summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/header/checksum_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/header/checksum_test.go')
-rw-r--r--pkg/tcpip/header/checksum_test.go94
1 files changed, 94 insertions, 0 deletions
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 309403482..5ab20ee86 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -19,6 +19,7 @@ package header_test
import (
"fmt"
"math/rand"
+ "sync"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -169,3 +170,96 @@ func BenchmarkChecksum(b *testing.B) {
}
}
}
+
+func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) {
+ // icmpChecksum should not do any modifications of the header to
+ // calculate its checksum. Let's call it from a few go-routines and the
+ // race detector will trigger a warning if there are any concurrent
+ // read/write accesses.
+
+ const concurrency = 5
+ start := make(chan int)
+ ready := make(chan bool, concurrency)
+ var wg sync.WaitGroup
+ wg.Add(concurrency)
+ defer wg.Wait()
+
+ for i := 0; i < concurrency; i++ {
+ go func() {
+ defer wg.Done()
+
+ ready <- true
+ <-start
+
+ if got := headerChecksum(); want != got {
+ t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
+ }
+ if got := icmpChecksum(); want != got {
+ t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want)
+ }
+ }()
+ }
+ for i := 0; i < concurrency; i++ {
+ <-ready
+ }
+ close(start)
+}
+
+func TestICMPv4Checksum(t *testing.T) {
+ rnd := rand.New(rand.NewSource(42))
+
+ h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize))
+ if _, err := rnd.Read(h); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ h.SetChecksum(0)
+
+ buf := make([]byte, 13)
+ if _, err := rnd.Read(buf); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ vv := buffer.NewVectorisedView(len(buf), []buffer.View{
+ buffer.NewViewFromBytes(buf[:5]),
+ buffer.NewViewFromBytes(buf[5:]),
+ })
+
+ want := header.Checksum(vv.ToView(), 0)
+ want = ^header.Checksum(h, want)
+ h.SetChecksum(want)
+
+ testICMPChecksum(t, h.Checksum, func() uint16 {
+ return header.ICMPv4Checksum(h, vv)
+ }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
+}
+
+func TestICMPv6Checksum(t *testing.T) {
+ rnd := rand.New(rand.NewSource(42))
+
+ h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize))
+ if _, err := rnd.Read(h); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ h.SetChecksum(0)
+
+ buf := make([]byte, 13)
+ if _, err := rnd.Read(buf); err != nil {
+ t.Fatalf("rnd.Read failed: %v", err)
+ }
+ vv := buffer.NewVectorisedView(len(buf), []buffer.View{
+ buffer.NewViewFromBytes(buf[:7]),
+ buffer.NewViewFromBytes(buf[7:10]),
+ buffer.NewViewFromBytes(buf[10:]),
+ })
+
+ dst := header.IPv6Loopback
+ src := header.IPv6Loopback
+
+ want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+ want = header.Checksum(vv.ToView(), want)
+ want = ^header.Checksum(h, want)
+ h.SetChecksum(want)
+
+ testICMPChecksum(t, h.Checksum, func() uint16 {
+ return header.ICMPv6Checksum(h, src, dst, vv)
+ }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
+}