summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/checker
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/checker')
-rw-r--r--pkg/tcpip/checker/checker.go414
1 files changed, 407 insertions, 7 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 8868cf4e3..0ac2000ca 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -20,6 +20,7 @@ import (
"encoding/binary"
"reflect"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -116,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.TTL()
case header.IPv6:
v = ip.HopLimit()
+ case *ipv6HeaderWithExtHdr:
+ v = ip.HopLimit()
+ default:
+ t.Fatalf("unrecognized header type %T for TTL evaluation", ip)
}
if v != ttl {
t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
@@ -216,6 +221,42 @@ func IPv4Options(want header.IPv4Options) NetworkChecker {
}
}
+// IPv4RouterAlert returns a checker that checks that the RouterAlert option is
+// set in an IPv4 packet.
+func IPv4RouterAlert() NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+ ip, ok := h[0].(header.IPv4)
+ if !ok {
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+ }
+ iterator := ip.Options().MakeIterator()
+ for {
+ opt, done, err := iterator.Next()
+ if err != nil {
+ t.Fatalf("error acquiring next IPv4 option %s", err)
+ }
+ if done {
+ break
+ }
+ if opt.Type() != header.IPv4OptionRouterAlertType {
+ continue
+ }
+ want := [header.IPv4OptionRouterAlertLength]byte{
+ byte(header.IPv4OptionRouterAlertType),
+ header.IPv4OptionRouterAlertLength,
+ header.IPv4OptionRouterAlertValue,
+ header.IPv4OptionRouterAlertValue,
+ }
+ if diff := cmp.Diff(want[:], opt.Contents()); diff != "" {
+ t.Errorf("router alert option mismatch (-want +got):\n%s", diff)
+ }
+ return
+ }
+ t.Errorf("failed to find router alert option in %v", ip.Options())
+ }
+}
+
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -284,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
}
}
+// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress
+// field in ControlMessages.
+func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasOriginalDstAddress {
+ t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress)
+ } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" {
+ t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -904,6 +958,12 @@ func ICMPv4Payload(want []byte) TransportChecker {
t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
payload := icmpv4.Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(payload) == 0 {
+ return
+ }
+
if diff := cmp.Diff(want, payload); diff != "" {
t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
@@ -994,12 +1054,86 @@ func ICMPv6Payload(want []byte) TransportChecker {
t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
payload := icmpv6.Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(payload) == 0 {
+ return
+ }
+
if diff := cmp.Diff(want, payload); diff != "" {
t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
}
}
+// MLD creates a checker that checks that the packet contains a valid MLD
+// message for type of mldType, with potentially additional checks specified by
+// checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// MLD message as far as the size of the message (minSize) is concerned. The
+// values within the message are up to checkers to validate.
+func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // Check normal ICMPv6 first.
+ ICMPv6(
+ ICMPv6Type(msgType),
+ ICMPv6Code(0))(t, h)
+
+ last := h[len(h)-1]
+
+ icmp := header.ICMPv6(last.Payload())
+ if got := len(icmp.MessageBody()); got < minSize {
+ t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
+ }
+
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay
+// field of a MLD message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid MLD message as far as the size is concerned.
+func MLDMaxRespDelay(want time.Duration) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.MLD(icmp.MessageBody())
+
+ if got := ns.MaximumResponseDelay(); got != want {
+ t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
+// MLDMulticastAddress creates a checker that checks the Multicast Address
+// field of a MLD message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid MLD message as far as the size is concerned.
+func MLDMulticastAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.MLD(icmp.MessageBody())
+
+ if got := ns.MulticastAddress(); got != want {
+ t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
// NDP creates a checker that checks that the packet contains a valid NDP
// message for type of ty, with potentially additional checks specified by
// checkers.
@@ -1019,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
last := h[len(h)-1]
icmp := header.ICMPv6(last.Payload())
- if got := len(icmp.NDPPayload()); got < minSize {
+ if got := len(icmp.MessageBody()); got < minSize {
t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
}
@@ -1053,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
if got := ns.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
@@ -1082,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.TargetAddress(); got != want {
t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
@@ -1100,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
if got := na.SolicitedFlag(); got != want {
t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
@@ -1171,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na := header.NDPNeighborAdvert(icmp.MessageBody())
ndpOptions(t, na.Options(), opts)
}
}
@@ -1186,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns := header.NDPNeighborSolicit(icmp.MessageBody())
ndpOptions(t, ns.Options(), opts)
}
}
@@ -1211,7 +1345,273 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker {
t.Helper()
icmp := h.(header.ICMPv6)
- rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ rs := header.NDPRouterSolicit(icmp.MessageBody())
ndpOptions(t, rs.Options(), opts)
}
}
+
+// IGMP checks the validity and properties of the given IGMP packet. It is
+// expected to be used in conjunction with other IGMP transport checkers for
+// specific properties.
+func IGMP(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.IGMPProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber)
+ }
+
+ igmp := header.IGMP(last.Payload())
+ for _, f := range checkers {
+ f(t, igmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// IGMPType creates a checker that checks the IGMP Type field.
+func IGMPType(want header.IGMPType) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.Type(); got != want {
+ t.Errorf("got igmp.Type() = %d, want = %d", got, want)
+ }
+ }
+}
+
+// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field.
+func IGMPMaxRespTime(want time.Duration) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.MaxRespTime(); got != want {
+ t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want)
+ }
+ }
+}
+
+// IGMPGroupAddress creates a checker that checks the IGMP Group Address field.
+func IGMPGroupAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ igmp, ok := h.(header.IGMP)
+ if !ok {
+ t.Fatalf("got transport header = %T, want = header.IGMP", h)
+ }
+ if got := igmp.GroupAddress(); got != want {
+ t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want)
+ }
+ }
+}
+
+// IPv6ExtHdrChecker is a function to check an extension header.
+type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader)
+
+// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers.
+func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) {
+ t.Helper()
+
+ ipv6 := header.IPv6(b)
+ if !ipv6.IsValid(len(b)) {
+ t.Error("not a valid IPv6 packet")
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()),
+ buffer.View(ipv6.Payload()).ToVectorisedView(),
+ )
+
+ var rawPayloadHeader header.IPv6RawPayloadHeader
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done)
+ return
+ }
+ r, ok := h.(header.IPv6RawPayloadHeader)
+ if ok {
+ rawPayloadHeader = r
+ break
+ }
+ }
+
+ networkHeader := ipv6HeaderWithExtHdr{
+ IPv6: ipv6,
+ transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier),
+ payload: rawPayloadHeader.Buf.ToView(),
+ }
+
+ for _, checker := range checkers {
+ checker(t, []header.Network{&networkHeader})
+ }
+}
+
+// IPv6ExtHdr checks for the presence of extension headers.
+//
+// All the extension headers in headers will be checked exhaustively in the
+// order provided.
+func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr)
+ if !ok {
+ t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0])
+ return
+ }
+
+ payloadIterator := header.MakeIPv6PayloadIterator(
+ header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()),
+ buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(),
+ )
+
+ for _, check := range headers {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done)
+ return
+ }
+ check(t, h)
+ }
+ // Validate we consumed all headers.
+ //
+ // The next one over should be a raw payload and then iterator should
+ // terminate.
+ wantDone := false
+ for {
+ h, done, err := payloadIterator.Next()
+ if err != nil {
+ t.Errorf("payloadIterator.Next(): %s", err)
+ return
+ }
+ if done != wantDone {
+ t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone)
+ return
+ }
+ if done {
+ break
+ }
+ if _, ok := h.(header.IPv6RawPayloadHeader); !ok {
+ t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h)
+ continue
+ }
+ wantDone = true
+ }
+ }
+}
+
+var _ header.Network = (*ipv6HeaderWithExtHdr)(nil)
+
+// ipv6HeaderWithExtHdr provides a header.Network implementation that takes
+// extension headers into consideration, which is not the case with vanilla
+// header.IPv6.
+type ipv6HeaderWithExtHdr struct {
+ header.IPv6
+ transport tcpip.TransportProtocolNumber
+ payload []byte
+}
+
+// TransportProtocol implements header.Network.
+func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber {
+ return h.transport
+}
+
+// Payload implements header.Network.
+func (h *ipv6HeaderWithExtHdr) Payload() []byte {
+ return h.payload
+}
+
+// IPv6ExtHdrOptionChecker is a function to check an extension header option.
+type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption)
+
+// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop
+// extension header and validates the containing options with checkers.
+//
+// checkers must exhaustively contain all the expected options.
+func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker {
+ return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) {
+ t.Helper()
+
+ hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr)
+ if !ok {
+ t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader)
+ return
+ }
+ optionsIterator := hbh.Iter()
+ for _, f := range checkers {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done)
+ }
+ f(t, opt)
+ }
+ // Validate all options were consumed.
+ for {
+ opt, done, err := optionsIterator.Next()
+ if err != nil {
+ t.Errorf("optionsIterator.Next(): %s", err)
+ return
+ }
+ if !done {
+ t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done)
+ }
+ if done {
+ break
+ }
+ }
+ }
+}
+
+// IPv6RouterAlert validates that an extension header option is the RouterAlert
+// option and matches on its value.
+func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker {
+ return func(t *testing.T, opt header.IPv6ExtHdrOption) {
+ routerAlert, ok := opt.(*header.IPv6RouterAlertOption)
+ if !ok {
+ t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt)
+ return
+ }
+ if routerAlert.Value != want {
+ t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want)
+ }
+ }
+}
+
+// IgnoreCmpPath returns a cmp.Option that ignores listed field paths.
+func IgnoreCmpPath(paths ...string) cmp.Option {
+ ignores := map[string]struct{}{}
+ for _, path := range paths {
+ ignores[path] = struct{}{}
+ }
+ return cmp.FilterPath(func(path cmp.Path) bool {
+ _, ok := ignores[path.String()]
+ return ok
+ }, cmp.Ignore())
+}