diff options
Diffstat (limited to 'pkg/tcpip/header/checksum_test.go')
-rw-r--r-- | pkg/tcpip/header/checksum_test.go | 461 |
1 files changed, 0 insertions, 461 deletions
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go deleted file mode 100644 index 3445511f4..000000000 --- a/pkg/tcpip/header/checksum_test.go +++ /dev/null @@ -1,461 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package header provides the implementation of the encoding and decoding of -// network protocol headers. -package header_test - -import ( - "bytes" - "fmt" - "math/rand" - "sync" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -func TestChecksumer(t *testing.T) { - testCases := []struct { - name string - data [][]byte - want uint16 - }{ - { - name: "empty", - want: 0, - }, - { - name: "OneOddView", - data: [][]byte{ - []byte{1, 9, 0, 5, 4}, - }, - want: 1294, - }, - { - name: "TwoOddViews", - data: [][]byte{ - []byte{1, 9, 0, 5, 4}, - []byte{4, 3, 7, 1, 2, 123}, - }, - want: 33819, - }, - { - name: "OneEvenView", - data: [][]byte{ - []byte{1, 9, 0, 5}, - }, - want: 270, - }, - { - name: "TwoEvenViews", - data: [][]byte{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0}), - buffer.NewViewFromBytes([]byte{9, 0, 5, 4}), - }, - want: 30981, - }, - { - name: "ThreeViews", - data: [][]byte{ - []byte{77, 11, 33, 0, 55, 44}, - []byte{98, 1, 9, 0, 5, 4}, - []byte{4, 3, 7, 1, 2, 123, 99}, - }, - want: 34236, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var all bytes.Buffer - var c header.Checksumer - for _, b := range tc.data { - c.Add(b) - // Append to the buffer. We will check the checksum as a whole later. - if _, err := all.Write(b); err != nil { - t.Fatalf("all.Write(b) = _, %s; want _, nil", err) - } - } - if got, want := c.Checksum(), tc.want; got != want { - t.Errorf("c.Checksum() = %d, want %d", got, want) - } - if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want { - t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want) - } - }) - } -} - -func TestChecksum(t *testing.T) { - var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} - type testCase struct { - buf []byte - initial uint16 - csumOrig uint16 - csumNew uint16 - } - testCases := make([]testCase, 100000) - // Ensure same buffer generation for test consistency. - rnd := rand.New(rand.NewSource(42)) - for i := range testCases { - testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) - testCases[i].initial = uint16(rnd.Intn(65536)) - rnd.Read(testCases[i].buf) - } - - for i := range testCases { - testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) - testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) - if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { - t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) - } - } -} - -func BenchmarkChecksum(b *testing.B) { - var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} - - checkSumImpls := []struct { - fn func([]byte, uint16) uint16 - name string - }{ - {header.ChecksumOld, fmt.Sprintf("checksum_old")}, - {header.Checksum, fmt.Sprintf("checksum")}, - } - - for _, csumImpl := range checkSumImpls { - // Ensure same buffer generation for test consistency. - rnd := rand.New(rand.NewSource(42)) - for _, bufSz := range bufSizes { - b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { - tc := struct { - buf []byte - initial uint16 - csum uint16 - }{ - buf: make([]byte, bufSz), - initial: uint16(rnd.Intn(65536)), - } - rnd.Read(tc.buf) - b.ResetTimer() - for i := 0; i < b.N; i++ { - tc.csum = csumImpl.fn(tc.buf, tc.initial) - } - }) - } - } -} - -func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { - // icmpChecksum should not do any modifications of the header to - // calculate its checksum. Let's call it from a few go-routines and the - // race detector will trigger a warning if there are any concurrent - // read/write accesses. - - const concurrency = 5 - start := make(chan int) - ready := make(chan bool, concurrency) - var wg sync.WaitGroup - wg.Add(concurrency) - defer wg.Wait() - - for i := 0; i < concurrency; i++ { - go func() { - defer wg.Done() - - ready <- true - <-start - - if got := headerChecksum(); want != got { - t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) - } - if got := icmpChecksum(); want != got { - t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) - } - }() - } - for i := 0; i < concurrency; i++ { - <-ready - } - close(start) -} - -func TestICMPv4Checksum(t *testing.T) { - rnd := rand.New(rand.NewSource(42)) - - h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) - if _, err := rnd.Read(h); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - h.SetChecksum(0) - - buf := make([]byte, 13) - if _, err := rnd.Read(buf); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - vv := buffer.NewVectorisedView(len(buf), []buffer.View{ - buffer.NewViewFromBytes(buf[:5]), - buffer.NewViewFromBytes(buf[5:]), - }) - - want := header.Checksum(vv.ToView(), 0) - want = ^header.Checksum(h, want) - h.SetChecksum(want) - - testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0)) - }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) -} - -func TestICMPv6Checksum(t *testing.T) { - rnd := rand.New(rand.NewSource(42)) - - h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) - if _, err := rnd.Read(h); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - h.SetChecksum(0) - - buf := make([]byte, 13) - if _, err := rnd.Read(buf); err != nil { - t.Fatalf("rnd.Read failed: %v", err) - } - vv := buffer.NewVectorisedView(len(buf), []buffer.View{ - buffer.NewViewFromBytes(buf[:7]), - buffer.NewViewFromBytes(buf[7:10]), - buffer.NewViewFromBytes(buf[10:]), - }) - - dst := header.IPv6Loopback - src := header.IPv6Loopback - - want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) - want = header.Checksum(vv.ToView(), want) - want = ^header.Checksum(h, want) - h.SetChecksum(want) - - testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: h, - Src: src, - Dst: dst, - PayloadCsum: header.ChecksumVV(vv, 0), - PayloadLen: vv.Size(), - }) - }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) -} - -func randomAddress(size int) tcpip.Address { - s := make([]byte, size) - for i := 0; i < size; i++ { - s[i] = byte(rand.Uint32()) - } - return tcpip.Address(s) -} - -func TestChecksummableNetworkUpdateAddress(t *testing.T) { - tests := []struct { - name string - update func(header.IPv4, tcpip.Address) - }{ - { - name: "SetSourceAddressWithChecksumUpdate", - update: header.IPv4.SetSourceAddressWithChecksumUpdate, - }, - { - name: "SetDestinationAddressWithChecksumUpdate", - update: header.IPv4.SetDestinationAddressWithChecksumUpdate, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for i := 0; i < 1000; i++ { - var origBytes [header.IPv4MinimumSize]byte - header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ - TOS: 1, - TotalLength: header.IPv4MinimumSize, - ID: 2, - Flags: 3, - FragmentOffset: 4, - TTL: 5, - Protocol: 6, - Checksum: 0, - SrcAddr: randomAddress(header.IPv4AddressSize), - DstAddr: randomAddress(header.IPv4AddressSize), - }) - - addr := randomAddress(header.IPv4AddressSize) - - bytesCopy := origBytes - h := header.IPv4(bytesCopy[:]) - origXSum := h.CalculateChecksum() - h.SetChecksum(^origXSum) - - test.update(h, addr) - got := ^h.Checksum() - h.SetChecksum(0) - want := h.CalculateChecksum() - if got != want { - t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) - } - } - }) - } -} - -func TestChecksummableTransportUpdatePort(t *testing.T) { - // The fields in the pseudo header is not tested here so we just use 0. - const pseudoHeaderXSum = 0 - - tests := []struct { - name string - transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) - proto tcpip.TransportProtocolNumber - }{ - { - name: "TCP", - transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { - h := header.TCP(make([]byte, header.TCPMinimumSize)) - h.Encode(&header.TCPFields{ - SrcPort: src, - DstPort: dst, - SeqNum: 1, - AckNum: 2, - DataOffset: header.TCPMinimumSize, - Flags: 3, - WindowSize: 4, - Checksum: 0, - UrgentPointer: 5, - }) - h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) - return h, h.CalculateChecksum - }, - proto: header.TCPProtocolNumber, - }, - { - name: "UDP", - transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { - h := header.UDP(make([]byte, header.UDPMinimumSize)) - h.Encode(&header.UDPFields{ - SrcPort: src, - DstPort: dst, - Length: 0, - Checksum: 0, - }) - h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) - return h, h.CalculateChecksum - }, - proto: header.UDPProtocolNumber, - }, - } - - for i := 0; i < 1000; i++ { - origSrcPort := uint16(rand.Uint32()) - origDstPort := uint16(rand.Uint32()) - newPort := uint16(rand.Uint32()) - - t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, subTest := range []struct { - name string - update func(header.ChecksummableTransport) - }{ - { - name: "Source port", - update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, - }, - { - name: "Destination port", - update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, - }, - } { - t.Run(subTest.name, func(t *testing.T) { - h, calcXSum := test.transportHdr(origSrcPort, origDstPort) - subTest.update(h) - // TCP and UDP hold the 1s complement of the fully calculated - // checksum. - got := ^h.Checksum() - h.SetChecksum(0) - - if want := calcXSum(pseudoHeaderXSum); got != want { - h, _ := test.transportHdr(origSrcPort, origDstPort) - t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) - } - }) - } - }) - } - }) - } -} - -func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { - const addressSize = 6 - - tests := []struct { - name string - transportHdr func() header.ChecksummableTransport - proto tcpip.TransportProtocolNumber - }{ - { - name: "TCP", - transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, - proto: header.TCPProtocolNumber, - }, - { - name: "UDP", - transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, - proto: header.UDPProtocolNumber, - }, - } - - for i := 0; i < 1000; i++ { - permanent := randomAddress(addressSize) - old := randomAddress(addressSize) - new := randomAddress(addressSize) - - t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - for _, fullChecksum := range []bool{true, false} { - t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { - initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) - if fullChecksum { - // TCP and UDP hold the 1s complement of the fully calculated - // checksum. - initialXSum = ^initialXSum - } - - h := test.transportHdr() - h.SetChecksum(initialXSum) - h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) - - got := h.Checksum() - if fullChecksum { - got = ^got - } - if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { - t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) - } - }) - } - }) - } - }) - } -} |