-rw-r--r-- | pkg/sentry/fsimpl/gofer/filesystem.go | 2
-rw-r--r-- | pkg/sentry/fsimpl/overlay/filesystem.go | 2
-rw-r--r-- | pkg/sentry/fsimpl/verity/filesystem.go | 2
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 7
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 6
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 2
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint.go | 3
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 2
-rw-r--r-- | tools/checklocks/analysis.go | 65
-rw-r--r-- | tools/checklocks/annotations.go | 17
-rw-r--r-- | tools/checklocks/facts.go | 71
-rw-r--r-- | tools/checklocks/state.go | 101
-rw-r--r-- | tools/checklocks/test/BUILD | 1
-rw-r--r-- | tools/checklocks/test/basics.go | 6
-rw-r--r-- | tools/checklocks/test/rwmutex.go | 52
15 files changed, 231 insertions, 108 deletions
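This commit teaches the checklocks analyzer to distinguish shared (RLock/RUnlock) from exclusive (Lock/Unlock) acquisition of an RWMutex. Lock state now records an exclusivity bit per held mutex, three read-lock annotation variants are added (+checklocksread, +checklocksacquireread, +checklocksreleaseread), and callers that only hold read locks are converted. A minimal sketch of how the new annotations are written — the counters type and its methods below are illustrative assumptions, not taken from this diff:

    package example

    import "sync"

    // counters is a hypothetical structure with a guarded field.
    type counters struct {
        mu sync.RWMutex
        // +checklocks:mu
        hits int
    }

    // hitsLocked requires mu to be held at least in read mode; reading
    // c.hits is allowed, but a write would still demand exclusive mode.
    // +checklocksread:c.mu
    func (c *counters) hitsLocked() int {
        return c.hits
    }

    // lockForRead acquires mu in read mode on behalf of the caller.
    // +checklocksacquireread:c.mu
    func (c *counters) lockForRead() {
        c.mu.RLock()
    }

    // unlockRead releases the read mode acquired by lockForRead.
    // +checklocksreleaseread:c.mu
    func (c *counters) unlockRead() {
        c.mu.RUnlock()
    }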
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index f7b3446d3..cf6b34cbf 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -170,7 +170,7 @@ func putDentrySlice(ds *[]*dentry) {
 // but dentry slices are allocated lazily, and it's much easier to say "defer
 // fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
 // fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
 func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **[]*dentry) {
 	fs.renameMu.RUnlock()
 	if *dsp == nil {
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 3b3dcf836..044902241 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -86,7 +86,7 @@ func putDentrySlice(ds *[]*dentry) {
 // fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
 // fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
 //
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
 func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) {
 	fs.renameMu.RUnlock()
 	if *dsp == nil {
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 52d47994d..8b059aa7d 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -74,7 +74,7 @@ func putDentrySlice(ds *[]*dentry) {
 // but dentry slices are allocated lazily, and it's much easier to say "defer
 // fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
 // fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
 func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
 	fs.renameMu.RUnlock()
 	if *ds == nil {
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 046679f76..eee0fc20c 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -146,7 +146,6 @@ func (cn *conn) timedOut(now time.Time) bool {
 // update the connection tracking state.
 //
-// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
 // +checklocks:cn.mu
 func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) {
 	if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
@@ -304,7 +303,7 @@ func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple {
 	return bkt.connForTIDRLocked(tid, now)
 }
 
-// +checklocks:bkt.mu
+// +checklocksread:bkt.mu
 func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple {
 	for other := bkt.tuples.Front(); other != nil; other = other.Next() {
 		if tid == other.id() && !other.conn.timedOut(now) {
@@ -591,8 +590,7 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
 // returns whether the tuple's connection has timed out.
 //
 // Precondition: ct.mu is read locked and bkt.mu is write locked.
-// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
-// +checklocks:ct.mu
+// +checklocksread:ct.mu
 // +checklocks:bkt.mu
 func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now time.Time) bool {
 	if !tuple.conn.timedOut(now) {
@@ -621,7 +619,6 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now t
 	return true
 }
 
-// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
 // +checklocks:b.mu
 func removeConnFromBucket(b *bucket, tuple *tuple) {
 	if tuple.reply {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 542d9257c..3474c292a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -71,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
 // descending order of match quality. If a call to yield returns false,
 // iterEndpointsLocked stops iteration and returns immediately.
 //
-// +checklocks:eps.mu
+// +checklocksread:eps.mu
 func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
 	// Try to find a match with the id as provided.
 	if ep, ok := eps.endpoints[id]; ok {
@@ -112,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
 // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
 // descending order of match quality.
 //
-// +checklocks:eps.mu
+// +checklocksread:eps.mu
 func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
 	var matchedEPs []*endpointsByNIC
 	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -124,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
 // findEndpointLocked returns the endpoint that most closely matches the given id.
 //
-// +checklocks:eps.mu
+// +checklocksread:eps.mu
 func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
 	var matchedEP *endpointsByNIC
 	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 31579a896..995f58616 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -200,7 +200,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
 // reacquire the mutex in exclusive mode.
 //
 // Returns true for retry if preparation should be retried.
-// +checklocks:e.mu
+// +checklocksread:e.mu
 func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
 	switch e.net.State() {
 	case transport.DatagramEndpointStateInitial:
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index e3094f59f..fb31e5104 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -363,8 +363,7 @@ func (e *Endpoint) Disconnect() {
 // configured multicast interface if no interface is specified and the
 // specified address is a multicast address.
 //
-// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
-// +checklocks:e.mu
+// +checklocksread:e.mu
 func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
 	localAddr := e.Info().ID.LocalAddress
 	if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 39b1e08c0..077a2325a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -292,7 +292,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
 // reacquire the mutex in exclusive mode.
 //
 // Returns true for retry if preparation should be retried.
-// +checklocks:e.mu
+// +checklocksread:e.mu
 func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
 	switch e.net.State() {
 	case transport.DatagramEndpointStateInitial:
diff --git a/tools/checklocks/analysis.go b/tools/checklocks/analysis.go
index d3fd797d0..ec0cba7f9 100644
--- a/tools/checklocks/analysis.go
+++ b/tools/checklocks/analysis.go
@@ -183,8 +183,8 @@ type instructionWithReferrers interface {
 // checkFieldAccess checks the validity of a field access.
 //
 // This also enforces atomicity constraints for fields that must be accessed
-// atomically. The parameter isWrite indicates whether this field is used
-// downstream for a write operation.
+// atomically. The parameter isWrite indicates whether this field is used for
+// a write operation.
 func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj ssa.Value, field int, ls *lockState, isWrite bool) {
 	var (
 		lff lockFieldFacts
@@ -200,13 +200,14 @@ func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj
 	for guardName, fl := range lgf.GuardedBy {
 		guardsFound++
 		r := fl.resolve(structObj)
-		if _, ok := ls.isHeld(r); ok {
+		if _, ok := ls.isHeld(r, isWrite); ok {
 			guardsHeld++
 			continue
 		}
 		if _, ok := pc.forced[pc.positionKey(inst.Pos())]; ok {
-			// Mark this as locked, since it has been forced.
-			ls.lockField(r)
+			// Mark this as locked, since it has been forced. All
+			// forces are treated as an exclusive lock.
+			ls.lockField(r, true /* exclusive */)
 			guardsHeld++
 			continue
 		}
@@ -301,7 +302,7 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
 			continue
 		}
 		r := fg.resolveCall(call.Common().Args, call.Value())
-		if s, ok := ls.unlockField(r); !ok {
+		if s, ok := ls.unlockField(r, fg.Exclusive); !ok {
 			if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
 				pc.maybeFail(call.Pos(), "attempt to release %s (%s), but not held (locks: %s)", fieldName, s, ls.String())
 			}
@@ -315,7 +316,7 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
 		}
 		// Acquire the lock per the annotation.
 		r := fg.resolveCall(call.Common().Args, call.Value())
-		if s, ok := ls.lockField(r); !ok {
+		if s, ok := ls.lockField(r, fg.Exclusive); !ok {
 			if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
 				pc.maybeFail(call.Pos(), "attempt to acquire %s (%s), but already held (locks: %s)", fieldName, s, ls.String())
 			}
 		}
 	}
@@ -323,6 +324,14 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
 	}
 }
 
+// exclusiveStr returns a string describing exclusive requirements.
+func exclusiveStr(exclusive bool) string {
+	if exclusive {
+		return "exclusively"
+	}
+	return "non-exclusively"
+}
+
 // checkFunctionCall checks preconditions for function calls, and tracks the
 // lock state by recording relevant calls to sync functions. Note that calls to
 // atomic functions are tracked by checkFieldAccess by looking directly at the
@@ -332,12 +341,12 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff
 	// Check all guards required are held.
 	for fieldName, fg := range lff.HeldOnEntry {
 		r := fg.resolveCall(call.Common().Args, call.Value())
-		if s, ok := ls.isHeld(r); !ok {
+		if s, ok := ls.isHeld(r, fg.Exclusive); !ok {
 			if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
-				pc.maybeFail(call.Pos(), "must hold %s (%s) to call %s, but not held (locks: %s)", fieldName, s, fn.Name(), ls.String())
+				pc.maybeFail(call.Pos(), "must hold %s %s (%s) to call %s, but not held (locks: %s)", fieldName, exclusiveStr(fg.Exclusive), s, fn.Name(), ls.String())
 			} else {
 				// Force the lock to be acquired.
-				ls.lockField(r)
+				ls.lockField(r, fg.Exclusive)
 			}
 		}
 	}
@@ -348,19 +357,33 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff
 	// Check if it's a method dispatch for something in the sync package.
 	// See: https://godoc.org/golang.org/x/tools/go/ssa#Function
 	if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil {
+		isExclusive := false
 		switch fn.Name() {
-		case "Lock", "RLock":
-			if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+		case "Lock":
+			isExclusive = true
+			fallthrough
+		case "RLock":
+			if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok {
 				if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
 					// Double locking a mutex that is already locked.
 					pc.maybeFail(call.Pos(), "%s already locked (locks: %s)", s, ls.String())
 				}
 			}
-		case "Unlock", "RUnlock":
-			if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+		case "Unlock":
+			isExclusive = true
+			fallthrough
+		case "RUnlock":
+			if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok {
 				if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
 					// Unlocking something that is already unlocked.
-					pc.maybeFail(call.Pos(), "%s already unlocked (locks: %s)", s, ls.String())
+					pc.maybeFail(call.Pos(), "%s already unlocked or locked differently (locks: %s)", s, ls.String())
+				}
+			}
+		case "DowngradeLock":
+			if s, ok := ls.downgradeField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+				if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
+					// Downgrading something that may not be downgraded.
+					pc.maybeFail(call.Pos(), "%s already unlocked or not exclusive (locks: %s)", s, ls.String())
 				}
 			}
 		}
@@ -531,13 +554,13 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock,
 	// Validate held locks.
 	for fieldName, fg := range lff.HeldOnExit {
 		r := fg.resolveStatic(fn, rv)
-		if s, ok := rls.isHeld(r); !ok {
+		if s, ok := rls.isHeld(r, fg.Exclusive); !ok {
 			if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok {
-				pc.maybeFail(rv.Pos(), "lock %s (%s) not held (locks: %s)", fieldName, s, rls.String())
+				pc.maybeFail(rv.Pos(), "lock %s (%s) not held %s (locks: %s)", fieldName, s, exclusiveStr(fg.Exclusive), rls.String())
 				failed = true
 			} else {
 				// Force the lock to be acquired.
-				rls.lockField(r)
+				rls.lockField(r, fg.Exclusive)
 			}
 		}
 	}
@@ -558,7 +581,7 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock,
 		if pls := pc.checkBasicBlock(fn, succ, lff, ls, seen); pls != nil {
 			if rls != nil && !rls.isCompatible(pls) {
 				if _, ok := pc.forced[pc.positionKey(fn.Pos())]; !ok {
-					pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %v)", rls.String(), pls.String())
+					pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %s)", rls.String(), pls.String())
 				}
 			}
 			rls = pls
@@ -601,11 +624,11 @@ func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *loc
 		// The first is the method object itself so we skip that when looking
 		// for receiver/function parameters.
 		r := fg.resolveStatic(fn, call.Value())
-		if s, ok := ls.lockField(r); !ok {
+		if s, ok := ls.lockField(r, fg.Exclusive); !ok {
 			// This can only happen if the same value is declared
 			// multiple times, and should be caught by the earlier
 			// fact scanning. Keep it here as a sanity check.
-			pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times (locks: %s)", fieldName, s, ls.String())
+			pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times or differently (locks: %s)", fieldName, s, ls.String())
 		}
 	}
diff --git a/tools/checklocks/annotations.go b/tools/checklocks/annotations.go
index 371260980..1f679e5be 100644
--- a/tools/checklocks/annotations.go
+++ b/tools/checklocks/annotations.go
@@ -23,13 +23,16 @@ import (
 )
 
 const (
-	checkLocksAnnotation  = "// +checklocks:"
-	checkLocksAcquires    = "// +checklocksacquire:"
-	checkLocksReleases    = "// +checklocksrelease:"
-	checkLocksIgnore      = "// +checklocksignore"
-	checkLocksForce       = "// +checklocksforce"
-	checkLocksFail        = "// +checklocksfail"
-	checkAtomicAnnotation = "// +checkatomic"
+	checkLocksAnnotation     = "// +checklocks:"
+	checkLocksAnnotationRead = "// +checklocksread:"
+	checkLocksAcquires       = "// +checklocksacquire:"
+	checkLocksAcquiresRead   = "// +checklocksacquireread:"
+	checkLocksReleases       = "// +checklocksrelease:"
+	checkLocksReleasesRead   = "// +checklocksreleaseread:"
+	checkLocksIgnore         = "// +checklocksignore"
+	checkLocksForce          = "// +checklocksforce"
+	checkLocksFail           = "// +checklocksfail"
+	checkAtomicAnnotation    = "// +checkatomic"
 )
 
 // failData indicates an expected failure.
diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go
index 34c9f5ef1..fd681adc3 100644
--- a/tools/checklocks/facts.go
+++ b/tools/checklocks/facts.go
@@ -166,6 +166,9 @@ type functionGuard struct {
 
 	// FieldList is the traversal path to the object.
 	FieldList fieldList
+
+	// Exclusive indicates an exclusive lock is required.
+	Exclusive bool
 }
 
 // resolveReturn resolves a return value.
@@ -262,7 +265,7 @@ type lockFunctionFacts struct {
 func (*lockFunctionFacts) AFact() {}
 
 // checkGuard validates the guardName.
-func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) {
+func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) {
 	if _, ok := lff.HeldOnEntry[guardName]; ok {
 		pc.maybeFail(d.Pos(), "annotation %s specified more than once, already required", guardName)
 		return functionGuard{}, false
@@ -271,13 +274,13 @@ func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guard
 		pc.maybeFail(d.Pos(), "annotation %s specified more than once, already acquired", guardName)
 		return functionGuard{}, false
 	}
-	fg, ok := pc.findFunctionGuard(d, guardName, allowReturn)
+	fg, ok := pc.findFunctionGuard(d, guardName, exclusive, allowReturn)
 	return fg, ok
 }
 
 // addGuardedBy adds a field to both HeldOnEntry and HeldOnExit.
-func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string) {
-	if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+	if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok {
 		if lff.HeldOnEntry == nil {
 			lff.HeldOnEntry = make(map[string]functionGuard)
 		}
@@ -290,8 +293,8 @@ func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, gua
 }
 
 // addAcquires adds a field to HeldOnExit.
-func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string) {
-	if fg, ok := lff.checkGuard(pc, d, guardName, true /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+	if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, true /* allowReturn */); ok {
 		if lff.HeldOnExit == nil {
 			lff.HeldOnExit = make(map[string]functionGuard)
 		}
@@ -300,8 +303,8 @@ func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guar
 }
 
 // addReleases adds a field to HeldOnEntry.
-func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string) {
-	if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+	if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok {
 		if lff.HeldOnEntry == nil {
 			lff.HeldOnEntry = make(map[string]functionGuard)
 		}
@@ -310,7 +313,7 @@ func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guar
 }
 
 // fieldListFor returns the fieldList for the given object.
-func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool) (int, bool) {
+func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool, exclusive bool) (int, bool) {
 	var lff lockFieldFacts
 	if !pc.pass.ImportObjectFact(fieldObj, &lff) {
 		// This should not happen: we export facts for all fields.
@@ -318,7 +321,11 @@ func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index
 	}
 	// Check that it is indeed a mutex.
 	if checkMutex && !lff.IsMutex && !lff.IsRWMutex {
-		pc.maybeFail(pos, "field %s is not a mutex or an rwmutex", fieldName)
+		pc.maybeFail(pos, "field %s is not a Mutex or an RWMutex", fieldName)
+		return 0, false
+	}
+	if checkMutex && !exclusive && !lff.IsRWMutex {
+		pc.maybeFail(pos, "field %s must be a RWMutex, but it is not", fieldName)
 		return 0, false
 	}
 	// Return the resolution path.
@@ -329,14 +336,14 @@
 }
 
 // resolveOneField resolves a field in a single struct.
-func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool) (fl fieldList, fieldObj types.Object, ok bool) {
+func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool, exclusive bool) (fl fieldList, fieldObj types.Object, ok bool) {
 	// Scan to match the next field.
 	for i := 0; i < structType.NumFields(); i++ {
 		fieldObj := structType.Field(i)
 		if fieldObj.Name() != fieldName {
 			continue
 		}
-		flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex)
+		flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex, exclusive)
 		if !ok {
 			return nil, nil, false
 		}
@@ -357,8 +364,8 @@ func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct,
 		// Need to check that there is a resolution path. If there is
 		// no resolution path that's not a failure: we just continue
 		// scanning the next embed to find a match.
-		flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false)
-		flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex)
+		flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false, exclusive)
+		flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex, exclusive)
 		if okEmbed && okCont {
 			fl = append(fl, flEmbed)
 			fl = append(fl, flCont...)
@@ -373,9 +380,9 @@ func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct,
 //
 // Note that this checks that the final element is a mutex of some kind, and
 // will fail appropriately.
-func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string) (fl fieldList, ok bool) {
+func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string, exclusive bool) (fl fieldList, ok bool) {
 	for partNumber, fieldName := range parts {
-		flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */)
+		flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */, exclusive)
 		if !ok {
 			// Error already reported.
 			return nil, false
@@ -474,16 +481,22 @@ func (pc *passContext) exportLockGuardFacts(structType *types.Struct, ss *ast.St
 			pc.maybeFail(fieldObj.Pos(), "annotation %s specified more than once", guardName)
 			return
 		}
-		fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."))
+		fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."), true /* exclusive */)
 		if ok {
-			// If we successfully resolved
-			// the field, then save it.
+			// If we successfully resolved the field, then save it.
 			if lgf.GuardedBy == nil {
 				lgf.GuardedBy = make(map[string]fieldList)
 			}
 			lgf.GuardedBy[guardName] = fl
 		}
 	},
+	// N.B. We support only the vanilla
+	// annotation on individual fields. If
+	// the field is a read lock, then we
+	// will allow read access by default.
+	checkLocksAnnotationRead: func(guardName string) {
+		pc.maybeFail(fieldObj.Pos(), "annotation %s not legal on fields", guardName)
+	},
 	})
 }
 // Save only if there is something meaningful.
@@ -514,7 +527,7 @@ func countFields(fl []*ast.Field) (count int) {
 }
 
 // matchFieldList attempts to match the given field.
-func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string) (functionGuard, bool) {
+func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string, exclusive bool) (functionGuard, bool) {
 	parts := strings.Split(guardName, ".")
 	parameterName := parts[0]
 	parameterNumber := 0
@@ -545,8 +558,9 @@
 	}
 	fg := functionGuard{
 		ParameterNumber: parameterNumber,
+		Exclusive:       exclusive,
 	}
-	fl, ok := pc.resolveField(pos, structType, parts[1:])
+	fl, ok := pc.resolveField(pos, structType, parts[1:], exclusive)
 	fg.FieldList = fl
 	return fg, ok // If ok is false, already failed.
 }
@@ -558,7 +572,7 @@
 // particular string of the 'a.b'.
 //
 // This function will report any errors directly.
-func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) {
+func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) {
 	var (
 		parameterList []*ast.Field
 		returnList    []*ast.Field
@@ -569,14 +583,14 @@ func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allo
 	if d.Type.Params != nil {
 		parameterList = append(parameterList, d.Type.Params.List...)
 	}
-	if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName); ok {
+	if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName, exclusive); ok {
 		return fg, ok
 	}
 	if allowReturn {
 		if d.Type.Results != nil {
 			returnList = append(returnList, d.Type.Results.List...)
 		}
-		if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName); ok {
+		if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName, exclusive); ok {
 			// Fix this up to apply to the return value, as noted
 			// in fg.ParameterNumber. For the ssa analysis, we must
 			// record whether this has multiple results, since
@@ -610,9 +624,12 @@ func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) {
 			// extremely rare.
 			lff.Ignore = true
 		},
-		checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName) },
-		checkLocksAcquires:   func(guardName string) { lff.addAcquires(pc, d, guardName) },
-		checkLocksReleases:   func(guardName string) { lff.addReleases(pc, d, guardName) },
+		checkLocksAnnotation:     func(guardName string) { lff.addGuardedBy(pc, d, guardName, true /* exclusive */) },
+		checkLocksAnnotationRead: func(guardName string) { lff.addGuardedBy(pc, d, guardName, false /* exclusive */) },
+		checkLocksAcquires:       func(guardName string) { lff.addAcquires(pc, d, guardName, true /* exclusive */) },
+		checkLocksAcquiresRead:   func(guardName string) { lff.addAcquires(pc, d, guardName, false /* exclusive */) },
+		checkLocksReleases:       func(guardName string) { lff.addReleases(pc, d, guardName, true /* exclusive */) },
+		checkLocksReleasesRead:   func(guardName string) { lff.addReleases(pc, d, guardName, false /* exclusive */) },
 	})
 }
diff --git a/tools/checklocks/state.go b/tools/checklocks/state.go
index 57061a32e..aaf997d79 100644
--- a/tools/checklocks/state.go
+++ b/tools/checklocks/state.go
@@ -29,7 +29,9 @@ type lockState struct {
 	// lockedMutexes is used to track which mutexes in a given struct are
 	// currently locked. Note that most of the heavy lifting is done by
 	// valueAsString below, which maps to specific structure fields, etc.
-	lockedMutexes []string
+	//
+	// The value indicates whether this is an exclusive lock.
+	lockedMutexes map[string]bool
 
 	// stored stores values that have been stored in memory, bound to
 	// FreeVars or passed as Parameterse.
@@ -51,7 +53,7 @@ type lockState struct {
 func newLockState() *lockState {
 	refs := int32(1) // Not shared.
 	return &lockState{
-		lockedMutexes: make([]string, 0),
+		lockedMutexes: make(map[string]bool),
 		used:          make(map[ssa.Value]struct{}),
 		stored:        make(map[ssa.Value]ssa.Value),
 		defers:        make([]*ssa.Defer, 0),
@@ -79,8 +81,10 @@ func (l *lockState) fork() *lockState {
 func (l *lockState) modify() {
 	if atomic.LoadInt32(l.refs) > 1 {
 		// Copy the lockedMutexes.
-		lm := make([]string, len(l.lockedMutexes))
-		copy(lm, l.lockedMutexes)
+		lm := make(map[string]bool)
+		for k, v := range l.lockedMutexes {
+			lm[k] = v
+		}
 		l.lockedMutexes = lm
 
 		// Copy the stored values.
@@ -106,55 +110,76 @@
 }
 
 // isHeld indicates whether the field is held is not.
-func (l *lockState) isHeld(rv resolvedValue) (string, bool) {
+func (l *lockState) isHeld(rv resolvedValue, exclusiveRequired bool) (string, bool) {
 	if !rv.valid {
 		return rv.valueAsString(l), false
 	}
 	s := rv.valueAsString(l)
-	for _, k := range l.lockedMutexes {
-		if k == s {
-			return s, true
-		}
+	isExclusive, ok := l.lockedMutexes[s]
+	if !ok {
+		return s, false
+	}
+	// Accept a weaker lock if exclusiveRequired is false.
+	if exclusiveRequired && !isExclusive {
+		return s, false
 	}
-	return s, false
+	return s, true
 }
 
 // lockField locks the given field.
 //
 // If false is returned, the field was already locked.
-func (l *lockState) lockField(rv resolvedValue) (string, bool) {
+func (l *lockState) lockField(rv resolvedValue, exclusive bool) (string, bool) {
 	if !rv.valid {
 		return rv.valueAsString(l), false
 	}
 	s := rv.valueAsString(l)
-	for _, k := range l.lockedMutexes {
-		if k == s {
-			return s, false
-		}
+	if _, ok := l.lockedMutexes[s]; ok {
+		return s, false
 	}
 	l.modify()
-	l.lockedMutexes = append(l.lockedMutexes, s)
+	l.lockedMutexes[s] = exclusive
 	return s, true
 }
 
 // unlockField unlocks the given field.
 //
 // If false is returned, the field was not locked.
-func (l *lockState) unlockField(rv resolvedValue) (string, bool) {
+func (l *lockState) unlockField(rv resolvedValue, exclusive bool) (string, bool) {
 	if !rv.valid {
 		return rv.valueAsString(l), false
 	}
 	s := rv.valueAsString(l)
-	for i, k := range l.lockedMutexes {
-		if k == s {
-			// Copy the last lock in and truncate.
-			l.modify()
-			l.lockedMutexes[i] = l.lockedMutexes[len(l.lockedMutexes)-1]
-			l.lockedMutexes = l.lockedMutexes[:len(l.lockedMutexes)-1]
-			return s, true
-		}
+	wasExclusive, ok := l.lockedMutexes[s]
+	if !ok {
+		return s, false
 	}
-	return s, false
+	if wasExclusive != exclusive {
+		return s, false
+	}
+	l.modify()
+	delete(l.lockedMutexes, s)
+	return s, true
+}
+
+// downgradeField downgrades the given field.
+//
+// If false was returned, the field was not downgraded.
+func (l *lockState) downgradeField(rv resolvedValue) (string, bool) {
+	if !rv.valid {
+		return rv.valueAsString(l), false
+	}
+	s := rv.valueAsString(l)
+	wasExclusive, ok := l.lockedMutexes[s]
+	if !ok {
+		return s, false
+	}
+	if !wasExclusive {
+		return s, false
+	}
+	l.modify()
+	l.lockedMutexes[s] = false // Downgraded.
+	return s, true
 }
 
 // store records an alias.
@@ -165,16 +190,17 @@ func (l *lockState) store(addr ssa.Value, v ssa.Value) {
 
 // isSubset indicates other holds all the locks held by l.
 func (l *lockState) isSubset(other *lockState) bool {
-	held := 0 // Number in l, held by other.
-	for _, k := range l.lockedMutexes {
-		for _, ok := range other.lockedMutexes {
-			if k == ok {
-				held++
-				break
-			}
+	for k, isExclusive := range l.lockedMutexes {
+		otherExclusive, otherOk := other.lockedMutexes[k]
+		if !otherOk {
+			return false
+		}
+		// Accept weaker locks as a subset.
+		if isExclusive && !otherExclusive {
+			return false
 		}
 	}
-	return held >= len(l.lockedMutexes)
+	return true
 }
 
 // count indicates the number of locks held.
@@ -293,7 +319,12 @@
 	if l.count() == 0 {
 		return "no locks held"
 	}
-	return strings.Join(l.lockedMutexes, ",")
+	keys := make([]string, 0, len(l.lockedMutexes))
+	for k, exclusive := range l.lockedMutexes {
+		// Include the exclusive status of each lock.
+		keys = append(keys, fmt.Sprintf("%s %s", k, exclusiveStr(exclusive)))
+	}
+	return strings.Join(keys, ",")
 }
 
 // pushDefer pushes a defer onto the stack.
diff --git a/tools/checklocks/test/BUILD b/tools/checklocks/test/BUILD
index d4d98c256..f2ea6c7c6 100644
--- a/tools/checklocks/test/BUILD
+++ b/tools/checklocks/test/BUILD
@@ -16,6 +16,7 @@ go_library(
        "methods.go",
        "parameters.go",
        "return.go",
+       "rwmutex.go",
        "test.go",
    ],
)
diff --git a/tools/checklocks/test/basics.go b/tools/checklocks/test/basics.go
index 7a773171f..e941fba5b 100644
--- a/tools/checklocks/test/basics.go
+++ b/tools/checklocks/test/basics.go
@@ -108,14 +108,14 @@ type rwGuardStruct struct {
 
 func testRWValidRead(tc *rwGuardStruct) {
 	tc.rwMu.Lock()
-	tc.guardedField = 1
+	_ = tc.guardedField
 	tc.rwMu.Unlock()
 }
 
 func testRWValidWrite(tc *rwGuardStruct) {
-	tc.rwMu.RLock()
+	tc.rwMu.Lock()
 	tc.guardedField = 2
-	tc.rwMu.RUnlock()
+	tc.rwMu.Unlock()
 }
 
 func testRWInvalidWrite(tc *rwGuardStruct) {
diff --git a/tools/checklocks/test/rwmutex.go b/tools/checklocks/test/rwmutex.go
new file mode 100644
index 000000000..d27ed10e3
--- /dev/null
+++ b/tools/checklocks/test/rwmutex.go
@@ -0,0 +1,52 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+	"sync"
+)
+
+// oneReadGuardStruct has one read-guarded field.
+type oneReadGuardStruct struct {
+	mu sync.RWMutex
+	// +checklocks:mu
+	guardedField int
+}
+
+func testRWAccessValidRead(tc *oneReadGuardStruct) {
+	tc.mu.Lock()
+	_ = tc.guardedField
+	tc.mu.Unlock()
+	tc.mu.RLock()
+	_ = tc.guardedField
+	tc.mu.RUnlock()
+}
+
+func testRWAccessValidWrite(tc *oneReadGuardStruct) {
+	tc.mu.Lock()
+	tc.guardedField = 1
+	tc.mu.Unlock()
+}
+
+func testRWAccessInvalidWrite(tc *oneReadGuardStruct) {
+	tc.guardedField = 2 // +checklocksfail
+	tc.mu.RLock()
+	tc.guardedField = 2 // +checklocksfail
+	tc.mu.RUnlock()
+}
+
+func testRWAccessInvalidRead(tc *oneReadGuardStruct) {
+	_ = tc.guardedField // +checklocksfail
+}
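The new DowngradeLock case in checkFunctionCall also lets the analyzer follow an exclusive-to-shared downgrade: downgradeField flips the exclusivity bit, after which writes fail and RUnlock (not Unlock) releases the lock. Note that the standard library's sync.RWMutex has no DowngradeLock; in these tests "sync" resolves to gVisor's own sync package, which provides it. A sketch in the style of the tests above — testRWDowngrade is hypothetical, not part of this diff:

    func testRWDowngrade(tc *oneReadGuardStruct) {
        tc.mu.Lock()
        tc.guardedField = 1 // Exclusive mode: writes allowed.
        tc.mu.DowngradeLock()
        _ = tc.guardedField // Read mode: reads still allowed.
        tc.guardedField = 2 // +checklocksfail (write under a read lock)
        tc.mu.RUnlock() // Matches the now-shared mode.
    }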