From 80cba65bd84d6415719b07daeca7188871000242 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Wed, 3 Nov 2021 22:14:59 -0700 Subject: Add automatic lock inference and globals support. Lock inference will apply annotations to all fields that seem to be protected. This is currently disabled for all code by default, but it can be enabled as annotations are applied more broadly. PiperOrigin-RevId: 407501915 --- nogo.yaml | 2 + tools/checklocks/README.md | 23 +- tools/checklocks/analysis.go | 192 +++++++-- tools/checklocks/checklocks.go | 77 +++- tools/checklocks/facts.go | 813 +++++++++++++++++++++++--------------- tools/checklocks/state.go | 129 +++--- tools/checklocks/test/BUILD | 6 + tools/checklocks/test/globals.go | 85 ++++ tools/checklocks/test/inferred.go | 35 ++ tools/checklocks/test/methods.go | 2 +- tools/checklocks/test/test.go | 2 +- 11 files changed, 926 insertions(+), 440 deletions(-) create mode 100644 tools/checklocks/test/globals.go create mode 100644 tools/checklocks/test/inferred.go diff --git a/nogo.yaml b/nogo.yaml index 3ef57d29d..29f9569f3 100644 --- a/nogo.yaml +++ b/nogo.yaml @@ -48,6 +48,7 @@ global: - "duplicate import" # These will never be annotated. - "unexpected call to atomic function" + - "may require checklocks annotation for" # Generated proto code creates declarations like 'var start int = iNdEx' - "should omit type .* from declaration; it will be inferred from the right-hand side" internal: @@ -63,6 +64,7 @@ global: - "unexpected call to atomic function.*" - "return with unexpected locks held.*" - "incompatible return states.*" + - "may require checklocks annotation for.*" exclude: # Generated: exempt all. - pkg/shim/runtimeoptions/runtimeoptions_cri.go diff --git a/tools/checklocks/README.md b/tools/checklocks/README.md index eaad69399..7444acfa0 100644 --- a/tools/checklocks/README.md +++ b/tools/checklocks/README.md @@ -1,6 +1,6 @@ # CheckLocks Analyzer - + Checklocks is an analyzer for lock and atomic constraints. The analyzer relies on explicit annotations to identify fields that should be checked for access. @@ -75,7 +75,26 @@ annotation refers either to something that is not a 'sync.Mutex' or 'sync.RWMutex' or where the field does not exist at all. This will prevent the annotations from becoming stale over time as fields are renamed, etc. -# Currently not supported +## Lock suggestions + +Based on locks held during field access, the analyzer will suggest annotations. +These can be ignored with the standard `+checklocksignore` annotation. + +The annotation will be generated when the lock is held the vast majority of the +time the field is accessed. Note that it is possible for this frequency to be +greater than 100%, if the lock is held multiple times. For example: + +```go +func foo(ts1 *testStruct, ts2 *testStruct) { + ts1.Lock() + ts2.Lock() + ts1.gaurdedField = 1 // 200% locks held. + ts1.Unlock() + ts2.Unlock() +} +``` + +## Currently not supported 1. Anonymous functions are not correctly evaluated. The analyzer does not currently support specifying annotations on anonymous functions as a result diff --git a/tools/checklocks/analysis.go b/tools/checklocks/analysis.go index 2def09744..c3216cc0d 100644 --- a/tools/checklocks/analysis.go +++ b/tools/checklocks/analysis.go @@ -168,19 +168,19 @@ func resolveStruct(typ types.Type) (*types.Struct, bool) { func findField(typ types.Type, field int) (types.Object, bool) { structType, ok := resolveStruct(typ) - if !ok { + if !ok || field >= structType.NumFields() { return nil, false } return structType.Field(field), true } -// instructionWithReferrers is a generalization over ssa.Field, ssa.FieldAddr. -type instructionWithReferrers interface { - ssa.Instruction +// almostInst is a generalization over ssa.Field, ssa.FieldAddr, ssa.Global. +type almostInst interface { + Pos() token.Pos Referrers() *[]ssa.Instruction } -// checkFieldAccess checks the validity of a field access. +// checkGuards checks the guards held. // // This also enforces atomicity constraints for fields that must be accessed // atomically. The parameter isWrite indicates whether this field is used @@ -188,41 +188,46 @@ type instructionWithReferrers interface { // // Note that this function is not called if lff.Ignore is true, since it cannot // discover any local anonymous functions or closures. -func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj ssa.Value, field int, ls *lockState, isWrite bool) { +func (pc *passContext) checkGuards(inst almostInst, from ssa.Value, accessObj types.Object, ls *lockState, isWrite bool) { var ( - lff lockFieldFacts lgf lockGuardFacts guardsFound int - guardsHeld int + guardsHeld = make(map[string]struct{}) // Keyed by resolved string. ) - fieldObj, _ := findField(structObj.Type(), field) - pc.pass.ImportObjectFact(fieldObj, &lff) - pc.pass.ImportObjectFact(fieldObj, &lgf) + // Load the facts for the object accessed. + pc.pass.ImportObjectFact(accessObj, &lgf) - for guardName, fl := range lgf.GuardedBy { + // Check guards held. + for guardName, fgr := range lgf.GuardedBy { guardsFound++ - r := fl.resolve(structObj) + r := fgr.resolveField(pc, ls, from) + if !r.valid() { + // See above; this cannot be forced. + pc.maybeFail(inst.Pos(), "field %s cannot be resolved", guardName) + continue + } s, ok := ls.isHeld(r, isWrite) if ok { - guardsHeld++ + guardsHeld[s] = struct{}{} continue } if _, ok := pc.forced[pc.positionKey(inst.Pos())]; ok { // Mark this as locked, since it has been forced. All // forces are treated as an exclusive lock. - ls.lockField(r, true /* exclusive */) - guardsHeld++ + s, _ := ls.lockField(r, true /* exclusive */) + guardsHeld[s] = struct{}{} continue } // Note that we may allow this if the disposition is atomic, // and we are allowing atomic reads only. This will fall into // the atomic disposition check below, which asserts that the - // access is atomic. Further, guardsHeld < guardsFound will be - // true for this case, so we require it to be read-only. + // access is atomic. Further, len(guardsHeld) < guardsFound + // will be true for this case, so we require it to be + // read-only. if lgf.AtomicDisposition != atomicRequired { // There is no force key, no atomic access and no lock held. - pc.maybeFail(inst.Pos(), "invalid field access, must hold %s (%s) when accessing %s (locks: %s)", guardName, s, fieldObj.Name(), ls.String()) + pc.maybeFail(inst.Pos(), "invalid field access, %s (%s) must be locked when accessing %s (locks: %s)", guardName, s, accessObj.Name(), ls.String()) } } @@ -230,25 +235,75 @@ func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj switch lgf.AtomicDisposition { case atomicRequired: // Check that this is used safely as an input. - readOnly := guardsHeld < guardsFound + readOnly := len(guardsHeld) < guardsFound if refs := inst.Referrers(); refs != nil { for _, otherInst := range *refs { - pc.checkAtomicCall(otherInst, fieldObj, true, readOnly) + pc.checkAtomicCall(otherInst, accessObj, true, readOnly) } } // Check that this is not otherwise written non-atomically, // even if we do hold all the locks. if isWrite { - pc.maybeFail(inst.Pos(), "non-atomic write of field %s, writes must still be atomic with locks held (locks: %s)", fieldObj.Name(), ls.String()) + pc.maybeFail(inst.Pos(), "non-atomic write of field %s, writes must still be atomic with locks held (locks: %s)", accessObj.Name(), ls.String()) } case atomicDisallow: // Check that this is *not* used atomically. if refs := inst.Referrers(); refs != nil { for _, otherInst := range *refs { - pc.checkAtomicCall(otherInst, fieldObj, false, false) + pc.checkAtomicCall(otherInst, accessObj, false, false) } } } + + // Check inferred locks. + if accessObj.Pkg() == pc.pass.Pkg { + oo := pc.observationsFor(accessObj) + oo.total++ + for s, info := range ls.lockedMutexes { + // Is this an object for which we have facts? If there + // is no ability to name this object, then we don't + // bother with any inferrence. We also ignore any self + // references (e.g. accessing a mutex while you are + // holding that exact mutex). + if info.object == nil || accessObj == info.object { + continue + } + // Has this already been held? + if _, ok := guardsHeld[s]; ok { + oo.counts[info.object]++ + continue + } + // Is this a global? Record directly. + if _, ok := from.(*ssa.Global); ok { + oo.counts[info.object]++ + continue + } + // Is the object a sibling to the accessObj? We need to + // check all fields and see if they match. We accept + // only siblings and globals for this recommendation. + structType, ok := resolveStruct(from.Type()) + if !ok { + continue + } + for i := 0; i < structType.NumFields(); i++ { + if fieldObj := structType.Field(i); fieldObj == info.object { + // Add to the maybe list. + oo.counts[info.object]++ + } + } + } + } +} + +// checkFieldAccess checks the validity of a field access. +func (pc *passContext) checkFieldAccess(inst almostInst, structObj ssa.Value, field int, ls *lockState, isWrite bool) { + fieldObj, _ := findField(structObj.Type(), field) + pc.checkGuards(inst, structObj, fieldObj, ls, isWrite) +} + +// checkGlobalAccess checks the validity of a global access. +func (pc *passContext) checkGlobalAccess(g *ssa.Global, ls *lockState, isWrite bool) { + pc.checkGuards(g, g, g.Object(), ls, isWrite) } func (pc *passContext) checkCall(call callCommon, lff *lockFunctionFacts, ls *lockState) { @@ -320,8 +375,13 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction if fg.IsAlias && !aliases { continue } - r := fg.resolveCall(call.Common().Args, call.Value()) - if s, ok := ls.unlockField(r, fg.Exclusive); !ok { + r := fg.Resolver.resolveCall(pc, ls, call.Common().Args, call.Value()) + if !r.valid() { + // See above: this cannot be forced. + pc.maybeFail(call.Pos(), "field %s cannot be resolved", fieldName) + continue + } + if s, ok := ls.unlockField(r, fg.Exclusive); !ok && !lff.Ignore { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { pc.maybeFail(call.Pos(), "attempt to release %s (%s), but not held (locks: %s)", fieldName, s, ls.String()) } @@ -337,8 +397,8 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction continue } // Acquire the lock per the annotation. - r := fg.resolveCall(call.Common().Args, call.Value()) - if s, ok := ls.lockField(r, fg.Exclusive); !ok { + r := fg.Resolver.resolveCall(pc, ls, call.Common().Args, call.Value()) + if s, ok := ls.lockField(r, fg.Exclusive); !ok && !lff.Ignore { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { pc.maybeFail(call.Pos(), "attempt to acquire %s (%s), but already held (locks: %s)", fieldName, s, ls.String()) } @@ -361,17 +421,15 @@ func exclusiveStr(exclusive bool) string { // instruction order). func (pc *passContext) checkFunctionCall(call callCommon, fn *types.Func, lff *lockFunctionFacts, ls *lockState) { // Extract the "receiver" properly. - var rcvr ssa.Value + var args []ssa.Value if call.Common().Method != nil { // This is an interface dispatch for sync.Locker. - rcvr = call.Common().Value - } else if args := call.Common().Args; len(args) > 0 && fn.Type().(*types.Signature).Recv() != nil { + args = append([]ssa.Value{call.Common().Value}, call.Common().Args...) + } else { // This matches the signature for the relevant // sync.Lock/sync.Unlock functions below. - rcvr = args[0] + args = call.Common().Args } - // Note that at this point, rcvr may be nil, but it should not match any - // of the function signatures below where rcvr may be used. // Check all guards required are held. Note that this explicitly does // not include aliases, hence false being passed below. @@ -379,7 +437,7 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *types.Func, lff *l if fg.IsAlias { continue } - r := fg.resolveCall(call.Common().Args, call.Value()) + r := fg.Resolver.resolveCall(pc, ls, args, call.Value()) if s, ok := ls.isHeld(r, fg.Exclusive); !ok { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { pc.maybeFail(call.Pos(), "must hold %s %s (%s) to call %s, but not held (locks: %s)", fieldName, exclusiveStr(fg.Exclusive), s, fn.Name(), ls.String()) @@ -395,15 +453,16 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *types.Func, lff *l // Check if it's a method dispatch for something in the sync package. // See: https://godoc.org/golang.org/x/tools/go/ssa#Function - if fn.Pkg() != nil && fn.Pkg().Name() == "sync" { + if fn.Pkg() != nil && fn.Pkg().Name() == "sync" && len(args) > 0 { + rv := makeResolvedValue(args[0], nil) isExclusive := false switch fn.Name() { case "Lock": isExclusive = true fallthrough case "RLock": - if s, ok := ls.lockField(resolvedValue{value: rcvr, valid: true}, isExclusive); !ok { - if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { + if s, ok := ls.lockField(rv, isExclusive); !ok && !lff.Ignore { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { // Double locking a mutex that is already locked. pc.maybeFail(call.Pos(), "%s already locked (locks: %s)", s, ls.String()) } @@ -412,14 +471,14 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *types.Func, lff *l isExclusive = true fallthrough case "RUnlock": - if s, ok := ls.unlockField(resolvedValue{value: rcvr, valid: true}, isExclusive); !ok { - if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { + if s, ok := ls.unlockField(rv, isExclusive); !ok && !lff.Ignore { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { // Unlocking something that is already unlocked. pc.maybeFail(call.Pos(), "%s already unlocked or locked differently (locks: %s)", s, ls.String()) } } case "DowngradeLock": - if s, ok := ls.downgradeField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok { + if s, ok := ls.downgradeField(rv); !ok { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { // Downgrading something that may not be downgraded. pc.maybeFail(call.Pos(), "%s already unlocked or not exclusive (locks: %s)", s, ls.String()) @@ -497,6 +556,24 @@ type callCommon interface { // checkInstruction checks the legality the single instruction based on the // current lockState. func (pc *passContext) checkInstruction(inst ssa.Instruction, lff *lockFunctionFacts, ls *lockState) (*ssa.Return, *lockState) { + // Record any observed globals, and check for violations. The global + // value is not itself an instruction, but we check all referrers to + // see where they are consumed. + var stackLocal [16]*ssa.Value + ops := inst.Operands(stackLocal[:]) + for _, v := range ops { + if v == nil { + continue + } + g, ok := (*v).(*ssa.Global) + if !ok { + continue + } + _, isWrite := inst.(*ssa.Store) + pc.checkGlobalAccess(g, ls, isWrite) + } + + // Process the instruction. switch x := inst.(type) { case *ssa.Store: // Record that this value is holding this other value. This is @@ -611,7 +688,12 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, failed := false // Validate held locks. for fieldName, fg := range lff.HeldOnExit { - r := fg.resolveStatic(fn, rv) + r := fg.Resolver.resolveStatic(pc, ls, fn, rv) + if !r.valid() { + // This cannot be forced, since we have no reference. + pc.maybeFail(rv.Pos(), "lock %s cannot be resolved", fieldName) + continue + } if s, ok := rls.isHeld(r, fg.Exclusive); !ok { if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok && !lff.Ignore { pc.maybeFail(rv.Pos(), "lock %s (%s) not held %s (locks: %s)", fieldName, s, exclusiveStr(fg.Exclusive), rls.String()) @@ -684,7 +766,12 @@ func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *loc for fieldName, fg := range lff.HeldOnEntry { // The first is the method object itself so we skip that when looking // for receiver/function parameters. - r := fg.resolveStatic(fn, call.Value()) + r := fg.Resolver.resolveStatic(pc, ls, fn, call.Value()) + if !r.valid() { + // See above: this cannot be forced. + pc.maybeFail(fn.Pos(), "lock %s cannot be resolved", fieldName) + continue + } if s, ok := ls.lockField(r, fg.Exclusive); !ok && !lff.Ignore { // This can only happen if the same value is declared // multiple times, and should be caught by the earlier @@ -710,3 +797,24 @@ func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *loc pc.postFunctionCallUpdate(call, lff, parent, true /* aliases */) } } + +// checkInferred checks for any inferred lock annotations. +func (pc *passContext) checkInferred() { + for obj, oo := range pc.observations { + var lgf lockGuardFacts + pc.pass.ImportObjectFact(obj, &lgf) + for other, count := range oo.counts { + // Is this already a guard? + if _, ok := lgf.GuardedBy[other.Name()]; ok { + continue + } + // Check to see if this field is used with a given lock + // held above the threshold. If yes, provide a helpful + // hint that this may something you wish to annotate. + const threshold = 0.9 + if usage := float64(count) / float64(oo.total); usage >= threshold { + pc.maybeFail(obj.Pos(), "may require checklocks annotation for %s, used with lock held %2.0f%% of the time", other.Name(), usage*100) + } + } + } +} diff --git a/tools/checklocks/checklocks.go b/tools/checklocks/checklocks.go index ae8db1a36..939af4239 100644 --- a/tools/checklocks/checklocks.go +++ b/tools/checklocks/checklocks.go @@ -30,20 +30,61 @@ import ( // Analyzer is the main entrypoint. var Analyzer = &analysis.Analyzer{ - Name: "checklocks", - Doc: "checks lock preconditions on functions and fields", - Run: run, - Requires: []*analysis.Analyzer{buildssa.Analyzer}, - FactTypes: []analysis.Fact{(*atomicAlignment)(nil), (*lockFieldFacts)(nil), (*lockGuardFacts)(nil), (*lockFunctionFacts)(nil)}, + Name: "checklocks", + Doc: "checks lock preconditions on functions and fields", + Run: run, + Requires: []*analysis.Analyzer{buildssa.Analyzer}, + FactTypes: []analysis.Fact{ + (*atomicAlignment)(nil), + (*lockGuardFacts)(nil), + (*lockFunctionFacts)(nil), + }, +} + +// objectObservations tracks lock correlations. +type objectObservations struct { + counts map[types.Object]int + total int } // passContext is a pass with additional expected failures. type passContext struct { - pass *analysis.Pass - failures map[positionKey]*failData - exemptions map[positionKey]struct{} - forced map[positionKey]struct{} - functions map[*ssa.Function]struct{} + pass *analysis.Pass + failures map[positionKey]*failData + exemptions map[positionKey]struct{} + forced map[positionKey]struct{} + functions map[*ssa.Function]struct{} + observations map[types.Object]*objectObservations +} + +// observationsFor retrieves observations for the given object. +func (pc *passContext) observationsFor(obj types.Object) *objectObservations { + if pc.observations == nil { + pc.observations = make(map[types.Object]*objectObservations) + } + oo, ok := pc.observations[obj] + if !ok { + oo = &objectObservations{ + counts: make(map[types.Object]int), + } + pc.observations[obj] = oo + } + return oo +} + +// forAllGlobals applies the given function to all globals. +func (pc *passContext) forAllGlobals(fn func(ts *ast.ValueSpec)) { + for _, f := range pc.pass.Files { + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.VAR { + continue + } + for _, gs := range d.Specs { + fn(gs.(*ast.ValueSpec)) + } + } + } } // forAllTypes applies the given function over all types. @@ -88,16 +129,17 @@ func run(pass *analysis.Pass) (interface{}, error) { pc.extractLineFailures() // Find all struct declarations and export relevant facts. - pc.forAllTypes(func(ts *ast.TypeSpec) { - if ss, ok := ts.Type.(*ast.StructType); ok { - structType := pc.pass.TypesInfo.TypeOf(ts.Name).Underlying().(*types.Struct) - pc.exportLockFieldFacts(structType, ss) + pc.forAllGlobals(func(vs *ast.ValueSpec) { + if ss, ok := vs.Type.(*ast.StructType); ok { + structType := pc.pass.TypesInfo.TypeOf(vs.Type).Underlying().(*types.Struct) + pc.structLockGuardFacts(structType, ss) } + pc.globalLockGuardFacts(vs) }) pc.forAllTypes(func(ts *ast.TypeSpec) { if ss, ok := ts.Type.(*ast.StructType); ok { structType := pc.pass.TypesInfo.TypeOf(ts.Name).Underlying().(*types.Struct) - pc.exportLockGuardFacts(structType, ss) + pc.structLockGuardFacts(structType, ss) } }) @@ -112,7 +154,7 @@ func run(pass *analysis.Pass) (interface{}, error) { // Find all function declarations and export relevant facts. pc.forAllFunctions(func(fn *ast.FuncDecl) { - pc.exportFunctionFacts(fn) + pc.functionFacts(fn) }) // Scan all code looking for invalid accesses. @@ -144,6 +186,9 @@ func run(pass *analysis.Pass) (interface{}, error) { pc.checkFunction(nil, fn, &nolff, nil, false /* force */) } + // Check for inferred checklocks annotations. + pc.checkInferred() + // Check for expected failures. pc.checkFailures() diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go index 17aef5790..f6dfeaec9 100644 --- a/tools/checklocks/facts.go +++ b/tools/checklocks/facts.go @@ -15,6 +15,7 @@ package checklocks import ( + "encoding/gob" "fmt" "go/ast" "go/token" @@ -22,6 +23,7 @@ import ( "regexp" "strings" + "golang.org/x/tools/go/analysis/passes/buildssa" "golang.org/x/tools/go/ssa" ) @@ -46,99 +48,110 @@ const ( atomicRequired ) +// fieldEntry is a single field type. +type fieldEntry interface { + // synthesize produces a string that is compatible with valueAndObject, + // along with the same object that should be produced in that case. + // + // Note that it is called synthesize because this is produced only the + // type information, and not with any ssa.Value objects. + synthesize(s string, typ types.Type) (string, types.Object) +} + +// fieldStruct is a non-pointer struct element. +type fieldStruct int + +// synthesize implements fieldEntry.synthesize. +func (f fieldStruct) synthesize(s string, typ types.Type) (string, types.Object) { + field, ok := findField(typ, int(f)) + if !ok { + // Should not happen as long as fieldList construction is correct. + panic(fmt.Sprintf("unable to resolve field %d in %s", int(f), typ.String())) + } + return fmt.Sprintf("&(%s.%s)", s, field.Name()), field +} + +// fieldStructPtr is a pointer struct element. +type fieldStructPtr int + +// synthesize implements fieldEntry.synthesize. +func (f fieldStructPtr) synthesize(s string, typ types.Type) (string, types.Object) { + field, ok := findField(typ, int(f)) + if !ok { + // See above, this should not happen. + panic(fmt.Sprintf("unable to resolve ptr field %d in %s", int(f), typ.String())) + } + return fmt.Sprintf("*(&(%s.%s))", s, field.Name()), field +} + // fieldList is a simple list of fields, used in two types below. -// -// Note that the integers in this list refer to one of two things: -// - A positive integer refers to a field index in a struct. -// - A negative integer refers to a field index in a struct, where -// that field is a pointer and must be subsequently resolved. -type fieldList []int +type fieldList []fieldEntry // resolvedValue is an ssa.Value with additional fields. // // This can be resolved to a string as part of a lock state. type resolvedValue struct { value ssa.Value - valid bool - fieldList []int -} - -// findExtract finds a relevant extract. This must exist within the referrers -// to the call object. If this doesn't then the object which is locked is never -// consumed, and we should consider this a bug. -func findExtract(v ssa.Value, index int) (ssa.Value, bool) { - if refs := v.Referrers(); refs != nil { - for _, inst := range *refs { - if x, ok := inst.(*ssa.Extract); ok && x.Tuple == v && x.Index == index { - return inst.(ssa.Value), true - } - } - } - return nil, false + fieldList fieldList } -// resolve resolves the given field list. -func (fl fieldList) resolve(v ssa.Value) (rv resolvedValue) { +// makeResolvedValue makes a new resolvedValue. +func makeResolvedValue(v ssa.Value, fl fieldList) resolvedValue { return resolvedValue{ value: v, fieldList: fl, - valid: true, } } -// valueAsString returns a string representing this value. +// valid indicates whether this is a valid resolvedValue. +func (rv *resolvedValue) valid() bool { + return rv.value != nil +} + +// valueAndObject returns a string and object. // -// This must align with how the string is generated in valueAsString. -func (rv resolvedValue) valueAsString(ls *lockState) string { +// This uses the lockState valueAndObject in order to produce a string and +// object for the base ssa.Value, then synthesizes a string representation +// based on the fieldList. +func (rv *resolvedValue) valueAndObject(ls *lockState) (string, types.Object) { + // N.B. obj.Type() and typ should be equal, but a check is omitted + // since, 1) we automatically chase through pointers during field + // resolution, and 2) obj may be nil if there is no source object. + s, obj := ls.valueAndObject(rv.value) typ := rv.value.Type() - s := ls.valueAsString(rv.value) - for i, fieldNumber := range rv.fieldList { - switch { - case fieldNumber > 0: - field, ok := findField(typ, fieldNumber-1) - if !ok { - // This can't be resolved, return for debugging. - return fmt.Sprintf("{%s+%v}", s, rv.fieldList[i:]) - } - s = fmt.Sprintf("&(%s.%s)", s, field.Name()) - typ = field.Type() - case fieldNumber < 1: - field, ok := findField(typ, (-fieldNumber)-1) - if !ok { - // See above. - return fmt.Sprintf("{%s+%v}", s, rv.fieldList[i:]) - } - s = fmt.Sprintf("*(&(%s.%s))", s, field.Name()) - typ = field.Type() - } + for _, entry := range rv.fieldList { + s, obj = entry.synthesize(s, typ) + typ = obj.Type() } - return s + return s, obj } -// lockFieldFacts apply on every struct field. -type lockFieldFacts struct { - // IsMutex is true if the field is of type sync.Mutex. - IsMutex bool - - // IsRWMutex is true if the field is of type sync.RWMutex. - IsRWMutex bool - - // IsPointer indicates if the field is a pointer. - IsPointer bool - - // FieldNumber is the number of this field in the struct. - FieldNumber int +// fieldGuardResolver details a guard for a field. +type fieldGuardResolver interface { + // resolveField is used to resolve a guard during a field access. The + // parent structure is available, as well as the current lock state. + resolveField(pc *passContext, ls *lockState, parent ssa.Value) resolvedValue } -// AFact implements analysis.Fact.AFact. -func (*lockFieldFacts) AFact() {} +// functionGuardResolver details a guard for a function. +type functionGuardResolver interface { + // resolveStatic is used to resolve a guard during static analysis, + // e.g. based on static annotations applied to a method. The function's + // ssa object is available, as well as the return value. + resolveStatic(pc *passContext, ls *lockState, fn *ssa.Function, rv interface{}) resolvedValue + + // resolveCall is used to resolve a guard during a call. The ssa + // return value is available from the instruction context where the + // call occurs, but the target's ssa representation is not available. + resolveCall(pc *passContext, ls *lockState, args []ssa.Value, rv ssa.Value) resolvedValue +} // lockGuardFacts contains guard information. type lockGuardFacts struct { // GuardedBy is the set of locks that are guarding this field. The key // is the original annotation value, and the field list is the object // traversal path. - GuardedBy map[string]fieldList + GuardedBy map[string]fieldGuardResolver // AtomicDisposition is the disposition for this field. Note that this // can affect the interpretation of the GuardedBy field above, see the @@ -149,86 +162,142 @@ type lockGuardFacts struct { // AFact implements analysis.Fact.AFact. func (*lockGuardFacts) AFact() {} -// functionGuard is used by lockFunctionFacts, below. -type functionGuard struct { - // ParameterNumber is the index of the object that contains the - // guarding mutex. From this parameter, a walk is performed - // subsequently using the resolve method. - // - // Note that is ParameterNumber is beyond the size of parameters, then - // it may return to a return value. This applies only for the Acquires - // relation below. - ParameterNumber int +// globalGuard is a global value. +type globalGuard struct { + // Object indicates the object from which resolution should occur. + Object types.Object + + // FieldList is the traversal path from object. + FieldList fieldList +} + +// ssaPackager returns the ssa package. +type ssaPackager interface { + Package() *ssa.Package +} + +// resolveCommon implements resolution for all cases. +func (g *globalGuard) resolveCommon(pc *passContext, ls *lockState) resolvedValue { + state := pc.pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + v := state.Pkg.Members[g.Object.Name()].(ssa.Value) + return makeResolvedValue(v, g.FieldList) +} + +// resolveStatic implements functionGuardResolver.resolveStatic. +func (g *globalGuard) resolveStatic(pc *passContext, ls *lockState, _ *ssa.Function, v interface{}) resolvedValue { + return g.resolveCommon(pc, ls) +} + +// resolveCall implements functionGuardResolver.resolveCall. +func (g *globalGuard) resolveCall(pc *passContext, ls *lockState, _ []ssa.Value, v ssa.Value) resolvedValue { + return g.resolveCommon(pc, ls) +} + +// resolveField implements fieldGuardResolver.resolveField. +func (g *globalGuard) resolveField(pc *passContext, ls *lockState, parent ssa.Value) resolvedValue { + return g.resolveCommon(pc, ls) +} + +// fieldGuard is a field-based guard. +type fieldGuard struct { + // FieldList is the traversal path from the parent. + FieldList fieldList +} + +// resolveField implements fieldGuardResolver.resolveField. +func (f *fieldGuard) resolveField(_ *passContext, _ *lockState, parent ssa.Value) resolvedValue { + return makeResolvedValue(parent, f.FieldList) +} + +// parameterGuard is a parameter-based guard. +type parameterGuard struct { + // Index is the parameter index of the object that contains the + // guarding mutex. + Index int + + // fieldList is the traversal path from the parameter. + FieldList fieldList +} + +// resolveStatic implements functionGuardResolver.resolveStatic. +func (p *parameterGuard) resolveStatic(_ *passContext, _ *lockState, fn *ssa.Function, _ interface{}) resolvedValue { + return makeResolvedValue(fn.Params[p.Index], p.FieldList) +} + +// resolveCall implements functionGuardResolver.resolveCall. +func (p *parameterGuard) resolveCall(_ *passContext, _ *lockState, args []ssa.Value, _ ssa.Value) resolvedValue { + return makeResolvedValue(args[p.Index], p.FieldList) +} + +// returnGuard is a return-based guard. +type returnGuard struct { + // Index is the index of the return value. + Index int // NeedsExtract is used in the case of a return value, and indicates // that the field must be extracted from a tuple. NeedsExtract bool - // IsAlias indicates that this guard is an alias. - IsAlias bool - - // FieldList is the traversal path to the object. + // FieldList is the traversal path from the return value. FieldList fieldList - - // Exclusive indicates an exclusive lock is required. - Exclusive bool } -// resolveReturn resolves a return value. -// -// Precondition: rv is either an ssa.Value, or an *ssa.Return. -func (fg *functionGuard) resolveReturn(rv interface{}, args int) resolvedValue { +// resolveCommon implements resolution for both cases. +func (r *returnGuard) resolveCommon(rv interface{}) resolvedValue { if rv == nil { // For defers and other objects, this may be nil. This is - // handled in state.go in the actual lock checking logic. - return resolvedValue{ - value: nil, - valid: false, - } + // handled in state.go in the actual lock checking logic. This + // means that there is no resolvedValue available. + return resolvedValue{} } - index := fg.ParameterNumber - args // If this is a *ssa.Return object, i.e. we are analyzing the function // and not the call site, then we can just pull the result directly. - if r, ok := rv.(*ssa.Return); ok { - return fg.FieldList.resolve(r.Results[index]) + if ret, ok := rv.(*ssa.Return); ok { + return makeResolvedValue(ret.Results[r.Index], r.FieldList) } - if fg.NeedsExtract { + if r.NeedsExtract { // Resolve on the extracted field, this is necessary if the // type here is not an explicit return. Note that rv must be an // ssa.Value, since it is not an *ssa.Return. - v, ok := findExtract(rv.(ssa.Value), index) - if !ok { - return resolvedValue{ - value: v, - valid: false, + v := rv.(ssa.Value) + if refs := v.Referrers(); refs != nil { + for _, inst := range *refs { + if x, ok := inst.(*ssa.Extract); ok && x.Tuple == v && x.Index == r.Index { + return makeResolvedValue(x, r.FieldList) + } } } - return fg.FieldList.resolve(v) + // Nothing resolved. + return resolvedValue{} } - if index != 0 { + if r.Index != 0 { // This should not happen, NeedsExtract should always be set. panic("NeedsExtract is false, but return value index is non-zero") } // Resolve on the single return. - return fg.FieldList.resolve(rv.(ssa.Value)) + return makeResolvedValue(rv.(ssa.Value), r.FieldList) } -// resolveStatic returns an ssa.Value representing the given field. -// -// Precondition: per resolveReturn. -func (fg *functionGuard) resolveStatic(fn *ssa.Function, rv interface{}) resolvedValue { - if fg.ParameterNumber >= len(fn.Params) { - return fg.resolveReturn(rv, len(fn.Params)) - } - return fg.FieldList.resolve(fn.Params[fg.ParameterNumber]) +// resolveStatic implements functionGuardResolver.resolveStatic. +func (r *returnGuard) resolveStatic(_ *passContext, _ *lockState, _ *ssa.Function, rv interface{}) resolvedValue { + return r.resolveCommon(rv) } -// resolveCall returns an ssa.Value representing the given field. -func (fg *functionGuard) resolveCall(args []ssa.Value, rv ssa.Value) resolvedValue { - if fg.ParameterNumber >= len(args) { - return fg.resolveReturn(rv, len(args)) - } - return fg.FieldList.resolve(args[fg.ParameterNumber]) +// resolveCall implements functionGuardResolver.resolveCall. +func (r *returnGuard) resolveCall(_ *passContext, _ *lockState, _ []ssa.Value, rv ssa.Value) resolvedValue { + return r.resolveCommon(rv) +} + +// functionGuardInfo is information about a method guard. +type functionGuardInfo struct { + // Resolver is the resolver for this guard. + Resolver functionGuardResolver + + // IsAlias indicates that this guard is an alias. + IsAlias bool + + // Exclusive indicates an exclusive lock is required. + Exclusive bool } // lockFunctionFacts apply on every method. @@ -250,13 +319,11 @@ type lockFunctionFacts struct { // ``` // // '`+checklocks:a.mu' will result in an entry in this map as shown below. - // HeldOnEntry: {"a.mu" => {ParameterNumber: 0, FieldNumbers: {0}} - // - // Unlikely lockFieldFacts, there is no atomic interpretation. - HeldOnEntry map[string]functionGuard + // HeldOnEntry: {"a.mu" => {Resolver: ¶meterGuard{Index: 0}} + HeldOnEntry map[string]functionGuardInfo // HeldOnExit tracks the locks that are expected to be held on exit. - HeldOnExit map[string]functionGuard + HeldOnExit map[string]functionGuardInfo // Ignore means this function has local analysis ignores. // @@ -268,14 +335,14 @@ type lockFunctionFacts struct { func (*lockFunctionFacts) AFact() {} // checkGuard validates the guardName. -func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) { +func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuardInfo, bool) { if _, ok := lff.HeldOnEntry[guardName]; ok { pc.maybeFail(d.Pos(), "annotation %s specified more than once, already required", guardName) - return functionGuard{}, false + return functionGuardInfo{}, false } if _, ok := lff.HeldOnExit[guardName]; ok { pc.maybeFail(d.Pos(), "annotation %s specified more than once, already acquired", guardName) - return functionGuard{}, false + return functionGuardInfo{}, false } fg, ok := pc.findFunctionGuard(d, guardName, exclusive, allowReturn) return fg, ok @@ -285,10 +352,10 @@ func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guard func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok { if lff.HeldOnEntry == nil { - lff.HeldOnEntry = make(map[string]functionGuard) + lff.HeldOnEntry = make(map[string]functionGuardInfo) } if lff.HeldOnExit == nil { - lff.HeldOnExit = make(map[string]functionGuard) + lff.HeldOnExit = make(map[string]functionGuardInfo) } lff.HeldOnEntry[guardName] = fg lff.HeldOnExit[guardName] = fg @@ -299,7 +366,7 @@ func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, gua func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, true /* allowReturn */); ok { if lff.HeldOnExit == nil { - lff.HeldOnExit = make(map[string]functionGuard) + lff.HeldOnExit = make(map[string]functionGuardInfo) } lff.HeldOnExit[guardName] = fg } @@ -309,7 +376,7 @@ func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guar func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok { if lff.HeldOnEntry == nil { - lff.HeldOnEntry = make(map[string]functionGuard) + lff.HeldOnEntry = make(map[string]functionGuardInfo) } lff.HeldOnEntry[guardName] = fg } @@ -345,213 +412,276 @@ func (lff *lockFunctionFacts) addAlias(pc *passContext, d *ast.FuncDecl, guardNa } } -// fieldListFor returns the fieldList for the given object. -func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool, exclusive bool) (int, bool) { - var lff lockFieldFacts - if !pc.pass.ImportObjectFact(fieldObj, &lff) { - // This should not happen: we export facts for all fields. - panic(fmt.Sprintf("no lockFieldFacts available for field %s", fieldName)) - } - // Check that it is indeed a mutex. - if checkMutex && !lff.IsMutex && !lff.IsRWMutex { - pc.maybeFail(pos, "field %s is not a Mutex or an RWMutex", fieldName) - return 0, false - } - if checkMutex && !exclusive && !lff.IsRWMutex { - pc.maybeFail(pos, "field %s must be a RWMutex, but it is not", fieldName) - return 0, false - } +// fieldEntryFor returns the fieldList value for the given object. +func (pc *passContext) fieldEntryFor(fieldObj types.Object, index int) fieldEntry { + // Return the resolution path. - if lff.IsPointer { - return -(index + 1), true + if _, ok := fieldObj.Type().Underlying().(*types.Pointer); ok { + return fieldStructPtr(index) } - return (index + 1), true + if _, ok := fieldObj.Type().Underlying().(*types.Interface); ok { + return fieldStructPtr(index) + } + return fieldStruct(index) } -// resolveOneField resolves a field in a single struct. -func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool, exclusive bool) (fl fieldList, fieldObj types.Object, ok bool) { +// findField resolves a field in a single struct. +func (pc *passContext) findField(structType *types.Struct, fieldName string) (fl fieldList, fieldObj types.Object, ok bool) { // Scan to match the next field. for i := 0; i < structType.NumFields(); i++ { fieldObj := structType.Field(i) if fieldObj.Name() != fieldName { continue } - flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex, exclusive) - if !ok { - return nil, nil, false - } - fl = append(fl, flOne) + fl = append(fl, pc.fieldEntryFor(fieldObj, i)) return fl, fieldObj, true } + // Is this an embed? for i := 0; i < structType.NumFields(); i++ { fieldObj := structType.Field(i) if !fieldObj.Embedded() { continue } + // Is this an embedded struct? structType, ok := resolveStruct(fieldObj.Type()) if !ok { continue } + // Need to check that there is a resolution path. If there is // no resolution path that's not a failure: we just continue // scanning the next embed to find a match. - flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false, exclusive) - flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex, exclusive) - if okEmbed && okCont { - fl = append(fl, flEmbed) - fl = append(fl, flCont...) - return fl, fieldObjCont, true + flEmbed := pc.fieldEntryFor(fieldObj, i) + flNext, fieldObjNext, ok := pc.findField(structType, fieldName) + if !ok { + continue } + + // Found an embedded chain. + fl = append(fl, flEmbed) + fl = append(fl, flNext...) + return fl, fieldObjNext, true } - pc.maybeFail(pos, "field %s does not exist", fieldName) + return nil, nil, false } -// resolveField resolves a set of fields given a string, such a 'a.b.c'. +var ( + mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)") + rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)") + lockerRE = regexp.MustCompile("((.*/)|^)sync.Locker") +) + +// validateMutex validates the mutex type. +// +// This function returns true iff the object is a valid mutex with an error +// reported at the given position if necessary. +func (pc *passContext) validateMutex(pos token.Pos, obj types.Object, exclusive bool) bool { + // Check that it is indeed a mutex. + s := obj.Type().String() + switch { + case mutexRE.MatchString(s), lockerRE.MatchString(s): + // Safe for exclusive cases. + if !exclusive { + pc.maybeFail(pos, "field %s must be a RWMutex", obj.Name()) + return false + } + return true + case rwMutexRE.MatchString(s): + // Safe for all cases. + return true + default: + // Not a mutex at all? + pc.maybeFail(pos, "field %s is not a Mutex or an RWMutex", obj.Name()) + return false + } +} + +// findFieldList resolves a set of fields given a string, such a 'a.b.c'. // -// Note that this checks that the final element is a mutex of some kind, and -// will fail appropriately. -func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string, exclusive bool) (fl fieldList, ok bool) { - for partNumber, fieldName := range parts { - flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */, exclusive) +// Note that parts must be non-zero in length. If it may be zero, then +// maybeFindFieldList should be used instead with an appropriate object. +func (pc *passContext) findFieldList(pos token.Pos, structType *types.Struct, parts []string, exclusive bool) (fl fieldList, ok bool) { + var obj types.Object + + // This loop requires at least one iteration in order to ensure that + // obj above is non-nil, and the type can be validated. + for i, fieldName := range parts { + flOne, fieldObj, ok := pc.findField(structType, fieldName) if !ok { - // Error already reported. return nil, false } fl = append(fl, flOne...) - if partNumber < len(parts)-1 { - // Traverse to the next type. - structType, ok = resolveStruct(fieldObj.Type()) + obj = fieldObj + if i < len(parts)-1 { + structType, ok = resolveStruct(obj.Type()) if !ok { - pc.maybeFail(pos, "invalid intermediate field %s", fieldName) - return fl, false + // N.B. This is associated with the original position. + pc.maybeFail(pos, "field %s expected to be struct", fieldName) + return nil, false } } } + + // Validate the final field. This reports the field to the caller + // anyways, since the error will be reported only once. + _ = pc.validateMutex(pos, obj, exclusive) return fl, true } -var ( - mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)") - rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)") - lockerRE = regexp.MustCompile("((.*/)|^)sync.Locker") -) - -// exportLockFieldFacts finds all struct fields that are mutexes, and ensures -// that they are annotated properly. +// maybeFindFieldList resolves the given object. // -// This information is consumed subsequently by exportLockGuardFacts, and this -// function must be called first on all structures. -func (pc *passContext) exportLockFieldFacts(structType *types.Struct, ss *ast.StructType) { - for i, field := range ss.Fields.List { - lff := &lockFieldFacts{ - FieldNumber: i, - } - // We use HasSuffix below because fieldType can be fully - // qualified with the package name eg for the gvisor sync - // package mutex fields have the type: - // "/sync/sync.Mutex" - fieldObj := structType.Field(i) - s := fieldObj.Type().String() - switch { - case mutexRE.MatchString(s): - lff.IsMutex = true - case rwMutexRE.MatchString(s): - lff.IsRWMutex = true - case lockerRE.MatchString(s): - lff.IsMutex = true - } - // Save whether this is a pointer. - _, lff.IsPointer = fieldObj.Type().Underlying().(*types.Pointer) - if !lff.IsPointer { - _, lff.IsPointer = fieldObj.Type().Underlying().(*types.Interface) - } - // We must always export the lockFieldFacts, since traversal - // can take place along any object in the struct. - pc.pass.ExportObjectFact(fieldObj, lff) - // If this is an anonymous type, then we won't discover it via - // the AST global declarations. We can recurse from here. - if ss, ok := field.Type.(*ast.StructType); ok { - if st, ok := fieldObj.Type().(*types.Struct); ok { - pc.exportLockFieldFacts(st, ss) - } +// Parts may be the empty list, unlike findFieldList. +func (pc *passContext) maybeFindFieldList(pos token.Pos, obj types.Object, parts []string, exclusive bool) (fl fieldList, ok bool) { + if len(parts) > 0 { + structType, ok := resolveStruct(obj.Type()) + if !ok { + // This does not have any fields; the access is not allowed. + pc.maybeFail(pos, "attempted field access on non-struct") + return nil, false } + return pc.findFieldList(pos, structType, parts, exclusive) } + + // See above. + _ = pc.validateMutex(pos, obj, exclusive) + return nil, true } -// exportLockGuardFacts finds all relevant guard information for structures. -// -// This function requires exportLockFieldFacts be called first on all -// structures. -func (pc *passContext) exportLockGuardFacts(structType *types.Struct, ss *ast.StructType) { - for i, field := range ss.Fields.List { - fieldObj := structType.Field(i) - if field.Doc != nil { - var ( - lff lockFieldFacts - lgf lockGuardFacts - ) - pc.pass.ImportObjectFact(structType.Field(i), &lff) - for _, l := range field.Doc.List { - pc.extractAnnotations(l.Text, map[string]func(string){ - checkAtomicAnnotation: func(string) { - switch lgf.AtomicDisposition { - case atomicRequired: - pc.maybeFail(fieldObj.Pos(), "annotation is redundant, already atomic required") - case atomicIgnore: - pc.maybeFail(fieldObj.Pos(), "annotation is contradictory, already atomic ignored") - } - lgf.AtomicDisposition = atomicRequired - }, - checkLocksIgnore: func(string) { - switch lgf.AtomicDisposition { - case atomicIgnore: - pc.maybeFail(fieldObj.Pos(), "annotation is redundant, already atomic ignored") - case atomicRequired: - pc.maybeFail(fieldObj.Pos(), "annotation is contradictory, already atomic required") - } - lgf.AtomicDisposition = atomicIgnore - }, - checkLocksAnnotation: func(guardName string) { - // Check for a duplicate annotation. - if _, ok := lgf.GuardedBy[guardName]; ok { - pc.maybeFail(fieldObj.Pos(), "annotation %s specified more than once", guardName) - return - } - fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."), true /* exclusive */) - if ok { - // If we successfully resolved the field, then save it. - if lgf.GuardedBy == nil { - lgf.GuardedBy = make(map[string]fieldList) - } - lgf.GuardedBy[guardName] = fl - } - }, - // N.B. We support only the vanilla - // annotation on individual fields. If - // the field is a read lock, then we - // will allow read access by default. - checkLocksAnnotationRead: func(guardName string) { - pc.maybeFail(fieldObj.Pos(), "annotation %s not legal on fields", guardName) - }, - }) - } - // Save only if there is something meaningful. - if len(lgf.GuardedBy) > 0 || lgf.AtomicDisposition != atomicDisallow { - pc.pass.ExportObjectFact(structType.Field(i), &lgf) - } +// findFieldGuardResolver finds a symbol resolver. +type findFieldGuardResolver func(pos token.Pos, guardName string) (fieldGuardResolver, bool) + +// findFunctionGuardResolver finds a symbol resolver. +type findFunctionGuardResolver func(pos token.Pos, guardName string) (functionGuardResolver, bool) + +// fillLockGuardFacts fills the facts with guard information. +func (pc *passContext) fillLockGuardFacts(obj types.Object, cg *ast.CommentGroup, find findFieldGuardResolver, lgf *lockGuardFacts) { + if cg == nil { + return + } + for _, l := range cg.List { + pc.extractAnnotations(l.Text, map[string]func(string){ + checkAtomicAnnotation: func(string) { + switch lgf.AtomicDisposition { + case atomicRequired: + pc.maybeFail(obj.Pos(), "annotation is redundant, already atomic required") + case atomicIgnore: + pc.maybeFail(obj.Pos(), "annotation is contradictory, already atomic ignored") + } + lgf.AtomicDisposition = atomicRequired + }, + checkLocksIgnore: func(string) { + switch lgf.AtomicDisposition { + case atomicIgnore: + pc.maybeFail(obj.Pos(), "annotation is redundant, already atomic ignored") + case atomicRequired: + pc.maybeFail(obj.Pos(), "annotation is contradictory, already atomic required") + } + lgf.AtomicDisposition = atomicIgnore + }, + checkLocksAnnotation: func(guardName string) { + // Check for a duplicate annotation. + if _, ok := lgf.GuardedBy[guardName]; ok { + pc.maybeFail(obj.Pos(), "annotation %s specified more than once", guardName) + return + } + // Add the item. + if lgf.GuardedBy == nil { + lgf.GuardedBy = make(map[string]fieldGuardResolver) + } + fr, ok := find(obj.Pos(), guardName) + if !ok { + pc.maybeFail(obj.Pos(), "annotation %s cannot be resolved", guardName) + return + } + lgf.GuardedBy[guardName] = fr + }, + // N.B. We support only the vanilla annotation on + // individual fields. If the field is a read lock, then + // we will allow read access by default. + checkLocksAnnotationRead: func(guardName string) { + pc.maybeFail(obj.Pos(), "annotation %s not legal on fields", guardName) + }, + }) + } + // Save only if there is something meaningful. + if len(lgf.GuardedBy) > 0 || lgf.AtomicDisposition != atomicDisallow { + pc.pass.ExportObjectFact(obj, lgf) + } +} + +// findGlobalGuard attempts to resolve a name globally. +func (pc *passContext) findGlobalGuard(pos token.Pos, guardName string) (*globalGuard, bool) { + // Attempt to resolve the object. + parts := strings.Split(guardName, ".") + globalObj := pc.pass.Pkg.Scope().Lookup(parts[0]) + if globalObj == nil { + // No global object. + return nil, false + } + fl, ok := pc.maybeFindFieldList(pos, globalObj, parts[1:], true /* exclusive */) + if !ok { + // Invalid fields. + return nil, false + } + return &globalGuard{ + Object: globalObj, + FieldList: fl, + }, true +} + +// findGlobalFieldGuard is compatible with findFieldGuardResolver. +func (pc *passContext) findGlobalFieldGuard(pos token.Pos, guardName string) (fieldGuardResolver, bool) { + g, ok := pc.findGlobalGuard(pos, guardName) + return g, ok +} + +// findGlobalFunctionGuard is compatible with findFunctionGuardResolver. +func (pc *passContext) findGlobalFunctionGuard(pos token.Pos, guardName string) (functionGuardResolver, bool) { + g, ok := pc.findGlobalGuard(pos, guardName) + return g, ok +} + +// structLockGuardFacts finds all relevant guard information for structures. +func (pc *passContext) structLockGuardFacts(structType *types.Struct, ss *ast.StructType) { + var fieldObj *types.Var + findLocal := func(pos token.Pos, guardName string) (fieldGuardResolver, bool) { + // Try to resolve from the local structure first. + fl, ok := pc.findFieldList(pos, structType, strings.Split(guardName, "."), true /* exclusive */) + if ok { + // Found a valid resolution. + return &fieldGuard{ + FieldList: fl, + }, true } + // Attempt a global resolution. + return pc.findGlobalFieldGuard(pos, guardName) + } + for i, field := range ss.Fields.List { + var lgf lockGuardFacts + fieldObj = structType.Field(i) // N.B. Captured above. + pc.fillLockGuardFacts(fieldObj, field.Doc, findLocal, &lgf) + // See above, for anonymous structure fields. if ss, ok := field.Type.(*ast.StructType); ok { if st, ok := fieldObj.Type().(*types.Struct); ok { - pc.exportLockGuardFacts(st, ss) + pc.structLockGuardFacts(st, ss) } } } } +// globalLockGuardFacts finds all relevant guard information for globals. +// +// Note that the Type is checked in checklocks.go at the top-level. +func (pc *passContext) globalLockGuardFacts(vs *ast.ValueSpec) { + var lgf lockGuardFacts + globalObj := pc.pass.TypesInfo.ObjectOf(vs.Names[0]) + pc.fillLockGuardFacts(globalObj, vs.Doc, pc.findGlobalFieldGuard, &lgf) +} + // countFields gives an accurate field count, according for unnamed arguments // and return values and the compact identifier format. func countFields(fl []*ast.Field) (count int) { @@ -566,89 +696,105 @@ func countFields(fl []*ast.Field) (count int) { } // matchFieldList attempts to match the given field. -func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string, exclusive bool) (functionGuard, bool) { +// +// This function may or may not report an error. This is indicated in the +// reported return value. If reported is true, then the specification is +// ambiguous or not valid, and should be propagated. +func (pc *passContext) matchFieldList(pos token.Pos, fields []*ast.Field, guardName string, exclusive bool) (number int, fl fieldList, reported, ok bool) { parts := strings.Split(guardName, ".") - parameterName := parts[0] - parameterNumber := 0 - for _, field := range fl { + firstName := parts[0] + index := 0 + for _, field := range fields { // See countFields, above. if len(field.Names) == 0 { - parameterNumber++ + index++ continue } for _, name := range field.Names { - if name.Name != parameterName { - parameterNumber++ + if name.Name != firstName { + index++ continue } - ptrType, ok := pc.pass.TypesInfo.TypeOf(field.Type).Underlying().(*types.Pointer) - if !ok { - // Since mutexes cannot be copied we only care - // about parameters that are pointer types when - // checking for guards. - pc.maybeFail(pos, "parameter name %s does not refer to a pointer type", parameterName) - return functionGuard{}, false - } - structType, ok := ptrType.Elem().Underlying().(*types.Struct) + obj := pc.pass.TypesInfo.ObjectOf(name) + fl, ok := pc.maybeFindFieldList(pos, obj, parts[1:], exclusive) if !ok { - // Fields can only be in named structures. - pc.maybeFail(pos, "parameter name %s does not refer to a pointer to a struct", parameterName) - return functionGuard{}, false + // Some intermediate name does not match. The + // resolveField function will not report. + pc.maybeFail(pos, "name %s does not resolve to a field", guardName) + return 0, nil, true, false } - fg := functionGuard{ - ParameterNumber: parameterNumber, - Exclusive: exclusive, - } - fl, ok := pc.resolveField(pos, structType, parts[1:], exclusive) - fg.FieldList = fl - return fg, ok // If ok is false, already failed. + // Successfully found a field. + return index, fl, false, true } } - return functionGuard{}, false + + // Nothing matching. + return 0, nil, false, false } // findFunctionGuard identifies the parameter number and field number for a // particular string of the 'a.b'. // // This function will report any errors directly. -func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) { - var ( - parameterList []*ast.Field - returnList []*ast.Field - ) +func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuardInfo, bool) { + // Match against receiver & parameters. + var parameterList []*ast.Field if d.Recv != nil { parameterList = append(parameterList, d.Recv.List...) } if d.Type.Params != nil { parameterList = append(parameterList, d.Type.Params.List...) } - if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName, exclusive); ok { - return fg, ok + if index, fl, reported, ok := pc.matchFieldList(d.Pos(), parameterList, guardName, exclusive); reported || ok { + if !ok { + return functionGuardInfo{}, false + } + return functionGuardInfo{ + Resolver: ¶meterGuard{ + Index: index, + FieldList: fl, + }, + Exclusive: exclusive, + }, true } + + // Match against return values, if allowed. if allowReturn { + var returnList []*ast.Field if d.Type.Results != nil { returnList = append(returnList, d.Type.Results.List...) } - if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName, exclusive); ok { - // Fix this up to apply to the return value, as noted - // in fg.ParameterNumber. For the ssa analysis, we must - // record whether this has multiple results, since - // *ssa.Call indicates: "The Call instruction yields - // the function result if there is exactly one. - // Otherwise it returns a tuple, the components of - // which are accessed via Extract." - fg.ParameterNumber += countFields(parameterList) - fg.NeedsExtract = countFields(returnList) > 1 - return fg, ok + if index, fl, reported, ok := pc.matchFieldList(d.Pos(), returnList, guardName, exclusive); reported || ok { + if !ok { + return functionGuardInfo{}, false + } + return functionGuardInfo{ + Resolver: &returnGuard{ + Index: index, + FieldList: fl, + NeedsExtract: countFields(returnList) > 1, + }, + Exclusive: exclusive, + }, true } } - // We never saw a matching parameter. - pc.maybeFail(d.Pos(), "annotation %s does not have a matching parameter", guardName) - return functionGuard{}, false + + // Match against globals. + if g, ok := pc.findGlobalFunctionGuard(d.Pos(), guardName); ok { + return functionGuardInfo{ + Resolver: g, + Exclusive: exclusive, + }, true + } + + // No match found. + pc.maybeFail(d.Pos(), "annotation %s does not have a match any parameter, return value or global", guardName) + return functionGuardInfo{}, false } -// exportFunctionFacts exports relevant function findings. -func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) { +// functionFacts exports relevant function findings. +func (pc *passContext) functionFacts(d *ast.FuncDecl) { + // Extract guard information. if d.Doc == nil || d.Doc.List == nil { return } @@ -679,3 +825,12 @@ func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) { pc.pass.ExportObjectFact(funcObj, &lff) } } + +func init() { + gob.Register((*returnGuard)(nil)) + gob.Register((*globalGuard)(nil)) + gob.Register((*parameterGuard)(nil)) + gob.Register((*fieldGuard)(nil)) + gob.Register((*fieldStructPtr)(nil)) + gob.Register((*fieldStruct)(nil)) +} diff --git a/tools/checklocks/state.go b/tools/checklocks/state.go index aaf997d79..2de373b27 100644 --- a/tools/checklocks/state.go +++ b/tools/checklocks/state.go @@ -24,20 +24,26 @@ import ( "golang.org/x/tools/go/ssa" ) +// lockInfo describes a held lock. +type lockInfo struct { + exclusive bool + object types.Object +} + // lockState tracks the locking state and aliases. type lockState struct { // lockedMutexes is used to track which mutexes in a given struct are // currently locked. Note that most of the heavy lifting is done by - // valueAsString below, which maps to specific structure fields, etc. + // valueAndObject below, which maps to specific structure fields, etc. // // The value indicates whether this is an exclusive lock. - lockedMutexes map[string]bool + lockedMutexes map[string]lockInfo // stored stores values that have been stored in memory, bound to // FreeVars or passed as Parameterse. stored map[ssa.Value]ssa.Value - // used is a temporary map, used only for valueAsString. It prevents + // used is a temporary map, used only for valueAndObject. It prevents // multiple use of the same memory location. used map[ssa.Value]struct{} @@ -53,7 +59,7 @@ type lockState struct { func newLockState() *lockState { refs := int32(1) // Not shared. return &lockState{ - lockedMutexes: make(map[string]bool), + lockedMutexes: make(map[string]lockInfo), used: make(map[ssa.Value]struct{}), stored: make(map[ssa.Value]ssa.Value), defers: make([]*ssa.Defer, 0), @@ -81,7 +87,7 @@ func (l *lockState) fork() *lockState { func (l *lockState) modify() { if atomic.LoadInt32(l.refs) > 1 { // Copy the lockedMutexes. - lm := make(map[string]bool) + lm := make(map[string]lockInfo) for k, v := range l.lockedMutexes { lm[k] = v } @@ -110,17 +116,19 @@ func (l *lockState) modify() { } // isHeld indicates whether the field is held is not. +// +// Precondition: rv must be valid. func (l *lockState) isHeld(rv resolvedValue, exclusiveRequired bool) (string, bool) { - if !rv.valid { - return rv.valueAsString(l), false + if !rv.valid() { + panic("invalid resolvedValue passed to isHeld") } - s := rv.valueAsString(l) - isExclusive, ok := l.lockedMutexes[s] + s, _ := rv.valueAndObject(l) + info, ok := l.lockedMutexes[s] if !ok { return s, false } // Accept a weaker lock if exclusiveRequired is false. - if exclusiveRequired && !isExclusive { + if exclusiveRequired && !info.exclusive { return s, false } return s, true @@ -129,32 +137,39 @@ func (l *lockState) isHeld(rv resolvedValue, exclusiveRequired bool) (string, bo // lockField locks the given field. // // If false is returned, the field was already locked. +// +// Precondition: rv must be valid. func (l *lockState) lockField(rv resolvedValue, exclusive bool) (string, bool) { - if !rv.valid { - return rv.valueAsString(l), false + if !rv.valid() { + panic("invalid resolvedValue passed to isHeld") } - s := rv.valueAsString(l) + s, obj := rv.valueAndObject(l) if _, ok := l.lockedMutexes[s]; ok { return s, false } l.modify() - l.lockedMutexes[s] = exclusive + l.lockedMutexes[s] = lockInfo{ + exclusive: exclusive, + object: obj, + } return s, true } // unlockField unlocks the given field. // // If false is returned, the field was not locked. +// +// Precondition: rv must be valid. func (l *lockState) unlockField(rv resolvedValue, exclusive bool) (string, bool) { - if !rv.valid { - return rv.valueAsString(l), false + if !rv.valid() { + panic("invalid resolvedValue passed to isHeld") } - s := rv.valueAsString(l) - wasExclusive, ok := l.lockedMutexes[s] + s, _ := rv.valueAndObject(l) + info, ok := l.lockedMutexes[s] if !ok { return s, false } - if wasExclusive != exclusive { + if info.exclusive != exclusive { return s, false } l.modify() @@ -165,20 +180,23 @@ func (l *lockState) unlockField(rv resolvedValue, exclusive bool) (string, bool) // downgradeField downgrades the given field. // // If false was returned, the field was not downgraded. +// +// Precondition: rv must be valid. func (l *lockState) downgradeField(rv resolvedValue) (string, bool) { - if !rv.valid { - return rv.valueAsString(l), false + if !rv.valid() { + panic("invalid resolvedValue passed to isHeld") } - s := rv.valueAsString(l) - wasExclusive, ok := l.lockedMutexes[s] + s, _ := rv.valueAndObject(l) + info, ok := l.lockedMutexes[s] if !ok { return s, false } - if !wasExclusive { + if !info.exclusive { return s, false } l.modify() - l.lockedMutexes[s] = false // Downgraded. + info.exclusive = false + l.lockedMutexes[s] = info // Downgraded. return s, true } @@ -190,13 +208,13 @@ func (l *lockState) store(addr ssa.Value, v ssa.Value) { // isSubset indicates other holds all the locks held by l. func (l *lockState) isSubset(other *lockState) bool { - for k, isExclusive := range l.lockedMutexes { - otherExclusive, otherOk := other.lockedMutexes[k] + for k, info := range l.lockedMutexes { + otherInfo, otherOk := other.lockedMutexes[k] if !otherOk { return false } // Accept weaker locks as a subset. - if isExclusive && !otherExclusive { + if info.exclusive && !otherInfo.exclusive { return false } } @@ -218,25 +236,26 @@ type elemType interface { Elem() types.Type } -// valueAsString returns a string for a given value. +// valueAndObject returns a string for a given value, along with a source level +// object (if available and relevant). // // This decomposes the value into the simplest possible representation in terms // of parameters, free variables and globals. During resolution, stored values // may be transferred, as well as bound free variables. // // Nil may not be passed here. -func (l *lockState) valueAsString(v ssa.Value) string { +func (l *lockState) valueAndObject(v ssa.Value) (string, types.Object) { switch x := v.(type) { case *ssa.Parameter: // Was this provided as a paramter for a local anonymous // function invocation? v, ok := l.stored[x] if ok { - return l.valueAsString(v) + return l.valueAndObject(v) } - return fmt.Sprintf("{param:%s}", x.Name()) + return fmt.Sprintf("{param:%s}", x.Name()), x.Object() case *ssa.Global: - return fmt.Sprintf("{global:%s}", x.Name()) + return fmt.Sprintf("{global:%s}", x.Name()), x.Object() case *ssa.FreeVar: // Attempt to resolve this, in case we are being invoked in a // scope where all the variables are bound. @@ -247,16 +266,18 @@ func (l *lockState) valueAsString(v ssa.Value) string { // may map to the same FreeVar, which we can check. stored, ok := l.stored[v] if ok { - return l.valueAsString(stored) + return l.valueAndObject(stored) } } - return fmt.Sprintf("{freevar:%s}", x.Name()) + // FreeVar does not have a corresponding source-level object + // that we can return here. + return fmt.Sprintf("{freevar:%s}", x.Name()), nil case *ssa.Convert: // Just disregard conversion. - return l.valueAsString(x.X) + return l.valueAndObject(x.X) case *ssa.ChangeType: // Ditto, disregard. - return l.valueAsString(x.X) + return l.valueAndObject(x.X) case *ssa.UnOp: if x.Op != token.MUL { break @@ -264,7 +285,7 @@ func (l *lockState) valueAsString(v ssa.Value) string { // Is this loading a free variable? If yes, then this can be // resolved in the original isAlias function. if fv, ok := x.X.(*ssa.FreeVar); ok { - return l.valueAsString(fv) + return l.valueAndObject(fv) } // Should be try to resolve via a memory address? This needs to // be done since a memory location can hold its own value. @@ -275,12 +296,13 @@ func (l *lockState) valueAsString(v ssa.Value) string { if ok { l.used[x.X] = struct{}{} defer func() { delete(l.used, x.X) }() - return l.valueAsString(v) + return l.valueAndObject(v) } } // x.X.Type is pointer. We must construct this type // dynamically, since the ssa.Value could be synthetic. - return fmt.Sprintf("*(%s)", l.valueAsString(x.X)) + s, obj := l.valueAndObject(x.X) + return fmt.Sprintf("*(%s)", s), obj case *ssa.Field: structType, ok := resolveStruct(x.X.Type()) if !ok { @@ -288,7 +310,8 @@ func (l *lockState) valueAsString(v ssa.Value) string { panic(fmt.Sprintf("structType not available for struct: %#v", x.X)) } fieldObj := structType.Field(x.Field) - return fmt.Sprintf("%s.%s", l.valueAsString(x.X), fieldObj.Name()) + s, _ := l.valueAndObject(x.X) + return fmt.Sprintf("%s.%s", s, fieldObj.Name()), fieldObj case *ssa.FieldAddr: structType, ok := resolveStruct(x.X.Type()) if !ok { @@ -296,22 +319,30 @@ func (l *lockState) valueAsString(v ssa.Value) string { panic(fmt.Sprintf("structType not available for struct: %#v", x.X)) } fieldObj := structType.Field(x.Field) - return fmt.Sprintf("&(%s.%s)", l.valueAsString(x.X), fieldObj.Name()) + s, _ := l.valueAndObject(x.X) + return fmt.Sprintf("&(%s.%s)", s, fieldObj.Name()), fieldObj case *ssa.Index: - return fmt.Sprintf("%s[%s]", l.valueAsString(x.X), l.valueAsString(x.Index)) + s, _ := l.valueAndObject(x.X) + i, _ := l.valueAndObject(x.Index) + return fmt.Sprintf("%s[%s]", s, i), nil case *ssa.IndexAddr: - return fmt.Sprintf("&(%s[%s])", l.valueAsString(x.X), l.valueAsString(x.Index)) + s, _ := l.valueAndObject(x.X) + i, _ := l.valueAndObject(x.Index) + return fmt.Sprintf("&(%s[%s])", s, i), nil case *ssa.Lookup: - return fmt.Sprintf("%s[%s]", l.valueAsString(x.X), l.valueAsString(x.Index)) + s, _ := l.valueAndObject(x.X) + i, _ := l.valueAndObject(x.Index) + return fmt.Sprintf("%s[%s]", s, i), nil case *ssa.Extract: - return fmt.Sprintf("%s[%d]", l.valueAsString(x.Tuple), x.Index) + s, _ := l.valueAndObject(x.Tuple) + return fmt.Sprintf("%s[%d]", s, x.Index), nil } // In the case of any other type (e.g. this may be an alloc, a return // value, etc.), just return the literal pointer value to the Value. // This will be unique within the ssa graph, and so if two values are // equal, they are from the same type. - return fmt.Sprintf("{%T:%p}", v, v) + return fmt.Sprintf("{%T:%p}", v, v), nil } // String returns the full lock state. @@ -320,9 +351,9 @@ func (l *lockState) String() string { return "no locks held" } keys := make([]string, 0, len(l.lockedMutexes)) - for k, exclusive := range l.lockedMutexes { + for k, info := range l.lockedMutexes { // Include the exclusive status of each lock. - keys = append(keys, fmt.Sprintf("%s %s", k, exclusiveStr(exclusive))) + keys = append(keys, fmt.Sprintf("%s %s", k, exclusiveStr(info.exclusive))) } return strings.Join(keys, ",") } diff --git a/tools/checklocks/test/BUILD b/tools/checklocks/test/BUILD index 4b90731f5..21a68fbdf 100644 --- a/tools/checklocks/test/BUILD +++ b/tools/checklocks/test/BUILD @@ -13,7 +13,9 @@ go_library( "branches.go", "closures.go", "defer.go", + "globals.go", "incompat.go", + "inferred.go", "locker.go", "methods.go", "parameters.go", @@ -21,4 +23,8 @@ go_library( "rwmutex.go", "test.go", ], + # This ensures that there are no dependencies, since we want to explicitly + # control expected failures for analysis. + marshal = False, + stateify = False, ) diff --git a/tools/checklocks/test/globals.go b/tools/checklocks/test/globals.go new file mode 100644 index 000000000..656b0c9a3 --- /dev/null +++ b/tools/checklocks/test/globals.go @@ -0,0 +1,85 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "sync" +) + +var ( + globalMu sync.Mutex + globalRWMu sync.RWMutex +) + +var globalStruct struct { + mu sync.Mutex + // +checklocks:mu + guardedField int +} + +var otherStruct struct { + // +checklocks:globalMu + guardedField1 int + // +checklocks:globalRWMu + guardedField2 int + // +checklocks:globalStruct.mu + guardedField3 int +} + +func testGlobalValid() { + globalMu.Lock() + otherStruct.guardedField1 = 1 + globalMu.Unlock() + + globalRWMu.Lock() + otherStruct.guardedField2 = 1 + globalRWMu.Unlock() + + globalRWMu.RLock() + _ = otherStruct.guardedField2 + globalRWMu.RUnlock() + + globalStruct.mu.Lock() + globalStruct.guardedField = 1 + otherStruct.guardedField3 = 1 + globalStruct.mu.Unlock() +} + +// +checklocks:globalStruct.mu +func testGlobalValidPreconditions0() { + globalStruct.guardedField = 1 +} + +// +checklocks:globalMu +func testGlobalValidPreconditions1() { + otherStruct.guardedField1 = 1 +} + +// +checklocks:globalRWMu +func testGlobalValidPreconditions2() { + otherStruct.guardedField2 = 1 +} + +// +checklocks:globalStruct.mu +func testGlobalValidPreconditions3() { + otherStruct.guardedField3 = 1 +} + +func testGlobalInvalid() { + globalStruct.guardedField = 1 // +checklocksfail + otherStruct.guardedField1 = 1 // +checklocksfail + otherStruct.guardedField2 = 1 // +checklocksfail + otherStruct.guardedField3 = 1 // +checklocksfail +} diff --git a/tools/checklocks/test/inferred.go b/tools/checklocks/test/inferred.go new file mode 100644 index 000000000..5495bdb2a --- /dev/null +++ b/tools/checklocks/test/inferred.go @@ -0,0 +1,35 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "sync" +) + +type inferredStruct struct { + mu sync.Mutex + guardedField int // +checklocksfail + unguardedField int +} + +func testInferredPositive(tc *inferredStruct) { + tc.mu.Lock() + tc.guardedField = 1 + tc.mu.Unlock() +} + +func testInferredNegative(tc *inferredStruct) { + tc.unguardedField = 1 +} diff --git a/tools/checklocks/test/methods.go b/tools/checklocks/test/methods.go index 72e26fca6..b67657b61 100644 --- a/tools/checklocks/test/methods.go +++ b/tools/checklocks/test/methods.go @@ -103,7 +103,7 @@ type testMethodsWithEmbedded struct { // +checklocks:mu guardedField int - p *testMethodsWithParameters + p *testMethodsWithParameters // +checklocksignore: Inferred as protected by mu. } // +checklocks:t.mu diff --git a/tools/checklocks/test/test.go b/tools/checklocks/test/test.go index cbf6b1635..d1a9992fb 100644 --- a/tools/checklocks/test/test.go +++ b/tools/checklocks/test/test.go @@ -51,7 +51,7 @@ type twoLocksStruct struct { // twoLocksDoubleGuardStruct has two locks and a single field with two guards. type twoLocksDoubleGuardStruct struct { mu sync.Mutex - secondMu sync.Mutex + secondMu sync.Mutex // +checklocksignore: mu is inferred as requisite. // +checklocks:mu // +checklocks:secondMu doubleGuardedField int -- cgit v1.2.3