diff options
Diffstat (limited to 'pkg/tcpip/header')
-rw-r--r-- | pkg/tcpip/header/BUILD | 7 | ||||
-rw-r--r-- | pkg/tcpip/header/checksum.go | 133 | ||||
-rw-r--r-- | pkg/tcpip/header/checksum_test.go | 62 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv6.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/header/ndp_router_solicit.go | 36 |
5 files changed, 238 insertions, 7 deletions
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index f2061c778..9da0d71f8 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -20,10 +19,10 @@ go_library( "ndp_neighbor_solicit.go", "ndp_options.go", "ndp_router_advert.go", + "ndp_router_solicit.go", "tcp.go", "udp.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/header", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -58,7 +57,7 @@ go_test( "eth_test.go", "ndp_test.go", ], - embed = [":header"], + library = ":header", deps = [ "//pkg/tcpip", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 9749c7f4d..14a4b2b44 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -45,12 +45,139 @@ func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { return ChecksumCombine(uint16(v), uint16(v>>16)), odd } +func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { + v := initial + + if odd { + v += uint32(buf[0]) + buf = buf[1:] + } + + l := len(buf) + odd = l&1 != 0 + if odd { + l-- + v += uint32(buf[l]) << 8 + } + for (l - 64) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[64:] + l = l - 64 + } + if (l - 32) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[32:] + l = l - 32 + } + if (l - 16) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[16:] + l = l - 16 + } + if (l - 8) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + buf = buf[8:] + l = l - 8 + } + if (l - 4) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + buf = buf[4:] + l = l - 4 + } + + // At this point since l was even before we started unrolling + // there can be only two bytes left to add. + if l != 0 { + v += (uint32(buf[0]) << 8) + uint32(buf[1]) + } + + return ChecksumCombine(uint16(v), uint16(v>>16)), odd +} + +// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in +// the given byte array. This function uses a non-optimized implementation. Its +// only retained for reference and to use as a benchmark/test. Most code should +// use the header.Checksum function. +// +// The initial checksum must have been computed on an even number of bytes. +func ChecksumOld(buf []byte, initial uint16) uint16 { + s, _ := calculateChecksum(buf, false, uint32(initial)) + return s +} + // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the -// given byte array. +// given byte array. This function uses an optimized unrolled version of the +// checksum algorithm. // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { - s, _ := calculateChecksum(buf, false, uint32(initial)) + s, _ := unrolledCalculateChecksum(buf, false, uint32(initial)) return s } @@ -86,7 +213,7 @@ func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, siz } v = v[:l] - sum, odd = calculateChecksum(v, odd, uint32(sum)) + sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum)) size -= len(v) if size == 0 { diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 86b466c1c..309403482 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -17,6 +17,8 @@ package header_test import ( + "fmt" + "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -107,3 +109,63 @@ func TestChecksumVVWithOffset(t *testing.T) { }) } } + +func TestChecksum(t *testing.T) { + var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} + type testCase struct { + buf []byte + initial uint16 + csumOrig uint16 + csumNew uint16 + } + testCases := make([]testCase, 100000) + // Ensure same buffer generation for test consistency. + rnd := rand.New(rand.NewSource(42)) + for i := range testCases { + testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) + testCases[i].initial = uint16(rnd.Intn(65536)) + rnd.Read(testCases[i].buf) + } + + for i := range testCases { + testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) + testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) + if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { + t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) + } + } +} + +func BenchmarkChecksum(b *testing.B) { + var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} + + checkSumImpls := []struct { + fn func([]byte, uint16) uint16 + name string + }{ + {header.ChecksumOld, fmt.Sprintf("checksum_old")}, + {header.Checksum, fmt.Sprintf("checksum")}, + } + + for _, csumImpl := range checkSumImpls { + // Ensure same buffer generation for test consistency. + rnd := rand.New(rand.NewSource(42)) + for _, bufSz := range bufSizes { + b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { + tc := struct { + buf []byte + initial uint16 + csum uint16 + }{ + buf: make([]byte, bufSz), + initial: uint16(rnd.Intn(65536)), + } + rnd.Read(tc.buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + tc.csum = csumImpl.fn(tc.buf, tc.initial) + } + }) + } + } +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 83425c614..70e6ce095 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -84,6 +84,13 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6AllRoutersMulticastAddress is a link-local multicast group that + // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers on a link. + // + // The address is ff02::2. + IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go new file mode 100644 index 000000000..9e67ba95d --- /dev/null +++ b/pkg/tcpip/header/ndp_router_solicit.go @@ -0,0 +1,36 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.1 for more details. +type NDPRouterSolicit []byte + +const ( + // NDPRSMinimumSize is the minimum size of a valid NDP Router + // Solicitation message (body of an ICMPv6 packet). + NDPRSMinimumSize = 4 + + // ndpRSOptionsOffset is the start of the NDP options in an + // NDPRouterSolicit. + ndpRSOptionsOffset = 4 +) + +// Options returns an NDPOptions of the the options body. +func (b NDPRouterSolicit) Options() NDPOptions { + return NDPOptions(b[ndpRSOptionsOffset:]) +} |