diff options
-rw-r--r-- | pkg/state/tests/register_test.go | 11 | ||||
-rw-r--r-- | pkg/state/tests/struct_test.go | 11 | ||||
-rw-r--r-- | pkg/state/types.go | 71 |
3 files changed, 47 insertions, 46 deletions
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go index c829753cc..75bdbfc6e 100644 --- a/pkg/state/tests/register_test.go +++ b/pkg/state/tests/register_test.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build race + package tests import ( @@ -165,3 +167,12 @@ func TestRegisterBad(t *testing.T) { } } + +func TestRegisterTypeOnlyStruct(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Register did not panic") + } + }() + state.Register((*typeOnlyEmptyStruct)(nil)) +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go index c91c2c032..9826f1ee9 100644 --- a/pkg/state/tests/struct_test.go +++ b/pkg/state/tests/struct_test.go @@ -17,8 +17,6 @@ package tests import ( "math/rand" "testing" - - "gvisor.dev/gvisor/pkg/state" ) func TestEmptyStruct(t *testing.T) { @@ -58,15 +56,6 @@ func TestEmptyStruct(t *testing.T) { }) } -func TestRegisterTypeOnlyStruct(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Errorf("Register did not panic") - } - }() - state.Register((*typeOnlyEmptyStruct)(nil)) -} - func TestEmbeddedPointers(t *testing.T) { // Give each int64 a random value to prevent Go from using // runtime.staticuint64s, which confounds tests for struct duplication. diff --git a/pkg/state/types.go b/pkg/state/types.go index 84aed8732..420675880 100644 --- a/pkg/state/types.go +++ b/pkg/state/types.go @@ -329,47 +329,48 @@ var reverseTypeDatabase = map[reflect.Type]string{} // This must be called on init and only done once. func Register(t Type) { name := t.StateTypeName() - fields := t.StateFields() - assertValidType(name, fields) - // Register must always be called on pointers. typ := reflect.TypeOf(t) - if typ.Kind() != reflect.Ptr { - Failf("Register must be called on pointers") + if raceEnabled { + assertValidType(name, t.StateFields()) + // Register must always be called on pointers. + if typ.Kind() != reflect.Ptr { + Failf("Register must be called on pointers") + } } typ = typ.Elem() - if typ.Kind() == reflect.Struct { - // All registered structs must implement SaverLoader. We allow - // the registration is non-struct types with just the Type - // interface, but we need to call StateSave/StateLoad methods - // on aggregate types. - if _, ok := t.(SaverLoader); !ok { - Failf("struct %T does not implement SaverLoader", t) + if raceEnabled { + if typ.Kind() == reflect.Struct { + // All registered structs must implement SaverLoader. We allow + // the registration is non-struct types with just the Type + // interface, but we need to call StateSave/StateLoad methods + // on aggregate types. + if _, ok := t.(SaverLoader); !ok { + Failf("struct %T does not implement SaverLoader", t) + } + } else { + // Non-structs must not have any fields. We don't support + // calling StateSave/StateLoad methods on any non-struct types. + // If custom behavior is required, these types should be + // wrapped in a structure of some kind. + if fields := t.StateFields(); len(fields) != 0 { + Failf("non-struct %T has non-zero fields %v", t, fields) + } + // We don't allow non-structs to implement StateSave/StateLoad + // methods, because they won't be called and it's confusing. + if _, ok := t.(SaverLoader); ok { + Failf("non-struct %T implements SaverLoader", t) + } } - } else { - // Non-structs must not have any fields. We don't support - // calling StateSave/StateLoad methods on any non-struct types. - // If custom behavior is required, these types should be - // wrapped in a structure of some kind. - if len(fields) != 0 { - Failf("non-struct %T has non-zero fields %v", t, fields) + if _, ok := primitiveTypeDatabase[name]; ok { + Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t) } - // We don't allow non-structs to implement StateSave/StateLoad - // methods, because they won't be called and it's confusing. - if _, ok := t.(SaverLoader); ok { - Failf("non-struct %T implements SaverLoader", t) + if _, ok := globalTypeDatabase[name]; ok { + Failf("conflicting globalTypeDatabase entries for %T: name conflict", t) + } + if name == interfaceType { + Failf("conflicting name for %T: matches interfaceType", t) } - } - if _, ok := primitiveTypeDatabase[name]; ok { - Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t) - } - if _, ok := globalTypeDatabase[name]; ok { - Failf("conflicting globalTypeDatabase entries for %T: name conflict", t) - } - if name == interfaceType { - Failf("conflicting name for %T: matches interfaceType", t) - } - globalTypeDatabase[name] = typ - if raceEnabled { reverseTypeDatabase[typ] = name } + globalTypeDatabase[name] = typ } |