summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/checker
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/checker')
-rw-r--r--pkg/tcpip/checker/BUILD4
-rw-r--r--pkg/tcpip/checker/checker.go246
2 files changed, 241 insertions, 9 deletions
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
index b6fa6fc37..c984470e6 100644
--- a/pkg/tcpip/checker/BUILD
+++ b/pkg/tcpip/checker/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,12 +6,12 @@ go_library(
name = "checker",
testonly = 1,
srcs = ["checker.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/checker",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 2f15bf1f1..b769094dc 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -21,6 +21,7 @@ import (
"reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -33,6 +34,9 @@ type NetworkChecker func(*testing.T, []header.Network)
// TransportChecker is a function to check a property of a transport packet.
type TransportChecker func(*testing.T, header.Transport)
+// ControlMessagesChecker is a function to check a property of ancillary data.
+type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
+
// IPv4 checks the validity and properties of the given IPv4 packet. It is
// expected to be used in conjunction with other network checkers for specific
// properties. For example, to check the source and destination address, one
@@ -104,6 +108,8 @@ func DstAddr(addr tcpip.Address) NetworkChecker {
// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
func TTL(ttl uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
var v uint8
switch ip := h[0].(type) {
case header.IPv4:
@@ -158,6 +164,44 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTClass creates a checker that checks the TCLASS field in
+// ControlMessages.
+func ReceiveTClass(want uint32) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTClass {
+ t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
+ } else if got := cm.TClass; got != want {
+ t.Errorf("got cm.TClass = %d, want %d", got, want)
+ }
+ }
+}
+
+// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
+func ReceiveTOS(want uint8) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTOS {
+ t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
+ } else if got := cm.TOS; got != want {
+ t.Errorf("got cm.TOS = %d, want %d", got, want)
+ }
+ }
+}
+
+// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
+// ControlMessages.
+func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPPacketInfo {
+ t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
+ } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
+ t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -280,12 +324,30 @@ func SrcPort(port uint16) TransportChecker {
// DstPort creates a checker that checks the destination port.
func DstPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if p := h.DestinationPort(); p != port {
t.Errorf("Bad destination port, got %v, want %v", p, port)
}
}
}
+// NoChecksum creates a checker that checks if the checksum is zero.
+func NoChecksum(noChecksum bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ udp, ok := h.(header.UDP)
+ if !ok {
+ return
+ }
+
+ if b := udp.Checksum() == 0; b != noChecksum {
+ t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
+ }
+ }
+}
+
// SeqNum creates a checker that checks the sequence number.
func SeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
@@ -306,6 +368,7 @@ func SeqNum(seq uint32) TransportChecker {
func AckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -320,6 +383,8 @@ func AckNum(seq uint32) TransportChecker {
// Window creates a checker that checks the tcp window.
func Window(window uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -351,6 +416,8 @@ func TCPFlags(flags uint8) TransportChecker {
// given mask, match the supplied flags.
func TCPFlagsMatch(flags, mask uint8) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -368,6 +435,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
// If wndscale is negative, the window scale option must not be present.
func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -464,6 +533,8 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
// skipped.
func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -582,6 +653,8 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
// Payload creates a checker that checks the payload.
func Payload(want []byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if got := h.Payload(); !reflect.DeepEqual(got, want) {
t.Errorf("Wrong payload, got %v, want %v", got, want)
}
@@ -614,6 +687,7 @@ func ICMPv4(checkers ...TransportChecker) NetworkChecker {
func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv4, ok := h.(header.ICMPv4)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
@@ -625,9 +699,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
}
// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
-func ICMPv4Code(want byte) TransportChecker {
+func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv4, ok := h.(header.ICMPv4)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
@@ -670,6 +745,7 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker {
func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv6, ok := h.(header.ICMPv6)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
@@ -681,9 +757,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
}
// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
-func ICMPv6Code(want byte) TransportChecker {
+func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv6, ok := h.(header.ICMPv6)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
@@ -698,7 +775,7 @@ func ICMPv6Code(want byte) TransportChecker {
// message for type of ty, with potentially additional checks specified by
// checkers.
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
// NDP message as far as the size of the message (minSize) is concerned. The
// values within the message are up to checkers to validate.
func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
@@ -730,9 +807,9 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
// Neighbor Solicitation message (as per the raw wire format), with potentially
// additional checks specified by checkers.
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNS message as far as the size of the messages concerned. The values within
-// the message are up to checkers to validate.
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNS message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
func NDPNS(checkers ...TransportChecker) NetworkChecker {
return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
}
@@ -750,7 +827,162 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
ns := header.NDPNeighborSolicit(icmp.NDPPayload())
if got := ns.TargetAddress(); got != want {
- t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
+// NDPNA creates a checker that checks that the packet contains a valid NDP
+// Neighbor Advertisement message (as per the raw wire format), with potentially
+// additional checks specified by checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNA message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
+func NDPNA(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
+}
+
+// NDPNATargetAddress creates a checker that checks the Target Address field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNATargetAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.TargetAddress(); got != want {
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
}
}
}
+
+// NDPNASolicitedFlag creates a checker that checks the Solicited field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNASolicitedFlag(want bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.SolicitedFlag(); got != want {
+ t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
+ }
+ }
+}
+
+// ndpOptions checks that optsBuf only contains opts.
+func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
+ t.Helper()
+
+ it, err := optsBuf.Iter(true)
+ if err != nil {
+ t.Errorf("optsBuf.Iter(true): %s", err)
+ return
+ }
+
+ i := 0
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ t.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
+ if done {
+ break
+ }
+
+ if i >= len(opts) {
+ t.Errorf("got unexpected option: %s", opt)
+ continue
+ }
+
+ switch wantOpt := opts[i].(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ case header.NDPTargetLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ default:
+ t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
+ }
+
+ i++
+ }
+
+ if missing := opts[i:]; len(missing) > 0 {
+ t.Errorf("missing options: %s", missing)
+ }
+}
+
+// NDPNAOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNAOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ ndpOptions(t, na.Options(), opts)
+ }
+}
+
+// NDPNSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ndpOptions(t, ns.Options(), opts)
+ }
+}
+
+// NDPRS creates a checker that checks that the packet contains a valid NDP
+// Router Solicitation message (as per the raw wire format).
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPRS as far as the size of the message is concerned. The values within the
+// message are up to checkers to validate.
+func NDPRS(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
+}
+
+// NDPRSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Router Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPRS message as far as the size is concerned.
+func NDPRSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ ndpOptions(t, rs.Options(), opts)
+ }
+}