summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/checker/checker.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/checker/checker.go')
-rw-r--r--pkg/tcpip/checker/checker.go1625
1 files changed, 0 insertions, 1625 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
deleted file mode 100644
index bab640faf..000000000
--- a/pkg/tcpip/checker/checker.go
+++ /dev/null
@@ -1,1625 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package checker provides helper functions to check networking packets for
-// validity.
-package checker
-
-import (
- "encoding/binary"
- "reflect"
- "testing"
- "time"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
-)
-
-// NetworkChecker is a function to check a property of a network packet.
-type NetworkChecker func(*testing.T, []header.Network)
-
-// TransportChecker is a function to check a property of a transport packet.
-type TransportChecker func(*testing.T, header.Transport)
-
-// ControlMessagesChecker is a function to check a property of ancillary data.
-type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
-
-// IPv4 checks the validity and properties of the given IPv4 packet. It is
-// expected to be used in conjunction with other network checkers for specific
-// properties. For example, to check the source and destination address, one
-// would call:
-//
-// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
-func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv4 := header.IPv4(b)
-
- if !ipv4.IsValid(len(b)) {
- t.Error("Not a valid IPv4 packet")
- }
-
- if !ipv4.IsChecksumValid() {
- t.Errorf("Bad checksum, got = %d", ipv4.Checksum())
- }
-
- for _, f := range checkers {
- f(t, []header.Network{ipv4})
- }
- if t.Failed() {
- t.FailNow()
- }
-}
-
-// IPv6 checks the validity and properties of the given IPv6 packet. The usage
-// is similar to IPv4.
-func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv6 := header.IPv6(b)
- if !ipv6.IsValid(len(b)) {
- t.Error("Not a valid IPv6 packet")
- }
-
- for _, f := range checkers {
- f(t, []header.Network{ipv6})
- }
- if t.Failed() {
- t.FailNow()
- }
-}
-
-// SrcAddr creates a checker that checks the source address.
-func SrcAddr(addr tcpip.Address) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if a := h[0].SourceAddress(); a != addr {
- t.Errorf("Bad source address, got %v, want %v", a, addr)
- }
- }
-}
-
-// DstAddr creates a checker that checks the destination address.
-func DstAddr(addr tcpip.Address) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if a := h[0].DestinationAddress(); a != addr {
- t.Errorf("Bad destination address, got %v, want %v", a, addr)
- }
- }
-}
-
-// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
-func TTL(ttl uint8) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- var v uint8
- switch ip := h[0].(type) {
- case header.IPv4:
- v = ip.TTL()
- case header.IPv6:
- v = ip.HopLimit()
- case *ipv6HeaderWithExtHdr:
- v = ip.HopLimit()
- default:
- t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
- }
- if v != ttl {
- t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
- }
- }
-}
-
-// IPFullLength creates a checker for the full IP packet length. The
-// expected size is checked against both the Total Length in the
-// header and the number of bytes received.
-func IPFullLength(packetLength uint16) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- var v uint16
- var l uint16
- switch ip := h[0].(type) {
- case header.IPv4:
- v = ip.TotalLength()
- l = uint16(len(ip))
- case header.IPv6:
- v = ip.PayloadLength() + header.IPv6FixedHeaderSize
- l = uint16(len(ip))
- default:
- t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
- }
- if l != packetLength {
- t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
- }
- if v != packetLength {
- t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
- }
- }
-}
-
-// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
-func IPv4HeaderLength(headerLength int) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- switch ip := h[0].(type) {
- case header.IPv4:
- if hl := ip.HeaderLength(); hl != uint8(headerLength) {
- t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
- }
- default:
- t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
- }
- }
-}
-
-// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(payloadLength int) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if l := len(h[0].Payload()); l != payloadLength {
- t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
- }
- }
-}
-
-// IPPayload creates a checker that checks the payload.
-func IPPayload(payload []byte) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- got := h[0].Payload()
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(got) == 0 && len(payload) == 0 {
- return
- }
-
- if diff := cmp.Diff(payload, got); diff != "" {
- t.Errorf("payload mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// IPv4Options returns a checker that checks the options in an IPv4 packet.
-func IPv4Options(want header.IPv4Options) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- ip, ok := h[0].(header.IPv4)
- if !ok {
- t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
- }
- options := ip.Options()
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(want) == 0 && len(options) == 0 {
- return
- }
- if diff := cmp.Diff(want, options); diff != "" {
- t.Errorf("options mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// IPv4RouterAlert returns a checker that checks that the RouterAlert option is
-// set in an IPv4 packet.
-func IPv4RouterAlert() NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
- ip, ok := h[0].(header.IPv4)
- if !ok {
- t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
- }
- iterator := ip.Options().MakeIterator()
- for {
- opt, done, err := iterator.Next()
- if err != nil {
- t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer)
- }
- if done {
- break
- }
- if opt.Type() != header.IPv4OptionRouterAlertType {
- continue
- }
- want := [header.IPv4OptionRouterAlertLength]byte{
- byte(header.IPv4OptionRouterAlertType),
- header.IPv4OptionRouterAlertLength,
- header.IPv4OptionRouterAlertValue,
- header.IPv4OptionRouterAlertValue,
- }
- if diff := cmp.Diff(want[:], opt.Contents()); diff != "" {
- t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
- }
- return
- }
- t.Errorf("failed to find router alert option in %v", ip.Options())
- }
-}
-
-// FragmentOffset creates a checker that checks the FragmentOffset field.
-func FragmentOffset(offset uint16) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // We only do this for IPv4 for now.
- switch ip := h[0].(type) {
- case header.IPv4:
- if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
- }
- }
- }
-}
-
-// FragmentFlags creates a checker that checks the fragment flags field.
-func FragmentFlags(flags uint8) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // We only do this for IPv4 for now.
- switch ip := h[0].(type) {
- case header.IPv4:
- if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
- }
- }
- }
-}
-
-// ReceiveTClass creates a checker that checks the TCLASS field in
-// ControlMessages.
-func ReceiveTClass(want uint32) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasTClass {
- t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
- } else if got := cm.TClass; got != want {
- t.Errorf("got cm.TClass = %d, want %d", got, want)
- }
- }
-}
-
-// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
-func ReceiveTOS(want uint8) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasTOS {
- t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
- } else if got := cm.TOS; got != want {
- t.Errorf("got cm.TOS = %d, want %d", got, want)
- }
- }
-}
-
-// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
-// ControlMessages.
-func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasIPPacketInfo {
- t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
- } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
- t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
-// field in ControlMessages.
-func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
- return func(t *testing.T, cm tcpip.ControlMessages) {
- t.Helper()
- if !cm.HasOriginalDstAddress {
- t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
- } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
- t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// TOS creates a checker that checks the TOS field.
-func TOS(tos uint8, label uint32) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
- }
- }
-}
-
-// Raw creates a checker that checks the bytes of payload.
-// The checker always checks the payload of the last network header.
-// For instance, in case of IPv6 fragments, the payload that will be checked
-// is the one containing the actual data that the packet is carrying, without
-// the bytes added by the IPv6 fragmentation.
-func Raw(want []byte) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
- t.Errorf("Wrong payload, got %v, want %v", got, want)
- }
- }
-}
-
-// IPv6Fragment creates a checker that validates an IPv6 fragment.
-func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
- }
-
- ipv6Frag := header.IPv6Fragment(h[0].Payload())
- if !ipv6Frag.IsValid() {
- t.Error("Not a valid IPv6 fragment")
- }
-
- for _, f := range checkers {
- f(t, []header.Network{h[0], ipv6Frag})
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// TCP creates a checker that checks that the transport protocol is TCP and
-// potentially additional transport header fields.
-func TCP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- first := h[0]
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
- }
-
- tcp := header.TCP(last.Payload())
- payload := tcp.Payload()
- payloadChecksum := header.Checksum(payload, 0)
- if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) {
- t.Errorf("Bad checksum, got = %d", tcp.Checksum())
- }
-
- // Run the transport checkers.
- for _, f := range checkers {
- f(t, tcp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// UDP creates a checker that checks that the transport protocol is UDP and
-// potentially additional transport header fields.
-func UDP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
- }
-
- udp := header.UDP(last.Payload())
- for _, f := range checkers {
- f(t, udp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// SrcPort creates a checker that checks the source port.
-func SrcPort(port uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got = %d, want = %d", p, port)
- }
- }
-}
-
-// DstPort creates a checker that checks the destination port.
-func DstPort(port uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got = %d, want = %d", p, port)
- }
- }
-}
-
-// NoChecksum creates a checker that checks if the checksum is zero.
-func NoChecksum(noChecksum bool) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- udp, ok := h.(header.UDP)
- if !ok {
- t.Fatalf("UDP header not found in h: %T", h)
- }
-
- if b := udp.Checksum() == 0; b != noChecksum {
- t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
- }
- }
-}
-
-// TCPSeqNum creates a checker that checks the sequence number.
-func TCPSeqNum(seq uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
- }
- }
-}
-
-// TCPAckNum creates a checker that checks the ack number.
-func TCPAckNum(seq uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
- }
- }
-}
-
-// TCPWindow creates a checker that checks the tcp window.
-func TCPWindow(window uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in hdr : %T", h)
- }
-
- if w := tcp.WindowSize(); w != window {
- t.Errorf("Bad window, got %d, want %d", w, window)
- }
- }
-}
-
-// TCPWindowGreaterThanEq creates a checker that checks that the TCP window
-// is greater than or equal to the provided value.
-func TCPWindowGreaterThanEq(window uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if w := tcp.WindowSize(); w < window {
- t.Errorf("Bad window, got %d, want > %d", w, window)
- }
- }
-}
-
-// TCPWindowLessThanEq creates a checker that checks that the tcp window
-// is less than or equal to the provided value.
-func TCPWindowLessThanEq(window uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if w := tcp.WindowSize(); w > window {
- t.Errorf("Bad window, got %d, want < %d", w, window)
- }
- }
-}
-
-// TCPFlags creates a checker that checks the tcp flags.
-func TCPFlags(flags header.TCPFlags) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if got := tcp.Flags(); got != flags {
- t.Errorf("got tcp.Flags() = %s, want %s", got, flags)
- }
- }
-}
-
-// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
-// given mask, match the supplied flags.
-func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- t.Fatalf("TCP header not found in h: %T", h)
- }
-
- if got := tcp.Flags(); (got & mask) != (flags & mask) {
- t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask)
- }
- }
-}
-
-// TCPSynOptions creates a checker that checks the presence of TCP options in
-// SYN segments.
-//
-// If wndscale is negative, the window scale option must not be present.
-func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- opts := tcp.Options()
- limit := len(opts)
- foundMSS := false
- foundWS := false
- foundTS := false
- foundSACKPermitted := false
- tsVal := uint32(0)
- tsEcr := uint32(0)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionMSS:
- v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
- if wantOpts.MSS != v {
- t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
- }
- foundMSS = true
- i += 4
- case header.TCPOptionWS:
- if wantOpts.WS < 0 {
- t.Error("WS present when it shouldn't be")
- }
- v := int(opts[i+2])
- if v != wantOpts.WS {
- t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
- }
- foundWS = true
- i += 3
- case header.TCPOptionTS:
- if i+9 >= limit {
- t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
- }
- if opts[i+1] != 10 {
- t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
- }
- tsVal = binary.BigEndian.Uint32(opts[i+2:])
- tsEcr = uint32(0)
- if tcp.Flags()&header.TCPFlagAck != 0 {
- // If the syn is an SYN-ACK then read
- // the tsEcr value as well.
- tsEcr = binary.BigEndian.Uint32(opts[i+6:])
- }
- foundTS = true
- i += 10
- case header.TCPOptionSACKPermitted:
- if i+1 >= limit {
- t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
- }
- if opts[i+1] != 2 {
- t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
- }
- foundSACKPermitted = true
- i += 2
-
- default:
- i += int(opts[i+1])
- }
- }
-
- if !foundMSS {
- t.Errorf("MSS option not found. Options: %x", opts)
- }
-
- if !foundWS && wantOpts.WS >= 0 {
- t.Errorf("WS option not found. Options: %x", opts)
- }
- if wantOpts.TS && !foundTS {
- t.Errorf("TS option not found. Options: %x", opts)
- }
- if foundTS && tsVal == 0 {
- t.Error("TS option specified but the timestamp value is zero")
- }
- if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
- }
- if wantOpts.SACKPermitted && !foundSACKPermitted {
- t.Errorf("SACKPermitted option not found. Options: %x", opts)
- }
- }
-}
-
-// TCPTimestampChecker creates a checker that validates that a TCP segment has a
-// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
-// wantTSEcr values with those in the TCP segment (if present).
-//
-// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
-// skipped.
-func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- opts := tcp.Options()
- limit := len(opts)
- foundTS := false
- tsVal := uint32(0)
- tsEcr := uint32(0)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionTS:
- if i+9 >= limit {
- t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
- }
- if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
- }
- tsVal = binary.BigEndian.Uint32(opts[i+2:])
- tsEcr = binary.BigEndian.Uint32(opts[i+6:])
- foundTS = true
- i += 10
- default:
- // We don't recognize this option, just skip over it.
- if i+2 > limit {
- return
- }
- l := int(opts[i+1])
- if i < 2 || i+l > limit {
- return
- }
- i += l
- }
- }
-
- if wantTS != foundTS {
- t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
- }
- if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
- }
- if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
- }
- }
-}
-
-// TCPSACKBlockChecker creates a checker that verifies that the segment does
-// contain the specified SACK blocks in the TCP options.
-func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
- tcp, ok := h.(header.TCP)
- if !ok {
- return
- }
- var gotSACKBlocks []header.SACKBlock
-
- opts := tcp.Options()
- limit := len(opts)
- for i := 0; i < limit; {
- switch opts[i] {
- case header.TCPOptionEOL:
- i = limit
- case header.TCPOptionNOP:
- i++
- case header.TCPOptionSACK:
- if i+2 > limit {
- // Malformed SACK block.
- t.Errorf("malformed SACK option in options: %v", opts)
- }
- sackOptionLen := int(opts[i+1])
- if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
- // Malformed SACK block.
- t.Errorf("malformed SACK option length in options: %v", opts)
- }
- numBlocks := sackOptionLen / 8
- for j := 0; j < numBlocks; j++ {
- start := binary.BigEndian.Uint32(opts[i+2+j*8:])
- end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
- gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
- Start: seqnum.Value(start),
- End: seqnum.Value(end),
- })
- }
- i += sackOptionLen
- default:
- // We don't recognize this option, just skip over it.
- if i+2 > limit {
- break
- }
- l := int(opts[i+1])
- if l < 2 || i+l > limit {
- break
- }
- i += l
- }
- }
-
- if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
- }
- }
-}
-
-// Payload creates a checker that checks the payload.
-func Payload(want []byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- if got := h.Payload(); !reflect.DeepEqual(got, want) {
- t.Errorf("Wrong payload, got %v, want %v", got, want)
- }
- }
-}
-
-// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4
-// and potentially additional ICMPv4 header fields.
-func ICMPv4(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
- }
-
- icmp := header.ICMPv4(last.Payload())
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// ICMPv4Type creates a checker that checks the ICMPv4 Type field.
-func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
-func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
-func ICMPv4Ident(want uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Ident(); got != want {
- t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
-func ICMPv4Seq(want uint16) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Sequence(); got != want {
- t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer.
-func ICMPv4Pointer(want uint8) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- if got := icmpv4.Pointer(); got != want {
- t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
-// This assumes that the payload exactly makes up the rest of the slice.
-func ICMPv4Checksum() TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- heldChecksum := icmpv4.Checksum()
- icmpv4.SetChecksum(0)
- newChecksum := ^header.Checksum(icmpv4, 0)
- icmpv4.SetChecksum(heldChecksum)
- if heldChecksum != newChecksum {
- t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
- }
- }
-}
-
-// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
-func ICMPv4Payload(want []byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv4, ok := h.(header.ICMPv4)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
- }
- payload := icmpv4.Payload()
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(want) == 0 && len(payload) == 0 {
- return
- }
-
- if diff := cmp.Diff(want, payload); diff != "" {
- t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
-// potentially additional ICMPv6 header fields.
-//
-// ICMPv6 will validate the checksum field before calling checkers.
-func ICMPv6(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
- }
-
- icmp := header.ICMPv6(last.Payload())
- if got, want := icmp.Checksum(), header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
- Header: icmp,
- Src: last.SourceAddress(),
- Dst: last.DestinationAddress(),
- }); got != want {
- t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// ICMPv6Type creates a checker that checks the ICMPv6 Type field.
-func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
- }
- if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
-func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
- }
- if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific
-// field.
-func ICMPv6TypeSpecific(want uint32) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
- }
- if got := icmpv6.TypeSpecific(); got != want {
- t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want)
- }
- }
-}
-
-// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet.
-func ICMPv6Payload(want []byte) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmpv6, ok := h.(header.ICMPv6)
- if !ok {
- t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
- }
- payload := icmpv6.Payload()
-
- // cmp.Diff does not consider nil slices equal to empty slices, but we do.
- if len(want) == 0 && len(payload) == 0 {
- return
- }
-
- if diff := cmp.Diff(want, payload); diff != "" {
- t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
- }
- }
-}
-
-// MLD creates a checker that checks that the packet contains a valid MLD
-// message for type of mldType, with potentially additional checks specified by
-// checkers.
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// MLD message as far as the size of the message (minSize) is concerned. The
-// values within the message are up to checkers to validate.
-func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // Check normal ICMPv6 first.
- ICMPv6(
- ICMPv6Type(msgType),
- ICMPv6Code(0))(t, h)
-
- last := h[len(h)-1]
-
- icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.MessageBody()); got < minSize {
- t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay
-// field of a MLD message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid MLD message as far as the size is concerned.
-func MLDMaxRespDelay(want time.Duration) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.MLD(icmp.MessageBody())
-
- if got := ns.MaximumResponseDelay(); got != want {
- t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want)
- }
- }
-}
-
-// MLDMulticastAddress creates a checker that checks the Multicast Address
-// field of a MLD message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid MLD message as far as the size is concerned.
-func MLDMulticastAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.MLD(icmp.MessageBody())
-
- if got := ns.MulticastAddress(); got != want {
- t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want)
- }
- }
-}
-
-// NDP creates a checker that checks that the packet contains a valid NDP
-// message for type of ty, with potentially additional checks specified by
-// checkers.
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDP message as far as the size of the message (minSize) is concerned. The
-// values within the message are up to checkers to validate.
-func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- // Check normal ICMPv6 first.
- ICMPv6(
- ICMPv6Type(msgType),
- ICMPv6Code(0))(t, h)
-
- last := h[len(h)-1]
-
- icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.MessageBody()); got < minSize {
- t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
- }
-
- for _, f := range checkers {
- f(t, icmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// NDPNS creates a checker that checks that the packet contains a valid NDP
-// Neighbor Solicitation message (as per the raw wire format), with potentially
-// additional checks specified by checkers.
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNS message as far as the size of the message is concerned. The values
-// within the message are up to checkers to validate.
-func NDPNS(checkers ...TransportChecker) NetworkChecker {
- return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
-}
-
-// NDPNSTargetAddress creates a checker that checks the Target Address field of
-// a header.NDPNeighborSolicit.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
-
- if got := ns.TargetAddress(); got != want {
- t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
- }
- }
-}
-
-// NDPNA creates a checker that checks that the packet contains a valid NDP
-// Neighbor Advertisement message (as per the raw wire format), with potentially
-// additional checks specified by checkers.
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNA message as far as the size of the message is concerned. The values
-// within the message are up to checkers to validate.
-func NDPNA(checkers ...TransportChecker) NetworkChecker {
- return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
-}
-
-// NDPNATargetAddress creates a checker that checks the Target Address field of
-// a header.NDPNeighborAdvert.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNA message as far as the size is concerned.
-func NDPNATargetAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
-
- if got := na.TargetAddress(); got != want {
- t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
- }
- }
-}
-
-// NDPNASolicitedFlag creates a checker that checks the Solicited field of
-// a header.NDPNeighborAdvert.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNA message as far as the size is concerned.
-func NDPNASolicitedFlag(want bool) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
-
- if got := na.SolicitedFlag(); got != want {
- t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
- }
- }
-}
-
-// ndpOptions checks that optsBuf only contains opts.
-func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
- t.Helper()
-
- it, err := optsBuf.Iter(true)
- if err != nil {
- t.Errorf("optsBuf.Iter(true): %s", err)
- return
- }
-
- i := 0
- for {
- opt, done, err := it.Next()
- if err != nil {
- // This should never happen as Iter(true) above did not return an error.
- t.Fatalf("unexpected error when iterating over NDP options: %s", err)
- }
- if done {
- break
- }
-
- if i >= len(opts) {
- t.Errorf("got unexpected option: %s", opt)
- continue
- }
-
- switch wantOpt := opts[i].(type) {
- case header.NDPSourceLinkLayerAddressOption:
- gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
- t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
- }
- case header.NDPTargetLinkLayerAddressOption:
- gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
- t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
- }
- case header.NDPNonceOption:
- gotOpt, ok := opt.(header.NDPNonceOption)
- if !ok {
- t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
- } else if diff := cmp.Diff(wantOpt.Nonce(), gotOpt.Nonce()); diff != "" {
- t.Errorf("nonce mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
- }
-
- i++
- }
-
- if missing := opts[i:]; len(missing) > 0 {
- t.Errorf("missing options: %s", missing)
- }
-}
-
-// NDPNAOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Neighbor Solicitation message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNA message as far as the size is concerned.
-func NDPNAOptions(opts []header.NDPOption) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.MessageBody())
- ndpOptions(t, na.Options(), opts)
- }
-}
-
-// NDPNSOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Neighbor Solicitation message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPNS message as far as the size is concerned.
-func NDPNSOptions(opts []header.NDPOption) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.MessageBody())
- ndpOptions(t, ns.Options(), opts)
- }
-}
-
-// NDPRS creates a checker that checks that the packet contains a valid NDP
-// Router Solicitation message (as per the raw wire format).
-//
-// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPRS as far as the size of the message is concerned. The values within the
-// message are up to checkers to validate.
-func NDPRS(checkers ...TransportChecker) NetworkChecker {
- return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
-}
-
-// NDPRSOptions creates a checker that checks that the packet contains the
-// provided NDP options within an NDP Router Solicitation message.
-//
-// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
-// containing a valid NDPRS message as far as the size is concerned.
-func NDPRSOptions(opts []header.NDPOption) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- icmp := h.(header.ICMPv6)
- rs := header.NDPRouterSolicit(icmp.MessageBody())
- ndpOptions(t, rs.Options(), opts)
- }
-}
-
-// IGMP checks the validity and properties of the given IGMP packet. It is
-// expected to be used in conjunction with other IGMP transport checkers for
-// specific properties.
-func IGMP(checkers ...TransportChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- last := h[len(h)-1]
-
- if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
- t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
- }
-
- igmp := header.IGMP(last.Payload())
- for _, f := range checkers {
- f(t, igmp)
- }
- if t.Failed() {
- t.FailNow()
- }
- }
-}
-
-// IGMPType creates a checker that checks the IGMP Type field.
-func IGMPType(want header.IGMPType) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- igmp, ok := h.(header.IGMP)
- if !ok {
- t.Fatalf("got transport header = %T, want = header.IGMP", h)
- }
- if got := igmp.Type(); got != want {
- t.Errorf("got igmp.Type() = %d, want = %d", got, want)
- }
- }
-}
-
-// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field.
-func IGMPMaxRespTime(want time.Duration) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- igmp, ok := h.(header.IGMP)
- if !ok {
- t.Fatalf("got transport header = %T, want = header.IGMP", h)
- }
- if got := igmp.MaxRespTime(); got != want {
- t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
- }
- }
-}
-
-// IGMPGroupAddress creates a checker that checks the IGMP Group Address field.
-func IGMPGroupAddress(want tcpip.Address) TransportChecker {
- return func(t *testing.T, h header.Transport) {
- t.Helper()
-
- igmp, ok := h.(header.IGMP)
- if !ok {
- t.Fatalf("got transport header = %T, want = header.IGMP", h)
- }
- if got := igmp.GroupAddress(); got != want {
- t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want)
- }
- }
-}
-
-// IPv6ExtHdrChecker is a function to check an extension header.
-type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
-
-// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
-func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
- t.Helper()
-
- ipv6 := header.IPv6(b)
- if !ipv6.IsValid(len(b)) {
- t.Error("not a valid IPv6 packet")
- return
- }
-
- payloadIterator := header.MakeIPv6PayloadIterator(
- header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
- buffer.View(ipv6.Payload()).ToVectorisedView(),
- )
-
- var rawPayloadHeader header.IPv6RawPayloadHeader
- for {
- h, done, err := payloadIterator.Next()
- if err != nil {
- t.Errorf("payloadIterator.Next(): %s", err)
- return
- }
- if done {
- t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
- return
- }
- r, ok := h.(header.IPv6RawPayloadHeader)
- if ok {
- rawPayloadHeader = r
- break
- }
- }
-
- networkHeader := ipv6HeaderWithExtHdr{
- IPv6: ipv6,
- transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
- payload: rawPayloadHeader.Buf.ToView(),
- }
-
- for _, checker := range checkers {
- checker(t, []header.Network{&networkHeader})
- }
-}
-
-// IPv6ExtHdr checks for the presence of extension headers.
-//
-// All the extension headers in headers will be checked exhaustively in the
-// order provided.
-func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
- return func(t *testing.T, h []header.Network) {
- t.Helper()
-
- extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
- if !ok {
- t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
- return
- }
-
- payloadIterator := header.MakeIPv6PayloadIterator(
- header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
- buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
- )
-
- for _, check := range headers {
- h, done, err := payloadIterator.Next()
- if err != nil {
- t.Errorf("payloadIterator.Next(): %s", err)
- return
- }
- if done {
- t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
- return
- }
- check(t, h)
- }
- // Validate we consumed all headers.
- //
- // The next one over should be a raw payload and then iterator should
- // terminate.
- wantDone := false
- for {
- h, done, err := payloadIterator.Next()
- if err != nil {
- t.Errorf("payloadIterator.Next(): %s", err)
- return
- }
- if done != wantDone {
- t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
- return
- }
- if done {
- break
- }
- if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
- t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
- continue
- }
- wantDone = true
- }
- }
-}
-
-var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
-
-// ipv6HeaderWithExtHdr provides a header.Network implementation that takes
-// extension headers into consideration, which is not the case with vanilla
-// header.IPv6.
-type ipv6HeaderWithExtHdr struct {
- header.IPv6
- transport tcpip.TransportProtocolNumber
- payload []byte
-}
-
-// TransportProtocol implements header.Network.
-func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
- return h.transport
-}
-
-// Payload implements header.Network.
-func (h *ipv6HeaderWithExtHdr) Payload() []byte {
- return h.payload
-}
-
-// IPv6ExtHdrOptionChecker is a function to check an extension header option.
-type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
-
-// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
-// extension header and validates the containing options with checkers.
-//
-// checkers must exhaustively contain all the expected options.
-func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
- return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
- t.Helper()
-
- hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
- if !ok {
- t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
- return
- }
- optionsIterator := hbh.Iter()
- for _, f := range checkers {
- opt, done, err := optionsIterator.Next()
- if err != nil {
- t.Errorf("optionsIterator.Next(): %s", err)
- return
- }
- if done {
- t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
- }
- f(t, opt)
- }
- // Validate all options were consumed.
- for {
- opt, done, err := optionsIterator.Next()
- if err != nil {
- t.Errorf("optionsIterator.Next(): %s", err)
- return
- }
- if !done {
- t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
- }
- if done {
- break
- }
- }
- }
-}
-
-// IPv6RouterAlert validates that an extension header option is the RouterAlert
-// option and matches on its value.
-func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
- return func(t *testing.T, opt header.IPv6ExtHdrOption) {
- routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
- if !ok {
- t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
- return
- }
- if routerAlert.Value != want {
- t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
- }
- }
-}
-
-// IPv6UnknownOption validates that an extension header option is the
-// unknown header option.
-func IPv6UnknownOption() IPv6ExtHdrOptionChecker {
- return func(t *testing.T, opt header.IPv6ExtHdrOption) {
- _, ok := opt.(*header.IPv6UnknownExtHdrOption)
- if !ok {
- t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt)
- }
- }
-}
-
-// IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
-func IgnoreCmpPath(paths ...string) cmp.Option {
- ignores := map[string]struct{}{}
- for _, path := range paths {
- ignores[path] = struct{}{}
- }
- return cmp.FilterPath(func(path cmp.Path) bool {
- _, ok := ignores[path.String()]
- return ok
- }, cmp.Ignore())
-}