summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/checker
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/checker')
-rw-r--r--pkg/tcpip/checker/BUILD16
-rw-r--r--pkg/tcpip/checker/checker.go842
2 files changed, 0 insertions, 858 deletions
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
deleted file mode 100644
index ed434807f..000000000
--- a/pkg/tcpip/checker/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "checker",
- testonly = 1,
- srcs = ["checker.go"],
- visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/header",
- "//pkg/tcpip/seqnum",
- ],
-)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
deleted file mode 100644
index c6c160dfc..000000000
--- a/pkg/tcpip/checker/checker.go
+++ /dev/null
@@ -1,842 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package checker provides helper functions to check networking packets for
-// validity.
-package checker
-
-import (
- "encoding/binary"
- "reflect"
- "testing"
-
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
-)
-
-// NetworkChecker is a function to check a property of a network packet.
-type NetworkChecker func(*testing.T, []header.Network)
-
-// TransportChecker is a function to check a property of a transport packet.
-type TransportChecker func(*testing.T, header.Transport)
-
-// ControlMessagesChecker is a function to check a property of ancillary data.
-type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
-
-// IPv4 checks the validity and properties of the given IPv4 packet. It is
-// expected to be used in conjunction with other network checkers for specific
-// properties. For example, to check the source and destination address, one
-// would call:
-//
-// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
-func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv4 := header.IPv4(b)
-
- if !ipv4.IsValid(len(b)) {
- t.Error("Not a valid IPv4 packet")
- }
-
- xsum := ipv4.CalculateChecksum()
- if xsum != 0 && xsum != 0xffff {
- t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
- }
-
- for _, f := range checkers {
- f(t, []header.Network{ipv4})
- }
- if t.Failed() {
- t.FailNow()
- }
-}
-
-// IPv6 checks the validity and properties of the given IPv6 packet. The usage
-// is similar to IPv4.
-func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv6 := header.IPv6(b)
- if !ipv6.IsValid(len(b)) {
- t.Error("Not a valid IPv6 packet")
- }
-
- for _, f := range checkers {
- f(t, []header.Network{ipv6})
- }
- if t.Failed() {
- t.FailNow()
- }
-}
-
-// SrcAddr creates a checker that checks the source address.
-func SrcAddr(addr tcpip.Address) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if a := h[0].SourceAddress(); a != addr {
- t.Errorf("Bad source address, got %v, want %v", a, addr)
- }
- }
-}
-
-// DstAddr creates a checker that checks the destination address.
-func DstAddr(addr tcpip.Address) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if a := h[0].DestinationAddress(); a != addr {
- t.Errorf("Bad destination address, got %v, want %v", a, addr)
- }
- }
-}
-
-// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
-func TTL(ttl uint8) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- var v uint8
- switch ip := h[0].(type) {
- case header.IPv4:
- v = ip.TTL()
- case header.IPv6:
- v = ip.HopLimit()
- }
- if v != ttl {
- t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
- }
- }
-}
-
-// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(plen int) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if l := len(h[0].Payload()); l != plen {
- t.Errorf("Bad payload length, got %v, want %v", l, plen)
- }
- }
-}
-
-// FragmentOffset creates a checker that checks the FragmentOffset field.
-func FragmentOffset(offset uint16) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // We only do this of IPv4 for now.
- switch ip := h[0].(type) {
- case header.IPv4:
- if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
- }
- }
- }
-}
-
-// FragmentFlags creates a checker that checks the fragment flags field.
-func FragmentFlags(flags uint8) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // We only do this of IPv4 for now.
- switch ip := h[0].(type) {
- case header.IPv4:
- if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
- }
- }
- }
-}
-
-// ReceiveTClass creates a checker that checks the TCLASS field in
-// ControlMessages.
-func ReceiveTClass(want uint32) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasTClass {
- t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
- }
- if got := cm.TClass; got != want {
- t.Fatalf("got cm.TClass = %d, want %d", got, want)
- }
- }
-}
-
-// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
-func ReceiveTOS(want uint8) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasTOS {
- t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
- }
- if got := cm.TOS; got != want {
- t.Fatalf("got cm.TOS = %d, want %d", got, want)
- }
- }
-}
-
-// TOS creates a checker that checks the TOS field.
-func TOS(tos uint8, label uint32) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
- }
- }
-}
-
-// Raw creates a checker that checks the bytes of payload.
-// The checker always checks the payload of the last network header.
-// For instance, in case of IPv6 fragments, the payload that will be checked
-// is the one containing the actual data that the packet is carrying, without
-// the bytes added by the IPv6 fragmentation.
-func Raw(want []byte) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
- t.Errorf("Wrong payload, got %v, want %v", got, want)
- }
- }
-}
-
-// IPv6Fragment creates a checker that validates an IPv6 fragment.
-func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
- }
-
- ipv6Frag := header.IPv6Fragment(h[0].Payload())
- if !ipv6Frag.IsValid() {
- t.Error("Not a valid IPv6 fragment")
- }
-
- for _, f := range checkers {
- f(t, []header.Network{h[0], ipv6Frag})
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// TCP creates a checker that checks that the transport protocol is TCP and
-// potentially additional transport header fields.
-func TCP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- first := h[0]
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
- }
-
- // Verify the checksum.
- tcp := header.TCP(last.Payload())
- l := uint16(len(tcp))
-
- xsum := header.Checksum([]byte(first.SourceAddress()), 0)
- xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
- xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
- xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
- xsum = header.Checksum(tcp, xsum)
-
- if xsum != 0 && xsum != 0xffff {
- t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
- }
-
- // Run the transport checkers.
- for _, f := range checkers {
- f(t, tcp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// UDP creates a checker that checks that the transport protocol is UDP and
-// potentially additional transport header fields.
-func UDP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
- }
-
- udp := header.UDP(last.Payload())
- for _, f := range checkers {
- f(t, udp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// SrcPort creates a checker that checks the source port.
-func SrcPort(port uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got %v, want %v", p, port)
- }
- }
-}
-
-// DstPort creates a checker that checks the destination port.
-func DstPort(port uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got %v, want %v", p, port)
- }
- }
-}
-
-// SeqNum creates a checker that checks the sequence number.
-func SeqNum(seq uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got %v, want %v", s, seq)
- }
- }
-}
-
-// AckNum creates a checker that checks the ack number.
-func AckNum(seq uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got %v, want %v", s, seq)
- }
- }
-}
-
-// Window creates a checker that checks the tcp window.
-func Window(window uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if w := tcp.WindowSize(); w != window {
- t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
- }
- }
-}
-
-// TCPFlags creates a checker that checks the tcp flags.
-func TCPFlags(flags uint8) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if f := tcp.Flags(); f != flags {
- t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
- }
- }
-}
-
-// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
-// given mask, match the supplied flags.
-func TCPFlagsMatch(flags, mask uint8) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
-
- if f := tcp.Flags(); (f & mask) != (flags & mask) {
- t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
- }
- }
-}
-
-// TCPSynOptions creates a checker that checks the presence of TCP options in
-// SYN segments.
-//
-// If wndscale is negative, the window scale option must not be present.
-func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- opts := tcp.Options()
- limit := len(opts)
- foundMSS := false
- foundWS := false
- foundTS := false
- foundSACKPermitted := false
- tsVal := uint32(0)
- tsEcr := uint32(0)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionMSS:
- v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
- if wantOpts.MSS != v {
- t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
- }
- foundMSS = true
- i += 4
- case header.TCPOptionWS:
- if wantOpts.WS < 0 {
- t.Error("WS present when it shouldn't be")
- }
- v := int(opts[i+2])
- if v != wantOpts.WS {
- t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
- }
- foundWS = true
- i += 3
- case header.TCPOptionTS:
- if i+9 >= limit {
- t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
- }
- if opts[i+1] != 10 {
- t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
- }
- tsVal = binary.BigEndian.Uint32(opts[i+2:])
- tsEcr = uint32(0)
- if tcp.Flags()&header.TCPFlagAck != 0 {
- // If the syn is an SYN-ACK then read
- // the tsEcr value as well.
- tsEcr = binary.BigEndian.Uint32(opts[i+6:])
- }
- foundTS = true
- i += 10
- case header.TCPOptionSACKPermitted:
- if i+1 >= limit {
- t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
- }
- if opts[i+1] != 2 {
- t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
- }
- foundSACKPermitted = true
- i += 2
-
- default:
- i += int(opts[i+1])
- }
- }
-
- if !foundMSS {
- t.Errorf("MSS option not found. Options: %x", opts)
- }
-
- if !foundWS && wantOpts.WS >= 0 {
- t.Errorf("WS option not found. Options: %x", opts)
- }
- if wantOpts.TS && !foundTS {
- t.Errorf("TS option not found. Options: %x", opts)
- }
- if foundTS && tsVal == 0 {
- t.Error("TS option specified but the timestamp value is zero")
- }
- if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
- }
- if wantOpts.SACKPermitted && !foundSACKPermitted {
- t.Errorf("SACKPermitted option not found. Options: %x", opts)
- }
- }
-}
-
-// TCPTimestampChecker creates a checker that validates that a TCP segment has a
-// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
-// wantTSEcr values with those in the TCP segment (if present).
-//
-// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
-// skipped.
-func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- opts := []byte(tcp.Options())
- limit := len(opts)
- foundTS := false
- tsVal := uint32(0)
- tsEcr := uint32(0)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionTS:
- if i+9 >= limit {
- t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
- }
- if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
- }
- tsVal = binary.BigEndian.Uint32(opts[i+2:])
- tsEcr = binary.BigEndian.Uint32(opts[i+6:])
- foundTS = true
- i += 10
- default:
- // We don't recognize this option, just skip over it.
- if i+2 > limit {
- return
- }
- l := int(opts[i+1])
- if i < 2 || i+l > limit {
- return
- }
- i += l
- }
- }
-
- if wantTS != foundTS {
- t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
- }
- if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
- }
- if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
- }
- }
-}
-
-// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
-// contain any SACK blocks in the TCP options.
-func TCPNoSACKBlockChecker() TransportChecker {
- return TCPSACKBlockChecker(nil)
-}
-
-// TCPSACKBlockChecker creates a checker that verifies that the segment does
-// contain the specified SACK blocks in the TCP options.
-func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- var gotSACKBlocks []header.SACKBlock
-
- opts := []byte(tcp.Options())
- limit := len(opts)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionSACK:
- if i+2 > limit {
- // Malformed SACK block.
- t.Errorf("malformed SACK option in options: %v", opts)
- }
- sackOptionLen := int(opts[i+1])
- if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
- // Malformed SACK block.
- t.Errorf("malformed SACK option length in options: %v", opts)
- }
- numBlocks := sackOptionLen / 8
- for j := 0; j < numBlocks; j++ {
- start := binary.BigEndian.Uint32(opts[i+2+j*8:])
- end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
- gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
- Start: seqnum.Value(start),
- End: seqnum.Value(end),
- })
- }
- i += sackOptionLen
- default:
- // We don't recognize this option, just skip over it.
- if i+2 > limit {
- break
- }
- l := int(opts[i+1])
- if l < 2 || i+l > limit {
- break
- }
- i += l
- }
- }
-
- if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
- }
- }
-}
-
-// Payload creates a checker that checks the payload.
-func Payload(want []byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- if got := h.Payload(); !reflect.DeepEqual(got, want) {
- t.Errorf("Wrong payload, got %v, want %v", got, want)
- }
- }
-}
-
-// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
-// potentially additional ICMPv4 header fields.
-func ICMPv4(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
- }
-
- icmp := header.ICMPv4(last.Payload())
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// ICMPv4Type creates a checker that checks the ICMPv4 Type field.
-func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
- }
- if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
- }
- }
-}
-
-// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
-func ICMPv4Code(want byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
- }
- if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
- }
- }
-}
-
-// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
-// potentially additional ICMPv6 header fields.
-//
-// ICMPv6 will validate the checksum field before calling checkers.
-func ICMPv6(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
- }
-
- icmp := header.ICMPv6(last.Payload())
- if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want {
- t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// ICMPv6Type creates a checker that checks the ICMPv6 Type field.
-func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
- }
- if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
- }
- }
-}
-
-// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
-func ICMPv6Code(want byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
- }
- if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
- }
- }
-}
-
-// NDP creates a checker that checks that the packet contains a valid NDP
-// message for type of ty, with potentially additional checks specified by
-// checkers.
-//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDP message as far as the size of the message (minSize) is concerned. The
-// values within the message are up to checkers to validate.
-func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // Check normal ICMPv6 first.
- ICMPv6(
- ICMPv6Type(msgType),
- ICMPv6Code(0))(t, h)
-
- last := h[len(h)-1]
-
- icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.NDPPayload()); got < minSize {
- t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// NDPNS creates a checker that checks that the packet contains a valid NDP
-// Neighbor Solicitation message (as per the raw wire format), with potentially
-// additional checks specified by checkers.
-//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNS message as far as the size of the messages concerned. The values within
-// the message are up to checkers to validate.
-func NDPNS(checkers ...TransportChecker) NetworkChecker {
- return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
-}
-
-// NDPNSTargetAddress creates a checker that checks the Target Address field of
-// a header.NDPNeighborSolicit.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
-
- if got := ns.TargetAddress(); got != want {
- t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
- }
- }
-}
-
-// NDPNSOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Neighbor Solicitation message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSOptions(opts []header.NDPOption) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
- it, err := ns.Options().Iter(true)
- if err != nil {
- t.Errorf("opts.Iter(true): %s", err)
- return
- }
-
- i := 0
- for {
- opt, done, _ := it.Next()
- if done {
- break
- }
-
- if i >= len(opts) {
- t.Errorf("got unexpected option: %s", opt)
- continue
- }
-
- switch wantOpt := opts[i].(type) {
- case header.NDPSourceLinkLayerAddressOption:
- gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
- t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
- }
- default:
- panic("not implemented")
- }
-
- i++
- }
-
- if missing := opts[i:]; len(missing) > 0 {
- t.Errorf("missing options: %s", missing)
- }
- }
-}
-
-// NDPRS creates a checker that checks that the packet contains a valid NDP
-// Router Solicitation message (as per the raw wire format).
-func NDPRS() NetworkChecker {
- return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize)
-}