summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go2
-rw-r--r--pkg/tcpip/stack/conntrack.go7
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go6
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go3
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--tools/checklocks/analysis.go65
-rw-r--r--tools/checklocks/annotations.go17
-rw-r--r--tools/checklocks/facts.go71
-rw-r--r--tools/checklocks/state.go101
-rw-r--r--tools/checklocks/test/BUILD1
-rw-r--r--tools/checklocks/test/basics.go6
-rw-r--r--tools/checklocks/test/rwmutex.go52
15 files changed, 231 insertions, 108 deletions
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index f7b3446d3..cf6b34cbf 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -170,7 +170,7 @@ func putDentrySlice(ds *[]*dentry) {
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **[]*dentry) {
fs.renameMu.RUnlock()
if *dsp == nil {
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 3b3dcf836..044902241 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -86,7 +86,7 @@ func putDentrySlice(ds *[]*dentry) {
// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
//
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) {
fs.renameMu.RUnlock()
if *dsp == nil {
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 52d47994d..8b059aa7d 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -74,7 +74,7 @@ func putDentrySlice(ds *[]*dentry) {
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
-// +checklocksrelease:fs.renameMu
+// +checklocksreleaseread:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
fs.renameMu.RUnlock()
if *ds == nil {
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 046679f76..eee0fc20c 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -146,7 +146,6 @@ func (cn *conn) timedOut(now time.Time) bool {
// update the connection tracking state.
//
-// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
// +checklocks:cn.mu
func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) {
if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
@@ -304,7 +303,7 @@ func (bkt *bucket) connForTID(tid tupleID, now time.Time) *tuple {
return bkt.connForTIDRLocked(tid, now)
}
-// +checklocks:bkt.mu
+// +checklocksread:bkt.mu
func (bkt *bucket) connForTIDRLocked(tid tupleID, now time.Time) *tuple {
for other := bkt.tuples.Front(); other != nil; other = other.Next() {
if tid == other.id() && !other.conn.timedOut(now) {
@@ -591,8 +590,7 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// returns whether the tuple's connection has timed out.
//
// Precondition: ct.mu is read locked and bkt.mu is write locked.
-// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
-// +checklocks:ct.mu
+// +checklocksread:ct.mu
// +checklocks:bkt.mu
func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now time.Time) bool {
if !tuple.conn.timedOut(now) {
@@ -621,7 +619,6 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now t
return true
}
-// TODO(https://gvisor.dev/issue/6590): annotate r/w locking requirements.
// +checklocks:b.mu
func removeConnFromBucket(b *bucket, tuple *tuple) {
if tuple.reply {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 542d9257c..3474c292a 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -71,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
-// +checklocks:eps.mu
+// +checklocksread:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
@@ -112,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield
// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
-// +checklocks:eps.mu
+// +checklocksread:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
var matchedEPs []*endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
@@ -124,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []
// findEndpointLocked returns the endpoint that most closely matches the given id.
//
-// +checklocks:eps.mu
+// +checklocksread:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
var matchedEP *endpointsByNIC
eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 31579a896..995f58616 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -200,7 +200,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-// +checklocks:e.mu
+// +checklocksread:e.mu
func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.net.State() {
case transport.DatagramEndpointStateInitial:
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index e3094f59f..fb31e5104 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -363,8 +363,7 @@ func (e *Endpoint) Disconnect() {
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
//
-// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement.
-// +checklocks:e.mu
+// +checklocksread:e.mu
func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
localAddr := e.Info().ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 39b1e08c0..077a2325a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -292,7 +292,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-// +checklocks:e.mu
+// +checklocksread:e.mu
func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.net.State() {
case transport.DatagramEndpointStateInitial:
diff --git a/tools/checklocks/analysis.go b/tools/checklocks/analysis.go
index d3fd797d0..ec0cba7f9 100644
--- a/tools/checklocks/analysis.go
+++ b/tools/checklocks/analysis.go
@@ -183,8 +183,8 @@ type instructionWithReferrers interface {
// checkFieldAccess checks the validity of a field access.
//
// This also enforces atomicity constraints for fields that must be accessed
-// atomically. The parameter isWrite indicates whether this field is used
-// downstream for a write operation.
+// atomically. The parameter isWrite indicates whether this field is used for
+// a write operation.
func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj ssa.Value, field int, ls *lockState, isWrite bool) {
var (
lff lockFieldFacts
@@ -200,13 +200,14 @@ func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj
for guardName, fl := range lgf.GuardedBy {
guardsFound++
r := fl.resolve(structObj)
- if _, ok := ls.isHeld(r); ok {
+ if _, ok := ls.isHeld(r, isWrite); ok {
guardsHeld++
continue
}
if _, ok := pc.forced[pc.positionKey(inst.Pos())]; ok {
- // Mark this as locked, since it has been forced.
- ls.lockField(r)
+ // Mark this as locked, since it has been forced. All
+ // forces are treated as an exclusive lock.
+ ls.lockField(r, true /* exclusive */)
guardsHeld++
continue
}
@@ -301,7 +302,7 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
continue
}
r := fg.resolveCall(call.Common().Args, call.Value())
- if s, ok := ls.unlockField(r); !ok {
+ if s, ok := ls.unlockField(r, fg.Exclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
pc.maybeFail(call.Pos(), "attempt to release %s (%s), but not held (locks: %s)", fieldName, s, ls.String())
}
@@ -315,7 +316,7 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
}
// Acquire the lock per the annotation.
r := fg.resolveCall(call.Common().Args, call.Value())
- if s, ok := ls.lockField(r); !ok {
+ if s, ok := ls.lockField(r, fg.Exclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
pc.maybeFail(call.Pos(), "attempt to acquire %s (%s), but already held (locks: %s)", fieldName, s, ls.String())
}
@@ -323,6 +324,14 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction
}
}
+// exclusiveStr returns a string describing exclusive requirements.
+func exclusiveStr(exclusive bool) string {
+ if exclusive {
+ return "exclusively"
+ }
+ return "non-exclusively"
+}
+
// checkFunctionCall checks preconditions for function calls, and tracks the
// lock state by recording relevant calls to sync functions. Note that calls to
// atomic functions are tracked by checkFieldAccess by looking directly at the
@@ -332,12 +341,12 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff
// Check all guards required are held.
for fieldName, fg := range lff.HeldOnEntry {
r := fg.resolveCall(call.Common().Args, call.Value())
- if s, ok := ls.isHeld(r); !ok {
+ if s, ok := ls.isHeld(r, fg.Exclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
- pc.maybeFail(call.Pos(), "must hold %s (%s) to call %s, but not held (locks: %s)", fieldName, s, fn.Name(), ls.String())
+ pc.maybeFail(call.Pos(), "must hold %s %s (%s) to call %s, but not held (locks: %s)", fieldName, exclusiveStr(fg.Exclusive), s, fn.Name(), ls.String())
} else {
// Force the lock to be acquired.
- ls.lockField(r)
+ ls.lockField(r, fg.Exclusive)
}
}
}
@@ -348,19 +357,33 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff
// Check if it's a method dispatch for something in the sync package.
// See: https://godoc.org/golang.org/x/tools/go/ssa#Function
if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil {
+ isExclusive := false
switch fn.Name() {
- case "Lock", "RLock":
- if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+ case "Lock":
+ isExclusive = true
+ fallthrough
+ case "RLock":
+ if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
// Double locking a mutex that is already locked.
pc.maybeFail(call.Pos(), "%s already locked (locks: %s)", s, ls.String())
}
}
- case "Unlock", "RUnlock":
- if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+ case "Unlock":
+ isExclusive = true
+ fallthrough
+ case "RUnlock":
+ if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok {
if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
// Unlocking something that is already unlocked.
- pc.maybeFail(call.Pos(), "%s already unlocked (locks: %s)", s, ls.String())
+ pc.maybeFail(call.Pos(), "%s already unlocked or locked differently (locks: %s)", s, ls.String())
+ }
+ }
+ case "DowngradeLock":
+ if s, ok := ls.downgradeField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+ if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
+ // Downgrading something that may not be downgraded.
+ pc.maybeFail(call.Pos(), "%s already unlocked or not exclusive (locks: %s)", s, ls.String())
}
}
}
@@ -531,13 +554,13 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock,
// Validate held locks.
for fieldName, fg := range lff.HeldOnExit {
r := fg.resolveStatic(fn, rv)
- if s, ok := rls.isHeld(r); !ok {
+ if s, ok := rls.isHeld(r, fg.Exclusive); !ok {
if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok {
- pc.maybeFail(rv.Pos(), "lock %s (%s) not held (locks: %s)", fieldName, s, rls.String())
+ pc.maybeFail(rv.Pos(), "lock %s (%s) not held %s (locks: %s)", fieldName, s, exclusiveStr(fg.Exclusive), rls.String())
failed = true
} else {
// Force the lock to be acquired.
- rls.lockField(r)
+ rls.lockField(r, fg.Exclusive)
}
}
}
@@ -558,7 +581,7 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock,
if pls := pc.checkBasicBlock(fn, succ, lff, ls, seen); pls != nil {
if rls != nil && !rls.isCompatible(pls) {
if _, ok := pc.forced[pc.positionKey(fn.Pos())]; !ok {
- pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %v)", rls.String(), pls.String())
+ pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %s)", rls.String(), pls.String())
}
}
rls = pls
@@ -601,11 +624,11 @@ func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *loc
// The first is the method object itself so we skip that when looking
// for receiver/function parameters.
r := fg.resolveStatic(fn, call.Value())
- if s, ok := ls.lockField(r); !ok {
+ if s, ok := ls.lockField(r, fg.Exclusive); !ok {
// This can only happen if the same value is declared
// multiple times, and should be caught by the earlier
// fact scanning. Keep it here as a sanity check.
- pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times (locks: %s)", fieldName, s, ls.String())
+ pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times or differently (locks: %s)", fieldName, s, ls.String())
}
}
diff --git a/tools/checklocks/annotations.go b/tools/checklocks/annotations.go
index 371260980..1f679e5be 100644
--- a/tools/checklocks/annotations.go
+++ b/tools/checklocks/annotations.go
@@ -23,13 +23,16 @@ import (
)
const (
- checkLocksAnnotation = "// +checklocks:"
- checkLocksAcquires = "// +checklocksacquire:"
- checkLocksReleases = "// +checklocksrelease:"
- checkLocksIgnore = "// +checklocksignore"
- checkLocksForce = "// +checklocksforce"
- checkLocksFail = "// +checklocksfail"
- checkAtomicAnnotation = "// +checkatomic"
+ checkLocksAnnotation = "// +checklocks:"
+ checkLocksAnnotationRead = "// +checklocksread:"
+ checkLocksAcquires = "// +checklocksacquire:"
+ checkLocksAcquiresRead = "// +checklocksacquireread:"
+ checkLocksReleases = "// +checklocksrelease:"
+ checkLocksReleasesRead = "// +checklocksreleaseread:"
+ checkLocksIgnore = "// +checklocksignore"
+ checkLocksForce = "// +checklocksforce"
+ checkLocksFail = "// +checklocksfail"
+ checkAtomicAnnotation = "// +checkatomic"
)
// failData indicates an expected failure.
diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go
index 34c9f5ef1..fd681adc3 100644
--- a/tools/checklocks/facts.go
+++ b/tools/checklocks/facts.go
@@ -166,6 +166,9 @@ type functionGuard struct {
// FieldList is the traversal path to the object.
FieldList fieldList
+
+ // Exclusive indicates an exclusive lock is required.
+ Exclusive bool
}
// resolveReturn resolves a return value.
@@ -262,7 +265,7 @@ type lockFunctionFacts struct {
func (*lockFunctionFacts) AFact() {}
// checkGuard validates the guardName.
-func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) {
+func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) {
if _, ok := lff.HeldOnEntry[guardName]; ok {
pc.maybeFail(d.Pos(), "annotation %s specified more than once, already required", guardName)
return functionGuard{}, false
@@ -271,13 +274,13 @@ func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guard
pc.maybeFail(d.Pos(), "annotation %s specified more than once, already acquired", guardName)
return functionGuard{}, false
}
- fg, ok := pc.findFunctionGuard(d, guardName, allowReturn)
+ fg, ok := pc.findFunctionGuard(d, guardName, exclusive, allowReturn)
return fg, ok
}
// addGuardedBy adds a field to both HeldOnEntry and HeldOnExit.
-func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string) {
- if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok {
if lff.HeldOnEntry == nil {
lff.HeldOnEntry = make(map[string]functionGuard)
}
@@ -290,8 +293,8 @@ func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, gua
}
// addAcquires adds a field to HeldOnExit.
-func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string) {
- if fg, ok := lff.checkGuard(pc, d, guardName, true /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, true /* allowReturn */); ok {
if lff.HeldOnExit == nil {
lff.HeldOnExit = make(map[string]functionGuard)
}
@@ -300,8 +303,8 @@ func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guar
}
// addReleases adds a field to HeldOnEntry.
-func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string) {
- if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok {
+func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string, exclusive bool) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, exclusive, false /* allowReturn */); ok {
if lff.HeldOnEntry == nil {
lff.HeldOnEntry = make(map[string]functionGuard)
}
@@ -310,7 +313,7 @@ func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guar
}
// fieldListFor returns the fieldList for the given object.
-func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool) (int, bool) {
+func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool, exclusive bool) (int, bool) {
var lff lockFieldFacts
if !pc.pass.ImportObjectFact(fieldObj, &lff) {
// This should not happen: we export facts for all fields.
@@ -318,7 +321,11 @@ func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index
}
// Check that it is indeed a mutex.
if checkMutex && !lff.IsMutex && !lff.IsRWMutex {
- pc.maybeFail(pos, "field %s is not a mutex or an rwmutex", fieldName)
+ pc.maybeFail(pos, "field %s is not a Mutex or an RWMutex", fieldName)
+ return 0, false
+ }
+ if checkMutex && !exclusive && !lff.IsRWMutex {
+ pc.maybeFail(pos, "field %s must be a RWMutex, but it is not", fieldName)
return 0, false
}
// Return the resolution path.
@@ -329,14 +336,14 @@ func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index
}
// resolveOneField resolves a field in a single struct.
-func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool) (fl fieldList, fieldObj types.Object, ok bool) {
+func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool, exclusive bool) (fl fieldList, fieldObj types.Object, ok bool) {
// Scan to match the next field.
for i := 0; i < structType.NumFields(); i++ {
fieldObj := structType.Field(i)
if fieldObj.Name() != fieldName {
continue
}
- flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex)
+ flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex, exclusive)
if !ok {
return nil, nil, false
}
@@ -357,8 +364,8 @@ func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct,
// Need to check that there is a resolution path. If there is
// no resolution path that's not a failure: we just continue
// scanning the next embed to find a match.
- flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false)
- flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex)
+ flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false, exclusive)
+ flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex, exclusive)
if okEmbed && okCont {
fl = append(fl, flEmbed)
fl = append(fl, flCont...)
@@ -373,9 +380,9 @@ func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct,
//
// Note that this checks that the final element is a mutex of some kind, and
// will fail appropriately.
-func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string) (fl fieldList, ok bool) {
+func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string, exclusive bool) (fl fieldList, ok bool) {
for partNumber, fieldName := range parts {
- flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */)
+ flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */, exclusive)
if !ok {
// Error already reported.
return nil, false
@@ -474,16 +481,22 @@ func (pc *passContext) exportLockGuardFacts(structType *types.Struct, ss *ast.St
pc.maybeFail(fieldObj.Pos(), "annotation %s specified more than once", guardName)
return
}
- fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."))
+ fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."), true /* exclusive */)
if ok {
- // If we successfully resolved
- // the field, then save it.
+ // If we successfully resolved the field, then save it.
if lgf.GuardedBy == nil {
lgf.GuardedBy = make(map[string]fieldList)
}
lgf.GuardedBy[guardName] = fl
}
},
+ // N.B. We support only the vanilla
+ // annotation on individual fields. If
+ // the field is a read lock, then we
+ // will allow read access by default.
+ checkLocksAnnotationRead: func(guardName string) {
+ pc.maybeFail(fieldObj.Pos(), "annotation %s not legal on fields", guardName)
+ },
})
}
// Save only if there is something meaningful.
@@ -514,7 +527,7 @@ func countFields(fl []*ast.Field) (count int) {
}
// matchFieldList attempts to match the given field.
-func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string) (functionGuard, bool) {
+func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string, exclusive bool) (functionGuard, bool) {
parts := strings.Split(guardName, ".")
parameterName := parts[0]
parameterNumber := 0
@@ -545,8 +558,9 @@ func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName
}
fg := functionGuard{
ParameterNumber: parameterNumber,
+ Exclusive: exclusive,
}
- fl, ok := pc.resolveField(pos, structType, parts[1:])
+ fl, ok := pc.resolveField(pos, structType, parts[1:], exclusive)
fg.FieldList = fl
return fg, ok // If ok is false, already failed.
}
@@ -558,7 +572,7 @@ func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName
// particular string of the 'a.b'.
//
// This function will report any errors directly.
-func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) {
+func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, exclusive bool, allowReturn bool) (functionGuard, bool) {
var (
parameterList []*ast.Field
returnList []*ast.Field
@@ -569,14 +583,14 @@ func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allo
if d.Type.Params != nil {
parameterList = append(parameterList, d.Type.Params.List...)
}
- if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName); ok {
+ if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName, exclusive); ok {
return fg, ok
}
if allowReturn {
if d.Type.Results != nil {
returnList = append(returnList, d.Type.Results.List...)
}
- if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName); ok {
+ if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName, exclusive); ok {
// Fix this up to apply to the return value, as noted
// in fg.ParameterNumber. For the ssa analysis, we must
// record whether this has multiple results, since
@@ -610,9 +624,12 @@ func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) {
// extremely rare.
lff.Ignore = true
},
- checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName) },
- checkLocksAcquires: func(guardName string) { lff.addAcquires(pc, d, guardName) },
- checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName) },
+ checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName, true /* exclusive */) },
+ checkLocksAnnotationRead: func(guardName string) { lff.addGuardedBy(pc, d, guardName, false /* exclusive */) },
+ checkLocksAcquires: func(guardName string) { lff.addAcquires(pc, d, guardName, true /* exclusive */) },
+ checkLocksAcquiresRead: func(guardName string) { lff.addAcquires(pc, d, guardName, false /* exclusive */) },
+ checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName, true /* exclusive */) },
+ checkLocksReleasesRead: func(guardName string) { lff.addReleases(pc, d, guardName, false /* exclusive */) },
})
}
diff --git a/tools/checklocks/state.go b/tools/checklocks/state.go
index 57061a32e..aaf997d79 100644
--- a/tools/checklocks/state.go
+++ b/tools/checklocks/state.go
@@ -29,7 +29,9 @@ type lockState struct {
// lockedMutexes is used to track which mutexes in a given struct are
// currently locked. Note that most of the heavy lifting is done by
// valueAsString below, which maps to specific structure fields, etc.
- lockedMutexes []string
+ //
+ // The value indicates whether this is an exclusive lock.
+ lockedMutexes map[string]bool
// stored stores values that have been stored in memory, bound to
// FreeVars or passed as Parameterse.
@@ -51,7 +53,7 @@ type lockState struct {
func newLockState() *lockState {
refs := int32(1) // Not shared.
return &lockState{
- lockedMutexes: make([]string, 0),
+ lockedMutexes: make(map[string]bool),
used: make(map[ssa.Value]struct{}),
stored: make(map[ssa.Value]ssa.Value),
defers: make([]*ssa.Defer, 0),
@@ -79,8 +81,10 @@ func (l *lockState) fork() *lockState {
func (l *lockState) modify() {
if atomic.LoadInt32(l.refs) > 1 {
// Copy the lockedMutexes.
- lm := make([]string, len(l.lockedMutexes))
- copy(lm, l.lockedMutexes)
+ lm := make(map[string]bool)
+ for k, v := range l.lockedMutexes {
+ lm[k] = v
+ }
l.lockedMutexes = lm
// Copy the stored values.
@@ -106,55 +110,76 @@ func (l *lockState) modify() {
}
// isHeld indicates whether the field is held is not.
-func (l *lockState) isHeld(rv resolvedValue) (string, bool) {
+func (l *lockState) isHeld(rv resolvedValue, exclusiveRequired bool) (string, bool) {
if !rv.valid {
return rv.valueAsString(l), false
}
s := rv.valueAsString(l)
- for _, k := range l.lockedMutexes {
- if k == s {
- return s, true
- }
+ isExclusive, ok := l.lockedMutexes[s]
+ if !ok {
+ return s, false
+ }
+ // Accept a weaker lock if exclusiveRequired is false.
+ if exclusiveRequired && !isExclusive {
+ return s, false
}
- return s, false
+ return s, true
}
// lockField locks the given field.
//
// If false is returned, the field was already locked.
-func (l *lockState) lockField(rv resolvedValue) (string, bool) {
+func (l *lockState) lockField(rv resolvedValue, exclusive bool) (string, bool) {
if !rv.valid {
return rv.valueAsString(l), false
}
s := rv.valueAsString(l)
- for _, k := range l.lockedMutexes {
- if k == s {
- return s, false
- }
+ if _, ok := l.lockedMutexes[s]; ok {
+ return s, false
}
l.modify()
- l.lockedMutexes = append(l.lockedMutexes, s)
+ l.lockedMutexes[s] = exclusive
return s, true
}
// unlockField unlocks the given field.
//
// If false is returned, the field was not locked.
-func (l *lockState) unlockField(rv resolvedValue) (string, bool) {
+func (l *lockState) unlockField(rv resolvedValue, exclusive bool) (string, bool) {
if !rv.valid {
return rv.valueAsString(l), false
}
s := rv.valueAsString(l)
- for i, k := range l.lockedMutexes {
- if k == s {
- // Copy the last lock in and truncate.
- l.modify()
- l.lockedMutexes[i] = l.lockedMutexes[len(l.lockedMutexes)-1]
- l.lockedMutexes = l.lockedMutexes[:len(l.lockedMutexes)-1]
- return s, true
- }
+ wasExclusive, ok := l.lockedMutexes[s]
+ if !ok {
+ return s, false
}
- return s, false
+ if wasExclusive != exclusive {
+ return s, false
+ }
+ l.modify()
+ delete(l.lockedMutexes, s)
+ return s, true
+}
+
+// downgradeField downgrades the given field.
+//
+// If false was returned, the field was not downgraded.
+func (l *lockState) downgradeField(rv resolvedValue) (string, bool) {
+ if !rv.valid {
+ return rv.valueAsString(l), false
+ }
+ s := rv.valueAsString(l)
+ wasExclusive, ok := l.lockedMutexes[s]
+ if !ok {
+ return s, false
+ }
+ if !wasExclusive {
+ return s, false
+ }
+ l.modify()
+ l.lockedMutexes[s] = false // Downgraded.
+ return s, true
}
// store records an alias.
@@ -165,16 +190,17 @@ func (l *lockState) store(addr ssa.Value, v ssa.Value) {
// isSubset indicates other holds all the locks held by l.
func (l *lockState) isSubset(other *lockState) bool {
- held := 0 // Number in l, held by other.
- for _, k := range l.lockedMutexes {
- for _, ok := range other.lockedMutexes {
- if k == ok {
- held++
- break
- }
+ for k, isExclusive := range l.lockedMutexes {
+ otherExclusive, otherOk := other.lockedMutexes[k]
+ if !otherOk {
+ return false
+ }
+ // Accept weaker locks as a subset.
+ if isExclusive && !otherExclusive {
+ return false
}
}
- return held >= len(l.lockedMutexes)
+ return true
}
// count indicates the number of locks held.
@@ -293,7 +319,12 @@ func (l *lockState) String() string {
if l.count() == 0 {
return "no locks held"
}
- return strings.Join(l.lockedMutexes, ",")
+ keys := make([]string, 0, len(l.lockedMutexes))
+ for k, exclusive := range l.lockedMutexes {
+ // Include the exclusive status of each lock.
+ keys = append(keys, fmt.Sprintf("%s %s", k, exclusiveStr(exclusive)))
+ }
+ return strings.Join(keys, ",")
}
// pushDefer pushes a defer onto the stack.
diff --git a/tools/checklocks/test/BUILD b/tools/checklocks/test/BUILD
index d4d98c256..f2ea6c7c6 100644
--- a/tools/checklocks/test/BUILD
+++ b/tools/checklocks/test/BUILD
@@ -16,6 +16,7 @@ go_library(
"methods.go",
"parameters.go",
"return.go",
+ "rwmutex.go",
"test.go",
],
)
diff --git a/tools/checklocks/test/basics.go b/tools/checklocks/test/basics.go
index 7a773171f..e941fba5b 100644
--- a/tools/checklocks/test/basics.go
+++ b/tools/checklocks/test/basics.go
@@ -108,14 +108,14 @@ type rwGuardStruct struct {
func testRWValidRead(tc *rwGuardStruct) {
tc.rwMu.Lock()
- tc.guardedField = 1
+ _ = tc.guardedField
tc.rwMu.Unlock()
}
func testRWValidWrite(tc *rwGuardStruct) {
- tc.rwMu.RLock()
+ tc.rwMu.Lock()
tc.guardedField = 2
- tc.rwMu.RUnlock()
+ tc.rwMu.Unlock()
}
func testRWInvalidWrite(tc *rwGuardStruct) {
diff --git a/tools/checklocks/test/rwmutex.go b/tools/checklocks/test/rwmutex.go
new file mode 100644
index 000000000..d27ed10e3
--- /dev/null
+++ b/tools/checklocks/test/rwmutex.go
@@ -0,0 +1,52 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "sync"
+)
+
+// oneReadGuardStruct has one read-guarded field.
+type oneReadGuardStruct struct {
+ mu sync.RWMutex
+ // +checklocks:mu
+ guardedField int
+}
+
+func testRWAccessValidRead(tc *oneReadGuardStruct) {
+ tc.mu.Lock()
+ _ = tc.guardedField
+ tc.mu.Unlock()
+ tc.mu.RLock()
+ _ = tc.guardedField
+ tc.mu.RUnlock()
+}
+
+func testRWAccessValidWrite(tc *oneReadGuardStruct) {
+ tc.mu.Lock()
+ tc.guardedField = 1
+ tc.mu.Unlock()
+}
+
+func testRWAccessInvalidWrite(tc *oneReadGuardStruct) {
+ tc.guardedField = 2 // +checklocksfail
+ tc.mu.RLock()
+ tc.guardedField = 2 // +checklocksfail
+ tc.mu.RUnlock()
+}
+
+func testRWAccessInvalidRead(tc *oneReadGuardStruct) {
+ _ = tc.guardedField // +checklocksfail
+}