summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/testutil/testutil.go
diff options
context:
space:
mode:
authorArthur Sfez <asfez@google.com>2021-05-11 10:23:06 -0700
committergVisor bot <gvisor-bot@google.com>2021-05-11 10:25:33 -0700
commit60bdf7ed31bc32bb8413cfdb67cd906ca3d5955a (patch)
treec78188b5336a782abc8d84151269d50fa268d7bc /pkg/tcpip/testutil/testutil.go
parent1daabac237ffb2b7d5711d87bfadc531dc457d08 (diff)
Move multicounter testutil functions out of network/ip
This is in preparation of having aggregated NIC stats at the stack level. These validation functions will be needed outside of the network layer packages to test aggregated NIC stats. PiperOrigin-RevId: 373180565
Diffstat (limited to 'pkg/tcpip/testutil/testutil.go')
-rw-r--r--pkg/tcpip/testutil/testutil.go68
1 files changed, 68 insertions, 0 deletions
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
index 1aaed590f..f84d399fb 100644
--- a/pkg/tcpip/testutil/testutil.go
+++ b/pkg/tcpip/testutil/testutil.go
@@ -18,6 +18,8 @@ package testutil
import (
"fmt"
"net"
+ "reflect"
+ "strings"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -41,3 +43,69 @@ func MustParse6(addr string) tcpip.Address {
}
return tcpip.Address(ip)
}
+
+func checkFieldCounts(ref, multi reflect.Value) error {
+ refTypeName := ref.Type().Name()
+ multiTypeName := multi.Type().Name()
+ refNumField := ref.NumField()
+ multiNumField := multi.NumField()
+
+ if refNumField != multiNumField {
+ return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
+ }
+
+ return nil
+}
+
+func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
+ s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
+ if !ok {
+ return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
+ }
+
+ // The field names are expected to match (case insensitive).
+ if !strings.EqualFold(refName, multiName) {
+ return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
+ }
+
+ base := (*s).Value()
+ m.Increment()
+ if (*s).Value() != base+1 {
+ return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
+ }
+
+ return nil
+}
+
+// ValidateMultiCounterStats verifies that every counter stored in multi is
+// correctly tracking its counterpart in the given counters.
+func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
+ for _, c := range counters {
+ if err := checkFieldCounts(c, multi); err != nil {
+ return err
+ }
+ }
+
+ for i := 0; i < multi.NumField(); i++ {
+ multiName := multi.Type().Field(i).Name
+ multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
+
+ if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
+ for _, c := range counters {
+ if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
+ return err
+ }
+ }
+ } else {
+ var countersNextField []reflect.Value
+ for _, c := range counters {
+ countersNextField = append(countersNextField, c.Field(i))
+ }
+ if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}