diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/checksum_test.go | 94 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv4.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv6.go | 27 |
3 files changed, 108 insertions, 29 deletions
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 309403482..5ab20ee86 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -19,6 +19,7 @@ package header_test import ( "fmt" "math/rand" + "sync" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -169,3 +170,96 @@ func BenchmarkChecksum(b *testing.B) { } } } + +func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { + // icmpChecksum should not do any modifications of the header to + // calculate its checksum. Let's call it from a few go-routines and the + // race detector will trigger a warning if there are any concurrent + // read/write accesses. + + const concurrency = 5 + start := make(chan int) + ready := make(chan bool, concurrency) + var wg sync.WaitGroup + wg.Add(concurrency) + defer wg.Wait() + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + ready <- true + <-start + + if got := headerChecksum(); want != got { + t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) + } + if got := icmpChecksum(); want != got { + t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) + } + }() + } + for i := 0; i < concurrency; i++ { + <-ready + } + close(start) +} + +func TestICMPv4Checksum(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + + h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) + if _, err := rnd.Read(h); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + h.SetChecksum(0) + + buf := make([]byte, 13) + if _, err := rnd.Read(buf); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + vv := buffer.NewVectorisedView(len(buf), []buffer.View{ + buffer.NewViewFromBytes(buf[:5]), + buffer.NewViewFromBytes(buf[5:]), + }) + + want := header.Checksum(vv.ToView(), 0) + want = ^header.Checksum(h, want) + h.SetChecksum(want) + + testICMPChecksum(t, h.Checksum, func() uint16 { + return header.ICMPv4Checksum(h, vv) + }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) +} + +func TestICMPv6Checksum(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + + h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) + if _, err := rnd.Read(h); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + h.SetChecksum(0) + + buf := make([]byte, 13) + if _, err := rnd.Read(buf); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + vv := buffer.NewVectorisedView(len(buf), []buffer.View{ + buffer.NewViewFromBytes(buf[:7]), + buffer.NewViewFromBytes(buf[7:10]), + buffer.NewViewFromBytes(buf[10:]), + }) + + dst := header.IPv6Loopback + src := header.IPv6Loopback + + want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) + want = header.Checksum(vv.ToView(), want) + want = ^header.Checksum(h, want) + h.SetChecksum(want) + + testICMPChecksum(t, h.Checksum, func() uint16 { + return header.ICMPv6Checksum(h, src, dst, vv) + }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 1be90d7d5..5f9b8e9e2 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -200,19 +200,13 @@ func (b ICMPv4) SetSequence(sequence uint16) { // ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, // and payload. func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := uint16(0) - for _, v := range vv.Views() { - xsum = Checksum(v, xsum) - } + xsum := ChecksumVV(vv, 0) - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^Checksum(h, xsum) - h[2], h[3] = h2, h3 + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = Checksum(h[:2], xsum) + xsum = Checksum(h[4:], xsum) - return xsum + return ^xsum } // ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 2eef64b4d..eca9750ab 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -265,22 +265,13 @@ func (b ICMPv6) Payload() []byte { // ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := Checksum([]byte(src), 0) - xsum = Checksum([]byte(dst), xsum) - var upperLayerLength [4]byte - binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) - xsum = Checksum(upperLayerLength[:], xsum) - xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum) - for _, v := range vv.Views() { - xsum = Checksum(v, xsum) - } - - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^Checksum(h, xsum) - h[2], h[3] = h2, h3 - - return xsum + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) + + xsum = ChecksumVV(vv, xsum) + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = Checksum(h[:2], xsum) + xsum = Checksum(h[4:], xsum) + + return ^xsum } |