summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/testutil
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/testutil')
-rw-r--r--pkg/tcpip/testutil/BUILD5
-rw-r--r--pkg/tcpip/testutil/testutil.go68
-rw-r--r--pkg/tcpip/testutil/testutil_unsafe.go26
3 files changed, 98 insertions, 1 deletions
diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD
index 472545a5d..02ee86ff1 100644
--- a/pkg/tcpip/testutil/BUILD
+++ b/pkg/tcpip/testutil/BUILD
@@ -5,7 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
testonly = True,
- srcs = ["testutil.go"],
+ srcs = [
+ "testutil.go",
+ "testutil_unsafe.go",
+ ],
visibility = ["//visibility:public"],
deps = ["//pkg/tcpip"],
)
diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go
index 1aaed590f..f84d399fb 100644
--- a/pkg/tcpip/testutil/testutil.go
+++ b/pkg/tcpip/testutil/testutil.go
@@ -18,6 +18,8 @@ package testutil
import (
"fmt"
"net"
+ "reflect"
+ "strings"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -41,3 +43,69 @@ func MustParse6(addr string) tcpip.Address {
}
return tcpip.Address(ip)
}
+
+func checkFieldCounts(ref, multi reflect.Value) error {
+ refTypeName := ref.Type().Name()
+ multiTypeName := multi.Type().Name()
+ refNumField := ref.NumField()
+ multiNumField := multi.NumField()
+
+ if refNumField != multiNumField {
+ return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
+ }
+
+ return nil
+}
+
+func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
+ s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
+ if !ok {
+ return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
+ }
+
+ // The field names are expected to match (case insensitive).
+ if !strings.EqualFold(refName, multiName) {
+ return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
+ }
+
+ base := (*s).Value()
+ m.Increment()
+ if (*s).Value() != base+1 {
+ return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
+ }
+
+ return nil
+}
+
+// ValidateMultiCounterStats verifies that every counter stored in multi is
+// correctly tracking its counterpart in the given counters.
+func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
+ for _, c := range counters {
+ if err := checkFieldCounts(c, multi); err != nil {
+ return err
+ }
+ }
+
+ for i := 0; i < multi.NumField(); i++ {
+ multiName := multi.Type().Field(i).Name
+ multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
+
+ if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
+ for _, c := range counters {
+ if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
+ return err
+ }
+ }
+ } else {
+ var countersNextField []reflect.Value
+ for _, c := range counters {
+ countersNextField = append(countersNextField, c.Field(i))
+ }
+ if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/testutil/testutil_unsafe.go b/pkg/tcpip/testutil/testutil_unsafe.go
new file mode 100644
index 000000000..5ff764800
--- /dev/null
+++ b/pkg/tcpip/testutil/testutil_unsafe.go
@@ -0,0 +1,26 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafeExposeUnexportedFields takes a Value and returns a version of it in
+// which even unexported fields can be read and written.
+func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value {
+ return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem()
+}