summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/header/checksum.go15
-rw-r--r--pkg/tcpip/header/checksum_test.go6
2 files changed, 12 insertions, 9 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index ce57b581a..204285576 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -160,20 +160,23 @@ func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bo
return ChecksumCombine(uint16(v), uint16(v>>16)), odd
}
-// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
-// given byte array.
+// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given byte array. This function uses a non-optimized implementation. Its
+// only retained for reference and to use as a benchmark/test. Most code should
+// use the header.Checksum function.
//
// The initial checksum must have been computed on an even number of bytes.
-func Checksum(buf []byte, initial uint16) uint16 {
+func ChecksumOld(buf []byte, initial uint16) uint16 {
s, _ := calculateChecksum(buf, false, uint32(initial))
return s
}
-// UnrolledChecksum calculates the checksum (as defined in RFC 1071) of the
-// bytes in the given byte array.
+// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// given byte array. This function uses an optimized unrolled version of the
+// checksum algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
-func UnrolledChecksum(buf []byte, initial uint16) uint16 {
+func Checksum(buf []byte, initial uint16) uint16 {
s, _ := unrolledCalculateChecksum(buf, false, uint32(initial))
return s
}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 2fbd16a65..309403482 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -128,8 +128,8 @@ func TestChecksum(t *testing.T) {
}
for i := range testCases {
- testCases[i].csumOrig = header.Checksum(testCases[i].buf, testCases[i].initial)
- testCases[i].csumNew = header.UnrolledChecksum(testCases[i].buf, testCases[i].initial)
+ testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
+ testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
}
@@ -143,8 +143,8 @@ func BenchmarkChecksum(b *testing.B) {
fn func([]byte, uint16) uint16
name string
}{
+ {header.ChecksumOld, fmt.Sprintf("checksum_old")},
{header.Checksum, fmt.Sprintf("checksum")},
- {header.UnrolledChecksum, fmt.Sprintf("unrolled_checksum")},
}
for _, csumImpl := range checkSumImpls {