diff options
Diffstat (limited to 'pkg/tcpip/checker')
-rw-r--r-- | pkg/tcpip/checker/BUILD | 17 | ||||
-rw-r--r-- | pkg/tcpip/checker/checker.go | 1605 |
2 files changed, 0 insertions, 1622 deletions
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD deleted file mode 100644 index c984470e6..000000000 --- a/pkg/tcpip/checker/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "checker", - testonly = 1, - srcs = ["checker.go"], - visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - "@com_github_google_go_cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go deleted file mode 100644 index 91971b687..000000000 --- a/pkg/tcpip/checker/checker.go +++ /dev/null @@ -1,1605 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package checker provides helper functions to check networking packets for -// validity. -package checker - -import ( - "encoding/binary" - "reflect" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" -) - -// NetworkChecker is a function to check a property of a network packet. -type NetworkChecker func(*testing.T, []header.Network) - -// TransportChecker is a function to check a property of a transport packet. -type TransportChecker func(*testing.T, header.Transport) - -// ControlMessagesChecker is a function to check a property of ancillary data. -type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) - -// IPv4 checks the validity and properties of the given IPv4 packet. It is -// expected to be used in conjunction with other network checkers for specific -// properties. For example, to check the source and destination address, one -// would call: -// -// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) -func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv4 := header.IPv4(b) - - if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") - } - - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) - } - - for _, f := range checkers { - f(t, []header.Network{ipv4}) - } - if t.Failed() { - t.FailNow() - } -} - -// IPv6 checks the validity and properties of the given IPv6 packet. The usage -// is similar to IPv4. -func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") - } - - for _, f := range checkers { - f(t, []header.Network{ipv6}) - } - if t.Failed() { - t.FailNow() - } -} - -// SrcAddr creates a checker that checks the source address. -func SrcAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].SourceAddress(); a != addr { - t.Errorf("Bad source address, got %v, want %v", a, addr) - } - } -} - -// DstAddr creates a checker that checks the destination address. -func DstAddr(addr tcpip.Address) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if a := h[0].DestinationAddress(); a != addr { - t.Errorf("Bad destination address, got %v, want %v", a, addr) - } - } -} - -// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). -func TTL(ttl uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - var v uint8 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TTL() - case header.IPv6: - v = ip.HopLimit() - case *ipv6HeaderWithExtHdr: - v = ip.HopLimit() - default: - t.Fatalf("unrecognized header type %T for TTL evaluation", ip) - } - if v != ttl { - t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) - } - } -} - -// IPFullLength creates a checker for the full IP packet length. The -// expected size is checked against both the Total Length in the -// header and the number of bytes received. -func IPFullLength(packetLength uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - var v uint16 - var l uint16 - switch ip := h[0].(type) { - case header.IPv4: - v = ip.TotalLength() - l = uint16(len(ip)) - case header.IPv6: - v = ip.PayloadLength() + header.IPv6FixedHeaderSize - l = uint16(len(ip)) - default: - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) - } - if l != packetLength { - t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) - } - if v != packetLength { - t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) - } - } -} - -// IPv4HeaderLength creates a checker that checks the IPv4 Header length. -func IPv4HeaderLength(headerLength int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - switch ip := h[0].(type) { - case header.IPv4: - if hl := ip.HeaderLength(); hl != uint8(headerLength) { - t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) - } - default: - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) - } - } -} - -// PayloadLen creates a checker that checks the payload length. -func PayloadLen(payloadLength int) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if l := len(h[0].Payload()); l != payloadLength { - t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) - } - } -} - -// IPPayload creates a checker that checks the payload. -func IPPayload(payload []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - got := h[0].Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(got) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(payload, got); diff != "" { - t.Errorf("payload mismatch (-want +got):\n%s", diff) - } - } -} - -// IPv4Options returns a checker that checks the options in an IPv4 packet. -func IPv4Options(want header.IPv4Options) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - ip, ok := h[0].(header.IPv4) - if !ok { - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) - } - options := ip.Options() - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(options) == 0 { - return - } - if diff := cmp.Diff(want, options); diff != "" { - t.Errorf("options mismatch (-want +got):\n%s", diff) - } - } -} - -// IPv4RouterAlert returns a checker that checks that the RouterAlert option is -// set in an IPv4 packet. -func IPv4RouterAlert() NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - ip, ok := h[0].(header.IPv4) - if !ok { - t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) - } - iterator := ip.Options().MakeIterator() - for { - opt, done, err := iterator.Next() - if err != nil { - t.Fatalf("error acquiring next IPv4 option %s", err) - } - if done { - break - } - if opt.Type() != header.IPv4OptionRouterAlertType { - continue - } - want := [header.IPv4OptionRouterAlertLength]byte{ - byte(header.IPv4OptionRouterAlertType), - header.IPv4OptionRouterAlertLength, - header.IPv4OptionRouterAlertValue, - header.IPv4OptionRouterAlertValue, - } - if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { - t.Errorf("router alert option mismatch (-want +got):\n%s", diff) - } - return - } - t.Errorf("failed to find router alert option in %v", ip.Options()) - } -} - -// FragmentOffset creates a checker that checks the FragmentOffset field. -func FragmentOffset(offset uint16) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this for IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) - } - } - } -} - -// FragmentFlags creates a checker that checks the fragment flags field. -func FragmentFlags(flags uint8) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // We only do this for IPv4 for now. - switch ip := h[0].(type) { - case header.IPv4: - if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) - } - } - } -} - -// ReceiveTClass creates a checker that checks the TCLASS field in -// ControlMessages. -func ReceiveTClass(want uint32) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTClass { - t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) - } else if got := cm.TClass; got != want { - t.Errorf("got cm.TClass = %d, want %d", got, want) - } - } -} - -// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. -func ReceiveTOS(want uint8) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTOS { - t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) - } else if got := cm.TOS; got != want { - t.Errorf("got cm.TOS = %d, want %d", got, want) - } - } -} - -// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in -// ControlMessages. -func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasIPPacketInfo { - t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) - } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { - t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) - } - } -} - -// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress -// field in ControlMessages. -func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasOriginalDstAddress { - t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) - } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { - t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) - } - } -} - -// TOS creates a checker that checks the TOS field. -func TOS(tos uint8, label uint32) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) - } - } -} - -// Raw creates a checker that checks the bytes of payload. -// The checker always checks the payload of the last network header. -// For instance, in case of IPv6 fragments, the payload that will be checked -// is the one containing the actual data that the packet is carrying, without -// the bytes added by the IPv6 fragmentation. -func Raw(want []byte) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// IPv6Fragment creates a checker that validates an IPv6 fragment. -func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) - } - - ipv6Frag := header.IPv6Fragment(h[0].Payload()) - if !ipv6Frag.IsValid() { - t.Error("Not a valid IPv6 fragment") - } - - for _, f := range checkers { - f(t, []header.Network{h[0], ipv6Frag}) - } - if t.Failed() { - t.FailNow() - } - } -} - -// TCP creates a checker that checks that the transport protocol is TCP and -// potentially additional transport header fields. -func TCP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - first := h[0] - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) - } - - // Verify the checksum. - tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) - } - - // Run the transport checkers. - for _, f := range checkers { - f(t, tcp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// UDP creates a checker that checks that the transport protocol is UDP and -// potentially additional transport header fields. -func UDP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) - } - - udp := header.UDP(last.Payload()) - for _, f := range checkers { - f(t, udp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// SrcPort creates a checker that checks the source port. -func SrcPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got = %d, want = %d", p, port) - } - } -} - -// DstPort creates a checker that checks the destination port. -func DstPort(port uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got = %d, want = %d", p, port) - } - } -} - -// NoChecksum creates a checker that checks if the checksum is zero. -func NoChecksum(noChecksum bool) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - udp, ok := h.(header.UDP) - if !ok { - t.Fatalf("UDP header not found in h: %T", h) - } - - if b := udp.Checksum() == 0; b != noChecksum { - t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) - } - } -} - -// TCPSeqNum creates a checker that checks the sequence number. -func TCPSeqNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) - } - } -} - -// TCPAckNum creates a checker that checks the ack number. -func TCPAckNum(seq uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got = %d, want = %d", s, seq) - } - } -} - -// TCPWindow creates a checker that checks the tcp window. -func TCPWindow(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in hdr : %T", h) - } - - if w := tcp.WindowSize(); w != window { - t.Errorf("Bad window, got %d, want %d", w, window) - } - } -} - -// TCPWindowGreaterThanEq creates a checker that checks that the TCP window -// is greater than or equal to the provided value. -func TCPWindowGreaterThanEq(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if w := tcp.WindowSize(); w < window { - t.Errorf("Bad window, got %d, want > %d", w, window) - } - } -} - -// TCPWindowLessThanEq creates a checker that checks that the tcp window -// is less than or equal to the provided value. -func TCPWindowLessThanEq(window uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if w := tcp.WindowSize(); w > window { - t.Errorf("Bad window, got %d, want < %d", w, window) - } - } -} - -// TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if f := tcp.Flags(); f != flags { - t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) - } - } -} - -// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the -// given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - t.Fatalf("TCP header not found in h: %T", h) - } - - if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) - } - } -} - -// TCPSynOptions creates a checker that checks the presence of TCP options in -// SYN segments. -// -// If wndscale is negative, the window scale option must not be present. -func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := tcp.Options() - limit := len(opts) - foundMSS := false - foundWS := false - foundTS := false - foundSACKPermitted := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionMSS: - v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) - if wantOpts.MSS != v { - t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) - } - foundMSS = true - i += 4 - case header.TCPOptionWS: - if wantOpts.WS < 0 { - t.Error("WS present when it shouldn't be") - } - v := int(opts[i+2]) - if v != wantOpts.WS { - t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) - } - foundWS = true - i += 3 - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = uint32(0) - if tcp.Flags()&header.TCPFlagAck != 0 { - // If the syn is an SYN-ACK then read - // the tsEcr value as well. - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - } - foundTS = true - i += 10 - case header.TCPOptionSACKPermitted: - if i+1 >= limit { - t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i) - } - if opts[i+1] != 2 { - t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit) - } - foundSACKPermitted = true - i += 2 - - default: - i += int(opts[i+1]) - } - } - - if !foundMSS { - t.Errorf("MSS option not found. Options: %x", opts) - } - - if !foundWS && wantOpts.WS >= 0 { - t.Errorf("WS option not found. Options: %x", opts) - } - if wantOpts.TS && !foundTS { - t.Errorf("TS option not found. Options: %x", opts) - } - if foundTS && tsVal == 0 { - t.Error("TS option specified but the timestamp value is zero") - } - if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) - } - if wantOpts.SACKPermitted && !foundSACKPermitted { - t.Errorf("SACKPermitted option not found. Options: %x", opts) - } - } -} - -// TCPTimestampChecker creates a checker that validates that a TCP segment has a -// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and -// wantTSEcr values with those in the TCP segment (if present). -// -// If wantTSVal or wantTSEcr is zero then the corresponding comparison is -// skipped. -func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - tcp, ok := h.(header.TCP) - if !ok { - return - } - opts := []byte(tcp.Options()) - limit := len(opts) - foundTS := false - tsVal := uint32(0) - tsEcr := uint32(0) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionTS: - if i+9 >= limit { - t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) - } - if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) - } - tsVal = binary.BigEndian.Uint32(opts[i+2:]) - tsEcr = binary.BigEndian.Uint32(opts[i+6:]) - foundTS = true - i += 10 - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - return - } - l := int(opts[i+1]) - if i < 2 || i+l > limit { - return - } - i += l - } - } - - if wantTS != foundTS { - t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) - } - if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) - } - if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) - } - } -} - -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does -// not contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - -// TCPSACKBlockChecker creates a checker that verifies that the segment does -// contain the specified SACK blocks in the TCP options. -func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - tcp, ok := h.(header.TCP) - if !ok { - return - } - var gotSACKBlocks []header.SACKBlock - - opts := []byte(tcp.Options()) - limit := len(opts) - for i := 0; i < limit; { - switch opts[i] { - case header.TCPOptionEOL: - i = limit - case header.TCPOptionNOP: - i++ - case header.TCPOptionSACK: - if i+2 > limit { - // Malformed SACK block. - t.Errorf("malformed SACK option in options: %v", opts) - } - sackOptionLen := int(opts[i+1]) - if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 { - // Malformed SACK block. - t.Errorf("malformed SACK option length in options: %v", opts) - } - numBlocks := sackOptionLen / 8 - for j := 0; j < numBlocks; j++ { - start := binary.BigEndian.Uint32(opts[i+2+j*8:]) - end := binary.BigEndian.Uint32(opts[i+2+j*8+4:]) - gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{ - Start: seqnum.Value(start), - End: seqnum.Value(end), - }) - } - i += sackOptionLen - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - break - } - l := int(opts[i+1]) - if l < 2 || i+l > limit { - break - } - i += l - } - } - - if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) - } - } -} - -// Payload creates a checker that checks the payload. -func Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - if got := h.Payload(); !reflect.DeepEqual(got, want) { - t.Errorf("Wrong payload, got %v, want %v", got, want) - } - } -} - -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 -// and potentially additional ICMPv4 header fields. -func ICMPv4(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber) - } - - icmp := header.ICMPv4(last.Payload()) - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv4Type creates a checker that checks the ICMPv4 Type field. -func ICMPv4Type(want header.ICMPv4Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Code creates a checker that checks the ICMPv4 Code field. -func ICMPv4Code(want header.ICMPv4Code) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. -func ICMPv4Ident(want uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Ident(); got != want { - t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. -func ICMPv4Seq(want uint16) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Sequence(); got != want { - t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. -func ICMPv4Pointer(want uint8) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - if got := icmpv4.Pointer(); got != want { - t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) - } - } -} - -// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. -// This assumes that the payload exactly makes up the rest of the slice. -func ICMPv4Checksum() TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - heldChecksum := icmpv4.Checksum() - icmpv4.SetChecksum(0) - newChecksum := ^header.Checksum(icmpv4, 0) - icmpv4.SetChecksum(heldChecksum) - if heldChecksum != newChecksum { - t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) - } - } -} - -// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. -func ICMPv4Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv4, ok := h.(header.ICMPv4) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) - } - payload := icmpv4.Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(want, payload); diff != "" { - t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) - } - } -} - -// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and -// potentially additional ICMPv6 header fields. -// -// ICMPv6 will validate the checksum field before calling checkers. -func ICMPv6(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber) - } - - icmp := header.ICMPv6(last.Payload()) - if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { - t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// ICMPv6Type creates a checker that checks the ICMPv6 Type field. -func ICMPv6Type(want header.ICMPv6Type) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6Code creates a checker that checks the ICMPv6 Code field. -func ICMPv6Code(want header.ICMPv6Code) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific -// field. -func ICMPv6TypeSpecific(want uint32) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - if got := icmpv6.TypeSpecific(); got != want { - t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) - } - } -} - -// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. -func ICMPv6Payload(want []byte) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmpv6, ok := h.(header.ICMPv6) - if !ok { - t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) - } - payload := icmpv6.Payload() - - // cmp.Diff does not consider nil slices equal to empty slices, but we do. - if len(want) == 0 && len(payload) == 0 { - return - } - - if diff := cmp.Diff(want, payload); diff != "" { - t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) - } - } -} - -// MLD creates a checker that checks that the packet contains a valid MLD -// message for type of mldType, with potentially additional checks specified by -// checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// MLD message as far as the size of the message (minSize) is concerned. The -// values within the message are up to checkers to validate. -func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // Check normal ICMPv6 first. - ICMPv6( - ICMPv6Type(msgType), - ICMPv6Code(0))(t, h) - - last := h[len(h)-1] - - icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.MessageBody()); got < minSize { - t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay -// field of a MLD message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid MLD message as far as the size is concerned. -func MLDMaxRespDelay(want time.Duration) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.MLD(icmp.MessageBody()) - - if got := ns.MaximumResponseDelay(); got != want { - t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) - } - } -} - -// MLDMulticastAddress creates a checker that checks the Multicast Address -// field of a MLD message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid MLD message as far as the size is concerned. -func MLDMulticastAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.MLD(icmp.MessageBody()) - - if got := ns.MulticastAddress(); got != want { - t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want) - } - } -} - -// NDP creates a checker that checks that the packet contains a valid NDP -// message for type of ty, with potentially additional checks specified by -// checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDP message as far as the size of the message (minSize) is concerned. The -// values within the message are up to checkers to validate. -func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - // Check normal ICMPv6 first. - ICMPv6( - ICMPv6Type(msgType), - ICMPv6Code(0))(t, h) - - last := h[len(h)-1] - - icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.MessageBody()); got < minSize { - t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) - } - - for _, f := range checkers { - f(t, icmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// NDPNS creates a checker that checks that the packet contains a valid NDP -// Neighbor Solicitation message (as per the raw wire format), with potentially -// additional checks specified by checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNS message as far as the size of the message is concerned. The values -// within the message are up to checkers to validate. -func NDPNS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) -} - -// NDPNSTargetAddress creates a checker that checks the Target Address field of -// a header.NDPNeighborSolicit. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSTargetAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - - if got := ns.TargetAddress(); got != want { - t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) - } - } -} - -// NDPNA creates a checker that checks that the packet contains a valid NDP -// Neighbor Advertisement message (as per the raw wire format), with potentially -// additional checks specified by checkers. -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPNA message as far as the size of the message is concerned. The values -// within the message are up to checkers to validate. -func NDPNA(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...) -} - -// NDPNATargetAddress creates a checker that checks the Target Address field of -// a header.NDPNeighborAdvert. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNATargetAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - - if got := na.TargetAddress(); got != want { - t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) - } - } -} - -// NDPNASolicitedFlag creates a checker that checks the Solicited field of -// a header.NDPNeighborAdvert. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNASolicitedFlag(want bool) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - - if got := na.SolicitedFlag(); got != want { - t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) - } - } -} - -// ndpOptions checks that optsBuf only contains opts. -func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { - t.Helper() - - it, err := optsBuf.Iter(true) - if err != nil { - t.Errorf("optsBuf.Iter(true): %s", err) - return - } - - i := 0 - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - t.Fatalf("unexpected error when iterating over NDP options: %s", err) - } - if done { - break - } - - if i >= len(opts) { - t.Errorf("got unexpected option: %s", opt) - continue - } - - switch wantOpt := opts[i].(type) { - case header.NDPSourceLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - case header.NDPTargetLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - default: - t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) - } - - i++ - } - - if missing := opts[i:]; len(missing) > 0 { - t.Errorf("missing options: %s", missing) - } -} - -// NDPNAOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Neighbor Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNA message as far as the size is concerned. -func NDPNAOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.MessageBody()) - ndpOptions(t, na.Options(), opts) - } -} - -// NDPNSOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Neighbor Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPNS message as far as the size is concerned. -func NDPNSOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.MessageBody()) - ndpOptions(t, ns.Options(), opts) - } -} - -// NDPRS creates a checker that checks that the packet contains a valid NDP -// Router Solicitation message (as per the raw wire format). -// -// Checkers may assume that a valid ICMPv6 is passed to it containing a valid -// NDPRS as far as the size of the message is concerned. The values within the -// message are up to checkers to validate. -func NDPRS(checkers ...TransportChecker) NetworkChecker { - return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) -} - -// NDPRSOptions creates a checker that checks that the packet contains the -// provided NDP options within an NDP Router Solicitation message. -// -// The returned TransportChecker assumes that a valid ICMPv6 is passed to it -// containing a valid NDPRS message as far as the size is concerned. -func NDPRSOptions(opts []header.NDPOption) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - icmp := h.(header.ICMPv6) - rs := header.NDPRouterSolicit(icmp.MessageBody()) - ndpOptions(t, rs.Options(), opts) - } -} - -// IGMP checks the validity and properties of the given IGMP packet. It is -// expected to be used in conjunction with other IGMP transport checkers for -// specific properties. -func IGMP(checkers ...TransportChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - last := h[len(h)-1] - - if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { - t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) - } - - igmp := header.IGMP(last.Payload()) - for _, f := range checkers { - f(t, igmp) - } - if t.Failed() { - t.FailNow() - } - } -} - -// IGMPType creates a checker that checks the IGMP Type field. -func IGMPType(want header.IGMPType) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.Type(); got != want { - t.Errorf("got igmp.Type() = %d, want = %d", got, want) - } - } -} - -// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. -func IGMPMaxRespTime(want time.Duration) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.MaxRespTime(); got != want { - t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) - } - } -} - -// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. -func IGMPGroupAddress(want tcpip.Address) TransportChecker { - return func(t *testing.T, h header.Transport) { - t.Helper() - - igmp, ok := h.(header.IGMP) - if !ok { - t.Fatalf("got transport header = %T, want = header.IGMP", h) - } - if got := igmp.GroupAddress(); got != want { - t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) - } - } -} - -// IPv6ExtHdrChecker is a function to check an extension header. -type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) - -// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. -func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { - t.Helper() - - ipv6 := header.IPv6(b) - if !ipv6.IsValid(len(b)) { - t.Error("not a valid IPv6 packet") - return - } - - payloadIterator := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), - buffer.View(ipv6.Payload()).ToVectorisedView(), - ) - - var rawPayloadHeader header.IPv6RawPayloadHeader - for { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) - return - } - r, ok := h.(header.IPv6RawPayloadHeader) - if ok { - rawPayloadHeader = r - break - } - } - - networkHeader := ipv6HeaderWithExtHdr{ - IPv6: ipv6, - transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), - payload: rawPayloadHeader.Buf.ToView(), - } - - for _, checker := range checkers { - checker(t, []header.Network{&networkHeader}) - } -} - -// IPv6ExtHdr checks for the presence of extension headers. -// -// All the extension headers in headers will be checked exhaustively in the -// order provided. -func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { - return func(t *testing.T, h []header.Network) { - t.Helper() - - extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) - if !ok { - t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) - return - } - - payloadIterator := header.MakeIPv6PayloadIterator( - header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), - buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), - ) - - for _, check := range headers { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) - return - } - check(t, h) - } - // Validate we consumed all headers. - // - // The next one over should be a raw payload and then iterator should - // terminate. - wantDone := false - for { - h, done, err := payloadIterator.Next() - if err != nil { - t.Errorf("payloadIterator.Next(): %s", err) - return - } - if done != wantDone { - t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) - return - } - if done { - break - } - if _, ok := h.(header.IPv6RawPayloadHeader); !ok { - t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) - continue - } - wantDone = true - } - } -} - -var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) - -// ipv6HeaderWithExtHdr provides a header.Network implementation that takes -// extension headers into consideration, which is not the case with vanilla -// header.IPv6. -type ipv6HeaderWithExtHdr struct { - header.IPv6 - transport tcpip.TransportProtocolNumber - payload []byte -} - -// TransportProtocol implements header.Network. -func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { - return h.transport -} - -// Payload implements header.Network. -func (h *ipv6HeaderWithExtHdr) Payload() []byte { - return h.payload -} - -// IPv6ExtHdrOptionChecker is a function to check an extension header option. -type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) - -// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop -// extension header and validates the containing options with checkers. -// -// checkers must exhaustively contain all the expected options. -func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { - return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { - t.Helper() - - hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) - if !ok { - t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) - return - } - optionsIterator := hbh.Iter() - for _, f := range checkers { - opt, done, err := optionsIterator.Next() - if err != nil { - t.Errorf("optionsIterator.Next(): %s", err) - return - } - if done { - t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) - } - f(t, opt) - } - // Validate all options were consumed. - for { - opt, done, err := optionsIterator.Next() - if err != nil { - t.Errorf("optionsIterator.Next(): %s", err) - return - } - if !done { - t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) - } - if done { - break - } - } - } -} - -// IPv6RouterAlert validates that an extension header option is the RouterAlert -// option and matches on its value. -func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { - return func(t *testing.T, opt header.IPv6ExtHdrOption) { - routerAlert, ok := opt.(*header.IPv6RouterAlertOption) - if !ok { - t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) - return - } - if routerAlert.Value != want { - t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) - } - } -} |