diff options
Diffstat (limited to 'pkg/tcpip/checker')
-rw-r--r-- | pkg/tcpip/checker/checker.go | 414 |
1 files changed, 407 insertions, 7 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 8868cf4e3..0ac2000ca 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "reflect" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -116,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker { v = ip.TTL() case header.IPv6: v = ip.HopLimit() + case *ipv6HeaderWithExtHdr: + v = ip.HopLimit() + default: + t.Fatalf("unrecognized header type %T for TTL evaluation", ip) } if v != ttl { t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) @@ -216,6 +221,42 @@ func IPv4Options(want header.IPv4Options) NetworkChecker { } } +// IPv4RouterAlert returns a checker that checks that the RouterAlert option is +// set in an IPv4 packet. +func IPv4RouterAlert() NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + ip, ok := h[0].(header.IPv4) + if !ok { + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) + } + iterator := ip.Options().MakeIterator() + for { + opt, done, err := iterator.Next() + if err != nil { + t.Fatalf("error acquiring next IPv4 option %s", err) + } + if done { + break + } + if opt.Type() != header.IPv4OptionRouterAlertType { + continue + } + want := [header.IPv4OptionRouterAlertLength]byte{ + byte(header.IPv4OptionRouterAlertType), + header.IPv4OptionRouterAlertLength, + header.IPv4OptionRouterAlertValue, + header.IPv4OptionRouterAlertValue, + } + if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { + t.Errorf("router alert option mismatch (-want +got):\n%s", diff) + } + return + } + t.Errorf("failed to find router alert option in %v", ip.Options()) + } +} + // FragmentOffset creates a checker that checks the FragmentOffset field. func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -284,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress +// field in ControlMessages. +func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasOriginalDstAddress { + t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) + } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { + t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -904,6 +958,12 @@ func ICMPv4Payload(want []byte) TransportChecker { t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } payload := icmpv4.Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(payload) == 0 { + return + } + if diff := cmp.Diff(want, payload); diff != "" { t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } @@ -994,12 +1054,86 @@ func ICMPv6Payload(want []byte) TransportChecker { t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } payload := icmpv6.Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(payload) == 0 { + return + } + if diff := cmp.Diff(want, payload); diff != "" { t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } } } +// MLD creates a checker that checks that the packet contains a valid MLD +// message for type of mldType, with potentially additional checks specified by +// checkers. +// +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// MLD message as far as the size of the message (minSize) is concerned. The +// values within the message are up to checkers to validate. +func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + // Check normal ICMPv6 first. + ICMPv6( + ICMPv6Type(msgType), + ICMPv6Code(0))(t, h) + + last := h[len(h)-1] + + icmp := header.ICMPv6(last.Payload()) + if got := len(icmp.MessageBody()); got < minSize { + t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) + } + + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay +// field of a MLD message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid MLD message as far as the size is concerned. +func MLDMaxRespDelay(want time.Duration) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.MLD(icmp.MessageBody()) + + if got := ns.MaximumResponseDelay(); got != want { + t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) + } + } +} + +// MLDMulticastAddress creates a checker that checks the Multicast Address +// field of a MLD message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid MLD message as far as the size is concerned. +func MLDMulticastAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.MLD(icmp.MessageBody()) + + if got := ns.MulticastAddress(); got != want { + t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want) + } + } +} + // NDP creates a checker that checks that the packet contains a valid NDP // message for type of ty, with potentially additional checks specified by // checkers. @@ -1019,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N last := h[len(h)-1] icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.NDPPayload()); got < minSize { + if got := len(icmp.MessageBody()); got < minSize { t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) } @@ -1053,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) if got := ns.TargetAddress(); got != want { t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) @@ -1082,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) if got := na.TargetAddress(); got != want { t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) @@ -1100,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) if got := na.SolicitedFlag(); got != want { t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) @@ -1171,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) ndpOptions(t, na.Options(), opts) } } @@ -1186,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ndpOptions(t, ns.Options(), opts) } } @@ -1211,7 +1345,273 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - rs := header.NDPRouterSolicit(icmp.NDPPayload()) + rs := header.NDPRouterSolicit(icmp.MessageBody()) ndpOptions(t, rs.Options(), opts) } } + +// IGMP checks the validity and properties of the given IGMP packet. It is +// expected to be used in conjunction with other IGMP transport checkers for +// specific properties. +func IGMP(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) + } + + igmp := header.IGMP(last.Payload()) + for _, f := range checkers { + f(t, igmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// IGMPType creates a checker that checks the IGMP Type field. +func IGMPType(want header.IGMPType) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.Type(); got != want { + t.Errorf("got igmp.Type() = %d, want = %d", got, want) + } + } +} + +// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. +func IGMPMaxRespTime(want time.Duration) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.MaxRespTime(); got != want { + t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) + } + } +} + +// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. +func IGMPGroupAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.GroupAddress(); got != want { + t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) + } + } +} + +// IPv6ExtHdrChecker is a function to check an extension header. +type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) + +// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. +func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + + ipv6 := header.IPv6(b) + if !ipv6.IsValid(len(b)) { + t.Error("not a valid IPv6 packet") + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var rawPayloadHeader header.IPv6RawPayloadHeader + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) + return + } + r, ok := h.(header.IPv6RawPayloadHeader) + if ok { + rawPayloadHeader = r + break + } + } + + networkHeader := ipv6HeaderWithExtHdr{ + IPv6: ipv6, + transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), + payload: rawPayloadHeader.Buf.ToView(), + } + + for _, checker := range checkers { + checker(t, []header.Network{&networkHeader}) + } +} + +// IPv6ExtHdr checks for the presence of extension headers. +// +// All the extension headers in headers will be checked exhaustively in the +// order provided. +func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) + if !ok { + t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), + buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), + ) + + for _, check := range headers { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) + return + } + check(t, h) + } + // Validate we consumed all headers. + // + // The next one over should be a raw payload and then iterator should + // terminate. + wantDone := false + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done != wantDone { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) + return + } + if done { + break + } + if _, ok := h.(header.IPv6RawPayloadHeader); !ok { + t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) + continue + } + wantDone = true + } + } +} + +var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) + +// ipv6HeaderWithExtHdr provides a header.Network implementation that takes +// extension headers into consideration, which is not the case with vanilla +// header.IPv6. +type ipv6HeaderWithExtHdr struct { + header.IPv6 + transport tcpip.TransportProtocolNumber + payload []byte +} + +// TransportProtocol implements header.Network. +func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { + return h.transport +} + +// Payload implements header.Network. +func (h *ipv6HeaderWithExtHdr) Payload() []byte { + return h.payload +} + +// IPv6ExtHdrOptionChecker is a function to check an extension header option. +type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) + +// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop +// extension header and validates the containing options with checkers. +// +// checkers must exhaustively contain all the expected options. +func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { + return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { + t.Helper() + + hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) + if !ok { + t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) + return + } + optionsIterator := hbh.Iter() + for _, f := range checkers { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + f(t, opt) + } + // Validate all options were consumed. + for { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if !done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + if done { + break + } + } + } +} + +// IPv6RouterAlert validates that an extension header option is the RouterAlert +// option and matches on its value. +func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + routerAlert, ok := opt.(*header.IPv6RouterAlertOption) + if !ok { + t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) + return + } + if routerAlert.Value != want { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) + } + } +} + +// IgnoreCmpPath returns a cmp.Option that ignores listed field paths. +func IgnoreCmpPath(paths ...string) cmp.Option { + ignores := map[string]struct{}{} + for _, path := range paths { + ignores[path] = struct{}{} + } + return cmp.FilterPath(func(path cmp.Path) bool { + _, ok := ignores[path.String()] + return ok + }, cmp.Ignore()) +} |