diff options
author | Adin Scannell <ascannell@google.com> | 2021-11-03 22:14:59 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-11-03 22:17:30 -0700 |
commit | 80cba65bd84d6415719b07daeca7188871000242 (patch) | |
tree | 42efeff08ddd8a96e63b06b7d0c9fa8e29a3c545 /tools/checklocks/facts.go | |
parent | 5185548e157be1ec4c8c161d15ca8ee045a31a36 (diff) |
Add automatic lock inference and globals support.
Lock inference will apply annotations to all fields that seem to be
protected. This is currently disabled for all code by default, but it
can be enabled as annotations are applied more broadly.
PiperOrigin-RevId: 407501915
Diffstat (limited to 'tools/checklocks/facts.go')
-rw-r--r-- | tools/checklocks/facts.go | 813 |
1 files changed, 484 insertions, 329 deletions
diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go index 17aef5790..f6dfeaec9 100644 --- a/tools/checklocks/facts.go +++ b/tools/checklocks/facts.go @@ -15,6 +15,7 @@ package checklocks import ( + "encoding/gob" "fmt" "go/ast" "go/token" @@ -22,6 +23,7 @@ import ( "regexp" "strings" + "golang.org/x/tools/go/analysis/passes/buildssa" "golang.org/x/tools/go/ssa" ) @@ -46,99 +48,110 @@ const ( atomicRequired ) +// fieldEntry is a single field type. +type fieldEntry interface { + // synthesize produces a string that is compatible with valueAndObject, + // along with the same object that should be produced in that case. + // + // Note that it is called synthesize because this is produced only the + // type information, and not with any ssa.Value objects. + synthesize(s string, typ types.Type) (string, types.Object) +} + +// fieldStruct is a non-pointer struct element. +type fieldStruct int + +// synthesize implements fieldEntry.synthesize. +func (f fieldStruct) synthesize(s string, typ types.Type) (string, types.Object) { + field, ok := findField(typ, int(f)) + if !ok { + // Should not happen as long as fieldList construction is correct. + panic(fmt.Sprintf("unable to resolve field %d in %s", int(f), typ.String())) + } + return fmt.Sprintf("&(%s.%s)", s, field.Name()), field +} + +// fieldStructPtr is a pointer struct element. +type fieldStructPtr int + +// synthesize implements fieldEntry.synthesize. +func (f fieldStructPtr) synthesize(s string, typ types.Type) (string, types.Object) { + field, ok := findField(typ, int(f)) + if !ok { + // See above, this should not happen. + panic(fmt.Sprintf("unable to resolve ptr field %d in %s", int(f), typ.String())) + } + return fmt.Sprintf("*(&(%s.%s))", s, field.Name()), field +} + // fieldList is a simple list of fields, used in two types below. -// -// Note that the integers in this list refer to one of two things: -// - A positive integer refers to a field index in a struct. -// - A negative integer refers to a field index in a struct, where -// that field is a pointer and must be subsequently resolved. -type fieldList []int +type fieldList []fieldEntry // resolvedValue is an ssa.Value with additional fields. // // This can be resolved to a string as part of a lock state. type resolvedValue struct { value ssa.Value - valid bool - fieldList []int -} - -// findExtract finds a relevant extract. This must exist within the referrers -// to the call object. If this doesn't then the object which is locked is never -// consumed, and we should consider this a bug. -func findExtract(v ssa.Value, index int) (ssa.Value, bool) { - if refs := v.Referrers(); refs != nil { - for _, inst := range *refs { - if x, ok := inst.(*ssa.Extract); ok && x.Tuple == v && x.Index == index { - return inst.(ssa.Value), true - } - } - } - return nil, false + fieldList fieldList } -// resolve resolves the given field list. -func (fl fieldList) resolve(v ssa.Value) (rv resolvedValue) { +// makeResolvedValue makes a new resolvedValue. +func makeResolvedValue(v ssa.Value, fl fieldList) resolvedValue { return resolvedValue{ value: v, fieldList: fl, - valid: true, } } -// valueAsString returns a string representing this value. +// valid indicates whether this is a valid resolvedValue. +func (rv *resolvedValue) valid() bool { + return rv.value != nil +} + +// valueAndObject returns a string and object. // -// This must align with how the string is generated in valueAsString. -func (rv resolvedValue) valueAsString(ls *lockState) string { +// This uses the lockState valueAndObject in order to produce a string and +// object for the base ssa.Value, then synthesizes a string representation +// based on the fieldList. +func (rv *resolvedValue) valueAndObject(ls *lockState) (string, types.Object) { + // N.B. obj.Type() and typ should be equal, but a check is omitted + // since, 1) we automatically chase through pointers during field + // resolution, and 2) obj may be nil if there is no source object. + s, obj := ls.valueAndObject(rv.value) typ := rv.value.Type() - s := ls.valueAsString(rv.value) - for i, fieldNumber := range rv.fieldList { - switch { - case fieldNumber > 0: - field, ok := findField(typ, fieldNumber-1) - if !ok { - // This can't be resolved, return for debugging. - return fmt.Sprintf("{%s+%v}", s, rv.fieldList[i:]) - } - s = fmt.Sprintf("&(%s.%s)", s, field.Name()) - typ = field.Type() - case fieldNumber < 1: - field, ok := findField(typ, (-fieldNumber)-1) - if !ok { - // See above. - return fmt.Sprintf("{%s+%v}", s, rv.fieldList[i:]) - } - s = fmt.Sprintf("*(&(%s.%s))", s, field.Name()) - typ = field.Type() - } + for _, entry := range rv.fieldList { + s, obj = entry.synthesize(s, typ) + typ = obj.Type() } - return s + return s, obj } -// lockFieldFacts apply on every struct field. -type lockFieldFacts struct { - // IsMutex is true if the field is of type sync.Mutex. - IsMutex bool - - // IsRWMutex is true if the field is of type sync.RWMutex. - IsRWMutex bool - - // IsPointer indicates if the field is a pointer. - IsPointer bool - - // FieldNumber is the number of this field in the struct. - FieldNumber int +// fieldGuardResolver details a guard for a field. +type fieldGuardResolver interface { + // resolveField is used to resolve a guard during a field access. The + // parent structure is available, as well as the current lock state. + resolveField(pc *passContext, ls *lockState, parent ssa.Value) resolvedValue } -// AFact implements analysis.Fact.AFact. -func (*lockFieldFacts) AFact() {} +// functionGuardResolver details a guard for a function. +type functionGuardResolver interface { + // resolveStatic is used to resolve a guard during static analysis, + // e.g. based on static annotations applied to a method. The function's + // ssa object is available, as well as the return value. + resolveStatic(pc *passContext, ls *lockState, fn *ssa.Function, rv interface{}) resolvedValue + + // resolveCall is used to resolve a guard during a call. The ssa + // return value is available from the instruction context where the + // call occurs, but the target's ssa representation is not available. + resolveCall(pc *passContext, ls *lockState, args []ssa.Value, rv ssa.Value) resolvedValue +} // lockGuardFacts contains guard information. type lockGuardFacts struct { // GuardedBy is the set of locks that are guarding this field. The key // is the original annotation value, and the field list is the object // traversal path. - GuardedBy map[string]fieldList + GuardedBy map[string]fieldGuardResolver // AtomicDisposition is the disposition for this field. Note that this // can affect the interpretation of the GuardedBy field above, see the @@ -149,86 +162,142 @@ type lockGuardFacts struct { // AFact implements analysis.Fact.AFact. func (*lockGuardFacts) AFact() {} -// functionGuard is used by lockFunctionFacts, below. -type functionGuard struct { - // ParameterNumber is the index of the object that contains the - // guarding mutex. From this parameter, a walk is performed - // subsequently using the resolve method. - // - // Note that is ParameterNumber is beyond the size of parameters, then - // it may return to a return value. This applies only for the Acquires - // relation below. - ParameterNumber int +// globalGuard is a global value. +type globalGuard struct { + // Object indicates the object from which resolution should occur. + Object types.Object + + // FieldList is the traversal path from object. + FieldList fieldList +} + +// ssaPackager returns the ssa package. +type ssaPackager interface { + Package() *ssa.Package +} + +// resolveCommon implements resolution for all cases. +func (g *globalGuard) resolveCommon(pc *passContext, ls *lockState) resolvedValue { + state := pc.pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + v := state.Pkg.Members[g.Object.Name()].(ssa.Value) + return makeResolvedValue(v, g.FieldList) +} + +// resolveStatic implements functionGuardResolver.resolveStatic. +func (g *globalGuard) resolveStatic(pc *passContext, ls *lockState, _ *ssa.Function, v interface{}) resolvedValue { + return g.resolveCommon(pc, ls) +} + +// resolveCall implements functionGuardResolver.resolveCall. +func (g *globalGuard) resolveCall(pc *passContext, ls *lockState, _ []ssa.Value, v ssa.Value) resolvedValue { + return g.resolveCommon(pc, ls) +} + +// resolveField implements fieldGuardResolver.resolveField. +func (g *globalGuard) resolveField(pc *passContext, ls *lockState, parent ssa.Value) resolvedValue { + return g.resolveCommon(pc, ls) +} + +// fieldGuard is a field-based guard. +type fieldGuard struct { + // FieldList is the traversal path from the parent. + FieldList fieldList +} + +// resolveField implements fieldGuardResolver.resolveField. +func (f *fieldGuard) resolveField(_ *passContext, _ *lockState, parent ssa.Value) resolvedValue { + return makeResolvedValue(parent, f.FieldList) +} + +// parameterGuard is a parameter-based guard. +type parameterGuard struct { + // Index is the parameter index of the object that contains the + // guarding mutex. + Index int + + // fieldList is the traversal path from the parameter. + FieldList fieldList +} + +// resolveStatic implements functionGuardResolver.resolveStatic. +func (p *parameterGuard) resolveStatic(_ *passContext, _ *lockState, fn *ssa.Function, _ interface{}) resolvedValue { + return makeResolvedValue(fn.Params[p.Index], p.FieldList) +} + +// resolveCall implements functionGuardResolver.resolveCall. +func (p *parameterGuard) resolveCall(_ *passContext, _ *lockState, args []ssa.Value, _ ssa.Value) resolvedValue { + return makeResolvedValue(args[p.Index], p.FieldList) +} + +// returnGuard is a return-based guard. +type returnGuard struct { + // Index is the index of the return value. + Index int // NeedsExtract is used in the case of a return value, and indicates // that the field must be extracted from a tuple. NeedsExtract bool - // IsAlias indicates that this guard is an alias. - IsAlias bool - - // FieldList is the traversal path to the object. + // FieldList is the traversal path from the return value. FieldList fieldList - - // Exclusive indicates an exclusive lock is required. - Exclusive bool } -// resolveReturn resolves a return value. -// -// Precondition: rv is either an ssa.Value, or an *ssa.Return. -func (fg *functionGuard) resolveReturn(rv interface{}, args int) resolvedValue { +// resolveCommon implements resolution for both cases. +func (r *returnGuard) resolveCommon(rv interface{}) resolvedValue { if rv == nil { // For defers and other objects, this may be nil. This is - // handled in state.go in the actual lock checking logic. - return resolvedValue{ - value: nil, - valid: false, - } + // handled in state.go in the actual lock checking logic. This + // means that there is no resolvedValue available. + return resolvedValue{} } - index := fg.ParameterNumber - args // If this is a *ssa.Return object, i.e. we are analyzing the function // and not the call site, then we can just pull the result directly. - if r, ok := rv.(*ssa.Return); ok { - return fg.FieldList.resolve(r.Results[index]) + if ret, ok := rv.(*ssa.Return); ok { + return makeResolvedValue(ret.Results[r.Index], r.FieldList) } - if fg.NeedsExtract { + if r.NeedsExtract { // Resolve on the extracted field, this is necessary if the // type here is not an explicit return. Note that rv must be an // ssa.Value, since it is not an *ssa.Return. - v, ok := findExtract(rv.(ssa.Value), index) - if !ok { - return resolvedValue{ - value: v, - valid: false, + v := rv.(ssa.Value) + if refs := v.Referrers(); refs != nil { + for _, inst := range *refs { + if x, ok := inst.(*ssa.Extract); ok && x.Tuple == v && x.Index == r.Index { + return makeResolvedValue(x, r.FieldList) + } } } - return fg.FieldList.resolve(v) + // Nothing resolved. + return resolvedValue{} } - if index != 0 { + if r.Index != 0 { // This should not happen, NeedsExtract should always be set. panic("NeedsExtract is false, but return value index is non-zero") } // Resolve on the single return. - return fg.FieldList.resolve(rv.(ssa.Value)) + return makeResolvedValue(rv.(ssa.Value), r.FieldList) } -// resolveStatic returns an ssa.Value representing the given field. -// -// Precondition: per resolveReturn. -func (fg *functionGuard) resolveStatic(fn *ssa.Function, rv interface{}) resolvedValue { - if fg.ParameterNumber >= len(fn.Params) { - return fg.resolveReturn(rv, len(fn.Params)) - } - return fg.FieldList.resolve(fn.Params[fg.ParameterNumber]) +// resolveStatic implements functionGuardResolver.resolveStatic. +func (r *returnGuard) resolveStatic(_ *passContext, _ *lockState, _ *ssa.Function, rv interface{}) resolvedValue { + return r.resolveCommon(rv) } -// resolveCall returns an ssa.Value representing the given field. -func (fg *functionGuard) resolveCall(args []ssa.Value, rv ssa.Value) resolvedValue { - if fg.ParameterNumber >= len(args) { - return fg.resolveReturn(rv, len(args)) - } - return fg.FieldList.resolve(args[fg.ParameterNumber]) +// resolveCall implements functionGuardResolver.resolveCall. +func (r *returnGuard) resolveCall(_ *passContext, _ *lockState, _ []ssa.Value, rv ssa.Value) resolvedValue { + return r.resolveCommon(rv) +} + +// functionGuardInfo is information about a method guard. +type functionGuardInfo struct { + // Resolver is the resolver for this guard. + Resolver functionGuardResolver + + // IsAlias indicates that this guard is an alias. + IsAlias bool + + // Exclusive indicates an exclusive lock is required. + Exclusive bool } // lockFunctionFacts apply on every method. @@ -250,13 +319,11 @@ type lockFunctionFacts struct { // ``` // // '`+checklocks:a.mu' will result in an entry in this map as shown below. - // HeldOnEntry: {"a.mu" => {ParameterNumber: 0, FieldNumbers: {0}} - // - // Unlikely lockFieldFacts, there is no atomic interpretation. - HeldOnEntry map[string]functionGuard + // HeldOnEntry: {"a.mu" => {Resolver: ¶meterGuard{Index: 0}} + HeldOnEntry map[string]functionGuardInfo // HeldOnExit tracks the locks that are expected to be held on exit. - HeldOnExit map[string]functionGuard + HeldOnExit map[string]functionGuardInfo // Ignore means this function has local analysis ignores. // @@ -268,14 +335,14 @@ type lockFunctionFacts struct { func (*lockFunctionFacts) AFact() {} // checkGuard validates the guardName. -func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) { +func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuardInfo, bool) { if _, ok := lff.HeldOnEntry[guardName]; ok { pc.maybeFail(d.Pos(), "annotation %s specified more than once, already required", guardName) - return functionGuard{}, false + return functionGuardInfo{}, false } if _, ok := lff.HeldOnExit[guardName]; ok { pc.maybeFail(d.Pos(), "annotation %s specified more than once, already acquired", guardName) - return functionGuard{}, false + return functionGuardInfo{}, false } fg, ok := pc.findFunctionGuard(d, guardName, exclusive, allowReturn) return fg, ok @@ -285,10 +352,10 @@ func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guard func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok { if lff.HeldOnEntry == nil { - lff.HeldOnEntry = make(map[string]functionGuard) + lff.HeldOnEntry = make(map[string]functionGuardInfo) } if lff.HeldOnExit == nil { - lff.HeldOnExit = make(map[string]functionGuard) + lff.HeldOnExit = make(map[string]functionGuardInfo) } lff.HeldOnEntry[guardName] = fg lff.HeldOnExit[guardName] = fg @@ -299,7 +366,7 @@ func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, gua func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, true /* allowReturn */); ok { if lff.HeldOnExit == nil { - lff.HeldOnExit = make(map[string]functionGuard) + lff.HeldOnExit = make(map[string]functionGuardInfo) } lff.HeldOnExit[guardName] = fg } @@ -309,7 +376,7 @@ func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guar func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok { if lff.HeldOnEntry == nil { - lff.HeldOnEntry = make(map[string]functionGuard) + lff.HeldOnEntry = make(map[string]functionGuardInfo) } lff.HeldOnEntry[guardName] = fg } @@ -345,213 +412,276 @@ func (lff *lockFunctionFacts) addAlias(pc *passContext, d *ast.FuncDecl, guardNa } } -// fieldListFor returns the fieldList for the given object. -func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool, exclusive bool) (int, bool) { - var lff lockFieldFacts - if !pc.pass.ImportObjectFact(fieldObj, &lff) { - // This should not happen: we export facts for all fields. - panic(fmt.Sprintf("no lockFieldFacts available for field %s", fieldName)) - } - // Check that it is indeed a mutex. - if checkMutex && !lff.IsMutex && !lff.IsRWMutex { - pc.maybeFail(pos, "field %s is not a Mutex or an RWMutex", fieldName) - return 0, false - } - if checkMutex && !exclusive && !lff.IsRWMutex { - pc.maybeFail(pos, "field %s must be a RWMutex, but it is not", fieldName) - return 0, false - } +// fieldEntryFor returns the fieldList value for the given object. +func (pc *passContext) fieldEntryFor(fieldObj types.Object, index int) fieldEntry { + // Return the resolution path. - if lff.IsPointer { - return -(index + 1), true + if _, ok := fieldObj.Type().Underlying().(*types.Pointer); ok { + return fieldStructPtr(index) } - return (index + 1), true + if _, ok := fieldObj.Type().Underlying().(*types.Interface); ok { + return fieldStructPtr(index) + } + return fieldStruct(index) } -// resolveOneField resolves a field in a single struct. -func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool, exclusive bool) (fl fieldList, fieldObj types.Object, ok bool) { +// findField resolves a field in a single struct. +func (pc *passContext) findField(structType *types.Struct, fieldName string) (fl fieldList, fieldObj types.Object, ok bool) { // Scan to match the next field. for i := 0; i < structType.NumFields(); i++ { fieldObj := structType.Field(i) if fieldObj.Name() != fieldName { continue } - flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex, exclusive) - if !ok { - return nil, nil, false - } - fl = append(fl, flOne) + fl = append(fl, pc.fieldEntryFor(fieldObj, i)) return fl, fieldObj, true } + // Is this an embed? for i := 0; i < structType.NumFields(); i++ { fieldObj := structType.Field(i) if !fieldObj.Embedded() { continue } + // Is this an embedded struct? structType, ok := resolveStruct(fieldObj.Type()) if !ok { continue } + // Need to check that there is a resolution path. If there is // no resolution path that's not a failure: we just continue // scanning the next embed to find a match. - flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false, exclusive) - flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex, exclusive) - if okEmbed && okCont { - fl = append(fl, flEmbed) - fl = append(fl, flCont...) - return fl, fieldObjCont, true + flEmbed := pc.fieldEntryFor(fieldObj, i) + flNext, fieldObjNext, ok := pc.findField(structType, fieldName) + if !ok { + continue } + + // Found an embedded chain. + fl = append(fl, flEmbed) + fl = append(fl, flNext...) + return fl, fieldObjNext, true } - pc.maybeFail(pos, "field %s does not exist", fieldName) + return nil, nil, false } -// resolveField resolves a set of fields given a string, such a 'a.b.c'. +var ( + mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)") + rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)") + lockerRE = regexp.MustCompile("((.*/)|^)sync.Locker") +) + +// validateMutex validates the mutex type. +// +// This function returns true iff the object is a valid mutex with an error +// reported at the given position if necessary. +func (pc *passContext) validateMutex(pos token.Pos, obj types.Object, exclusive bool) bool { + // Check that it is indeed a mutex. + s := obj.Type().String() + switch { + case mutexRE.MatchString(s), lockerRE.MatchString(s): + // Safe for exclusive cases. + if !exclusive { + pc.maybeFail(pos, "field %s must be a RWMutex", obj.Name()) + return false + } + return true + case rwMutexRE.MatchString(s): + // Safe for all cases. + return true + default: + // Not a mutex at all? + pc.maybeFail(pos, "field %s is not a Mutex or an RWMutex", obj.Name()) + return false + } +} + +// findFieldList resolves a set of fields given a string, such a 'a.b.c'. // -// Note that this checks that the final element is a mutex of some kind, and -// will fail appropriately. -func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string, exclusive bool) (fl fieldList, ok bool) { - for partNumber, fieldName := range parts { - flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */, exclusive) +// Note that parts must be non-zero in length. If it may be zero, then +// maybeFindFieldList should be used instead with an appropriate object. +func (pc *passContext) findFieldList(pos token.Pos, structType *types.Struct, parts []string, exclusive bool) (fl fieldList, ok bool) { + var obj types.Object + + // This loop requires at least one iteration in order to ensure that + // obj above is non-nil, and the type can be validated. + for i, fieldName := range parts { + flOne, fieldObj, ok := pc.findField(structType, fieldName) if !ok { - // Error already reported. return nil, false } fl = append(fl, flOne...) - if partNumber < len(parts)-1 { - // Traverse to the next type. - structType, ok = resolveStruct(fieldObj.Type()) + obj = fieldObj + if i < len(parts)-1 { + structType, ok = resolveStruct(obj.Type()) if !ok { - pc.maybeFail(pos, "invalid intermediate field %s", fieldName) - return fl, false + // N.B. This is associated with the original position. + pc.maybeFail(pos, "field %s expected to be struct", fieldName) + return nil, false } } } + + // Validate the final field. This reports the field to the caller + // anyways, since the error will be reported only once. + _ = pc.validateMutex(pos, obj, exclusive) return fl, true } -var ( - mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)") - rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)") - lockerRE = regexp.MustCompile("((.*/)|^)sync.Locker") -) - -// exportLockFieldFacts finds all struct fields that are mutexes, and ensures -// that they are annotated properly. +// maybeFindFieldList resolves the given object. // -// This information is consumed subsequently by exportLockGuardFacts, and this -// function must be called first on all structures. -func (pc *passContext) exportLockFieldFacts(structType *types.Struct, ss *ast.StructType) { - for i, field := range ss.Fields.List { - lff := &lockFieldFacts{ - FieldNumber: i, - } - // We use HasSuffix below because fieldType can be fully - // qualified with the package name eg for the gvisor sync - // package mutex fields have the type: - // "<package path>/sync/sync.Mutex" - fieldObj := structType.Field(i) - s := fieldObj.Type().String() - switch { - case mutexRE.MatchString(s): - lff.IsMutex = true - case rwMutexRE.MatchString(s): - lff.IsRWMutex = true - case lockerRE.MatchString(s): - lff.IsMutex = true - } - // Save whether this is a pointer. - _, lff.IsPointer = fieldObj.Type().Underlying().(*types.Pointer) - if !lff.IsPointer { - _, lff.IsPointer = fieldObj.Type().Underlying().(*types.Interface) - } - // We must always export the lockFieldFacts, since traversal - // can take place along any object in the struct. - pc.pass.ExportObjectFact(fieldObj, lff) - // If this is an anonymous type, then we won't discover it via - // the AST global declarations. We can recurse from here. - if ss, ok := field.Type.(*ast.StructType); ok { - if st, ok := fieldObj.Type().(*types.Struct); ok { - pc.exportLockFieldFacts(st, ss) - } +// Parts may be the empty list, unlike findFieldList. +func (pc *passContext) maybeFindFieldList(pos token.Pos, obj types.Object, parts []string, exclusive bool) (fl fieldList, ok bool) { + if len(parts) > 0 { + structType, ok := resolveStruct(obj.Type()) + if !ok { + // This does not have any fields; the access is not allowed. + pc.maybeFail(pos, "attempted field access on non-struct") + return nil, false } + return pc.findFieldList(pos, structType, parts, exclusive) } + + // See above. + _ = pc.validateMutex(pos, obj, exclusive) + return nil, true } -// exportLockGuardFacts finds all relevant guard information for structures. -// -// This function requires exportLockFieldFacts be called first on all -// structures. -func (pc *passContext) exportLockGuardFacts(structType *types.Struct, ss *ast.StructType) { - for i, field := range ss.Fields.List { - fieldObj := structType.Field(i) - if field.Doc != nil { - var ( - lff lockFieldFacts - lgf lockGuardFacts - ) - pc.pass.ImportObjectFact(structType.Field(i), &lff) - for _, l := range field.Doc.List { - pc.extractAnnotations(l.Text, map[string]func(string){ - checkAtomicAnnotation: func(string) { - switch lgf.AtomicDisposition { - case atomicRequired: - pc.maybeFail(fieldObj.Pos(), "annotation is redundant, already atomic required") - case atomicIgnore: - pc.maybeFail(fieldObj.Pos(), "annotation is contradictory, already atomic ignored") - } - lgf.AtomicDisposition = atomicRequired - }, - checkLocksIgnore: func(string) { - switch lgf.AtomicDisposition { - case atomicIgnore: - pc.maybeFail(fieldObj.Pos(), "annotation is redundant, already atomic ignored") - case atomicRequired: - pc.maybeFail(fieldObj.Pos(), "annotation is contradictory, already atomic required") - } - lgf.AtomicDisposition = atomicIgnore - }, - checkLocksAnnotation: func(guardName string) { - // Check for a duplicate annotation. - if _, ok := lgf.GuardedBy[guardName]; ok { - pc.maybeFail(fieldObj.Pos(), "annotation %s specified more than once", guardName) - return - } - fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."), true /* exclusive */) - if ok { - // If we successfully resolved the field, then save it. - if lgf.GuardedBy == nil { - lgf.GuardedBy = make(map[string]fieldList) - } - lgf.GuardedBy[guardName] = fl - } - }, - // N.B. We support only the vanilla - // annotation on individual fields. If - // the field is a read lock, then we - // will allow read access by default. - checkLocksAnnotationRead: func(guardName string) { - pc.maybeFail(fieldObj.Pos(), "annotation %s not legal on fields", guardName) - }, - }) - } - // Save only if there is something meaningful. - if len(lgf.GuardedBy) > 0 || lgf.AtomicDisposition != atomicDisallow { - pc.pass.ExportObjectFact(structType.Field(i), &lgf) - } +// findFieldGuardResolver finds a symbol resolver. +type findFieldGuardResolver func(pos token.Pos, guardName string) (fieldGuardResolver, bool) + +// findFunctionGuardResolver finds a symbol resolver. +type findFunctionGuardResolver func(pos token.Pos, guardName string) (functionGuardResolver, bool) + +// fillLockGuardFacts fills the facts with guard information. +func (pc *passContext) fillLockGuardFacts(obj types.Object, cg *ast.CommentGroup, find findFieldGuardResolver, lgf *lockGuardFacts) { + if cg == nil { + return + } + for _, l := range cg.List { + pc.extractAnnotations(l.Text, map[string]func(string){ + checkAtomicAnnotation: func(string) { + switch lgf.AtomicDisposition { + case atomicRequired: + pc.maybeFail(obj.Pos(), "annotation is redundant, already atomic required") + case atomicIgnore: + pc.maybeFail(obj.Pos(), "annotation is contradictory, already atomic ignored") + } + lgf.AtomicDisposition = atomicRequired + }, + checkLocksIgnore: func(string) { + switch lgf.AtomicDisposition { + case atomicIgnore: + pc.maybeFail(obj.Pos(), "annotation is redundant, already atomic ignored") + case atomicRequired: + pc.maybeFail(obj.Pos(), "annotation is contradictory, already atomic required") + } + lgf.AtomicDisposition = atomicIgnore + }, + checkLocksAnnotation: func(guardName string) { + // Check for a duplicate annotation. + if _, ok := lgf.GuardedBy[guardName]; ok { + pc.maybeFail(obj.Pos(), "annotation %s specified more than once", guardName) + return + } + // Add the item. + if lgf.GuardedBy == nil { + lgf.GuardedBy = make(map[string]fieldGuardResolver) + } + fr, ok := find(obj.Pos(), guardName) + if !ok { + pc.maybeFail(obj.Pos(), "annotation %s cannot be resolved", guardName) + return + } + lgf.GuardedBy[guardName] = fr + }, + // N.B. We support only the vanilla annotation on + // individual fields. If the field is a read lock, then + // we will allow read access by default. + checkLocksAnnotationRead: func(guardName string) { + pc.maybeFail(obj.Pos(), "annotation %s not legal on fields", guardName) + }, + }) + } + // Save only if there is something meaningful. + if len(lgf.GuardedBy) > 0 || lgf.AtomicDisposition != atomicDisallow { + pc.pass.ExportObjectFact(obj, lgf) + } +} + +// findGlobalGuard attempts to resolve a name globally. +func (pc *passContext) findGlobalGuard(pos token.Pos, guardName string) (*globalGuard, bool) { + // Attempt to resolve the object. + parts := strings.Split(guardName, ".") + globalObj := pc.pass.Pkg.Scope().Lookup(parts[0]) + if globalObj == nil { + // No global object. + return nil, false + } + fl, ok := pc.maybeFindFieldList(pos, globalObj, parts[1:], true /* exclusive */) + if !ok { + // Invalid fields. + return nil, false + } + return &globalGuard{ + Object: globalObj, + FieldList: fl, + }, true +} + +// findGlobalFieldGuard is compatible with findFieldGuardResolver. +func (pc *passContext) findGlobalFieldGuard(pos token.Pos, guardName string) (fieldGuardResolver, bool) { + g, ok := pc.findGlobalGuard(pos, guardName) + return g, ok +} + +// findGlobalFunctionGuard is compatible with findFunctionGuardResolver. +func (pc *passContext) findGlobalFunctionGuard(pos token.Pos, guardName string) (functionGuardResolver, bool) { + g, ok := pc.findGlobalGuard(pos, guardName) + return g, ok +} + +// structLockGuardFacts finds all relevant guard information for structures. +func (pc *passContext) structLockGuardFacts(structType *types.Struct, ss *ast.StructType) { + var fieldObj *types.Var + findLocal := func(pos token.Pos, guardName string) (fieldGuardResolver, bool) { + // Try to resolve from the local structure first. + fl, ok := pc.findFieldList(pos, structType, strings.Split(guardName, "."), true /* exclusive */) + if ok { + // Found a valid resolution. + return &fieldGuard{ + FieldList: fl, + }, true } + // Attempt a global resolution. + return pc.findGlobalFieldGuard(pos, guardName) + } + for i, field := range ss.Fields.List { + var lgf lockGuardFacts + fieldObj = structType.Field(i) // N.B. Captured above. + pc.fillLockGuardFacts(fieldObj, field.Doc, findLocal, &lgf) + // See above, for anonymous structure fields. if ss, ok := field.Type.(*ast.StructType); ok { if st, ok := fieldObj.Type().(*types.Struct); ok { - pc.exportLockGuardFacts(st, ss) + pc.structLockGuardFacts(st, ss) } } } } +// globalLockGuardFacts finds all relevant guard information for globals. +// +// Note that the Type is checked in checklocks.go at the top-level. +func (pc *passContext) globalLockGuardFacts(vs *ast.ValueSpec) { + var lgf lockGuardFacts + globalObj := pc.pass.TypesInfo.ObjectOf(vs.Names[0]) + pc.fillLockGuardFacts(globalObj, vs.Doc, pc.findGlobalFieldGuard, &lgf) +} + // countFields gives an accurate field count, according for unnamed arguments // and return values and the compact identifier format. func countFields(fl []*ast.Field) (count int) { @@ -566,89 +696,105 @@ func countFields(fl []*ast.Field) (count int) { } // matchFieldList attempts to match the given field. -func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string, exclusive bool) (functionGuard, bool) { +// +// This function may or may not report an error. This is indicated in the +// reported return value. If reported is true, then the specification is +// ambiguous or not valid, and should be propagated. +func (pc *passContext) matchFieldList(pos token.Pos, fields []*ast.Field, guardName string, exclusive bool) (number int, fl fieldList, reported, ok bool) { parts := strings.Split(guardName, ".") - parameterName := parts[0] - parameterNumber := 0 - for _, field := range fl { + firstName := parts[0] + index := 0 + for _, field := range fields { // See countFields, above. if len(field.Names) == 0 { - parameterNumber++ + index++ continue } for _, name := range field.Names { - if name.Name != parameterName { - parameterNumber++ + if name.Name != firstName { + index++ continue } - ptrType, ok := pc.pass.TypesInfo.TypeOf(field.Type).Underlying().(*types.Pointer) - if !ok { - // Since mutexes cannot be copied we only care - // about parameters that are pointer types when - // checking for guards. - pc.maybeFail(pos, "parameter name %s does not refer to a pointer type", parameterName) - return functionGuard{}, false - } - structType, ok := ptrType.Elem().Underlying().(*types.Struct) + obj := pc.pass.TypesInfo.ObjectOf(name) + fl, ok := pc.maybeFindFieldList(pos, obj, parts[1:], exclusive) if !ok { - // Fields can only be in named structures. - pc.maybeFail(pos, "parameter name %s does not refer to a pointer to a struct", parameterName) - return functionGuard{}, false + // Some intermediate name does not match. The + // resolveField function will not report. + pc.maybeFail(pos, "name %s does not resolve to a field", guardName) + return 0, nil, true, false } - fg := functionGuard{ - ParameterNumber: parameterNumber, - Exclusive: exclusive, - } - fl, ok := pc.resolveField(pos, structType, parts[1:], exclusive) - fg.FieldList = fl - return fg, ok // If ok is false, already failed. + // Successfully found a field. + return index, fl, false, true } } - return functionGuard{}, false + + // Nothing matching. + return 0, nil, false, false } // findFunctionGuard identifies the parameter number and field number for a // particular string of the 'a.b'. // // This function will report any errors directly. -func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) { - var ( - parameterList []*ast.Field - returnList []*ast.Field - ) +func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuardInfo, bool) { + // Match against receiver & parameters. + var parameterList []*ast.Field if d.Recv != nil { parameterList = append(parameterList, d.Recv.List...) } if d.Type.Params != nil { parameterList = append(parameterList, d.Type.Params.List...) } - if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName, exclusive); ok { - return fg, ok + if index, fl, reported, ok := pc.matchFieldList(d.Pos(), parameterList, guardName, exclusive); reported || ok { + if !ok { + return functionGuardInfo{}, false + } + return functionGuardInfo{ + Resolver: ¶meterGuard{ + Index: index, + FieldList: fl, + }, + Exclusive: exclusive, + }, true } + + // Match against return values, if allowed. if allowReturn { + var returnList []*ast.Field if d.Type.Results != nil { returnList = append(returnList, d.Type.Results.List...) } - if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName, exclusive); ok { - // Fix this up to apply to the return value, as noted - // in fg.ParameterNumber. For the ssa analysis, we must - // record whether this has multiple results, since - // *ssa.Call indicates: "The Call instruction yields - // the function result if there is exactly one. - // Otherwise it returns a tuple, the components of - // which are accessed via Extract." - fg.ParameterNumber += countFields(parameterList) - fg.NeedsExtract = countFields(returnList) > 1 - return fg, ok + if index, fl, reported, ok := pc.matchFieldList(d.Pos(), returnList, guardName, exclusive); reported || ok { + if !ok { + return functionGuardInfo{}, false + } + return functionGuardInfo{ + Resolver: &returnGuard{ + Index: index, + FieldList: fl, + NeedsExtract: countFields(returnList) > 1, + }, + Exclusive: exclusive, + }, true } } - // We never saw a matching parameter. - pc.maybeFail(d.Pos(), "annotation %s does not have a matching parameter", guardName) - return functionGuard{}, false + + // Match against globals. + if g, ok := pc.findGlobalFunctionGuard(d.Pos(), guardName); ok { + return functionGuardInfo{ + Resolver: g, + Exclusive: exclusive, + }, true + } + + // No match found. + pc.maybeFail(d.Pos(), "annotation %s does not have a match any parameter, return value or global", guardName) + return functionGuardInfo{}, false } -// exportFunctionFacts exports relevant function findings. -func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) { +// functionFacts exports relevant function findings. +func (pc *passContext) functionFacts(d *ast.FuncDecl) { + // Extract guard information. if d.Doc == nil || d.Doc.List == nil { return } @@ -679,3 +825,12 @@ func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) { pc.pass.ExportObjectFact(funcObj, &lff) } } + +func init() { + gob.Register((*returnGuard)(nil)) + gob.Register((*globalGuard)(nil)) + gob.Register((*parameterGuard)(nil)) + gob.Register((*fieldGuard)(nil)) + gob.Register((*fieldStructPtr)(nil)) + gob.Register((*fieldStruct)(nil)) +} |