summaryrefslogtreecommitdiffhomepage
path: root/pkg/state
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/state')
-rw-r--r--pkg/state/tests/integer_test.go26
-rw-r--r--pkg/state/tests/register_test.go11
-rw-r--r--pkg/state/tests/struct_test.go11
-rw-r--r--pkg/state/types.go71
4 files changed, 60 insertions, 59 deletions
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
index d3931c952..2b1609af0 100644
--- a/pkg/state/tests/integer_test.go
+++ b/pkg/state/tests/integer_test.go
@@ -20,21 +20,21 @@ import (
)
var (
- allIntTs = []int{-1, 0, 1}
- allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
- allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
- allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
- allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
- allUintTs = []uint{0, 1}
- allUintptrs = []uintptr{0, 1, ^uintptr(0)}
- allUint8s = []uint8{0, 1, math.MaxUint8}
- allUint16s = []uint16{0, 1, math.MaxUint16}
- allUint32s = []uint32{0, 1, math.MaxUint32}
- allUint64s = []uint64{0, 1, math.MaxUint64}
+ allBasicInts = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allBasicUints = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
)
var allInts = flatten(
- allIntTs,
+ allBasicInts,
allInt8s,
allInt16s,
allInt32s,
@@ -42,7 +42,7 @@ var allInts = flatten(
)
var allUints = flatten(
- allUintTs,
+ allBasicUints,
allUintptrs,
allUint8s,
allUint16s,
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go
index c829753cc..75bdbfc6e 100644
--- a/pkg/state/tests/register_test.go
+++ b/pkg/state/tests/register_test.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build race
+
package tests
import (
@@ -165,3 +167,12 @@ func TestRegisterBad(t *testing.T) {
}
}
+
+func TestRegisterTypeOnlyStruct(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Register did not panic")
+ }
+ }()
+ state.Register((*typeOnlyEmptyStruct)(nil))
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
index c91c2c032..9826f1ee9 100644
--- a/pkg/state/tests/struct_test.go
+++ b/pkg/state/tests/struct_test.go
@@ -17,8 +17,6 @@ package tests
import (
"math/rand"
"testing"
-
- "gvisor.dev/gvisor/pkg/state"
)
func TestEmptyStruct(t *testing.T) {
@@ -58,15 +56,6 @@ func TestEmptyStruct(t *testing.T) {
})
}
-func TestRegisterTypeOnlyStruct(t *testing.T) {
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("Register did not panic")
- }
- }()
- state.Register((*typeOnlyEmptyStruct)(nil))
-}
-
func TestEmbeddedPointers(t *testing.T) {
// Give each int64 a random value to prevent Go from using
// runtime.staticuint64s, which confounds tests for struct duplication.
diff --git a/pkg/state/types.go b/pkg/state/types.go
index 84aed8732..420675880 100644
--- a/pkg/state/types.go
+++ b/pkg/state/types.go
@@ -329,47 +329,48 @@ var reverseTypeDatabase = map[reflect.Type]string{}
// This must be called on init and only done once.
func Register(t Type) {
name := t.StateTypeName()
- fields := t.StateFields()
- assertValidType(name, fields)
- // Register must always be called on pointers.
typ := reflect.TypeOf(t)
- if typ.Kind() != reflect.Ptr {
- Failf("Register must be called on pointers")
+ if raceEnabled {
+ assertValidType(name, t.StateFields())
+ // Register must always be called on pointers.
+ if typ.Kind() != reflect.Ptr {
+ Failf("Register must be called on pointers")
+ }
}
typ = typ.Elem()
- if typ.Kind() == reflect.Struct {
- // All registered structs must implement SaverLoader. We allow
- // the registration is non-struct types with just the Type
- // interface, but we need to call StateSave/StateLoad methods
- // on aggregate types.
- if _, ok := t.(SaverLoader); !ok {
- Failf("struct %T does not implement SaverLoader", t)
+ if raceEnabled {
+ if typ.Kind() == reflect.Struct {
+ // All registered structs must implement SaverLoader. We allow
+ // the registration is non-struct types with just the Type
+ // interface, but we need to call StateSave/StateLoad methods
+ // on aggregate types.
+ if _, ok := t.(SaverLoader); !ok {
+ Failf("struct %T does not implement SaverLoader", t)
+ }
+ } else {
+ // Non-structs must not have any fields. We don't support
+ // calling StateSave/StateLoad methods on any non-struct types.
+ // If custom behavior is required, these types should be
+ // wrapped in a structure of some kind.
+ if fields := t.StateFields(); len(fields) != 0 {
+ Failf("non-struct %T has non-zero fields %v", t, fields)
+ }
+ // We don't allow non-structs to implement StateSave/StateLoad
+ // methods, because they won't be called and it's confusing.
+ if _, ok := t.(SaverLoader); ok {
+ Failf("non-struct %T implements SaverLoader", t)
+ }
}
- } else {
- // Non-structs must not have any fields. We don't support
- // calling StateSave/StateLoad methods on any non-struct types.
- // If custom behavior is required, these types should be
- // wrapped in a structure of some kind.
- if len(fields) != 0 {
- Failf("non-struct %T has non-zero fields %v", t, fields)
+ if _, ok := primitiveTypeDatabase[name]; ok {
+ Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
}
- // We don't allow non-structs to implement StateSave/StateLoad
- // methods, because they won't be called and it's confusing.
- if _, ok := t.(SaverLoader); ok {
- Failf("non-struct %T implements SaverLoader", t)
+ if _, ok := globalTypeDatabase[name]; ok {
+ Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
+ }
+ if name == interfaceType {
+ Failf("conflicting name for %T: matches interfaceType", t)
}
- }
- if _, ok := primitiveTypeDatabase[name]; ok {
- Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
- }
- if _, ok := globalTypeDatabase[name]; ok {
- Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
- }
- if name == interfaceType {
- Failf("conflicting name for %T: matches interfaceType", t)
- }
- globalTypeDatabase[name] = typ
- if raceEnabled {
reverseTypeDatabase[typ] = name
}
+ globalTypeDatabase[name] = typ
}