diff options
Diffstat (limited to 'tools/checklocks')
-rw-r--r-- | tools/checklocks/analysis.go | 65 | ||||
-rw-r--r-- | tools/checklocks/annotations.go | 17 | ||||
-rw-r--r-- | tools/checklocks/facts.go | 71 | ||||
-rw-r--r-- | tools/checklocks/state.go | 101 | ||||
-rw-r--r-- | tools/checklocks/test/BUILD | 1 | ||||
-rw-r--r-- | tools/checklocks/test/basics.go | 6 | ||||
-rw-r--r-- | tools/checklocks/test/rwmutex.go | 52 |
7 files changed, 220 insertions, 93 deletions
diff --git a/tools/checklocks/analysis.go b/tools/checklocks/analysis.go index d3fd797d0..ec0cba7f9 100644 --- a/tools/checklocks/analysis.go +++ b/tools/checklocks/analysis.go @@ -183,8 +183,8 @@ type instructionWithReferrers interface { // checkFieldAccess checks the validity of a field access. // // This also enforces atomicity constraints for fields that must be accessed -// atomically. The parameter isWrite indicates whether this field is used -// downstream for a write operation. +// atomically. The parameter isWrite indicates whether this field is used for +// a write operation. func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj ssa.Value, field int, ls *lockState, isWrite bool) { var ( lff lockFieldFacts @@ -200,13 +200,14 @@ func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj for guardName, fl := range lgf.GuardedBy { guardsFound++ r := fl.resolve(structObj) - if _, ok := ls.isHeld(r); ok { + if _, ok := ls.isHeld(r, isWrite); ok { guardsHeld++ continue } if _, ok := pc.forced[pc.positionKey(inst.Pos())]; ok { - // Mark this as locked, since it has been forced. - ls.lockField(r) + // Mark this as locked, since it has been forced. All + // forces are treated as an exclusive lock. + ls.lockField(r, true /* exclusive */) guardsHeld++ continue } @@ -301,7 +302,7 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction continue } r := fg.resolveCall(call.Common().Args, call.Value()) - if s, ok := ls.unlockField(r); !ok { + if s, ok := ls.unlockField(r, fg.Exclusive); !ok { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { pc.maybeFail(call.Pos(), "attempt to release %s (%s), but not held (locks: %s)", fieldName, s, ls.String()) } @@ -315,7 +316,7 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction } // Acquire the lock per the annotation. r := fg.resolveCall(call.Common().Args, call.Value()) - if s, ok := ls.lockField(r); !ok { + if s, ok := ls.lockField(r, fg.Exclusive); !ok { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { pc.maybeFail(call.Pos(), "attempt to acquire %s (%s), but already held (locks: %s)", fieldName, s, ls.String()) } @@ -323,6 +324,14 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction } } +// exclusiveStr returns a string describing exclusive requirements. +func exclusiveStr(exclusive bool) string { + if exclusive { + return "exclusively" + } + return "non-exclusively" +} + // checkFunctionCall checks preconditions for function calls, and tracks the // lock state by recording relevant calls to sync functions. Note that calls to // atomic functions are tracked by checkFieldAccess by looking directly at the @@ -332,12 +341,12 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff // Check all guards required are held. for fieldName, fg := range lff.HeldOnEntry { r := fg.resolveCall(call.Common().Args, call.Value()) - if s, ok := ls.isHeld(r); !ok { + if s, ok := ls.isHeld(r, fg.Exclusive); !ok { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { - pc.maybeFail(call.Pos(), "must hold %s (%s) to call %s, but not held (locks: %s)", fieldName, s, fn.Name(), ls.String()) + pc.maybeFail(call.Pos(), "must hold %s %s (%s) to call %s, but not held (locks: %s)", fieldName, exclusiveStr(fg.Exclusive), s, fn.Name(), ls.String()) } else { // Force the lock to be acquired. - ls.lockField(r) + ls.lockField(r, fg.Exclusive) } } } @@ -348,19 +357,33 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff // Check if it's a method dispatch for something in the sync package. // See: https://godoc.org/golang.org/x/tools/go/ssa#Function if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil { + isExclusive := false switch fn.Name() { - case "Lock", "RLock": - if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok { + case "Lock": + isExclusive = true + fallthrough + case "RLock": + if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { // Double locking a mutex that is already locked. pc.maybeFail(call.Pos(), "%s already locked (locks: %s)", s, ls.String()) } } - case "Unlock", "RUnlock": - if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok { + case "Unlock": + isExclusive = true + fallthrough + case "RUnlock": + if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok { if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { // Unlocking something that is already unlocked. - pc.maybeFail(call.Pos(), "%s already unlocked (locks: %s)", s, ls.String()) + pc.maybeFail(call.Pos(), "%s already unlocked or locked differently (locks: %s)", s, ls.String()) + } + } + case "DowngradeLock": + if s, ok := ls.downgradeField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + // Downgrading something that may not be downgraded. + pc.maybeFail(call.Pos(), "%s already unlocked or not exclusive (locks: %s)", s, ls.String()) } } } @@ -531,13 +554,13 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, // Validate held locks. for fieldName, fg := range lff.HeldOnExit { r := fg.resolveStatic(fn, rv) - if s, ok := rls.isHeld(r); !ok { + if s, ok := rls.isHeld(r, fg.Exclusive); !ok { if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok { - pc.maybeFail(rv.Pos(), "lock %s (%s) not held (locks: %s)", fieldName, s, rls.String()) + pc.maybeFail(rv.Pos(), "lock %s (%s) not held %s (locks: %s)", fieldName, s, exclusiveStr(fg.Exclusive), rls.String()) failed = true } else { // Force the lock to be acquired. - rls.lockField(r) + rls.lockField(r, fg.Exclusive) } } } @@ -558,7 +581,7 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, if pls := pc.checkBasicBlock(fn, succ, lff, ls, seen); pls != nil { if rls != nil && !rls.isCompatible(pls) { if _, ok := pc.forced[pc.positionKey(fn.Pos())]; !ok { - pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %v)", rls.String(), pls.String()) + pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %s)", rls.String(), pls.String()) } } rls = pls @@ -601,11 +624,11 @@ func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *loc // The first is the method object itself so we skip that when looking // for receiver/function parameters. r := fg.resolveStatic(fn, call.Value()) - if s, ok := ls.lockField(r); !ok { + if s, ok := ls.lockField(r, fg.Exclusive); !ok { // This can only happen if the same value is declared // multiple times, and should be caught by the earlier // fact scanning. Keep it here as a sanity check. - pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times (locks: %s)", fieldName, s, ls.String()) + pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times or differently (locks: %s)", fieldName, s, ls.String()) } } diff --git a/tools/checklocks/annotations.go b/tools/checklocks/annotations.go index 371260980..1f679e5be 100644 --- a/tools/checklocks/annotations.go +++ b/tools/checklocks/annotations.go @@ -23,13 +23,16 @@ import ( ) const ( - checkLocksAnnotation = "// +checklocks:" - checkLocksAcquires = "// +checklocksacquire:" - checkLocksReleases = "// +checklocksrelease:" - checkLocksIgnore = "// +checklocksignore" - checkLocksForce = "// +checklocksforce" - checkLocksFail = "// +checklocksfail" - checkAtomicAnnotation = "// +checkatomic" + checkLocksAnnotation = "// +checklocks:" + checkLocksAnnotationRead = "// +checklocksread:" + checkLocksAcquires = "// +checklocksacquire:" + checkLocksAcquiresRead = "// +checklocksacquireread:" + checkLocksReleases = "// +checklocksrelease:" + checkLocksReleasesRead = "// +checklocksreleaseread:" + checkLocksIgnore = "// +checklocksignore" + checkLocksForce = "// +checklocksforce" + checkLocksFail = "// +checklocksfail" + checkAtomicAnnotation = "// +checkatomic" ) // failData indicates an expected failure. diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go index 34c9f5ef1..fd681adc3 100644 --- a/tools/checklocks/facts.go +++ b/tools/checklocks/facts.go @@ -166,6 +166,9 @@ type functionGuard struct { // FieldList is the traversal path to the object. FieldList fieldList + + // Exclusive indicates an exclusive lock is required. + Exclusive bool } // resolveReturn resolves a return value. @@ -262,7 +265,7 @@ type lockFunctionFacts struct { func (*lockFunctionFacts) AFact() {} // checkGuard validates the guardName. -func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) { +func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) { if _, ok := lff.HeldOnEntry[guardName]; ok { pc.maybeFail(d.Pos(), "annotation %s specified more than once, already required", guardName) return functionGuard{}, false @@ -271,13 +274,13 @@ func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guard pc.maybeFail(d.Pos(), "annotation %s specified more than once, already acquired", guardName) return functionGuard{}, false } - fg, ok := pc.findFunctionGuard(d, guardName, allowReturn) + fg, ok := pc.findFunctionGuard(d, guardName, exclusive, allowReturn) return fg, ok } // addGuardedBy adds a field to both HeldOnEntry and HeldOnExit. -func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string) { - if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok { +func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { + if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok { if lff.HeldOnEntry == nil { lff.HeldOnEntry = make(map[string]functionGuard) } @@ -290,8 +293,8 @@ func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, gua } // addAcquires adds a field to HeldOnExit. -func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string) { - if fg, ok := lff.checkGuard(pc, d, guardName, true /* allowReturn */); ok { +func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { + if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, true /* allowReturn */); ok { if lff.HeldOnExit == nil { lff.HeldOnExit = make(map[string]functionGuard) } @@ -300,8 +303,8 @@ func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guar } // addReleases adds a field to HeldOnEntry. -func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string) { - if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok { +func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) { + if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok { if lff.HeldOnEntry == nil { lff.HeldOnEntry = make(map[string]functionGuard) } @@ -310,7 +313,7 @@ func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guar } // fieldListFor returns the fieldList for the given object. -func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool) (int, bool) { +func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool, exclusive bool) (int, bool) { var lff lockFieldFacts if !pc.pass.ImportObjectFact(fieldObj, &lff) { // This should not happen: we export facts for all fields. @@ -318,7 +321,11 @@ func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index } // Check that it is indeed a mutex. if checkMutex && !lff.IsMutex && !lff.IsRWMutex { - pc.maybeFail(pos, "field %s is not a mutex or an rwmutex", fieldName) + pc.maybeFail(pos, "field %s is not a Mutex or an RWMutex", fieldName) + return 0, false + } + if checkMutex && !exclusive && !lff.IsRWMutex { + pc.maybeFail(pos, "field %s must be a RWMutex, but it is not", fieldName) return 0, false } // Return the resolution path. @@ -329,14 +336,14 @@ func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index } // resolveOneField resolves a field in a single struct. -func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool) (fl fieldList, fieldObj types.Object, ok bool) { +func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool, exclusive bool) (fl fieldList, fieldObj types.Object, ok bool) { // Scan to match the next field. for i := 0; i < structType.NumFields(); i++ { fieldObj := structType.Field(i) if fieldObj.Name() != fieldName { continue } - flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex) + flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex, exclusive) if !ok { return nil, nil, false } @@ -357,8 +364,8 @@ func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, // Need to check that there is a resolution path. If there is // no resolution path that's not a failure: we just continue // scanning the next embed to find a match. - flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false) - flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex) + flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false, exclusive) + flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex, exclusive) if okEmbed && okCont { fl = append(fl, flEmbed) fl = append(fl, flCont...) @@ -373,9 +380,9 @@ func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, // // Note that this checks that the final element is a mutex of some kind, and // will fail appropriately. -func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string) (fl fieldList, ok bool) { +func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string, exclusive bool) (fl fieldList, ok bool) { for partNumber, fieldName := range parts { - flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */) + flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */, exclusive) if !ok { // Error already reported. return nil, false @@ -474,16 +481,22 @@ func (pc *passContext) exportLockGuardFacts(structType *types.Struct, ss *ast.St pc.maybeFail(fieldObj.Pos(), "annotation %s specified more than once", guardName) return } - fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, ".")) + fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."), true /* exclusive */) if ok { - // If we successfully resolved - // the field, then save it. + // If we successfully resolved the field, then save it. if lgf.GuardedBy == nil { lgf.GuardedBy = make(map[string]fieldList) } lgf.GuardedBy[guardName] = fl } }, + // N.B. We support only the vanilla + // annotation on individual fields. If + // the field is a read lock, then we + // will allow read access by default. + checkLocksAnnotationRead: func(guardName string) { + pc.maybeFail(fieldObj.Pos(), "annotation %s not legal on fields", guardName) + }, }) } // Save only if there is something meaningful. @@ -514,7 +527,7 @@ func countFields(fl []*ast.Field) (count int) { } // matchFieldList attempts to match the given field. -func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string) (functionGuard, bool) { +func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string, exclusive bool) (functionGuard, bool) { parts := strings.Split(guardName, ".") parameterName := parts[0] parameterNumber := 0 @@ -545,8 +558,9 @@ func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName } fg := functionGuard{ ParameterNumber: parameterNumber, + Exclusive: exclusive, } - fl, ok := pc.resolveField(pos, structType, parts[1:]) + fl, ok := pc.resolveField(pos, structType, parts[1:], exclusive) fg.FieldList = fl return fg, ok // If ok is false, already failed. } @@ -558,7 +572,7 @@ func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName // particular string of the 'a.b'. // // This function will report any errors directly. -func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) { +func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) { var ( parameterList []*ast.Field returnList []*ast.Field @@ -569,14 +583,14 @@ func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allo if d.Type.Params != nil { parameterList = append(parameterList, d.Type.Params.List...) } - if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName); ok { + if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName, exclusive); ok { return fg, ok } if allowReturn { if d.Type.Results != nil { returnList = append(returnList, d.Type.Results.List...) } - if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName); ok { + if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName, exclusive); ok { // Fix this up to apply to the return value, as noted // in fg.ParameterNumber. For the ssa analysis, we must // record whether this has multiple results, since @@ -610,9 +624,12 @@ func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) { // extremely rare. lff.Ignore = true }, - checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName) }, - checkLocksAcquires: func(guardName string) { lff.addAcquires(pc, d, guardName) }, - checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName) }, + checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName, true /* exclusive */) }, + checkLocksAnnotationRead: func(guardName string) { lff.addGuardedBy(pc, d, guardName, false /* exclusive */) }, + checkLocksAcquires: func(guardName string) { lff.addAcquires(pc, d, guardName, true /* exclusive */) }, + checkLocksAcquiresRead: func(guardName string) { lff.addAcquires(pc, d, guardName, false /* exclusive */) }, + checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName, true /* exclusive */) }, + checkLocksReleasesRead: func(guardName string) { lff.addReleases(pc, d, guardName, false /* exclusive */) }, }) } diff --git a/tools/checklocks/state.go b/tools/checklocks/state.go index 57061a32e..aaf997d79 100644 --- a/tools/checklocks/state.go +++ b/tools/checklocks/state.go @@ -29,7 +29,9 @@ type lockState struct { // lockedMutexes is used to track which mutexes in a given struct are // currently locked. Note that most of the heavy lifting is done by // valueAsString below, which maps to specific structure fields, etc. - lockedMutexes []string + // + // The value indicates whether this is an exclusive lock. + lockedMutexes map[string]bool // stored stores values that have been stored in memory, bound to // FreeVars or passed as Parameterse. @@ -51,7 +53,7 @@ type lockState struct { func newLockState() *lockState { refs := int32(1) // Not shared. return &lockState{ - lockedMutexes: make([]string, 0), + lockedMutexes: make(map[string]bool), used: make(map[ssa.Value]struct{}), stored: make(map[ssa.Value]ssa.Value), defers: make([]*ssa.Defer, 0), @@ -79,8 +81,10 @@ func (l *lockState) fork() *lockState { func (l *lockState) modify() { if atomic.LoadInt32(l.refs) > 1 { // Copy the lockedMutexes. - lm := make([]string, len(l.lockedMutexes)) - copy(lm, l.lockedMutexes) + lm := make(map[string]bool) + for k, v := range l.lockedMutexes { + lm[k] = v + } l.lockedMutexes = lm // Copy the stored values. @@ -106,55 +110,76 @@ func (l *lockState) modify() { } // isHeld indicates whether the field is held is not. -func (l *lockState) isHeld(rv resolvedValue) (string, bool) { +func (l *lockState) isHeld(rv resolvedValue, exclusiveRequired bool) (string, bool) { if !rv.valid { return rv.valueAsString(l), false } s := rv.valueAsString(l) - for _, k := range l.lockedMutexes { - if k == s { - return s, true - } + isExclusive, ok := l.lockedMutexes[s] + if !ok { + return s, false + } + // Accept a weaker lock if exclusiveRequired is false. + if exclusiveRequired && !isExclusive { + return s, false } - return s, false + return s, true } // lockField locks the given field. // // If false is returned, the field was already locked. -func (l *lockState) lockField(rv resolvedValue) (string, bool) { +func (l *lockState) lockField(rv resolvedValue, exclusive bool) (string, bool) { if !rv.valid { return rv.valueAsString(l), false } s := rv.valueAsString(l) - for _, k := range l.lockedMutexes { - if k == s { - return s, false - } + if _, ok := l.lockedMutexes[s]; ok { + return s, false } l.modify() - l.lockedMutexes = append(l.lockedMutexes, s) + l.lockedMutexes[s] = exclusive return s, true } // unlockField unlocks the given field. // // If false is returned, the field was not locked. -func (l *lockState) unlockField(rv resolvedValue) (string, bool) { +func (l *lockState) unlockField(rv resolvedValue, exclusive bool) (string, bool) { if !rv.valid { return rv.valueAsString(l), false } s := rv.valueAsString(l) - for i, k := range l.lockedMutexes { - if k == s { - // Copy the last lock in and truncate. - l.modify() - l.lockedMutexes[i] = l.lockedMutexes[len(l.lockedMutexes)-1] - l.lockedMutexes = l.lockedMutexes[:len(l.lockedMutexes)-1] - return s, true - } + wasExclusive, ok := l.lockedMutexes[s] + if !ok { + return s, false } - return s, false + if wasExclusive != exclusive { + return s, false + } + l.modify() + delete(l.lockedMutexes, s) + return s, true +} + +// downgradeField downgrades the given field. +// +// If false was returned, the field was not downgraded. +func (l *lockState) downgradeField(rv resolvedValue) (string, bool) { + if !rv.valid { + return rv.valueAsString(l), false + } + s := rv.valueAsString(l) + wasExclusive, ok := l.lockedMutexes[s] + if !ok { + return s, false + } + if !wasExclusive { + return s, false + } + l.modify() + l.lockedMutexes[s] = false // Downgraded. + return s, true } // store records an alias. @@ -165,16 +190,17 @@ func (l *lockState) store(addr ssa.Value, v ssa.Value) { // isSubset indicates other holds all the locks held by l. func (l *lockState) isSubset(other *lockState) bool { - held := 0 // Number in l, held by other. - for _, k := range l.lockedMutexes { - for _, ok := range other.lockedMutexes { - if k == ok { - held++ - break - } + for k, isExclusive := range l.lockedMutexes { + otherExclusive, otherOk := other.lockedMutexes[k] + if !otherOk { + return false + } + // Accept weaker locks as a subset. + if isExclusive && !otherExclusive { + return false } } - return held >= len(l.lockedMutexes) + return true } // count indicates the number of locks held. @@ -293,7 +319,12 @@ func (l *lockState) String() string { if l.count() == 0 { return "no locks held" } - return strings.Join(l.lockedMutexes, ",") + keys := make([]string, 0, len(l.lockedMutexes)) + for k, exclusive := range l.lockedMutexes { + // Include the exclusive status of each lock. + keys = append(keys, fmt.Sprintf("%s %s", k, exclusiveStr(exclusive))) + } + return strings.Join(keys, ",") } // pushDefer pushes a defer onto the stack. diff --git a/tools/checklocks/test/BUILD b/tools/checklocks/test/BUILD index d4d98c256..f2ea6c7c6 100644 --- a/tools/checklocks/test/BUILD +++ b/tools/checklocks/test/BUILD @@ -16,6 +16,7 @@ go_library( "methods.go", "parameters.go", "return.go", + "rwmutex.go", "test.go", ], ) diff --git a/tools/checklocks/test/basics.go b/tools/checklocks/test/basics.go index 7a773171f..e941fba5b 100644 --- a/tools/checklocks/test/basics.go +++ b/tools/checklocks/test/basics.go @@ -108,14 +108,14 @@ type rwGuardStruct struct { func testRWValidRead(tc *rwGuardStruct) { tc.rwMu.Lock() - tc.guardedField = 1 + _ = tc.guardedField tc.rwMu.Unlock() } func testRWValidWrite(tc *rwGuardStruct) { - tc.rwMu.RLock() + tc.rwMu.Lock() tc.guardedField = 2 - tc.rwMu.RUnlock() + tc.rwMu.Unlock() } func testRWInvalidWrite(tc *rwGuardStruct) { diff --git a/tools/checklocks/test/rwmutex.go b/tools/checklocks/test/rwmutex.go new file mode 100644 index 000000000..d27ed10e3 --- /dev/null +++ b/tools/checklocks/test/rwmutex.go @@ -0,0 +1,52 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "sync" +) + +// oneReadGuardStruct has one read-guarded field. +type oneReadGuardStruct struct { + mu sync.RWMutex + // +checklocks:mu + guardedField int +} + +func testRWAccessValidRead(tc *oneReadGuardStruct) { + tc.mu.Lock() + _ = tc.guardedField + tc.mu.Unlock() + tc.mu.RLock() + _ = tc.guardedField + tc.mu.RUnlock() +} + +func testRWAccessValidWrite(tc *oneReadGuardStruct) { + tc.mu.Lock() + tc.guardedField = 1 + tc.mu.Unlock() +} + +func testRWAccessInvalidWrite(tc *oneReadGuardStruct) { + tc.guardedField = 2 // +checklocksfail + tc.mu.RLock() + tc.guardedField = 2 // +checklocksfail + tc.mu.RUnlock() +} + +func testRWAccessInvalidRead(tc *oneReadGuardStruct) { + _ = tc.guardedField // +checklocksfail +} |