summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAdin Scannell <ascannell@google.com>2021-07-01 15:05:28 -0700
committergVisor bot <gvisor-bot@google.com>2021-07-01 15:07:56 -0700
commit16b751b6c610ec2c5a913cb8a818e9239ee7da71 (patch)
tree5596ea010c6afbbe79d1196197cd4bfc5d517e79
parent570ca571805d6939c4c24b6a88660eefaf558ae7 (diff)
Mix checklocks and atomic analyzers.
This change makes the checklocks analyzer considerable more powerful, adding: * The ability to traverse complex structures, e.g. to have multiple nested fields as part of the annotation. * The ability to resolve simple anonymous functions and closures, and perform lock analysis across these invocations. This does not apply to closures that are passed elsewhere, since it is not possible to know the context in which they might be invoked. * The ability to annotate return values in addition to receivers and other parameters, with the same complex structures noted above. * Ignoring locking semantics for "fresh" objects, i.e. objects that are allocated in the local frame (typically a new-style function). * Sanity checking of locking state across block transitions and returns, to ensure that no unexpected locks are held. Note that initially, most of these findings are excluded by a comprehensive nogo.yaml. The findings that are included are fundamental lock violations. The changes here should be relatively low risk, minor refactorings to either include necessary annotations to simplify the code structure (in general removing closures in favor of methods) so that the analyzer can be easily track the lock state. This change additional includes two changes to nogo itself: * Sanity checking of all types to ensure that the binary and ast-derived types have a consistent objectpath, to prevent the bug above from occurring silently (and causing much confusion). This also requires a trick in order to ensure that serialized facts are consumable downstream. This can be removed with https://go-review.googlesource.com/c/tools/+/331789 merged. * A minor refactoring to isolation the objdump settings in its own package. This was originally used to implement the sanity check above, but this information is now being passed another way. The minor refactor is preserved however, since it cleans up the code slightly and is minimal risk. PiperOrigin-RevId: 382613300
-rw-r--r--nogo.yaml50
-rw-r--r--pkg/sentry/fs/dirent.go84
-rw-r--r--pkg/sentry/fs/fs.go26
-rw-r--r--pkg/sentry/fs/gofer/inode_state.go10
-rw-r--r--pkg/sentry/fs/gofer/path.go80
-rw-r--r--pkg/sentry/fs/gofer/session.go17
-rw-r--r--pkg/sentry/fs/gofer/socket.go9
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go16
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go18
-rw-r--r--pkg/sentry/fsimpl/gofer/revalidate.go10
-rw-r--r--pkg/sentry/fsimpl/gofer/symlink.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go4
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go4
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go2
-rw-r--r--pkg/sentry/kernel/futex/futex.go41
-rw-r--r--pkg/sentry/kernel/pipe/pipe_unsafe.go2
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go1
-rw-r--r--pkg/sentry/kernel/ptrace.go1
-rw-r--r--pkg/sentry/kernel/sessions.go5
-rw-r--r--pkg/sentry/mm/syscalls.go1
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go2
-rw-r--r--pkg/sentry/time/calibrated_clock_test.go1
-rw-r--r--pkg/sentry/vfs/dentry.go27
-rw-r--r--pkg/sync/mutex_test.go2
-rw-r--r--pkg/sync/mutex_unsafe.go14
-rw-r--r--pkg/sync/rwmutex_test.go2
-rw-r--r--pkg/sync/rwmutex_unsafe.go8
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go10
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go2
-rw-r--r--pkg/tcpip/stack/conntrack.go4
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go9
-rw-r--r--pkg/tcpip/transport/tcp/connect.go74
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go140
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go150
-rw-r--r--runsc/cmd/boot.go12
-rw-r--r--runsc/cmd/gofer.go4
-rw-r--r--tools/bazeldefs/BUILD5
-rw-r--r--tools/checkescape/BUILD1
-rw-r--r--tools/checkescape/checkescape.go178
-rw-r--r--tools/checklocks/BUILD9
-rw-r--r--tools/checklocks/README.md83
-rw-r--r--tools/checklocks/analysis.go628
-rw-r--r--tools/checklocks/annotations.go129
-rw-r--r--tools/checklocks/checklocks.go758
-rw-r--r--tools/checklocks/facts.go614
-rw-r--r--tools/checklocks/state.go315
-rw-r--r--tools/checklocks/test/BUILD14
-rw-r--r--tools/checklocks/test/alignment.go51
-rw-r--r--tools/checklocks/test/atomics.go91
-rw-r--r--tools/checklocks/test/basics.go145
-rw-r--r--tools/checklocks/test/branches.go56
-rw-r--r--tools/checklocks/test/closures.go100
-rw-r--r--tools/checklocks/test/defer.go38
-rw-r--r--tools/checklocks/test/incompat.go54
-rw-r--r--tools/checklocks/test/methods.go117
-rw-r--r--tools/checklocks/test/parameters.go48
-rw-r--r--tools/checklocks/test/return.go61
-rw-r--r--tools/checklocks/test/test.go328
-rw-r--r--tools/nogo/BUILD2
-rw-r--r--tools/nogo/check/main.go18
-rw-r--r--tools/nogo/defs.bzl39
-rw-r--r--tools/nogo/nogo.go148
-rw-r--r--tools/nogo/objdump/BUILD10
-rw-r--r--tools/nogo/objdump/objdump.go96
67 files changed, 3427 insertions, 1565 deletions
diff --git a/nogo.yaml b/nogo.yaml
index d9b6a5ffe..9b7fc5c8f 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -46,6 +46,8 @@ global:
- "(field|method|struct|type) .* should be .*"
# Generated proto code sometimes duplicates imports with aliases.
- "duplicate import"
+ # These will never be annotated.
+ - "unexpected call to atomic function"
internal:
suppress:
# We use ALL_CAPS for system definitions,
@@ -55,6 +57,10 @@ global:
# Same story for underscores.
- "should not use ALL_CAPS in Go names"
- "should not use underscores in Go names"
+ # These need to be annotated.
+ - "unexpected call to atomic function.*"
+ - "return with unexpected locks held.*"
+ - "incompatible return states.*"
exclude:
# Generated: exempt all.
- pkg/shim/runtimeoptions/runtimeoptions_cri.go
@@ -76,49 +82,7 @@ analyzers:
checklocks:
internal:
exclude:
- - "^-$" # b/181776900: analyzer fails on buildkite
- - pkg/sentry/fs/dirent.go # unsupported usage.
- - pkg/sentry/fs/fsutil/inode_cached.go # unsupported usage.
- - pkg/sentry/fs/gofer/inode_state.go # unsupported usage.
- - pkg/sentry/fs/gofer/session.go # unsupported usage.
- - pkg/sentry/fs/ramfs/dir.go # unsupported usage.
- - pkg/sentry/fsimpl/fuse/connection.go # unsupported usage.
- - pkg/sentry/fsimpl/kernfs/filesystem.go # unsupported usage.
- - pkg/sentry/fsimpl/kernfs/inode_impl_util.go # unsupported usage.
- - pkg/sentry/fsimpl/fuse/dev_test.go # unsupported usage.
- - pkg/sentry/fsimpl/gofer/filesystem.go # unsupported usage.
- - pkg/sentry/fsimpl/gofer/gofer.go # unsupported usage.
- - pkg/sentry/fsimpl/gofer/regular_file.go # unsupported usage.
- - pkg/sentry/fsimpl/gofer/revalidate.go # unsupported usage.
- - pkg/sentry/fsimpl/gofer/special_file.go # unsupported usage.
- - pkg/sentry/fsimpl/gofer/symlink.go # unsupported usage.
- - pkg/sentry/fsimpl/overlay/copy_up.go # unsupported usage.
- - pkg/sentry/fsimpl/overlay/filesystem.go # unsupported usage.
- - pkg/sentry/fsimpl/tmpfs/filesystem.go # unsupported usage.
- - pkg/sentry/fsimpl/verity/filesystem.go # unsupported usage.
- - pkg/sentry/kernel/futex/futex.go # unsupported usage.
- - pkg/sentry/kernel/pipe/vfs.go # unsupported usage.
- - pkg/sentry/mm/syscalls.go # unsupported usage.
- - pkg/sentry/kernel/fd_table.go # unsupported usage.
- - pkg/sentry/kernel/ptrace.go # unsupported usage.
- - pkg/sentry/time/calibrated_clock_test.go # unsupported usage.
- - pkg/sentry/kernel/task_context.go # unsupported usage.
- - pkg/sentry/pgalloc/pgalloc.go # unsupported usage.
- - pkg/sentry/socket/unix/transport/connectioned.go # unsupported usage.
- - pkg/sentry/vfs/dentry.go # unsupported usage.
- - pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go # unsupported usage.
- - pkg/tcpip/stack/conntrack.go # unsupported usage.
- - pkg/tcpip/transport/packet/endpoint_state.go # unsupported usage.
- - pkg/tcpip/transport/raw/endpoint_state.go # unsupported usage.
- - pkg/tcpip/transport/icmp/endpoint.go # unsupported usage.
- - pkg/tcpip/transport/icmp/endpoint_state.go # unsupported usage.
- - pkg/tcpip/transport/tcp/accept.go # unsupported usage.
- - pkg/tcpip/transport/tcp/connect.go # unsupported usage.
- - pkg/tcpip/transport/tcp/dispatcher.go # unsupported usage (TryLock)
- - pkg/tcpip/transport/tcp/endpoint.go # unsupported usage.
- - pkg/tcpip/transport/tcp/endpoint_state.go # unsupported usage.
- - pkg/tcpip/transport/udp/endpoint.go # unsupported usage (defer unlock in anonymous function)
- - pkg/tcpip/transport/udp/endpoint_state.go # unsupported usage (missing nested mutex annotation support)
+ - "^-$" # b/181776900: analyzer fails on buildkite.
shadow: # Disable for now.
generated:
exclude: [".*"]
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index 3a45e9041..8d7660e79 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -488,11 +488,11 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be
// expensive, if possible release the lock and re-acquire it.
if walkMayUnlock {
- d.mu.Unlock()
+ d.mu.Unlock() // +checklocksforce: results in an inconsistent block.
}
c, err := d.Inode.Lookup(ctx, name)
if walkMayUnlock {
- d.mu.Lock()
+ d.mu.Lock() // +checklocksforce: see above.
}
// No dice.
if err != nil {
@@ -594,21 +594,27 @@ func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool {
// lockDirectory should be called for any operation that changes this `d`s
// children (creating or removing them).
-func (d *Dirent) lockDirectory() func() {
+// +checklocksacquire:d.dirMu
+// +checklocksacquire:d.mu
+func (d *Dirent) lockDirectory() {
renameMu.RLock()
d.dirMu.Lock()
d.mu.Lock()
- return func() {
- d.mu.Unlock()
- d.dirMu.Unlock()
- renameMu.RUnlock()
- }
+}
+
+// unlockDirectory is the reverse of lockDirectory.
+// +checklocksrelease:d.dirMu
+// +checklocksrelease:d.mu
+func (d *Dirent) unlockDirectory() {
+ d.mu.Unlock()
+ d.dirMu.Unlock()
+ renameMu.RUnlock() // +checklocksforce: see lockDirectory.
}
// Create creates a new regular file in this directory.
func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags FileFlags, perms FilePermissions) (*File, error) {
- unlock := d.lockDirectory()
- defer unlock()
+ d.lockDirectory()
+ defer d.unlockDirectory()
// Does something already exist?
if d.exists(ctx, root, name) {
@@ -670,8 +676,8 @@ func (d *Dirent) finishCreate(ctx context.Context, child *Dirent, name string) {
// genericCreate executes create if name does not exist. Removes a negative Dirent at name if
// create succeeds.
func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, create func() error) error {
- unlock := d.lockDirectory()
- defer unlock()
+ d.lockDirectory()
+ defer d.unlockDirectory()
// Does something already exist?
if d.exists(ctx, root, name) {
@@ -1021,8 +1027,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
panic("Dirent.Remove: root must not be nil")
}
- unlock := d.lockDirectory()
- defer unlock()
+ d.lockDirectory()
+ defer d.unlockDirectory()
// Try to walk to the node.
child, err := d.walk(ctx, root, name, false /* may unlock */)
@@ -1082,8 +1088,8 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
panic("Dirent.Remove: root must not be nil")
}
- unlock := d.lockDirectory()
- defer unlock()
+ d.lockDirectory()
+ defer d.unlockDirectory()
// Check for dots.
if name == "." {
@@ -1259,17 +1265,15 @@ func (d *Dirent) dropExtendedReference() {
d.Inode.MountSource.fscache.Remove(d)
}
-// lockForRename takes locks on oldParent and newParent as required by Rename
-// and returns a function that will unlock the locks taken. The returned
-// function must be called even if a non-nil error is returned.
-func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName string) (func(), error) {
+// lockForRename takes locks on oldParent and newParent as required by Rename.
+// On return, unlockForRename must always be called, even with an error.
+// +checklocksacquire:oldParent.mu
+// +checklocksacquire:newParent.mu
+func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName string) error {
renameMu.Lock()
if oldParent == newParent {
oldParent.mu.Lock()
- return func() {
- oldParent.mu.Unlock()
- renameMu.Unlock()
- }, nil
+ return nil // +checklocksforce: only one lock exists.
}
// Renaming between directories is a bit subtle:
@@ -1297,11 +1301,7 @@ func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName
// itself.
err = unix.EINVAL
}
- return func() {
- newParent.mu.Unlock()
- oldParent.mu.Unlock()
- renameMu.Unlock()
- }, err
+ return err
}
child = p
}
@@ -1310,11 +1310,21 @@ func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName
// have no relationship; in either case we can do this:
newParent.mu.Lock()
oldParent.mu.Lock()
- return func() {
+ return nil
+}
+
+// unlockForRename is the opposite of lockForRename.
+// +checklocksrelease:oldParent.mu
+// +checklocksrelease:newParent.mu
+func unlockForRename(oldParent, newParent *Dirent) {
+ if oldParent == newParent {
oldParent.mu.Unlock()
- newParent.mu.Unlock()
- renameMu.Unlock()
- }, nil
+ renameMu.Unlock() // +checklocksforce: only one lock exists.
+ return
+ }
+ newParent.mu.Unlock()
+ oldParent.mu.Unlock()
+ renameMu.Unlock() // +checklocksforce: not tracked.
}
func (d *Dirent) checkSticky(ctx context.Context, victim *Dirent) error {
@@ -1353,8 +1363,8 @@ func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error
return err
}
- unlock := d.lockDirectory()
- defer unlock()
+ d.lockDirectory()
+ defer d.unlockDirectory()
victim, err := d.walk(ctx, root, name, true /* may unlock */)
if err != nil {
@@ -1392,8 +1402,8 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
}
// Acquire global renameMu lock, and mu locks on oldParent/newParent.
- unlock, err := lockForRename(oldParent, oldName, newParent, newName)
- defer unlock()
+ err := lockForRename(oldParent, oldName, newParent, newName)
+ defer unlockForRename(oldParent, newParent)
if err != nil {
return err
}
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index 44587bb37..a346c316b 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -80,23 +80,33 @@ func AsyncBarrier() {
// Async executes a function asynchronously.
//
// Async must not be called recursively.
+// +checklocksignore
func Async(f func()) {
workMu.RLock()
- go func() { // S/R-SAFE: AsyncBarrier must be called.
- defer workMu.RUnlock() // Ensure RUnlock in case of panic.
- f()
- }()
+ go asyncWork(f) // S/R-SAFE: AsyncBarrier must be called.
+}
+
+// +checklocksignore
+func asyncWork(f func()) {
+ // Ensure RUnlock in case of panic.
+ defer workMu.RUnlock()
+ f()
}
// AsyncWithContext is just like Async, except that it calls the asynchronous
// function with the given context as argument. This function exists to avoid
// needing to allocate an extra function on the heap in a hot path.
+// +checklocksignore
func AsyncWithContext(ctx context.Context, f func(context.Context)) {
workMu.RLock()
- go func() { // S/R-SAFE: AsyncBarrier must be called.
- defer workMu.RUnlock() // Ensure RUnlock in case of panic.
- f(ctx)
- }()
+ go asyncWorkWithContext(ctx, f)
+}
+
+// +checklocksignore
+func asyncWorkWithContext(ctx context.Context, f func(context.Context)) {
+ // Ensure RUnlock in case of panic.
+ defer workMu.RUnlock()
+ f(ctx)
}
// AsyncErrorBarrier waits for all outstanding asynchronous work to complete, or
diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go
index e2af1d2ae..19f91f010 100644
--- a/pkg/sentry/fs/gofer/inode_state.go
+++ b/pkg/sentry/fs/gofer/inode_state.go
@@ -112,13 +112,6 @@ func (i *inodeFileState) loadLoading(_ struct{}) {
// +checklocks:i.loading
func (i *inodeFileState) afterLoad() {
load := func() (err error) {
- // See comment on i.loading().
- defer func() {
- if err == nil {
- i.loading.Unlock()
- }
- }()
-
// Manually restore the p9.File.
name, ok := i.s.inodeMappings[i.sattr.InodeID]
if !ok {
@@ -167,6 +160,9 @@ func (i *inodeFileState) afterLoad() {
i.savedUAttr = nil
}
+ // See comment on i.loading(). This only unlocks on the
+ // non-error path.
+ i.loading.Unlock() // +checklocksforce: per comment.
return nil
}
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index aa2405f68..958f46bd6 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -47,7 +47,8 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
return nil, linuxerr.ENAMETOOLONG
}
- cp := i.session().cachePolicy
+ s := i.session()
+ cp := s.cachePolicy
if cp.cacheReaddir() {
// Check to see if we have readdirCache that indicates the
// child does not exist. Avoid holding readdirMu longer than
@@ -78,7 +79,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
return nil, err
}
- if i.session().overrides != nil {
+ if s.overrides != nil {
// Check if file belongs to a internal named pipe. Note that it doesn't need
// to check for sockets because it's done in newInodeOperations below.
deviceKey := device.MultiDeviceKey{
@@ -86,13 +87,13 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
SecondaryDevice: i.session().connID,
Inode: qids[0].Path,
}
- unlock := i.session().overrides.lock()
- if pipeInode := i.session().overrides.getPipe(deviceKey); pipeInode != nil {
- unlock()
+ s.overrides.lock()
+ if pipeInode := s.overrides.getPipe(deviceKey); pipeInode != nil {
+ s.overrides.unlock()
pipeInode.IncRef()
return fs.NewDirent(ctx, pipeInode, name), nil
}
- unlock()
+ s.overrides.unlock()
}
// Construct the Inode operations.
@@ -221,17 +222,20 @@ func (i *inodeOperations) CreateHardLink(ctx context.Context, inode *fs.Inode, t
if err := i.fileState.file.link(ctx, &targetOpts.fileState.file, newName); err != nil {
return err
}
- if i.session().cachePolicy.cacheUAttrs(inode) {
+
+ s := i.session()
+ if s.cachePolicy.cacheUAttrs(inode) {
// Increase link count.
targetOpts.cachingInodeOps.IncLinks(ctx)
}
+
i.touchModificationAndStatusChangeTime(ctx, inode)
return nil
}
// CreateDirectory uses Create to create a directory named s under inodeOperations.
-func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s string, perm fs.FilePermissions) error {
- if len(s) > maxFilenameLen {
+func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ if len(name) > maxFilenameLen {
return linuxerr.ENAMETOOLONG
}
@@ -247,16 +251,18 @@ func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, s
perm.SetGID = true
}
- if _, err := i.fileState.file.mkdir(ctx, s, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ if _, err := i.fileState.file.mkdir(ctx, name, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
return err
}
- if i.session().cachePolicy.cacheUAttrs(dir) {
+
+ s := i.session()
+ if s.cachePolicy.cacheUAttrs(dir) {
// Increase link count.
//
// N.B. This will update the modification time.
i.cachingInodeOps.IncLinks(ctx)
}
- if i.session().cachePolicy.cacheReaddir() {
+ if s.cachePolicy.cacheReaddir() {
// Invalidate readdir cache.
i.markDirectoryDirty()
}
@@ -269,13 +275,14 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string,
return nil, linuxerr.ENAMETOOLONG
}
- if i.session().overrides == nil {
+ s := i.session()
+ if s.overrides == nil {
return nil, syserror.EOPNOTSUPP
}
// Stabilize the override map while creation is in progress.
- unlock := i.session().overrides.lock()
- defer unlock()
+ s.overrides.lock()
+ defer s.overrides.unlock()
sattr, iops, err := i.createEndpointFile(ctx, dir, name, perm, p9.ModeSocket)
if err != nil {
@@ -284,7 +291,7 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string,
// Construct the positive Dirent.
childDir := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
- i.session().overrides.addBoundEndpoint(iops.fileState.key, childDir, ep)
+ s.overrides.addBoundEndpoint(iops.fileState.key, childDir, ep)
return childDir, nil
}
@@ -298,8 +305,9 @@ func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name st
mode := p9.FileMode(perm.LinuxMode()) | p9.ModeNamedPipe
// N.B. FIFOs use major/minor numbers 0.
+ s := i.session()
if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
- if i.session().overrides == nil || !linuxerr.Equals(linuxerr.EPERM, err) {
+ if s.overrides == nil || !linuxerr.Equals(linuxerr.EPERM, err) {
return err
}
// If gofer doesn't support mknod, check if we can create an internal fifo.
@@ -311,13 +319,14 @@ func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name st
}
func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode, name string, owner fs.FileOwner, perm fs.FilePermissions) error {
- if i.session().overrides == nil {
+ s := i.session()
+ if s.overrides == nil {
return linuxerr.EPERM
}
// Stabilize the override map while creation is in progress.
- unlock := i.session().overrides.lock()
- defer unlock()
+ s.overrides.lock()
+ defer s.overrides.unlock()
sattr, fileOps, err := i.createEndpointFile(ctx, dir, name, perm, p9.ModeNamedPipe)
if err != nil {
@@ -336,7 +345,7 @@ func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode,
// Construct the positive Dirent.
childDir := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
- i.session().overrides.addPipe(fileOps.fileState.key, childDir, inode)
+ s.overrides.addPipe(fileOps.fileState.key, childDir, inode)
return nil
}
@@ -386,8 +395,9 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
return linuxerr.ENAMETOOLONG
}
+ s := i.session()
var key *device.MultiDeviceKey
- if i.session().overrides != nil {
+ if s.overrides != nil {
// Find out if file being deleted is a socket or pipe that needs to be
// removed from endpoint map.
if d, err := i.Lookup(ctx, dir, name); err == nil {
@@ -402,8 +412,8 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
}
// Stabilize the override map while deletion is in progress.
- unlock := i.session().overrides.lock()
- defer unlock()
+ s.overrides.lock()
+ defer s.overrides.unlock()
}
}
}
@@ -412,7 +422,7 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
return err
}
if key != nil {
- i.session().overrides.remove(ctx, *key)
+ s.overrides.remove(ctx, *key)
}
i.touchModificationAndStatusChangeTime(ctx, dir)
@@ -429,11 +439,13 @@ func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, na
if err := i.fileState.file.unlinkAt(ctx, name, 0x200); err != nil {
return err
}
- if i.session().cachePolicy.cacheUAttrs(dir) {
+
+ s := i.session()
+ if s.cachePolicy.cacheUAttrs(dir) {
// Decrease link count and updates atime.
i.cachingInodeOps.DecLinks(ctx)
}
- if i.session().cachePolicy.cacheReaddir() {
+ if s.cachePolicy.cacheReaddir() {
// Invalidate readdir cache.
i.markDirectoryDirty()
}
@@ -463,12 +475,13 @@ func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent
}
// Is the renamed entity a directory? Fix link counts.
+ s := i.session()
if fs.IsDir(i.fileState.sattr) {
// Update cached state.
- if i.session().cachePolicy.cacheUAttrs(oldParent) {
+ if s.cachePolicy.cacheUAttrs(oldParent) {
oldParentInodeOperations.cachingInodeOps.DecLinks(ctx)
}
- if i.session().cachePolicy.cacheUAttrs(newParent) {
+ if s.cachePolicy.cacheUAttrs(newParent) {
// Only IncLinks if there is a new addition to
// newParent. If this is replacement, then the total
// count remains the same.
@@ -477,7 +490,7 @@ func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent
}
}
}
- if i.session().cachePolicy.cacheReaddir() {
+ if s.cachePolicy.cacheReaddir() {
// Mark old directory dirty.
oldParentInodeOperations.markDirectoryDirty()
if oldParent != newParent {
@@ -487,17 +500,18 @@ func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent
}
// Rename always updates ctime.
- if i.session().cachePolicy.cacheUAttrs(inode) {
+ if s.cachePolicy.cacheUAttrs(inode) {
i.cachingInodeOps.TouchStatusChangeTime(ctx)
}
return nil
}
func (i *inodeOperations) touchModificationAndStatusChangeTime(ctx context.Context, inode *fs.Inode) {
- if i.session().cachePolicy.cacheUAttrs(inode) {
+ s := i.session()
+ if s.cachePolicy.cacheUAttrs(inode) {
i.cachingInodeOps.TouchModificationAndStatusChangeTime(ctx)
}
- if i.session().cachePolicy.cacheReaddir() {
+ if s.cachePolicy.cacheReaddir() {
// Invalidate readdir cache.
i.markDirectoryDirty()
}
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 7cf3522ff..b7debeecb 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -98,9 +98,14 @@ func (e *overrideMaps) remove(ctx context.Context, key device.MultiDeviceKey) {
// lock blocks other addition and removal operations from happening while
// the backing file is being created or deleted. Returns a function that unlocks
// the endpoint map.
-func (e *overrideMaps) lock() func() {
+// +checklocksacquire:e.mu
+func (e *overrideMaps) lock() {
e.mu.Lock()
- return func() { e.mu.Unlock() }
+}
+
+// +checklocksrelease:e.mu
+func (e *overrideMaps) unlock() {
+ e.mu.Unlock()
}
// getBoundEndpoint returns the bound endpoint mapped to the given key.
@@ -366,8 +371,8 @@ func newOverrideMaps() *overrideMaps {
// fillKeyMap populates key and dirent maps upon restore from saved pathmap.
func (s *session) fillKeyMap(ctx context.Context) error {
- unlock := s.overrides.lock()
- defer unlock()
+ s.overrides.lock()
+ defer s.overrides.unlock()
for ep, dirPath := range s.overrides.pathMap {
_, file, err := s.attach.walk(ctx, splitAbsolutePath(dirPath))
@@ -394,8 +399,8 @@ func (s *session) fillKeyMap(ctx context.Context) error {
// fillPathMap populates paths for overrides from dirents in direntMap
// before save.
func (s *session) fillPathMap(ctx context.Context) error {
- unlock := s.overrides.lock()
- defer unlock()
+ s.overrides.lock()
+ defer s.overrides.unlock()
for _, endpoint := range s.overrides.keyMap {
mountRoot := endpoint.dirent.MountRoot()
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
index 8a1c69ac2..1fd8a0910 100644
--- a/pkg/sentry/fs/gofer/socket.go
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -32,10 +32,11 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.
return nil
}
- if i.session().overrides != nil {
- unlock := i.session().overrides.lock()
- defer unlock()
- ep := i.session().overrides.getBoundEndpoint(i.fileState.key)
+ s := i.session()
+ if s.overrides != nil {
+ s.overrides.lock()
+ defer s.overrides.unlock()
+ ep := s.overrides.getBoundEndpoint(i.fileState.key)
if ep != nil {
return ep
}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 237d17921..652e5fe77 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -147,6 +147,7 @@ func putDentrySlice(ds *[]*dentry) {
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
+// +checklocksrelease:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **[]*dentry) {
fs.renameMu.RUnlock()
if *dsp == nil {
@@ -159,6 +160,7 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **
putDentrySlice(*dsp)
}
+// +checklocksrelease:fs.renameMu
func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
if *ds == nil {
fs.renameMu.Unlock()
@@ -540,7 +542,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
if child.syntheticChildren != 0 {
// This is definitely not an empty directory, irrespective of
// fs.opts.interop.
- vfsObj.AbortDeleteDentry(&child.vfsd)
+ vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: PrepareDeleteDentry called if child != nil.
return linuxerr.ENOTEMPTY
}
// If InteropModeShared is in effect and the first call to
@@ -550,12 +552,12 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
// still exist) would be a waste of time.
if child.cachedMetadataAuthoritative() {
if !child.isDir() {
- vfsObj.AbortDeleteDentry(&child.vfsd)
+ vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
return syserror.ENOTDIR
}
for _, grandchild := range child.children {
if grandchild != nil {
- vfsObj.AbortDeleteDentry(&child.vfsd)
+ vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
return linuxerr.ENOTEMPTY
}
}
@@ -565,12 +567,12 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
} else {
// child must be a non-directory file.
if child != nil && child.isDir() {
- vfsObj.AbortDeleteDentry(&child.vfsd)
+ vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
return syserror.EISDIR
}
if rp.MustBeDir() {
if child != nil {
- vfsObj.AbortDeleteDentry(&child.vfsd)
+ vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
}
return syserror.ENOTDIR
}
@@ -583,7 +585,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
err = parent.file.unlinkAt(ctx, name, flags)
if err != nil {
if child != nil {
- vfsObj.AbortDeleteDentry(&child.vfsd)
+ vfsObj.AbortDeleteDentry(&child.vfsd) // +checklocksforce: see above.
}
return err
}
@@ -601,7 +603,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
}
if child != nil {
- vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd) // +checklocksforce: see above.
child.setDeleted()
if child.isSynthetic() {
parent.syntheticChildren--
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index fe4c2e0e1..2f85215d9 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -947,10 +947,10 @@ func (d *dentry) cachedMetadataAuthoritative() bool {
// updateFromP9Attrs is called to update d's metadata after an update from the
// remote filesystem.
// Precondition: d.metadataMu must be locked.
+// +checklocks:d.metadataMu
func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
if mask.Mode {
if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
- d.metadataMu.Unlock()
panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got))
}
atomic.StoreUint32(&d.mode, uint32(attr.Mode))
@@ -989,6 +989,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
// Preconditions: !d.isSynthetic().
// Preconditions: d.metadataMu is locked.
+// +checklocks:d.metadataMu
func (d *dentry) refreshSizeLocked(ctx context.Context) error {
d.handleMu.RLock()
@@ -1020,6 +1021,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
// Preconditions:
// * !d.isSynthetic().
// * d.metadataMu is locked.
+// +checklocks:d.metadataMu
func (d *dentry) updateFromGetattrLocked(ctx context.Context) error {
// Use d.readFile or d.writeFile, which represent 9P FIDs that have been
// opened, in preference to d.file, which represents a 9P fid that has not.
@@ -1044,7 +1046,8 @@ func (d *dentry) updateFromGetattrLocked(ctx context.Context) error {
_, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask())
if handleMuRLocked {
- d.handleMu.RUnlock() // must be released before updateFromP9AttrsLocked()
+ // handleMu must be released before updateFromP9AttrsLocked().
+ d.handleMu.RUnlock() // +checklocksforce: complex case.
}
if err != nil {
return err
@@ -1470,7 +1473,7 @@ func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked boo
if d.isDeleted() {
d.watches.HandleDeletion(ctx)
}
- d.destroyLocked(ctx)
+ d.destroyLocked(ctx) // +checklocksforce: renameMu must be acquired at this point.
return
}
// If d still has inotify watches and it is not deleted or invalidated, it
@@ -1498,7 +1501,7 @@ func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked boo
delete(d.parent.children, d.name)
d.parent.dirMu.Unlock()
}
- d.destroyLocked(ctx)
+ d.destroyLocked(ctx) // +checklocksforce: see above.
return
}
@@ -1527,7 +1530,7 @@ func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked boo
d.fs.renameMu.Lock()
defer d.fs.renameMu.Unlock()
}
- d.fs.evictCachedDentryLocked(ctx)
+ d.fs.evictCachedDentryLocked(ctx) // +checklocksforce: see above.
}
}
@@ -1544,6 +1547,7 @@ func (d *dentry) removeFromCacheLocked() {
// Precondition: fs.renameMu must be locked for writing; it may be temporarily
// unlocked.
+// +checklocks:fs.renameMu
func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
for fs.cachedDentriesLen != 0 {
fs.evictCachedDentryLocked(ctx)
@@ -1552,6 +1556,7 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
// Preconditions:
// * fs.renameMu must be locked for writing; it may be temporarily unlocked.
+// +checklocks:fs.renameMu
func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
fs.cacheMu.Lock()
victim := fs.cachedDentries.Back()
@@ -1588,7 +1593,7 @@ func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
// will try to acquire fs.renameMu (which we have already acquired). Hence,
// fs.renameMu will synchronize the destroy attempts.
victim.cachingMu.Unlock()
- victim.destroyLocked(ctx)
+ victim.destroyLocked(ctx) // +checklocksforce: owned as precondition, victim.fs == fs.
}
// destroyLocked destroys the dentry.
@@ -1598,6 +1603,7 @@ func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
// * d.refs == 0.
// * d.parent.children[d.name] != d, i.e. d is not reachable by path traversal
// from its former parent dentry.
+// +checklocks:d.fs.renameMu
func (d *dentry) destroyLocked(ctx context.Context) {
switch atomic.LoadInt64(&d.refs) {
case 0:
diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go
index 8f81f0822..226790a11 100644
--- a/pkg/sentry/fsimpl/gofer/revalidate.go
+++ b/pkg/sentry/fsimpl/gofer/revalidate.go
@@ -247,16 +247,16 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF
if found && !d.isSynthetic() {
// First dentry is where the search is starting, just update attributes
// since it cannot be replaced.
- d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr)
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: acquired by lockAllMetadata.
}
- d.metadataMu.Unlock()
+ d.metadataMu.Unlock() // +checklocksforce: see above.
continue
}
// Note that synthetic dentries will always fails the comparison check
// below.
if !found || d.qidPath != stats[i].QID.Path {
- d.metadataMu.Unlock()
+ d.metadataMu.Unlock() // +checklocksforce: see above.
if !found && d.isSynthetic() {
// We have a synthetic file, and no remote file has arisen to replace
// it.
@@ -298,7 +298,7 @@ func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualF
}
// The file at this path hasn't changed. Just update cached metadata.
- d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr)
+ d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) // +checklocksforce: see above.
d.metadataMu.Unlock()
}
@@ -354,6 +354,7 @@ func (r *revalidateState) add(name string, d *dentry) {
r.dentries = append(r.dentries, d)
}
+// +checklocksignore
func (r *revalidateState) lockAllMetadata() {
for _, d := range r.dentries {
d.metadataMu.Lock()
@@ -372,6 +373,7 @@ func (r *revalidateState) popFront() *dentry {
// reset releases all metadata locks and resets all fields to allow this
// instance to be reused.
+// +checklocksignore
func (r *revalidateState) reset() {
if r.locked {
// Unlock any remaining dentries.
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
index 2ec819f86..dbd834c67 100644
--- a/pkg/sentry/fsimpl/gofer/symlink.go
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -41,7 +41,7 @@ func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
d.haveTarget = true
d.target = target
}
- d.dataMu.Unlock()
+ d.dataMu.Unlock() // +checklocksforce: guaranteed locked from above.
}
return target, err
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 38c2b6df1..20d2526ad 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -752,7 +752,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
fs.deferDecRef(replaced)
replaceVFSD = replaced.VFSDentry()
}
- virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaceVFSD)
+ virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaceVFSD) // +checklocksforce: to may be nil, that's okay.
return nil
}
@@ -788,7 +788,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
defer mntns.DecRef(ctx)
vfsd := d.VFSDentry()
if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
- return err
+ return err // +checklocksforce: vfsd is not locked.
}
if err := parentDentry.inode.RmDir(ctx, d.name, d.inode); err != nil {
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 41207211a..77f9affc1 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -87,7 +87,7 @@ func putDentrySlice(ds *[]*dentry) {
// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
//
-// +checklocks:fs.renameMu
+// +checklocksrelease:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) {
fs.renameMu.RUnlock()
if *dsp == nil {
@@ -113,7 +113,7 @@ func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*
putDentrySlice(*dsp)
}
-// +checklocks:fs.renameMu
+// +checklocksrelease:fs.renameMu
func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
if *ds == nil {
fs.renameMu.Unlock()
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index e4bfbd3c9..358a66072 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -75,6 +75,7 @@ func putDentrySlice(ds *[]*dentry) {
// but dentry slices are allocated lazily, and it's much easier to say "defer
// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
+// +checklocksrelease:fs.renameMu
func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
fs.renameMu.RUnlock()
if *ds == nil {
@@ -90,6 +91,7 @@ func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*d
putDentrySlice(*ds)
}
+// +checklocksrelease:fs.renameMu
func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
if *ds == nil {
fs.renameMu.Unlock()
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 6377abb94..f5c364c96 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -398,8 +398,8 @@ func (m *Manager) Fork() *Manager {
}
// lockBucket returns a locked bucket for the given key.
-func (m *Manager) lockBucket(k *Key) *bucket {
- var b *bucket
+// +checklocksacquire:b.mu
+func (m *Manager) lockBucket(k *Key) (b *bucket) {
if k.Kind == KindSharedMappable {
b = m.sharedBucket
} else {
@@ -410,7 +410,9 @@ func (m *Manager) lockBucket(k *Key) *bucket {
}
// lockBuckets returns locked buckets for the given keys.
-func (m *Manager) lockBuckets(k1, k2 *Key) (*bucket, *bucket) {
+// +checklocksacquire:b1.mu
+// +checklocksacquire:b2.mu
+func (m *Manager) lockBuckets(k1, k2 *Key) (b1 *bucket, b2 *bucket) {
// Buckets must be consistently ordered to avoid circular lock
// dependencies. We order buckets in m.privateBuckets by index (lowest
// index first), and all buckets in m.privateBuckets precede
@@ -420,8 +422,8 @@ func (m *Manager) lockBuckets(k1, k2 *Key) (*bucket, *bucket) {
if k1.Kind != KindSharedMappable && k2.Kind != KindSharedMappable {
i1 := bucketIndexForAddr(k1.addr())
i2 := bucketIndexForAddr(k2.addr())
- b1 := &m.privateBuckets[i1]
- b2 := &m.privateBuckets[i2]
+ b1 = &m.privateBuckets[i1]
+ b2 = &m.privateBuckets[i2]
switch {
case i1 < i2:
b1.mu.Lock()
@@ -432,19 +434,30 @@ func (m *Manager) lockBuckets(k1, k2 *Key) (*bucket, *bucket) {
default:
b1.mu.Lock()
}
- return b1, b2
+ return b1, b2 // +checklocksforce
}
// At least one of b1 or b2 should be m.sharedBucket.
- b1 := m.sharedBucket
- b2 := m.sharedBucket
+ b1 = m.sharedBucket
+ b2 = m.sharedBucket
if k1.Kind != KindSharedMappable {
b1 = m.lockBucket(k1)
} else if k2.Kind != KindSharedMappable {
b2 = m.lockBucket(k2)
}
m.sharedBucket.mu.Lock()
- return b1, b2
+ return b1, b2 // +checklocksforce
+}
+
+// unlockBuckets unlocks two buckets.
+// +checklocksrelease:b1.mu
+// +checklocksrelease:b2.mu
+func (m *Manager) unlockBuckets(b1, b2 *bucket) {
+ b1.mu.Unlock()
+ if b1 != b2 {
+ b2.mu.Unlock()
+ }
+ return // +checklocksforce
}
// Wake wakes up to n waiters matching the bitmask on the given addr.
@@ -477,10 +490,7 @@ func (m *Manager) doRequeue(t Target, addr, naddr hostarch.Addr, private bool, c
defer k2.release(t)
b1, b2 := m.lockBuckets(&k1, &k2)
- defer b1.mu.Unlock()
- if b2 != b1 {
- defer b2.mu.Unlock()
- }
+ defer m.unlockBuckets(b1, b2)
if checkval {
if err := check(t, addr, val); err != nil {
@@ -527,10 +537,7 @@ func (m *Manager) WakeOp(t Target, addr1, addr2 hostarch.Addr, private bool, nwa
defer k2.release(t)
b1, b2 := m.lockBuckets(&k1, &k2)
- defer b1.mu.Unlock()
- if b2 != b1 {
- defer b2.mu.Unlock()
- }
+ defer m.unlockBuckets(b1, b2)
done := 0
cond, err := atomicOp(t, addr2, op)
diff --git a/pkg/sentry/kernel/pipe/pipe_unsafe.go b/pkg/sentry/kernel/pipe/pipe_unsafe.go
index dd60cba24..077c5d596 100644
--- a/pkg/sentry/kernel/pipe/pipe_unsafe.go
+++ b/pkg/sentry/kernel/pipe/pipe_unsafe.go
@@ -23,6 +23,8 @@ import (
// concurrent calls cannot deadlock.
//
// Preconditions: x != y.
+// +checklocksacquire:x.mu
+// +checklocksacquire:y.mu
func lockTwoPipes(x, y *Pipe) {
// Lock the two pipes in order of increasing address.
if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) {
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index 84f9f6234..c883a9014 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -157,6 +157,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume
//
// mu must be held by the caller. waitFor returns with mu held, but it will
// drop mu before blocking for any reader/writers.
+// +checklocks:mu
func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
// Ideally this function would simply use a condition variable. However, the
// wait needs to be interruptible via 'sleeper', so we must sychronize via a
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index cdaee5d7f..161140980 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -652,6 +652,7 @@ func (t *Task) forgetTracerLocked() {
// Preconditions:
// * The signal mutex must be locked.
// * The caller must be running on the task goroutine.
+// +checklocks:t.tg.signalHandlers.mu
func (t *Task) ptraceSignalLocked(info *linux.SignalInfo) bool {
if linux.Signal(info.Signo) == linux.SIGKILL {
return false
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index c0c1f1f13..ae21a55da 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -121,8 +121,9 @@ func (pg *ProcessGroup) Originator() *ThreadGroup {
// IsOrphan returns true if this process group is an orphan.
func (pg *ProcessGroup) IsOrphan() bool {
- pg.originator.TaskSet().mu.RLock()
- defer pg.originator.TaskSet().mu.RUnlock()
+ ts := pg.originator.TaskSet()
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
return pg.ancestors == 0
}
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 125fd855b..b51ec6aa7 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -204,6 +204,7 @@ func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar h
// * vseg.Range().IsSupersetOf(ar).
//
// Postconditions: mm.mappingMu will be unlocked.
+// +checklocksrelease:mm.mappingMu
func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar hostarch.AddrRange, precommit bool) {
// See populateVMA above for commentary.
if !vseg.ValuePtr().effectivePerms.Any() {
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index f7d5a1800..0c8542485 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -945,7 +945,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
// NOTE(b/165896008): mincore (which is passed as checkCommitted)
// by f.UpdateUsage() might take a really long time. So unlock f.mu
// while checkCommitted runs.
- f.mu.Unlock()
+ f.mu.Unlock() // +checklocksforce
err := checkCommitted(s, buf)
f.mu.Lock()
if err != nil {
diff --git a/pkg/sentry/time/calibrated_clock_test.go b/pkg/sentry/time/calibrated_clock_test.go
index d6622bfe2..0a4b1f1bf 100644
--- a/pkg/sentry/time/calibrated_clock_test.go
+++ b/pkg/sentry/time/calibrated_clock_test.go
@@ -50,6 +50,7 @@ func TestConstantFrequency(t *testing.T) {
if !c.ready {
c.mu.RUnlock()
t.Fatalf("clock not ready")
+ return // For checklocks consistency.
}
// A bit after the last sample.
now, ok := c.params.ComputeTime(750000)
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 242eb5ecb..cb92b6eee 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -196,11 +196,12 @@ func (d *Dentry) OnZeroWatches(ctx context.Context) {
// PrepareDeleteDentry must be called before attempting to delete the file
// represented by d. If PrepareDeleteDentry succeeds, the caller must call
// AbortDeleteDentry or CommitDeleteDentry depending on the deletion's outcome.
+// +checklocksacquire:d.mu
func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dentry) error {
vfs.mountMu.Lock()
if mntns.mountpoints[d] != 0 {
vfs.mountMu.Unlock()
- return linuxerr.EBUSY
+ return linuxerr.EBUSY // +checklocksforce: inconsistent return.
}
d.mu.Lock()
vfs.mountMu.Unlock()
@@ -211,14 +212,14 @@ func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dent
// AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion
// fails.
-// +checklocks:d.mu
+// +checklocksrelease:d.mu
func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) {
d.mu.Unlock()
}
// CommitDeleteDentry must be called after PrepareDeleteDentry if the deletion
// succeeds.
-// +checklocks:d.mu
+// +checklocksrelease:d.mu
func (vfs *VirtualFilesystem) CommitDeleteDentry(ctx context.Context, d *Dentry) {
d.dead = true
d.mu.Unlock()
@@ -249,16 +250,18 @@ func (vfs *VirtualFilesystem) InvalidateDentry(ctx context.Context, d *Dentry) {
// Preconditions:
// * If to is not nil, it must be a child Dentry from the same Filesystem.
// * from != to.
+// +checklocksacquire:from.mu
+// +checklocksacquire:to.mu
func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, to *Dentry) error {
vfs.mountMu.Lock()
if mntns.mountpoints[from] != 0 {
vfs.mountMu.Unlock()
- return linuxerr.EBUSY
+ return linuxerr.EBUSY // +checklocksforce: no locks acquired.
}
if to != nil {
if mntns.mountpoints[to] != 0 {
vfs.mountMu.Unlock()
- return linuxerr.EBUSY
+ return linuxerr.EBUSY // +checklocksforce: no locks acquired.
}
to.mu.Lock()
}
@@ -267,13 +270,13 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t
// Return with from.mu and to.mu locked, which will be unlocked by
// AbortRenameDentry, CommitRenameReplaceDentry, or
// CommitRenameExchangeDentry.
- return nil
+ return nil // +checklocksforce: to may not be acquired.
}
// AbortRenameDentry must be called after PrepareRenameDentry if the rename
// fails.
-// +checklocks:from.mu
-// +checklocks:to.mu
+// +checklocksrelease:from.mu
+// +checklocksrelease:to.mu
func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) {
from.mu.Unlock()
if to != nil {
@@ -286,8 +289,8 @@ func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) {
// that was replaced by from.
//
// Preconditions: PrepareRenameDentry was previously called on from and to.
-// +checklocks:from.mu
-// +checklocks:to.mu
+// +checklocksrelease:from.mu
+// +checklocksrelease:to.mu
func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, from, to *Dentry) {
from.mu.Unlock()
if to != nil {
@@ -303,8 +306,8 @@ func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, fro
// from and to are exchanged by rename(RENAME_EXCHANGE).
//
// Preconditions: PrepareRenameDentry was previously called on from and to.
-// +checklocks:from.mu
-// +checklocks:to.mu
+// +checklocksrelease:from.mu
+// +checklocksrelease:to.mu
func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) {
from.mu.Unlock()
to.mu.Unlock()
diff --git a/pkg/sync/mutex_test.go b/pkg/sync/mutex_test.go
index 4fb51a8ab..9e4e3f0b2 100644
--- a/pkg/sync/mutex_test.go
+++ b/pkg/sync/mutex_test.go
@@ -64,7 +64,7 @@ func TestTryLockUnlock(t *testing.T) {
if !m.TryLock() {
t.Fatal("failed to aquire lock")
}
- m.Unlock()
+ m.Unlock() // +checklocksforce
if !m.TryLock() {
t.Fatal("failed to aquire lock after unlock")
}
diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go
index 411a80a8a..b829765d9 100644
--- a/pkg/sync/mutex_unsafe.go
+++ b/pkg/sync/mutex_unsafe.go
@@ -32,6 +32,18 @@ func (m *CrossGoroutineMutex) state() *int32 {
return &(*syncMutex)(unsafe.Pointer(&m.Mutex)).state
}
+// Lock locks the underlying Mutex.
+// +checklocksignore
+func (m *CrossGoroutineMutex) Lock() {
+ m.Mutex.Lock()
+}
+
+// Unlock unlocks the underlying Mutex.
+// +checklocksignore
+func (m *CrossGoroutineMutex) Unlock() {
+ m.Mutex.Unlock()
+}
+
const (
mutexUnlocked = 0
mutexLocked = 1
@@ -62,6 +74,7 @@ type Mutex struct {
// Lock locks m. If the lock is already in use, the calling goroutine blocks
// until the mutex is available.
+// +checklocksignore
func (m *Mutex) Lock() {
noteLock(unsafe.Pointer(m))
m.m.Lock()
@@ -80,6 +93,7 @@ func (m *Mutex) Unlock() {
// TryLock tries to acquire the mutex. It returns true if it succeeds and false
// otherwise. TryLock does not block.
+// +checklocksignore
func (m *Mutex) TryLock() bool {
// Note lock first to enforce proper locking even if unsuccessful.
noteLock(unsafe.Pointer(m))
diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go
index 5ca96d12b..56a88e712 100644
--- a/pkg/sync/rwmutex_test.go
+++ b/pkg/sync/rwmutex_test.go
@@ -172,7 +172,7 @@ func TestRWTryLockUnlock(t *testing.T) {
if !rwm.TryLock() {
t.Fatal("failed to aquire lock")
}
- rwm.Unlock()
+ rwm.Unlock() // +checklocksforce
if !rwm.TryLock() {
t.Fatal("failed to aquire lock after unlock")
}
diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index 892d3e641..7829b06db 100644
--- a/pkg/sync/rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
@@ -37,6 +37,7 @@ const rwmutexMaxReaders = 1 << 30
// TryRLock locks rw for reading. It returns true if it succeeds and false
// otherwise. It does not block.
+// +checklocksignore
func (rw *CrossGoroutineRWMutex) TryRLock() bool {
if RaceEnabled {
RaceDisable()
@@ -65,6 +66,7 @@ func (rw *CrossGoroutineRWMutex) TryRLock() bool {
// It should not be used for recursive read locking; a blocked Lock call
// excludes new readers from acquiring the lock. See the documentation on the
// RWMutex type.
+// +checklocksignore
func (rw *CrossGoroutineRWMutex) RLock() {
if RaceEnabled {
RaceDisable()
@@ -83,6 +85,7 @@ func (rw *CrossGoroutineRWMutex) RLock() {
//
// Preconditions:
// * rw is locked for reading.
+// +checklocksignore
func (rw *CrossGoroutineRWMutex) RUnlock() {
if RaceEnabled {
RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
@@ -134,6 +137,7 @@ func (rw *CrossGoroutineRWMutex) TryLock() bool {
// Lock locks rw for writing. If the lock is already locked for reading or
// writing, Lock blocks until the lock is available.
+// +checklocksignore
func (rw *CrossGoroutineRWMutex) Lock() {
if RaceEnabled {
RaceDisable()
@@ -228,6 +232,7 @@ type RWMutex struct {
// TryRLock locks rw for reading. It returns true if it succeeds and false
// otherwise. It does not block.
+// +checklocksignore
func (rw *RWMutex) TryRLock() bool {
// Note lock first to enforce proper locking even if unsuccessful.
noteLock(unsafe.Pointer(rw))
@@ -243,6 +248,7 @@ func (rw *RWMutex) TryRLock() bool {
// It should not be used for recursive read locking; a blocked Lock call
// excludes new readers from acquiring the lock. See the documentation on the
// RWMutex type.
+// +checklocksignore
func (rw *RWMutex) RLock() {
noteLock(unsafe.Pointer(rw))
rw.m.RLock()
@@ -261,6 +267,7 @@ func (rw *RWMutex) RUnlock() {
// TryLock locks rw for writing. It returns true if it succeeds and false
// otherwise. It does not block.
+// +checklocksignore
func (rw *RWMutex) TryLock() bool {
// Note lock first to enforce proper locking even if unsuccessful.
noteLock(unsafe.Pointer(rw))
@@ -273,6 +280,7 @@ func (rw *RWMutex) TryLock() bool {
// Lock locks rw for writing. If the lock is already locked for reading or
// writing, Lock blocks until the lock is available.
+// +checklocksignore
func (rw *RWMutex) Lock() {
noteLock(unsafe.Pointer(rw))
rw.m.Lock()
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
index 0b51563cd..1261ad414 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -126,7 +126,7 @@ func (m *mockMulticastGroupProtocol) sendQueuedReports() {
// Precondition: m.mu must be read locked.
func (m *mockMulticastGroupProtocol) Enabled() bool {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
}
@@ -138,11 +138,11 @@ func (m *mockMulticastGroupProtocol) Enabled() bool {
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
if m.mu.TryRLock() {
- m.mu.RUnlock()
+ m.mu.RUnlock() // +checklocksforce: TryLock.
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
@@ -155,11 +155,11 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
if m.mu.TryRLock() {
- m.mu.RUnlock()
+ m.mu.RUnlock() // +checklocksforce: TryLock.
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ce9cebdaa..ae0bb4ace 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -249,7 +249,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// or we are adding a new temporary or permanent address.
//
// The address MUST be write locked at this point.
- defer addrState.mu.Unlock()
+ defer addrState.mu.Unlock() // +checklocksforce
if permanent {
if addrState.mu.kind.IsPermanent() {
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 782e74b24..068dab7ce 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -363,7 +363,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Unlocking can happen in any order.
ct.buckets[tupleBucket].mu.Unlock()
if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock()
+ ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
}
}
@@ -626,7 +626,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
// Don't re-unlock if both tuples are in the same bucket.
if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock()
+ ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
}
return true
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index cb316d27a..f9a15efb2 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -213,6 +213,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
+// +checklocks:e.mu
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.state {
case stateInitial:
@@ -229,10 +230,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
}
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
+ defer e.mu.DowngradeLock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d807b13b7..aa413ad05 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -330,7 +330,9 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions,
}
ep := h.ep
- if err := h.complete(); err != nil {
+ // N.B. the endpoint is generated above by startHandshake, and will be
+ // returned locked. This first call is forced.
+ if err := h.complete(); err != nil { // +checklocksforce
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.stats.FailedConnectionAttempts.Increment()
l.cleanupFailedHandshake(h)
@@ -364,6 +366,7 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
@@ -504,7 +507,9 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
go func() {
- if err := h.complete(); err != nil {
+ // Note that startHandshake returns a locked endpoint. The
+ // force call here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index e39d1623d..93ed161f9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -511,6 +511,7 @@ func (h *handshake) start() {
}
// complete completes the TCP 3-way handshake initiated by h.start().
+// +checklocks:h.ep.mu
func (h *handshake) complete() tcpip.Error {
// Set up the wakers.
var s sleep.Sleeper
@@ -1283,42 +1284,45 @@ func (e *endpoint) disableKeepaliveTimer() {
e.keepalive.Unlock()
}
-// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
-// goroutine and is responsible for sending segments and handling received
-// segments.
-func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
- e.mu.Lock()
- var closeTimer tcpip.Timer
- var closeWaker sleep.Waker
-
- epilogue := func() {
- // e.mu is expected to be hold upon entering this section.
- if e.snd != nil {
- e.snd.resendTimer.cleanup()
- e.snd.probeTimer.cleanup()
- e.snd.reorderTimer.cleanup()
- }
+// protocolMainLoopDone is called at the end of protocolMainLoop.
+// +checklocksrelease:e.mu
+func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) {
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ e.snd.probeTimer.cleanup()
+ e.snd.reorderTimer.cleanup()
+ }
- if closeTimer != nil {
- closeTimer.Stop()
- }
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
- e.completeWorkerLocked()
+ e.completeWorkerLocked()
- if e.drainDone != nil {
- close(e.drainDone)
- }
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
- e.mu.Unlock()
+ e.mu.Unlock()
- e.drainClosingSegmentQueue()
+ e.drainClosingSegmentQueue()
- // When the protocol loop exits we should wake up our waiters.
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
- }
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+}
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
+ var (
+ closeTimer tcpip.Timer
+ closeWaker sleep.Waker
+ )
+
+ e.mu.Lock()
if handshake {
- if err := e.h.complete(); err != nil {
+ if err := e.h.complete(); err != nil { // +checklocksforce
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
@@ -1327,8 +1331,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.hardError = err
e.workerCleanup = true
- // Lock released below.
- epilogue()
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return err
}
}
@@ -1472,7 +1475,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
- e.mu.Unlock()
+ e.mu.Unlock() // +checklocksforce
<-e.undrain
e.mu.Lock()
}
@@ -1533,8 +1536,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if err != nil {
e.resetConnectionLocked(err)
}
- // Lock released below.
- epilogue()
}
loop:
@@ -1558,6 +1559,7 @@ loop:
// just want to terminate the loop and cleanup the
// endpoint.
cleanupOnError(nil)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
case StateTimeWait:
fallthrough
@@ -1566,6 +1568,7 @@ loop:
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
}
}
@@ -1589,13 +1592,13 @@ loop:
// Handle any StateError transition from StateTimeWait.
if e.EndpointState() == StateError {
cleanupOnError(nil)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
}
e.transitionToStateCloseLocked()
- // Lock released below.
- epilogue()
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
@@ -1665,6 +1668,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
// should be executed after releasing the endpoint registrations. This is
// done in cases where a new SYN is received during TIME_WAIT that carries
// a sequence number larger than one see on the connection.
+// +checklocks:e.mu
func (e *endpoint) doTimeWait() (twReuse func()) {
// Trigger a 2 * MSL time wait state. During this period
// we will drop all incoming segments.
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index dff7cb89c..7d110516b 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -127,7 +127,7 @@ func (p *processor) start(wg *sync.WaitGroup) {
case !ep.segmentQueue.empty():
p.epQ.enqueue(ep)
}
- ep.mu.Unlock()
+ ep.mu.Unlock() // +checklocksforce
} else {
ep.newSegmentWaker.Assert()
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 4acddc959..1ed4ba419 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -664,6 +664,7 @@ func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The assumption behind spinning here being that background packet processing
// should not be holding the lock for long and spinning reduces latency as we
// avoid an expensive sleep/wakeup of of the syscall goroutine).
+// +checklocksacquire:e.mu
func (e *endpoint) LockUser() {
for {
// Try first if the sock is locked then check if it's owned
@@ -683,7 +684,7 @@ func (e *endpoint) LockUser() {
continue
}
atomic.StoreUint32(&e.ownedByUser, 1)
- return
+ return // +checklocksforce
}
}
@@ -700,7 +701,7 @@ func (e *endpoint) LockUser() {
// protocol goroutine altogether.
//
// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
-// +checklocks:e.mu
+// +checklocksrelease:e.mu
func (e *endpoint) UnlockUser() {
// Lock segment queue before checking so that we avoid a race where
// segments can be queued between the time we check if queue is empty
@@ -736,12 +737,13 @@ func (e *endpoint) UnlockUser() {
}
// StopWork halts packet processing. Only to be used in tests.
+// +checklocksacquire:e.mu
func (e *endpoint) StopWork() {
e.mu.Lock()
}
// ResumeWork resumes packet processing. Only to be used in tests.
-// +checklocks:e.mu
+// +checklocksrelease:e.mu
func (e *endpoint) ResumeWork() {
e.mu.Unlock()
}
@@ -1480,86 +1482,95 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
return avail, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
+// readFromPayloader reads a slice from the Payloader.
+// +checklocks:e.mu
+// +checklocks:e.sndQueueInfo.sndQueueMu
+func (e *endpoint) readFromPayloader(p tcpip.Payloader, opts tcpip.WriteOptions, avail int) ([]byte, tcpip.Error) {
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndQueueInfo.sndQueueMu.Unlock()
+ defer e.sndQueueInfo.sndQueueMu.Lock()
- e.LockUser()
- defer e.UnlockUser()
+ e.UnlockUser()
+ defer e.LockUser()
+ }
- nextSeg, n, err := func() (*segment, int, tcpip.Error) {
- e.sndQueueInfo.sndQueueMu.Lock()
- defer e.sndQueueInfo.sndQueueMu.Unlock()
+ // Fetch data.
+ if l := p.Len(); l < avail {
+ avail = l
+ }
+ if avail == 0 {
+ return nil, nil
+ }
+ v := make([]byte, avail)
+ n, err := p.Read(v)
+ if err != nil && err != io.EOF {
+ return nil, &tcpip.ErrBadBuffer{}
+ }
+ return v[:n], nil
+}
+
+// queueSegment reads data from the payloader and returns a segment to be sent.
+// +checklocks:e.mu
+func (e *endpoint) queueSegment(p tcpip.Payloader, opts tcpip.WriteOptions) (*segment, int, tcpip.Error) {
+ e.sndQueueInfo.sndQueueMu.Lock()
+ defer e.sndQueueInfo.sndQueueMu.Unlock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return nil, 0, err
+ }
+ v, err := e.readFromPayloader(p, opts, avail)
+ if err != nil {
+ return nil, 0, err
+ }
+ if !opts.Atomic {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
avail, err := e.isEndpointWritableLocked()
if err != nil {
e.stats.WriteErrors.WriteClosed.Increment()
return nil, 0, err
}
- v, err := func() ([]byte, tcpip.Error) {
- // We can release locks while copying data.
- //
- // This is not possible if atomic is set, because we can't allow the
- // available buffer space to be consumed by some other caller while we
- // are copying data in.
- if !opts.Atomic {
- e.sndQueueInfo.sndQueueMu.Unlock()
- defer e.sndQueueInfo.sndQueueMu.Lock()
-
- e.UnlockUser()
- defer e.LockUser()
- }
-
- // Fetch data.
- if l := p.Len(); l < avail {
- avail = l
- }
- if avail == 0 {
- return nil, nil
- }
- v := make([]byte, avail)
- n, err := p.Read(v)
- if err != nil && err != io.EOF {
- return nil, &tcpip.ErrBadBuffer{}
- }
- return v[:n], nil
- }()
- if len(v) == 0 || err != nil {
- return nil, 0, err
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
+ }
- if !opts.Atomic {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- avail, err := e.isEndpointWritableLocked()
- if err != nil {
- e.stats.WriteErrors.WriteClosed.Increment()
- return nil, 0, err
- }
+ // Add data to the send queue.
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
+ e.sndQueueInfo.SndBufUsed += len(v)
+ e.snd.writeList.PushBack(s)
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- }
+ return s, len(v), nil
+}
- // Add data to the send queue.
- s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
- e.sndQueueInfo.SndBufUsed += len(v)
- e.snd.writeList.PushBack(s)
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.LockUser()
+ defer e.UnlockUser()
- return s, len(v), nil
- }()
// Return if either we didn't queue anything or if an error occurred while
// attempting to queue data.
+ nextSeg, n, err := e.queueSegment(p, opts)
if n == 0 || err != nil {
return 0, err
}
+
e.sendData(nextSeg)
return int64(n), nil
}
@@ -2504,6 +2515,7 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
+// +checklocksrelease:e.mu
func (e *endpoint) startAcceptedLoop() {
e.workerRunning = true
e.mu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 65c86823a..2e709ed78 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -164,8 +164,9 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
return nil, err
}
- // Start the protocol goroutine.
- ep.startAcceptedLoop()
+ // Start the protocol goroutine. Note that the endpoint is returned
+ // from performHandshake locked.
+ ep.startAcceptedLoop() // +checklocksforce
return ep, nil
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index def9d7186..82a3f2287 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -364,6 +364,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
+// +checklocks:e.mu
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.EndpointState() {
case StateInitial:
@@ -380,10 +381,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
}
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
+ defer e.mu.DowngradeLock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
@@ -449,37 +448,20 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if err := e.LastError(); err != nil {
- return 0, err
- }
-
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
e.mu.RLock()
- lockReleased := false
- defer func() {
- if lockReleased {
- return
- }
- e.mu.RUnlock()
- }()
+ defer e.mu.RUnlock()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
+ return udpPacketInfo{}, &tcpip.ErrClosedForSend{}
}
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWrite(opts.To)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
if !retry {
@@ -489,34 +471,34 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
route := e.route
dstPort := e.dstPort
- if to != nil {
+ if opts.To != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicID := to.NIC
+ nicID := opts.To.NIC
if nicID == 0 {
nicID = tcpip.NICID(e.ops.GetBindToDevice())
}
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
+ return udpPacketInfo{}, &tcpip.ErrNoRoute{}
}
nicID = e.BindNICID
}
- if to.Port == 0 {
+ if opts.To.Port == 0 {
// Port 0 is an invalid port to send to.
- return 0, &tcpip.ErrInvalidEndpointState{}
+ return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
}
- dst, netProto, err := e.checkV4MappedLocked(*to)
+ dst, netProto, err := e.checkV4MappedLocked(*opts.To)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
defer r.Release()
@@ -525,12 +507,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return 0, &tcpip.ErrBroadcastDisabled{}
+ return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{}
}
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
- return 0, &tcpip.ErrBadBuffer{}
+ return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
@@ -548,24 +530,39 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
v,
)
}
- return 0, &tcpip.ErrMessageTooLong{}
+ return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}
ttl := e.ttl
useDefaultTTL := ttl == 0
-
if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
ttl = e.multicastTTL
// Multicast allows a 0 TTL.
useDefaultTTL = false
}
- localPort := e.ID.LocalPort
- sendTOS := e.sendTOS
- owner := e.owner
- noChecksum := e.SocketOptions().GetNoChecksum()
- lockReleased = true
- e.mu.RUnlock()
+ return udpPacketInfo{
+ route: route,
+ data: buffer.View(v),
+ localPort: e.ID.LocalPort,
+ remotePort: dstPort,
+ ttl: ttl,
+ useDefaultTTL: useDefaultTTL,
+ tos: e.sendTOS,
+ owner: e.owner,
+ noChecksum: e.SocketOptions().GetNoChecksum(),
+ }, nil
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ if err := e.LastError(); err != nil {
+ return 0, err
+ }
+
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, &tcpip.ErrInvalidOptionValue{}
+ }
// Do not hold lock when sending as loopback is synchronous and if the UDP
// datagram ends up generating an ICMP response then it can result in a
@@ -577,10 +574,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
//
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
+ u, err := e.buildUDPPacketInfo(p, opts)
+ if err != nil {
return 0, err
}
- return int64(len(v)), nil
+ n, err := u.send()
+ if err != nil {
+ return 0, err
+ }
+ return int64(n), nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
@@ -817,14 +819,30 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
return nil
}
-// sendUDP sends a UDP segment via the provided network endpoint and under the
-// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error {
+// udpPacketInfo contains all information required to send a UDP packet.
+//
+// This should be used as a value-only type, which exists in order to simplify
+// return value syntax. It should not be exported or extended.
+type udpPacketInfo struct {
+ route *stack.Route
+ data buffer.View
+ localPort uint16
+ remotePort uint16
+ ttl uint8
+ useDefaultTTL bool
+ tos uint8
+ owner tcpip.PacketOwner
+ noChecksum bool
+}
+
+// send sends the given packet.
+func (u *udpPacketInfo) send() (int, tcpip.Error) {
+ vv := u.data.ToVectorisedView()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
- Data: data,
+ ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()),
+ Data: vv,
})
- pkt.Owner = owner
+ pkt.Owner = u.owner
// Initialize the UDP header.
udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
@@ -832,8 +850,8 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
- SrcPort: localPort,
- DstPort: remotePort,
+ SrcPort: u.localPort,
+ DstPort: u.remotePort,
Length: length,
})
@@ -841,30 +859,30 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if r.RequiresTXTransportChecksum() &&
- (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
- for _, v := range data.Views() {
+ if u.route.RequiresTXTransportChecksum() &&
+ (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) {
+ xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length)
+ for _, v := range vv.Views() {
xsum = header.Checksum(v, xsum)
}
udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
- if useDefaultTTL {
- ttl = r.DefaultTTL()
+ if u.useDefaultTTL {
+ u.ttl = u.route.DefaultTTL()
}
- if err := r.WritePacket(stack.NetworkHeaderParams{
+ if err := u.route.WritePacket(stack.NetworkHeaderParams{
Protocol: ProtocolNumber,
- TTL: ttl,
- TOS: tos,
+ TTL: u.ttl,
+ TOS: u.tos,
}, pkt); err != nil {
- r.Stats().UDP.PacketSendErrors.Increment()
- return err
+ u.route.Stats().UDP.PacketSendErrors.Increment()
+ return 0, err
}
// Track count of packets sent.
- r.Stats().UDP.PacketsSent.Increment()
- return nil
+ u.route.Stats().UDP.PacketsSent.Increment()
+ return len(u.data), nil
}
// checkV4MappedLocked determines the effective network protocol and converts
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index a14249641..42c66fbcf 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -157,10 +157,8 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// we will read it again after the exec call. This works
// because the ReadSpecFromFile function seeks to the beginning
// of the file before reading.
- if err := callSelfAsNobody(args); err != nil {
- Fatalf("%v", err)
- }
- panic("callSelfAsNobody must never return success")
+ Fatalf("callSelfAsNobody(%v): %v", args, callSelfAsNobody(args))
+ panic("unreachable")
}
}
@@ -199,10 +197,8 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// we will read it again after the exec call. This works
// because the ReadSpecFromFile function seeks to the beginning
// of the file before reading.
- if err := setCapsAndCallSelf(args, caps); err != nil {
- Fatalf("%v", err)
- }
- panic("setCapsAndCallSelf must never return success")
+ Fatalf("setCapsAndCallSelf(%v, %v): %v", args, caps, setCapsAndCallSelf(args, caps))
+ panic("unreachable")
}
// Read resolved mount list and replace the original one from the spec.
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 5ded7b946..80da9c9a2 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -116,9 +116,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Note: minimal argument handling for the default case to keep it simple.
args := os.Args
args = append(args, "--apply-caps=false", "--setup-root=false")
- if err := setCapsAndCallSelf(args, goferCaps); err != nil {
- Fatalf("Unable to apply caps: %v", err)
- }
+ Fatalf("setCapsAndCallSelf(%v, %v): %v", args, goferCaps, setCapsAndCallSelf(args, goferCaps))
panic("unreachable")
}
diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD
index 24e6f8a94..5295f4a85 100644
--- a/tools/bazeldefs/BUILD
+++ b/tools/bazeldefs/BUILD
@@ -46,6 +46,11 @@ genrule(
outs = ["version.txt"],
cmd = "cat bazel-out/stable-status.txt | grep STABLE_VERSION | cut -d' ' -f2- | sed 's/^[^[:digit:]]*//g' >$@",
stamp = True,
+ tags = [
+ "manual",
+ "nobuilder",
+ "notap",
+ ],
visibility = ["//:sandbox"],
)
diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD
index 940538b9e..109b5410c 100644
--- a/tools/checkescape/BUILD
+++ b/tools/checkescape/BUILD
@@ -8,6 +8,7 @@ go_library(
nogo = False,
visibility = ["//tools/nogo:__subpackages__"],
deps = [
+ "//tools/nogo/objdump",
"@org_golang_x_tools//go/analysis:go_default_library",
"@org_golang_x_tools//go/analysis/passes/buildssa:go_default_library",
"@org_golang_x_tools//go/ssa:go_default_library",
diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go
index c788654a8..ddd1212d7 100644
--- a/tools/checkescape/checkescape.go
+++ b/tools/checkescape/checkescape.go
@@ -61,21 +61,19 @@ package checkescape
import (
"bufio"
"bytes"
- "flag"
"fmt"
"go/ast"
"go/token"
"go/types"
"io"
"log"
- "os"
- "os/exec"
"path/filepath"
"strings"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/buildssa"
"golang.org/x/tools/go/ssa"
+ "gvisor.dev/gvisor/tools/nogo/objdump"
)
const (
@@ -92,21 +90,6 @@ const (
exempt = "// escapes"
)
-var (
- // Binary is the binary under analysis.
- //
- // See Reader, below.
- binary = flag.String("binary", "", "binary under analysis")
-
- // Reader is the input stream.
- //
- // This may be set instead of Binary.
- Reader io.Reader
-
- // objdumpTool is the tool used to dump a binary.
- objdumpTool = flag.String("objdump_tool", "", "tool used to dump a binary")
-)
-
// EscapeReason is an escape reason.
//
// This is a simple enum.
@@ -374,31 +357,6 @@ func MergeAll(others []Escapes) (es Escapes) {
// Note that the map uses <basename.go>:<line> because that is all that is
// provided in the objdump format. Since this is all local, it is sufficient.
func loadObjdump() (map[string][]string, error) {
- var (
- args []string
- stdin io.Reader
- )
- if *binary != "" {
- args = append(args, *binary)
- } else if Reader != nil {
- stdin = Reader
- } else {
- // We have no input stream or binary.
- return nil, fmt.Errorf("no binary or reader provided")
- }
-
- // Construct our command.
- cmd := exec.Command(*objdumpTool, args...)
- cmd.Stdin = stdin
- cmd.Stderr = os.Stderr
- out, err := cmd.StdoutPipe()
- if err != nil {
- return nil, err
- }
- if err := cmd.Start(); err != nil {
- return nil, err
- }
-
// Identify calls by address or name. Note that this is also
// constructed dynamically below, as we encounted the addresses.
// This is because some of the functions (duffzero) may have
@@ -431,78 +389,83 @@ func loadObjdump() (map[string][]string, error) {
// Build the map.
nextFunc := "" // For funcsAllowed.
m := make(map[string][]string)
- r := bufio.NewReader(out)
-NextLine:
- for {
- line, err := r.ReadString('\n')
- if err != nil && err != io.EOF {
- return nil, err
- }
- fields := strings.Fields(line)
-
- // Is this an "allowed" function definition?
- if len(fields) >= 2 && fields[0] == "TEXT" {
- nextFunc = strings.TrimSuffix(fields[1], "(SB)")
- if _, ok := funcsAllowed[nextFunc]; !ok {
- nextFunc = "" // Don't record addresses.
- }
- }
- if nextFunc != "" && len(fields) > 2 {
- // Save the given address (in hex form, as it appears).
- addrsAllowed[fields[1]] = struct{}{}
- }
-
- // We recognize lines corresponding to actual code (not the
- // symbol name or other metadata) and annotate them if they
- // correspond to an explicit CALL instruction. We assume that
- // the lack of a CALL for a given line is evidence that escape
- // analysis has eliminated an allocation.
- //
- // Lines look like this (including the first space):
- // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX
- if len(fields) >= 5 && line[0] == ' ' {
- if !strings.Contains(fields[3], "CALL") {
- continue
+ if err := objdump.Load(func(origR io.Reader) error {
+ r := bufio.NewReader(origR)
+ NextLine:
+ for {
+ line, err := r.ReadString('\n')
+ if err != nil && err != io.EOF {
+ return err
}
- site := fields[0]
- target := strings.TrimSuffix(fields[4], "(SB)")
+ fields := strings.Fields(line)
- // Ignore strings containing allowed functions.
- if _, ok := funcsAllowed[target]; ok {
- continue
+ // Is this an "allowed" function definition?
+ if len(fields) >= 2 && fields[0] == "TEXT" {
+ nextFunc = strings.TrimSuffix(fields[1], "(SB)")
+ if _, ok := funcsAllowed[nextFunc]; !ok {
+ nextFunc = "" // Don't record addresses.
+ }
}
- if _, ok := addrsAllowed[target]; ok {
- continue
+ if nextFunc != "" && len(fields) > 2 {
+ // Save the given address (in hex form, as it appears).
+ addrsAllowed[fields[1]] = struct{}{}
}
- if len(fields) > 5 {
- // This may be a future relocation. Some
- // objdump versions describe this differently.
- // If it contains any of the functions allowed
- // above as a string, we let it go.
- softTarget := strings.Join(fields[5:], " ")
- for name := range funcsAllowed {
- if strings.Contains(softTarget, name) {
- continue NextLine
+
+ // We recognize lines corresponding to actual code (not the
+ // symbol name or other metadata) and annotate them if they
+ // correspond to an explicit CALL instruction. We assume that
+ // the lack of a CALL for a given line is evidence that escape
+ // analysis has eliminated an allocation.
+ //
+ // Lines look like this (including the first space):
+ // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX
+ if len(fields) >= 5 && line[0] == ' ' {
+ if !strings.Contains(fields[3], "CALL") {
+ continue
+ }
+ site := fields[0]
+ target := strings.TrimSuffix(fields[4], "(SB)")
+
+ // Ignore strings containing allowed functions.
+ if _, ok := funcsAllowed[target]; ok {
+ continue
+ }
+ if _, ok := addrsAllowed[target]; ok {
+ continue
+ }
+ if len(fields) > 5 {
+ // This may be a future relocation. Some
+ // objdump versions describe this differently.
+ // If it contains any of the functions allowed
+ // above as a string, we let it go.
+ softTarget := strings.Join(fields[5:], " ")
+ for name := range funcsAllowed {
+ if strings.Contains(softTarget, name) {
+ continue NextLine
+ }
}
}
- }
- // Does this exist already?
- existing, ok := m[site]
- if !ok {
- existing = make([]string, 0, 1)
- }
- for _, other := range existing {
- if target == other {
- continue NextLine
+ // Does this exist already?
+ existing, ok := m[site]
+ if !ok {
+ existing = make([]string, 0, 1)
+ }
+ for _, other := range existing {
+ if target == other {
+ continue NextLine
+ }
}
+ existing = append(existing, target)
+ m[site] = existing // Update.
+ }
+ if err == io.EOF {
+ break
}
- existing = append(existing, target)
- m[site] = existing // Update.
- }
- if err == io.EOF {
- break
}
+ return nil
+ }); err != nil {
+ return nil, err
}
// Zap any accidental false positives.
@@ -518,11 +481,6 @@ NextLine:
final[site] = filteredCalls
}
- // Wait for the dump to finish.
- if err := cmd.Wait(); err != nil {
- return nil, err
- }
-
return final, nil
}
diff --git a/tools/checklocks/BUILD b/tools/checklocks/BUILD
index 7d4c63dc7..d23b7cde6 100644
--- a/tools/checklocks/BUILD
+++ b/tools/checklocks/BUILD
@@ -4,11 +4,16 @@ package(licenses = ["notice"])
go_library(
name = "checklocks",
- srcs = ["checklocks.go"],
+ srcs = [
+ "analysis.go",
+ "annotations.go",
+ "checklocks.go",
+ "facts.go",
+ "state.go",
+ ],
nogo = False,
visibility = ["//tools/nogo:__subpackages__"],
deps = [
- "//pkg/log",
"@org_golang_x_tools//go/analysis:go_default_library",
"@org_golang_x_tools//go/analysis/passes/buildssa:go_default_library",
"@org_golang_x_tools//go/ssa:go_default_library",
diff --git a/tools/checklocks/README.md b/tools/checklocks/README.md
index dfb0275ab..bd4beb649 100644
--- a/tools/checklocks/README.md
+++ b/tools/checklocks/README.md
@@ -1,16 +1,29 @@
# CheckLocks Analyzer
-<!--* freshness: { owner: 'gvisor-eng' reviewed: '2020-10-05' } *-->
+<!--* freshness: { owner: 'gvisor-eng' reviewed: '2021-03-21' } *-->
-Checklocks is a nogo analyzer that at compile time uses Go's static analysis
-tools to identify and flag cases where a field that is guarded by a mutex in the
-same struct is accessed outside of a mutex lock.
+Checklocks is an analyzer for lock and atomic constraints. The analyzer relies
+on explicit annotations to identify fields that should be checked for access.
-The analyzer relies on explicit '// +checklocks:<mutex-name>' kind of
-annotations to identify fields that should be checked for access.
+## Atomic annotations
-Individual struct members may be protected by annotations that indicate how they
-must be accessed. These annotations are of the form:
+Individual struct members may be noted as requiring atomic access. These
+annotations are of the form:
+
+```go
+type foo struct {
+ // +checkatomic
+ bar int32
+}
+```
+
+This will ensure that all accesses to bar are atomic, with the exception of
+operations on newly allocated objects.
+
+## Lock annotations
+
+Individual struct members may be protected by annotations that indicate locking
+requirements for accessing members. These annotations are of the form:
```go
type foo struct {
@@ -64,30 +77,6 @@ annotations from becoming stale over time as fields are renamed, etc.
# Currently not supported
-1. The analyzer does not correctly handle deferred functions. e.g The following
- code is not correctly checked by the analyzer. The defer call is never
- evaluated. As a result if the lock was to be say unlocked twice via deferred
- functions it would not be caught by the analyzer.
-
- Similarly deferred anonymous functions are not evaluated either.
-
-```go
-type A struct {
- mu sync.Mutex
-
- // +checklocks:mu
- x int
-}
-
-func abc() {
- var a A
- a.mu.Lock()
- defer a.mu.Unlock()
- defer a.mu.Unlock()
- a.x = 1
-}
-```
-
1. Anonymous functions are not correctly evaluated. The analyzer does not
currently support specifying annotations on anonymous functions as a result
evaluation of a function that accesses protected fields will fail.
@@ -107,10 +96,9 @@ func abc() {
f()
a.mu.Unlock()
}
-
```
-# Explicitly Not Supported
+### Explicitly Not Supported
1. Checking for embedded mutexes as sync.Locker rather than directly as
'sync.Mutex'. In other words, the checker will not track mutex Lock and
@@ -140,3 +128,30 @@ func abc() {
checklocks. Only struct members can be used.
2. The checker will not support checking for lock ordering violations.
+
+## Mixed mode
+
+Some members may allow read-only atomic access, but be protected against writes
+by a mutex. Generally, this imposes the following requirements:
+
+For a read, one of the following must be true:
+
+1. A lock held be held.
+1. The access is atomic.
+
+For a write, both of the following must be true:
+
+1. The lock must be held.
+1. The write must be atomic.
+
+In order to annotate a relevant field, simply apply *both* annotations from
+above. For example:
+
+```go
+type foo struct {
+ mu sync.Mutex
+ // +checklocks:mu
+ // +checkatomic
+ bar int32
+}
+```
diff --git a/tools/checklocks/analysis.go b/tools/checklocks/analysis.go
new file mode 100644
index 000000000..d3fd797d0
--- /dev/null
+++ b/tools/checklocks/analysis.go
@@ -0,0 +1,628 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package checklocks
+
+import (
+ "go/token"
+ "go/types"
+ "strings"
+
+ "golang.org/x/tools/go/ssa"
+)
+
+func gcd(a, b atomicAlignment) atomicAlignment {
+ for b != 0 {
+ a, b = b, a%b
+ }
+ return a
+}
+
+// typeAlignment returns the type alignment for the given type.
+func (pc *passContext) typeAlignment(pkg *types.Package, obj types.Object) atomicAlignment {
+ requiredOffset := atomicAlignment(1)
+ if pc.pass.ImportObjectFact(obj, &requiredOffset) {
+ return requiredOffset
+ }
+
+ switch x := obj.Type().Underlying().(type) {
+ case *types.Struct:
+ fields := make([]*types.Var, x.NumFields())
+ for i := 0; i < x.NumFields(); i++ {
+ fields[i] = x.Field(i)
+ }
+ offsets := pc.pass.TypesSizes.Offsetsof(fields)
+ for i := 0; i < x.NumFields(); i++ {
+ // Check the offset, and then assuming that this offset
+ // aligns with the offset for the broader type.
+ fieldRequired := pc.typeAlignment(pkg, fields[i])
+ if offsets[i]%int64(fieldRequired) != 0 {
+ // The offset of this field is not compatible.
+ pc.maybeFail(fields[i].Pos(), "have alignment %d, need %d", offsets[i], fieldRequired)
+ }
+ // Ensure the requiredOffset is the LCM of the offset.
+ requiredOffset *= fieldRequired / gcd(requiredOffset, fieldRequired)
+ }
+ case *types.Array:
+ // Export direct alignment requirements.
+ if named, ok := x.Elem().(*types.Named); ok {
+ requiredOffset = pc.typeAlignment(pkg, named.Obj())
+ }
+ default:
+ // Use the compiler's underlying alignment.
+ requiredOffset = atomicAlignment(pc.pass.TypesSizes.Alignof(obj.Type().Underlying()))
+ }
+
+ if pkg == obj.Pkg() {
+ // Cache as an object fact, to subsequent calls. Note that we
+ // can only export object facts for the package that we are
+ // currently analyzing. There may be no exported facts for
+ // array types or alias types, for example.
+ pc.pass.ExportObjectFact(obj, &requiredOffset)
+ }
+
+ return requiredOffset
+}
+
+// checkTypeAlignment checks the alignment of the given type.
+//
+// This calls typeAlignment, which resolves all types recursively. This method
+// should be called for all types individual to ensure full coverage.
+func (pc *passContext) checkTypeAlignment(pkg *types.Package, typ *types.Named) {
+ _ = pc.typeAlignment(pkg, typ.Obj())
+}
+
+// checkAtomicCall checks for an atomic access.
+//
+// inst is the instruction analyzed, obj is used only for maybeFail.
+//
+// If mustBeAtomic is true, then we assert that the instruction *is* an atomic
+// fucnction call. If it is false, then we assert that it is *not* an atomic
+// dispatch.
+//
+// If readOnly is true, then only atomic read access are allowed. Note that
+// readOnly is only meaningful if mustBeAtomic is set.
+func (pc *passContext) checkAtomicCall(inst ssa.Instruction, obj types.Object, mustBeAtomic, readOnly bool) {
+ switch x := inst.(type) {
+ case *ssa.Call:
+ if x.Common().IsInvoke() {
+ if mustBeAtomic {
+ // This is an illegal interface dispatch.
+ pc.maybeFail(inst.Pos(), "dynamic dispatch with atomic-only field")
+ }
+ return
+ }
+ fn, ok := x.Common().Value.(*ssa.Function)
+ if !ok {
+ if mustBeAtomic {
+ // This is an illegal call to a non-static function.
+ pc.maybeFail(inst.Pos(), "dispatch to non-static function with atomic-only field")
+ }
+ return
+ }
+ pkg := fn.Package()
+ if pkg == nil {
+ if mustBeAtomic {
+ // This is a call to some shared wrapper function.
+ pc.maybeFail(inst.Pos(), "dispatch to shared function or wrapper")
+ }
+ return
+ }
+ var lff lockFunctionFacts // Check for exemption.
+ if obj := fn.Object(); obj != nil && pc.pass.ImportObjectFact(obj, &lff) && lff.Ignore {
+ return
+ }
+ if name := pkg.Pkg.Name(); name != "atomic" && name != "atomicbitops" {
+ if mustBeAtomic {
+ // This is an illegal call to a non-atomic package function.
+ pc.maybeFail(inst.Pos(), "dispatch to non-atomic function with atomic-only field")
+ }
+ return
+ }
+ if !mustBeAtomic {
+ // We are *not* expecting an atomic dispatch.
+ if _, ok := pc.forced[pc.positionKey(inst.Pos())]; !ok {
+ pc.maybeFail(inst.Pos(), "unexpected call to atomic function")
+ }
+ }
+ if !strings.HasPrefix(fn.Name(), "Load") && readOnly {
+ // We are not allowing any reads in this context.
+ if _, ok := pc.forced[pc.positionKey(inst.Pos())]; !ok {
+ pc.maybeFail(inst.Pos(), "unexpected call to atomic write function, is a lock missing?")
+ }
+ return
+ }
+ default:
+ if mustBeAtomic {
+ // This is something else entirely.
+ if _, ok := pc.forced[pc.positionKey(inst.Pos())]; !ok {
+ pc.maybeFail(inst.Pos(), "illegal use of atomic-only field by %T instruction", inst)
+ }
+ return
+ }
+ }
+}
+
+func resolveStruct(typ types.Type) (*types.Struct, bool) {
+ structType, ok := typ.Underlying().(*types.Struct)
+ if ok {
+ return structType, true
+ }
+ ptrType, ok := typ.Underlying().(*types.Pointer)
+ if ok {
+ return resolveStruct(ptrType.Elem())
+ }
+ return nil, false
+}
+
+func findField(typ types.Type, field int) (types.Object, bool) {
+ structType, ok := resolveStruct(typ)
+ if !ok {
+ return nil, false
+ }
+ return structType.Field(field), true
+}
+
+// instructionWithReferrers is a generalization over ssa.Field, ssa.FieldAddr.
+type instructionWithReferrers interface {
+ ssa.Instruction
+ Referrers() *[]ssa.Instruction
+}
+
+// checkFieldAccess checks the validity of a field access.
+//
+// This also enforces atomicity constraints for fields that must be accessed
+// atomically. The parameter isWrite indicates whether this field is used
+// downstream for a write operation.
+func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj ssa.Value, field int, ls *lockState, isWrite bool) {
+ var (
+ lff lockFieldFacts
+ lgf lockGuardFacts
+ guardsFound int
+ guardsHeld int
+ )
+
+ fieldObj, _ := findField(structObj.Type(), field)
+ pc.pass.ImportObjectFact(fieldObj, &lff)
+ pc.pass.ImportObjectFact(fieldObj, &lgf)
+
+ for guardName, fl := range lgf.GuardedBy {
+ guardsFound++
+ r := fl.resolve(structObj)
+ if _, ok := ls.isHeld(r); ok {
+ guardsHeld++
+ continue
+ }
+ if _, ok := pc.forced[pc.positionKey(inst.Pos())]; ok {
+ // Mark this as locked, since it has been forced.
+ ls.lockField(r)
+ guardsHeld++
+ continue
+ }
+ // Note that we may allow this if the disposition is atomic,
+ // and we are allowing atomic reads only. This will fall into
+ // the atomic disposition check below, which asserts that the
+ // access is atomic. Further, guardsHeld < guardsFound will be
+ // true for this case, so we require it to be read-only.
+ if lgf.AtomicDisposition != atomicRequired {
+ // There is no force key, no atomic access and no lock held.
+ pc.maybeFail(inst.Pos(), "invalid field access, %s must be locked when accessing %s (locks: %s)", guardName, fieldObj.Name(), ls.String())
+ }
+ }
+
+ // Check the atomic access for this field.
+ switch lgf.AtomicDisposition {
+ case atomicRequired:
+ // Check that this is used safely as an input.
+ readOnly := guardsHeld < guardsFound
+ if refs := inst.Referrers(); refs != nil {
+ for _, otherInst := range *refs {
+ pc.checkAtomicCall(otherInst, fieldObj, true, readOnly)
+ }
+ }
+ // Check that this is not otherwise written non-atomically,
+ // even if we do hold all the locks.
+ if isWrite {
+ pc.maybeFail(inst.Pos(), "non-atomic write of field %s, writes must still be atomic with locks held (locks: %s)", fieldObj.Name(), ls.String())
+ }
+ case atomicDisallow:
+ // Check that this is *not* used atomically.
+ if refs := inst.Referrers(); refs != nil {
+ for _, otherInst := range *refs {
+ pc.checkAtomicCall(otherInst, fieldObj, false, false)
+ }
+ }
+ }
+}
+
+func (pc *passContext) checkCall(call callCommon, ls *lockState) {
+ // See: https://godoc.org/golang.org/x/tools/go/ssa#CallCommon
+ //
+ // 1. "call" mode: when Method is nil (!IsInvoke), a CallCommon represents an ordinary
+ // function call of the value in Value, which may be a *Builtin, a *Function or any
+ // other value of kind 'func'.
+ //
+ // Value may be one of:
+ // (a) a *Function, indicating a statically dispatched call
+ // to a package-level function, an anonymous function, or
+ // a method of a named type.
+ //
+ // (b) a *MakeClosure, indicating an immediately applied
+ // function literal with free variables.
+ //
+ // (c) a *Builtin, indicating a statically dispatched call
+ // to a built-in function.
+ //
+ // (d) any other value, indicating a dynamically dispatched
+ // function call.
+ switch fn := call.Common().Value.(type) {
+ case *ssa.Function:
+ var lff lockFunctionFacts
+ if fn.Object() != nil {
+ pc.pass.ImportObjectFact(fn.Object(), &lff)
+ pc.checkFunctionCall(call, fn, &lff, ls)
+ } else {
+ // Anonymous functions have no facts, and cannot be
+ // annotated. We don't check for violations using the
+ // function facts, since they cannot exist. Instead, we
+ // do a fresh analysis using the current lock state.
+ fnls := ls.fork()
+ for i, arg := range call.Common().Args {
+ fnls.store(fn.Params[i], arg)
+ }
+ pc.checkFunction(call, fn, &lff, fnls, true /* force */)
+ }
+ case *ssa.MakeClosure:
+ // Note that creating and then invoking closures locally is
+ // allowed, but analysis of passing closures is done when
+ // checking individual instructions.
+ pc.checkClosure(call, fn, ls)
+ default:
+ return
+ }
+}
+
+// postFunctionCallUpdate updates all conditions.
+func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunctionFacts, ls *lockState) {
+ // Release all locks not still held.
+ for fieldName, fg := range lff.HeldOnEntry {
+ if _, ok := lff.HeldOnExit[fieldName]; ok {
+ continue
+ }
+ r := fg.resolveCall(call.Common().Args, call.Value())
+ if s, ok := ls.unlockField(r); !ok {
+ if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
+ pc.maybeFail(call.Pos(), "attempt to release %s (%s), but not held (locks: %s)", fieldName, s, ls.String())
+ }
+ }
+ }
+
+ // Update all held locks if acquired.
+ for fieldName, fg := range lff.HeldOnExit {
+ if _, ok := lff.HeldOnEntry[fieldName]; ok {
+ continue
+ }
+ // Acquire the lock per the annotation.
+ r := fg.resolveCall(call.Common().Args, call.Value())
+ if s, ok := ls.lockField(r); !ok {
+ if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
+ pc.maybeFail(call.Pos(), "attempt to acquire %s (%s), but already held (locks: %s)", fieldName, s, ls.String())
+ }
+ }
+ }
+}
+
+// checkFunctionCall checks preconditions for function calls, and tracks the
+// lock state by recording relevant calls to sync functions. Note that calls to
+// atomic functions are tracked by checkFieldAccess by looking directly at the
+// referrers (because ordering doesn't matter there, so we need not scan in
+// instruction order).
+func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff *lockFunctionFacts, ls *lockState) {
+ // Check all guards required are held.
+ for fieldName, fg := range lff.HeldOnEntry {
+ r := fg.resolveCall(call.Common().Args, call.Value())
+ if s, ok := ls.isHeld(r); !ok {
+ if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
+ pc.maybeFail(call.Pos(), "must hold %s (%s) to call %s, but not held (locks: %s)", fieldName, s, fn.Name(), ls.String())
+ } else {
+ // Force the lock to be acquired.
+ ls.lockField(r)
+ }
+ }
+ }
+
+ // Update all lock state accordingly.
+ pc.postFunctionCallUpdate(call, lff, ls)
+
+ // Check if it's a method dispatch for something in the sync package.
+ // See: https://godoc.org/golang.org/x/tools/go/ssa#Function
+ if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil {
+ switch fn.Name() {
+ case "Lock", "RLock":
+ if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+ if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
+ // Double locking a mutex that is already locked.
+ pc.maybeFail(call.Pos(), "%s already locked (locks: %s)", s, ls.String())
+ }
+ }
+ case "Unlock", "RUnlock":
+ if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok {
+ if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok {
+ // Unlocking something that is already unlocked.
+ pc.maybeFail(call.Pos(), "%s already unlocked (locks: %s)", s, ls.String())
+ }
+ }
+ }
+ }
+}
+
+// checkClosure forks the lock state, and creates a binding for the FreeVars of
+// the closure. This allows the analysis to resolve the closure.
+func (pc *passContext) checkClosure(call callCommon, fn *ssa.MakeClosure, ls *lockState) {
+ clls := ls.fork()
+ clfn := fn.Fn.(*ssa.Function)
+ for i, fv := range clfn.FreeVars {
+ clls.store(fv, fn.Bindings[i])
+ }
+
+ // Note that this is *not* a call to check function call, which checks
+ // against the function preconditions. Instead, this does a fresh
+ // analysis of the function from source code with a different state.
+ var nolff lockFunctionFacts
+ pc.checkFunction(call, clfn, &nolff, clls, true /* force */)
+}
+
+// freshAlloc indicates that v has been allocated within the local scope. There
+// is no lock checking done on objects that are freshly allocated.
+func freshAlloc(v ssa.Value) bool {
+ switch x := v.(type) {
+ case *ssa.Alloc:
+ return true
+ case *ssa.FieldAddr:
+ return freshAlloc(x.X)
+ case *ssa.Field:
+ return freshAlloc(x.X)
+ case *ssa.IndexAddr:
+ return freshAlloc(x.X)
+ case *ssa.Index:
+ return freshAlloc(x.X)
+ case *ssa.Convert:
+ return freshAlloc(x.X)
+ case *ssa.ChangeType:
+ return freshAlloc(x.X)
+ default:
+ return false
+ }
+}
+
+// isWrite indicates that this value is used as the addr field in a store.
+//
+// Note that this may still be used for a write. The return here is optimistic
+// but sufficient for basic analysis.
+func isWrite(v ssa.Value) bool {
+ refs := v.Referrers()
+ if refs == nil {
+ return false
+ }
+ for _, ref := range *refs {
+ if s, ok := ref.(*ssa.Store); ok && s.Addr == v {
+ return true
+ }
+ }
+ return false
+}
+
+// callCommon is an ssa.Value that also implements Common.
+type callCommon interface {
+ Pos() token.Pos
+ Common() *ssa.CallCommon
+ Value() *ssa.Call
+}
+
+// checkInstruction checks the legality the single instruction based on the
+// current lockState.
+func (pc *passContext) checkInstruction(inst ssa.Instruction, ls *lockState) (*ssa.Return, *lockState) {
+ switch x := inst.(type) {
+ case *ssa.Store:
+ // Record that this value is holding this other value. This is
+ // because at the beginning of each ssa execution, there is a
+ // series of assignments of parameter values to alloc objects.
+ // This allows us to trace these back to the original
+ // parameters as aliases above.
+ //
+ // Note that this may overwrite an existing value in the lock
+ // state, but this is intentional.
+ ls.store(x.Addr, x.Val)
+ case *ssa.Field:
+ if !freshAlloc(x.X) {
+ pc.checkFieldAccess(x, x.X, x.Field, ls, false)
+ }
+ case *ssa.FieldAddr:
+ if !freshAlloc(x.X) {
+ pc.checkFieldAccess(x, x.X, x.Field, ls, isWrite(x))
+ }
+ case *ssa.Call:
+ pc.checkCall(x, ls)
+ case *ssa.Defer:
+ ls.pushDefer(x)
+ case *ssa.RunDefers:
+ for d := ls.popDefer(); d != nil; d = ls.popDefer() {
+ pc.checkCall(d, ls)
+ }
+ case *ssa.MakeClosure:
+ refs := x.Referrers()
+ if refs == nil {
+ // This is strange, it's not used? Ignore this case,
+ // since it will probably be optimized away.
+ return nil, nil
+ }
+ hasNonCall := false
+ for _, ref := range *refs {
+ switch ref.(type) {
+ case *ssa.Call, *ssa.Defer:
+ // Analysis will be done on the call itself
+ // subsequently, including the lock state at
+ // the time of the call.
+ default:
+ // We need to analyze separately. Per below,
+ // this means that we'll analyze at closure
+ // construction time no zero assumptions about
+ // when it will be called.
+ hasNonCall = true
+ }
+ }
+ if !hasNonCall {
+ return nil, nil
+ }
+ // Analyze the closure without bindings. This means that we
+ // assume no lock facts or have any existing lock state. Only
+ // trivial closures are acceptable in this case.
+ clfn := x.Fn.(*ssa.Function)
+ var nolff lockFunctionFacts
+ pc.checkFunction(nil, clfn, &nolff, nil, false /* force */)
+ case *ssa.Return:
+ return x, ls // Valid return state.
+ }
+ return nil, nil
+}
+
+// checkBasicBlock traverses the control flow graph starting at a set of given
+// block and checks each instruction for allowed operations.
+func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, lff *lockFunctionFacts, parent *lockState, seen map[*ssa.BasicBlock]*lockState) *lockState {
+ if oldLS, ok := seen[block]; ok && oldLS.isCompatible(parent) {
+ return nil
+ }
+
+ // If the lock state is not compatible, then we need to do the
+ // recursive analysis to ensure that it is still sane. For example, the
+ // following is guaranteed to generate incompatible locking states:
+ //
+ // if foo {
+ // mu.Lock()
+ // }
+ // other stuff ...
+ // if foo {
+ // mu.Unlock()
+ // }
+
+ var (
+ rv *ssa.Return
+ rls *lockState
+ )
+
+ // Analyze this block.
+ seen[block] = parent
+ ls := parent.fork()
+ for _, inst := range block.Instrs {
+ rv, rls = pc.checkInstruction(inst, ls)
+ if rls != nil {
+ failed := false
+ // Validate held locks.
+ for fieldName, fg := range lff.HeldOnExit {
+ r := fg.resolveStatic(fn, rv)
+ if s, ok := rls.isHeld(r); !ok {
+ if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok {
+ pc.maybeFail(rv.Pos(), "lock %s (%s) not held (locks: %s)", fieldName, s, rls.String())
+ failed = true
+ } else {
+ // Force the lock to be acquired.
+ rls.lockField(r)
+ }
+ }
+ }
+ // Check for other locks, but only if the above didn't trip.
+ if !failed && rls.count() != len(lff.HeldOnExit) {
+ pc.maybeFail(rv.Pos(), "return with unexpected locks held (locks: %s)", rls.String())
+ }
+ }
+ }
+
+ // Analyze all successors.
+ for _, succ := range block.Succs {
+ // Collect possible return values, and make sure that the lock
+ // state aligns with any return value that we may have found
+ // above. Note that checkBasicBlock will recursively analyze
+ // the lock state to ensure that Releases and Acquires are
+ // respected.
+ if pls := pc.checkBasicBlock(fn, succ, lff, ls, seen); pls != nil {
+ if rls != nil && !rls.isCompatible(pls) {
+ if _, ok := pc.forced[pc.positionKey(fn.Pos())]; !ok {
+ pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %v)", rls.String(), pls.String())
+ }
+ }
+ rls = pls
+ }
+ }
+ return rls
+}
+
+// checkFunction checks a function invocation, typically starting with nil lockState.
+func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *lockFunctionFacts, parent *lockState, force bool) {
+ defer func() {
+ // Mark this function as checked. This is used by the top-level
+ // loop to ensure that all anonymous functions are scanned, if
+ // they are not explicitly invoked here. Note that this can
+ // happen if the anonymous functions are e.g. passed only as
+ // parameters or used to initialize some structure.
+ pc.functions[fn] = struct{}{}
+ }()
+ if _, ok := pc.functions[fn]; !force && ok {
+ // This function has already been analyzed at least once.
+ // That's all we permit for each function, although this may
+ // cause some anonymous functions to be analyzed in only one
+ // context.
+ return
+ }
+
+ // If no return value is provided, then synthesize one. This is used
+ // below only to check against the locks preconditions, which may
+ // include return values.
+ if call == nil {
+ call = &ssa.Call{Call: ssa.CallCommon{Value: fn}}
+ }
+
+ // Initialize ls with any preconditions that require locks to be held
+ // for the method to be invoked. Note that in the overwhleming majority
+ // of cases, parent will be nil. However, in the case of closures and
+ // anonymous functions, we may start with a non-nil lock state.
+ ls := parent.fork()
+ for fieldName, fg := range lff.HeldOnEntry {
+ // The first is the method object itself so we skip that when looking
+ // for receiver/function parameters.
+ r := fg.resolveStatic(fn, call.Value())
+ if s, ok := ls.lockField(r); !ok {
+ // This can only happen if the same value is declared
+ // multiple times, and should be caught by the earlier
+ // fact scanning. Keep it here as a sanity check.
+ pc.maybeFail(fn.Pos(), "lock %s (%s) acquired multiple times (locks: %s)", fieldName, s, ls.String())
+ }
+ }
+
+ // Scan the blocks.
+ seen := make(map[*ssa.BasicBlock]*lockState)
+ if len(fn.Blocks) > 0 {
+ pc.checkBasicBlock(fn, fn.Blocks[0], lff, ls, seen)
+ }
+
+ // Scan the recover block.
+ if fn.Recover != nil {
+ pc.checkBasicBlock(fn, fn.Recover, lff, ls, seen)
+ }
+
+ // Update all lock state accordingly. This will be called only if we
+ // are doing inline analysis for e.g. an anonymous function.
+ if call != nil && parent != nil {
+ pc.postFunctionCallUpdate(call, lff, parent)
+ }
+}
diff --git a/tools/checklocks/annotations.go b/tools/checklocks/annotations.go
new file mode 100644
index 000000000..371260980
--- /dev/null
+++ b/tools/checklocks/annotations.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package checklocks
+
+import (
+ "fmt"
+
+ "go/token"
+ "strconv"
+ "strings"
+)
+
+const (
+ checkLocksAnnotation = "// +checklocks:"
+ checkLocksAcquires = "// +checklocksacquire:"
+ checkLocksReleases = "// +checklocksrelease:"
+ checkLocksIgnore = "// +checklocksignore"
+ checkLocksForce = "// +checklocksforce"
+ checkLocksFail = "// +checklocksfail"
+ checkAtomicAnnotation = "// +checkatomic"
+)
+
+// failData indicates an expected failure.
+type failData struct {
+ pos token.Pos
+ count int
+ seen int
+}
+
+// positionKey is a simple position string.
+type positionKey string
+
+// positionKey converts from a token.Pos to a key we can use to track failures
+// as the position of the failure annotation is not the same as the position of
+// the actual failure (different column/offsets). Hence we ignore these fields
+// and only use the file/line numbers to track failures.
+func (pc *passContext) positionKey(pos token.Pos) positionKey {
+ position := pc.pass.Fset.Position(pos)
+ return positionKey(fmt.Sprintf("%s:%d", position.Filename, position.Line))
+}
+
+// addFailures adds an expected failure.
+func (pc *passContext) addFailures(pos token.Pos, s string) {
+ count := 1
+ if len(s) > 0 && s[0] == ':' {
+ parsedCount, err := strconv.Atoi(s[1:])
+ if err != nil {
+ pc.pass.Reportf(pos, "unable to parse failure annotation %q: %v", s[1:], err)
+ return
+ }
+ count = parsedCount
+ }
+ pc.failures[pc.positionKey(pos)] = &failData{
+ pos: pos,
+ count: count,
+ }
+}
+
+// addExemption adds an exemption.
+func (pc *passContext) addExemption(pos token.Pos) {
+ pc.exemptions[pc.positionKey(pos)] = struct{}{}
+}
+
+// addForce adds a force annotation.
+func (pc *passContext) addForce(pos token.Pos) {
+ pc.forced[pc.positionKey(pos)] = struct{}{}
+}
+
+// maybeFail checks a potential failure against a specific failure map.
+func (pc *passContext) maybeFail(pos token.Pos, fmtStr string, args ...interface{}) {
+ if fd, ok := pc.failures[pc.positionKey(pos)]; ok {
+ fd.seen++
+ return
+ }
+ if _, ok := pc.exemptions[pc.positionKey(pos)]; ok {
+ return // Ignored, not counted.
+ }
+ pc.pass.Reportf(pos, fmtStr, args...)
+}
+
+// checkFailure checks for the expected failure counts.
+func (pc *passContext) checkFailures() {
+ for _, fd := range pc.failures {
+ if fd.count != fd.seen {
+ // We are missing expect failures, report as much as possible.
+ pc.pass.Reportf(fd.pos, "got %d failures, want %d failures", fd.seen, fd.count)
+ }
+ }
+}
+
+// extractAnnotations extracts annotations from text.
+func (pc *passContext) extractAnnotations(s string, fns map[string]func(p string)) {
+ for prefix, fn := range fns {
+ if strings.HasPrefix(s, prefix) {
+ fn(s[len(prefix):])
+ }
+ }
+}
+
+// extractLineFailures extracts all line-based exceptions.
+//
+// Note that this applies only to individual line exemptions, and does not
+// consider function-wide exemptions, or specific field exemptions, which are
+// extracted separately as part of the saved facts for those objects.
+func (pc *passContext) extractLineFailures() {
+ for _, f := range pc.pass.Files {
+ for _, cg := range f.Comments {
+ for _, c := range cg.List {
+ pc.extractAnnotations(c.Text, map[string]func(string){
+ checkLocksFail: func(p string) { pc.addFailures(c.Pos(), p) },
+ checkLocksIgnore: func(string) { pc.addExemption(c.Pos()) },
+ checkLocksForce: func(string) { pc.addForce(c.Pos()) },
+ })
+ }
+ }
+ }
+}
diff --git a/tools/checklocks/checklocks.go b/tools/checklocks/checklocks.go
index 1e877d394..180f8873f 100644
--- a/tools/checklocks/checklocks.go
+++ b/tools/checklocks/checklocks.go
@@ -13,32 +13,19 @@
// limitations under the License.
// Package checklocks performs lock analysis to identify and flag unprotected
-// access to field annotated with a '// +checklocks:<mutex-name>' annotation.
+// access to annotated fields.
//
-// For detailed ussage refer to README.md in the same directory.
+// For detailed usage refer to README.md in the same directory.
package checklocks
import (
- "bytes"
- "fmt"
"go/ast"
"go/token"
"go/types"
- "reflect"
- "regexp"
- "strconv"
- "strings"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/buildssa"
"golang.org/x/tools/go/ssa"
- "gvisor.dev/gvisor/pkg/log"
-)
-
-const (
- checkLocksAnnotation = "// +checklocks:"
- checkLocksIgnore = "// +checklocksignore"
- checkLocksFail = "// +checklocksfail"
)
// Analyzer is the main entrypoint.
@@ -47,712 +34,121 @@ var Analyzer = &analysis.Analyzer{
Doc: "checks lock preconditions on functions and fields",
Run: run,
Requires: []*analysis.Analyzer{buildssa.Analyzer},
- FactTypes: []analysis.Fact{(*lockFieldFacts)(nil), (*lockFunctionFacts)(nil)},
-}
-
-// lockFieldFacts apply on every struct field protected by a lock or that is a
-// lock.
-type lockFieldFacts struct {
- // GuardedBy tracks the names and field numbers that guard this field.
- GuardedBy map[string]int
-
- // IsMutex is true if the field is of type sync.Mutex.
- IsMutex bool
-
- // IsRWMutex is true if the field is of type sync.RWMutex.
- IsRWMutex bool
-
- // FieldNumber is the number of this field in the struct.
- FieldNumber int
-}
-
-// AFact implements analysis.Fact.AFact.
-func (*lockFieldFacts) AFact() {}
-
-type functionGuard struct {
- // ParameterNumber is the index of the object that contains the guarding mutex.
- // This is required during SSA analysis as field names and parameters names are
- // not available in SSA. For example, from the example below ParameterNumber would
- // be 1 and FieldNumber would correspond to the field number of 'mu' within b's type.
- //
- // //+checklocks:b.mu
- // func (a *A) method(b *B, c *C) {
- // ...
- // }
- ParameterNumber int
-
- // FieldNumber is the field index of the mutex in the parameter's struct
- // type. Refer to example above for more details.
- FieldNumber int
-}
-
-// lockFunctionFacts apply on every method.
-type lockFunctionFacts struct {
- // GuardedBy tracks the names and number of parameter (including receiver)
- // lockFuncfields that guard calls to this function.
- // The key is the name specified in the checklocks annotation. e.g given
- // the following code.
- // ```
- // type A struct {
- // mu sync.Mutex
- // a int
- // }
- //
- // // +checklocks:a.mu
- // func xyz(a *A) {..}
- // ```
- //
- // '`+checklocks:a.mu' will result in an entry in this map as shown below.
- // GuardedBy: {"a.mu" => {ParameterNumber: 0, FieldNumber: 0}
- GuardedBy map[string]functionGuard
-}
-
-// AFact implements analysis.Fact.AFact.
-func (*lockFunctionFacts) AFact() {}
-
-type positionKey string
-
-// toPositionKey converts from a token.Position to a key we can use to track
-// failures as the position of the failure annotation is not the same as the
-// position of the actual failure (different column/offsets). Hence we ignore
-// these fields and only use the file/line numbers to track failures.
-func toPositionKey(position token.Position) positionKey {
- return positionKey(fmt.Sprintf("%s:%d", position.Filename, position.Line))
-}
-
-type failData struct {
- pos token.Pos
- count int
-}
-
-func (f failData) String() string {
- return fmt.Sprintf("pos: %d, count: %d", f.pos, f.count)
+ FactTypes: []analysis.Fact{(*atomicAlignment)(nil), (*lockFieldFacts)(nil), (*lockGuardFacts)(nil), (*lockFunctionFacts)(nil)},
}
+// passContext is a pass with additional expected failures.
type passContext struct {
- pass *analysis.Pass
-
- // exemptions tracks functions that should be exempted from lock checking due
- // to '// +checklocksignore' annotation.
- exemptions map[types.Object]struct{}
-
- failures map[positionKey]*failData
+ pass *analysis.Pass
+ failures map[positionKey]*failData
+ exemptions map[positionKey]struct{}
+ forced map[positionKey]struct{}
+ functions map[*ssa.Function]struct{}
}
-var (
- mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)")
- rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)")
-)
-
-func (pc *passContext) extractFieldAnnotations(field *ast.Field, fieldType *types.Var) *lockFieldFacts {
- s := fieldType.Type().String()
- // We use HasSuffix below because fieldType can be fully qualified with the
- // package name eg for the gvisor sync package mutex fields have the type:
- // "<package path>/sync/sync.Mutex"
- switch {
- case mutexRE.Match([]byte(s)):
- return &lockFieldFacts{IsMutex: true}
- case rwMutexRE.Match([]byte(s)):
- return &lockFieldFacts{IsRWMutex: true}
- default:
- }
- if field.Doc == nil {
- return nil
- }
- fieldFacts := &lockFieldFacts{GuardedBy: make(map[string]int)}
- for _, l := range field.Doc.List {
- if strings.HasPrefix(l.Text, checkLocksAnnotation) {
- guardName := strings.TrimPrefix(l.Text, checkLocksAnnotation)
- if _, ok := fieldFacts.GuardedBy[guardName]; ok {
- pc.pass.Reportf(field.Pos(), "annotation %s specified more than once.", l.Text)
- continue
- }
- fieldFacts.GuardedBy[guardName] = -1
- }
- }
-
- return fieldFacts
-}
-
-func (pc *passContext) findField(v ssa.Value, fieldNumber int) types.Object {
- structType, ok := v.Type().Underlying().(*types.Struct)
- if !ok {
- structType = v.Type().Underlying().(*types.Pointer).Elem().Underlying().(*types.Struct)
- }
- return structType.Field(fieldNumber)
-}
-
-// findAndExportStructFacts finds any struct fields that are annotated with the
-// "// +checklocks:" annotation and exports relevant facts about the fields to
-// be used in later analysis.
-func (pc *passContext) findAndExportStructFacts(ss *ast.StructType, structType *types.Struct) {
- type fieldRef struct {
- fieldObj *types.Var
- facts *lockFieldFacts
- }
- mutexes := make(map[string]*fieldRef)
- rwMutexes := make(map[string]*fieldRef)
- guardedFields := make(map[string]*fieldRef)
- for i, field := range ss.Fields.List {
- fieldObj := structType.Field(i)
- fieldFacts := pc.extractFieldAnnotations(field, fieldObj)
- if fieldFacts == nil {
- continue
- }
- fieldFacts.FieldNumber = i
-
- ref := &fieldRef{fieldObj, fieldFacts}
- if fieldFacts.IsMutex {
- mutexes[fieldObj.Name()] = ref
- }
- if fieldFacts.IsRWMutex {
- rwMutexes[fieldObj.Name()] = ref
- }
- if len(fieldFacts.GuardedBy) != 0 {
- guardedFields[fieldObj.Name()] = ref
- }
- }
-
- // Export facts about all mutexes.
- for _, f := range mutexes {
- pc.pass.ExportObjectFact(f.fieldObj, f.facts)
- }
- // Export facts about all rwMutexes.
- for _, f := range rwMutexes {
- pc.pass.ExportObjectFact(f.fieldObj, f.facts)
- }
-
- // Validate that guarded fields annotations refer to actual mutexes or
- // rwMutexes in the struct.
- for _, gf := range guardedFields {
- for g := range gf.facts.GuardedBy {
- if f, ok := mutexes[g]; ok {
- gf.facts.GuardedBy[g] = f.facts.FieldNumber
- } else if f, ok := rwMutexes[g]; ok {
- gf.facts.GuardedBy[g] = f.facts.FieldNumber
- } else {
- pc.maybeFail(gf.fieldObj.Pos(), false /* isExempted */, "invalid mutex guard, no such mutex %s in struct %s", g, structType.String())
- continue
- }
- // Export guarded field fact.
- pc.pass.ExportObjectFact(gf.fieldObj, gf.facts)
- }
- }
-}
-
-func (pc *passContext) findAndExportFuncFacts(d *ast.FuncDecl) {
- log.Debugf("finding and exporting function facts\n")
- // for each function definition, check for +checklocks:mu annotation, which
- // means that the function must be called with that lock held.
- fnObj := pc.pass.TypesInfo.ObjectOf(d.Name)
- funcFacts := lockFunctionFacts{GuardedBy: make(map[string]functionGuard)}
- var (
- ignore bool
- ignorePos token.Pos
- )
-
-outerLoop:
- for _, l := range d.Doc.List {
- if strings.HasPrefix(l.Text, checkLocksIgnore) {
- pc.exemptions[fnObj] = struct{}{}
- ignore = true
- ignorePos = l.Pos()
- continue
- }
- if strings.HasPrefix(l.Text, checkLocksAnnotation) {
- guardName := strings.TrimPrefix(l.Text, checkLocksAnnotation)
- if _, ok := funcFacts.GuardedBy[guardName]; ok {
- pc.pass.Reportf(l.Pos(), "annotation %s specified more than once.", l.Text)
- continue
- }
-
- found := false
- x := strings.Split(guardName, ".")
- if len(x) != 2 {
- pc.pass.Reportf(l.Pos(), "checklocks mutex annotation should be of the form 'a.b'")
+// forAllTypes applies the given function over all types.
+func (pc *passContext) forAllTypes(fn func(ts *ast.TypeSpec)) {
+ for _, f := range pc.pass.Files {
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.GenDecl)
+ if !ok || d.Tok != token.TYPE {
continue
}
- paramName, fieldName := x[0], x[1]
- log.Debugf("paramName: %s, fieldName: %s", paramName, fieldName)
- var paramList []*ast.Field
- if d.Recv != nil {
- paramList = append(paramList, d.Recv.List...)
- }
- if d.Type.Params != nil {
- paramList = append(paramList, d.Type.Params.List...)
- }
- for paramNum, field := range paramList {
- log.Debugf("field names: %+v", field.Names)
- if len(field.Names) == 0 {
- log.Debugf("skipping because parameter is unnamed", paramName)
- continue
- }
- nameExists := false
- for _, name := range field.Names {
- if name.Name == paramName {
- nameExists = true
- }
- }
- if !nameExists {
- log.Debugf("skipping because parameter name(s) does not match : %s", paramName)
- continue
- }
- ptrType, ok := pc.pass.TypesInfo.TypeOf(field.Type).Underlying().(*types.Pointer)
- if !ok {
- // Since mutexes cannot be copied we only care about parameters that
- // are pointer types when checking for guards.
- pc.pass.Reportf(l.Pos(), "annotation %s incorrectly specified, parameter name does not refer to a pointer type", l.Text)
- continue outerLoop
- }
-
- structType, ok := ptrType.Elem().Underlying().(*types.Struct)
- if !ok {
- pc.pass.Reportf(l.Pos(), "annotation %s incorrectly specified, parameter name does not refer to a pointer to a struct", l.Text)
- continue outerLoop
- }
-
- for i := 0; i < structType.NumFields(); i++ {
- if structType.Field(i).Name() == fieldName {
- var fieldFacts lockFieldFacts
- pc.pass.ImportObjectFact(structType.Field(i), &fieldFacts)
- if !fieldFacts.IsMutex && !fieldFacts.IsRWMutex {
- pc.pass.Reportf(l.Pos(), "field %s of param %s is not a mutex or an rwmutex", paramName, structType.Field(i))
- continue outerLoop
- }
- funcFacts.GuardedBy[guardName] = functionGuard{ParameterNumber: paramNum, FieldNumber: i}
- found = true
- continue outerLoop
- }
- }
- if !found {
- pc.pass.Reportf(l.Pos(), "annotation refers to a non-existent field %s in %s", guardName, structType)
- continue outerLoop
- }
- }
- if !found {
- pc.pass.Reportf(l.Pos(), "annotation refers to a non-existent parameter %s", paramName)
- }
- }
- }
-
- if len(funcFacts.GuardedBy) == 0 {
- return
- }
- if ignore {
- pc.pass.Reportf(ignorePos, "//+checklocksignore cannot be specified with other annotations on the function")
- }
- funcObj, ok := pc.pass.TypesInfo.Defs[d.Name].(*types.Func)
- if !ok {
- panic(fmt.Sprintf("function type information missing for %+v", d))
- }
- log.Debugf("export fact for d: %+v, funcObj: %+v, funcFacts: %+v\n", d, funcObj, funcFacts)
- pc.pass.ExportObjectFact(funcObj, &funcFacts)
-}
-
-type mutexState struct {
- // lockedMutexes is used to track which mutexes in a given struct are
- // currently locked using the field number of the mutex as the key.
- lockedMutexes map[int]struct{}
-}
-
-// locksHeld tracks all currently held locks.
-type locksHeld struct {
- locks map[ssa.Value]mutexState
-}
-
-// Same returns true if the locks held by other and l are the same.
-func (l *locksHeld) Same(other *locksHeld) bool {
- return reflect.DeepEqual(l.locks, other.locks)
-}
-
-// Copy creates a copy of all the lock state held by l.
-func (l *locksHeld) Copy() *locksHeld {
- out := &locksHeld{locks: make(map[ssa.Value]mutexState)}
- for ssaVal, mState := range l.locks {
- newLM := make(map[int]struct{})
- for k, v := range mState.lockedMutexes {
- newLM[k] = v
- }
- out.locks[ssaVal] = mutexState{lockedMutexes: newLM}
- }
- return out
-}
-
-func isAlias(first, second ssa.Value) bool {
- if first == second {
- return true
- }
- switch x := first.(type) {
- case *ssa.Field:
- if y, ok := second.(*ssa.Field); ok {
- return x.Field == y.Field && isAlias(x.X, y.X)
- }
- case *ssa.FieldAddr:
- if y, ok := second.(*ssa.FieldAddr); ok {
- return x.Field == y.Field && isAlias(x.X, y.X)
- }
- case *ssa.Index:
- if y, ok := second.(*ssa.Index); ok {
- return isAlias(x.Index, y.Index) && isAlias(x.X, y.X)
- }
- case *ssa.IndexAddr:
- if y, ok := second.(*ssa.IndexAddr); ok {
- return isAlias(x.Index, y.Index) && isAlias(x.X, y.X)
- }
- case *ssa.UnOp:
- if y, ok := second.(*ssa.UnOp); ok {
- return isAlias(x.X, y.X)
- }
- }
- return false
-}
-
-// checkBasicBlocks traverses the control flow graph starting at a set of given
-// block and checks each instruction for allowed operations.
-//
-// funcFact are the exported facts for the enclosing function for these basic
-// blocks.
-func (pc *passContext) checkBasicBlocks(blocks []*ssa.BasicBlock, recoverBlock *ssa.BasicBlock, fn *ssa.Function, funcFact lockFunctionFacts) {
- if len(blocks) == 0 {
- return
- }
-
- // mutexes is used to track currently locked sync.Mutexes/sync.RWMutexes for a
- // given *struct identified by ssa.Value.
- seen := make(map[*ssa.BasicBlock]*locksHeld)
- var scan func(block *ssa.BasicBlock, parent *locksHeld)
- scan = func(block *ssa.BasicBlock, parent *locksHeld) {
- _, isExempted := pc.exemptions[block.Parent().Object()]
- if oldLocksHeld, ok := seen[block]; ok {
- if oldLocksHeld.Same(parent) {
- return
- }
- pc.maybeFail(block.Instrs[0].Pos(), isExempted, "failure entering a block %+v with different sets of lock held, oldLocks: %+v, parentLocks: %+v", block, oldLocksHeld, parent)
- return
- }
- seen[block] = parent
- var lh = parent.Copy()
- for _, inst := range block.Instrs {
- pc.checkInstruction(inst, isExempted, lh)
- }
- for _, b := range block.Succs {
- scan(b, lh)
- }
- }
-
- // Initialize lh with any preconditions that require locks to be held for the
- // method to be invoked.
- lh := &locksHeld{locks: make(map[ssa.Value]mutexState)}
- for _, fg := range funcFact.GuardedBy {
- // The first is the method object itself so we skip that when looking
- // for receiver/function parameters.
- log.Debugf("fn: %s, fn.Operands() == %+v", fn, fn.Operands(nil))
- r := fn.Params[fg.ParameterNumber]
- guardObj := findField(r, fg.FieldNumber)
- var fieldFacts lockFieldFacts
- pc.pass.ImportObjectFact(guardObj, &fieldFacts)
- if fieldFacts.IsMutex || fieldFacts.IsRWMutex {
- m, ok := lh.locks[r]
- if !ok {
- m = mutexState{lockedMutexes: make(map[int]struct{})}
- lh.locks[r] = m
+ for _, gs := range d.Specs {
+ fn(gs.(*ast.TypeSpec))
}
- m.lockedMutexes[fieldFacts.FieldNumber] = struct{}{}
- } else {
- panic(fmt.Sprintf("function: %+v has an invalid guard that is not a mutex: %+v", fn, guardObj))
- }
- }
-
- // Start scanning from the first basic block.
- scan(blocks[0], lh)
-
- // Validate that all blocks were touched.
- for _, b := range blocks {
- if _, ok := seen[b]; !ok && b != recoverBlock {
- panic(fmt.Sprintf("block %+v was not visited during checkBasicBlocks", b))
- }
- }
-}
-
-func (pc *passContext) checkInstruction(inst ssa.Instruction, isExempted bool, lh *locksHeld) {
- log.Debugf("checking instruction: %s, isExempted: %t", inst, isExempted)
- switch x := inst.(type) {
- case *ssa.Field:
- pc.checkFieldAccess(inst, x.X, x.Field, isExempted, lh)
- case *ssa.FieldAddr:
- pc.checkFieldAccess(inst, x.X, x.Field, isExempted, lh)
- case *ssa.Call:
- pc.checkFunctionCall(x, isExempted, lh)
- }
-}
-
-func findField(v ssa.Value, field int) types.Object {
- structType, ok := v.Type().Underlying().(*types.Struct)
- if !ok {
- ptrType, ok := v.Type().Underlying().(*types.Pointer)
- if !ok {
- return nil
- }
- structType = ptrType.Elem().Underlying().(*types.Struct)
- }
- return structType.Field(field)
-}
-
-func (pc *passContext) maybeFail(pos token.Pos, isExempted bool, fmtStr string, args ...interface{}) {
- posKey := toPositionKey(pc.pass.Fset.Position(pos))
- log.Debugf("maybeFail: pos: %d, positionKey: %s", pos, posKey)
- if fData, ok := pc.failures[posKey]; ok {
- fData.count--
- if fData.count == 0 {
- delete(pc.failures, posKey)
}
- return
- }
- if !isExempted {
- pc.pass.Reportf(pos, fmt.Sprintf(fmtStr, args...))
}
}
-func (pc *passContext) checkFieldAccess(inst ssa.Instruction, structObj ssa.Value, field int, isExempted bool, lh *locksHeld) {
- var fieldFacts lockFieldFacts
- fieldObj := findField(structObj, field)
- pc.pass.ImportObjectFact(fieldObj, &fieldFacts)
- log.Debugf("fieldObj: %s, fieldFacts: %+v", fieldObj, fieldFacts)
- for _, guardFieldNumber := range fieldFacts.GuardedBy {
- guardObj := findField(structObj, guardFieldNumber)
- var guardfieldFacts lockFieldFacts
- pc.pass.ImportObjectFact(guardObj, &guardfieldFacts)
- log.Debugf("guardObj: %s, guardFieldFacts: %+v", guardObj, guardfieldFacts)
- if guardfieldFacts.IsMutex || guardfieldFacts.IsRWMutex {
- log.Debugf("guard is a mutex")
- m, ok := lh.locks[structObj]
+// forAllFunctions applies the given function over all functions.
+func (pc *passContext) forAllFunctions(fn func(fn *ast.FuncDecl)) {
+ for _, f := range pc.pass.Files {
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.FuncDecl)
if !ok {
- pc.maybeFail(inst.Pos(), isExempted, "invalid field access, %s must be locked when accessing %s", guardObj.Name(), fieldObj.Name())
- continue
- }
- if _, ok := m.lockedMutexes[guardfieldFacts.FieldNumber]; !ok {
- pc.maybeFail(inst.Pos(), isExempted, "invalid field access, %s must be locked when accessing %s", guardObj.Name(), fieldObj.Name())
- }
- } else {
- panic("incorrect guard that is not a mutex or an RWMutex")
- }
- }
-}
-
-func (pc *passContext) checkFunctionCall(call *ssa.Call, isExempted bool, lh *locksHeld) {
- // See: https://godoc.org/golang.org/x/tools/go/ssa#CallCommon
- //
- // 1. "call" mode: when Method is nil (!IsInvoke), a CallCommon represents an ordinary
- // function call of the value in Value, which may be a *Builtin, a *Function or any
- // other value of kind 'func'.
- //
- // Value may be one of:
- // (a) a *Function, indicating a statically dispatched call
- // to a package-level function, an anonymous function, or
- // a method of a named type.
- //
- // (b) a *MakeClosure, indicating an immediately applied
- // function literal with free variables.
- //
- // (c) a *Builtin, indicating a statically dispatched call
- // to a built-in function.
- //
- // (d) any other value, indicating a dynamically dispatched
- // function call.
- fn, ok := call.Common().Value.(*ssa.Function)
- if !ok {
- return
- }
- if fn.Object() == nil {
- return
- }
-
- // Check if the function should be called with any locks held.
- var funcFact lockFunctionFacts
- pc.pass.ImportObjectFact(fn.Object(), &funcFact)
- if len(funcFact.GuardedBy) > 0 {
- for _, fg := range funcFact.GuardedBy {
- // The first is the method object itself so we skip that when looking
- // for receiver/function parameters.
- r := (*call.Value().Operands(nil)[fg.ParameterNumber+1])
- guardObj := findField(r, fg.FieldNumber)
- if guardObj == nil {
continue
}
- var fieldFacts lockFieldFacts
- pc.pass.ImportObjectFact(guardObj, &fieldFacts)
- if fieldFacts.IsMutex || fieldFacts.IsRWMutex {
- heldMutexes, ok := lh.locks[r]
- if !ok {
- log.Debugf("fn: %s, funcFact: %+v", fn, funcFact)
- pc.maybeFail(call.Pos(), isExempted, "invalid function call %s must be held", guardObj.Name())
- continue
- }
- if _, ok := heldMutexes.lockedMutexes[fg.FieldNumber]; !ok {
- log.Debugf("fn: %s, funcFact: %+v", fn, funcFact)
- pc.maybeFail(call.Pos(), isExempted, "invalid function call %s must be held", guardObj.Name())
- }
- } else {
- panic(fmt.Sprintf("function: %+v has an invalid guard that is not a mutex: %+v", fn, guardObj))
- }
- }
- }
-
- // Check if it's a method dispatch for something in the sync package.
- // See: https://godoc.org/golang.org/x/tools/go/ssa#Function
- if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil {
- r, ok := call.Common().Args[0].(*ssa.FieldAddr)
- if !ok {
- return
- }
- guardObj := findField(r.X, r.Field)
- var fieldFacts lockFieldFacts
- pc.pass.ImportObjectFact(guardObj, &fieldFacts)
- if fieldFacts.IsMutex || fieldFacts.IsRWMutex {
- switch fn.Name() {
- case "Lock", "RLock":
- obj := r.X
- m := mutexState{lockedMutexes: make(map[int]struct{})}
- for k, v := range lh.locks {
- if isAlias(r.X, k) {
- obj = k
- m = v
- }
- }
- if _, ok := m.lockedMutexes[r.Field]; ok {
- // Double locking a mutex that is already locked.
- pc.maybeFail(call.Pos(), isExempted, "trying to a lock %s when already locked", guardObj.Name())
- return
- }
- m.lockedMutexes[r.Field] = struct{}{}
- lh.locks[obj] = m
- case "Unlock", "RUnlock":
- // Find the associated locker object.
- var (
- obj ssa.Value
- m mutexState
- )
- for k, v := range lh.locks {
- if isAlias(r.X, k) {
- obj = k
- m = v
- break
- }
- }
- if _, ok := m.lockedMutexes[r.Field]; !ok {
- pc.maybeFail(call.Pos(), isExempted, "trying to unlock a mutex %s that is already unlocked", guardObj.Name())
- return
- }
- delete(m.lockedMutexes, r.Field)
- if len(m.lockedMutexes) == 0 {
- delete(lh.locks, obj)
- }
- case "RLocker", "DowngradeLock", "TryLock", "TryRLock":
- // we explicitly ignore this for now.
- default:
- panic(fmt.Sprintf("unexpected mutex/rwmutex method invoked: %s", fn.Name()))
- }
+ fn(d)
}
}
}
+// run is the main entrypoint.
func run(pass *analysis.Pass) (interface{}, error) {
pc := &passContext{
pass: pass,
- exemptions: make(map[types.Object]struct{}),
failures: make(map[positionKey]*failData),
+ exemptions: make(map[positionKey]struct{}),
+ forced: make(map[positionKey]struct{}),
+ functions: make(map[*ssa.Function]struct{}),
}
// Find all line failure annotations.
- for _, f := range pass.Files {
- for _, cg := range f.Comments {
- for _, c := range cg.List {
- if strings.Contains(c.Text, checkLocksFail) {
- cnt := 1
- if strings.Contains(c.Text, checkLocksFail+":") {
- parts := strings.SplitAfter(c.Text, checkLocksFail+":")
- parsedCount, err := strconv.Atoi(parts[1])
- if err != nil {
- pc.pass.Reportf(c.Pos(), "invalid checklocks annotation : %s", err)
- continue
- }
- cnt = parsedCount
- }
- position := toPositionKey(pass.Fset.Position(c.Pos()))
- pc.failures[position] = &failData{pos: c.Pos(), count: cnt}
- }
- }
- }
- }
-
- // Find all struct declarations and export any relevant facts.
- for _, f := range pass.Files {
- for _, decl := range f.Decls {
- d, ok := decl.(*ast.GenDecl)
- // A GenDecl node (generic declaration node) represents an import,
- // constant, type or variable declaration. We only care about struct
- // declarations so skip any declaration that doesn't declare a new type.
- if !ok || d.Tok != token.TYPE {
- continue
- }
+ pc.extractLineFailures()
- for _, gs := range d.Specs {
- ts := gs.(*ast.TypeSpec)
- ss, ok := ts.Type.(*ast.StructType)
- if !ok {
- continue
- }
- structType := pass.TypesInfo.TypeOf(ts.Name).Underlying().(*types.Struct)
- pc.findAndExportStructFacts(ss, structType)
- }
+ // Find all struct declarations and export relevant facts.
+ pc.forAllTypes(func(ts *ast.TypeSpec) {
+ if ss, ok := ts.Type.(*ast.StructType); ok {
+ pc.exportLockFieldFacts(ts, ss)
}
- }
+ })
+ pc.forAllTypes(func(ts *ast.TypeSpec) {
+ if ss, ok := ts.Type.(*ast.StructType); ok {
+ pc.exportLockGuardFacts(ts, ss)
+ }
+ })
- // Find all method calls and export any relevant facts.
- for _, f := range pass.Files {
- for _, decl := range f.Decls {
- d, ok := decl.(*ast.FuncDecl)
- // Ignore any non function declarations and any functions that do not have
- // any comments.
- if !ok || d.Doc == nil {
- continue
- }
- pc.findAndExportFuncFacts(d)
+ // Check all alignments.
+ pc.forAllTypes(func(ts *ast.TypeSpec) {
+ typ, ok := pass.TypesInfo.TypeOf(ts.Name).(*types.Named)
+ if !ok {
+ return
}
- }
+ pc.checkTypeAlignment(pass.Pkg, typ)
+ })
- // log all known facts and all failures if debug logging is enabled.
- allFacts := pass.AllObjectFacts()
- for i := range allFacts {
- log.Debugf("fact.object: %+v, fact.Fact: %+v", allFacts[i].Object, allFacts[i].Fact)
- }
- log.Debugf("all expected failures: %+v", pc.failures)
+ // Find all function declarations and export relevant facts.
+ pc.forAllFunctions(func(fn *ast.FuncDecl) {
+ pc.exportFunctionFacts(fn)
+ })
// Scan all code looking for invalid accesses.
state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
for _, fn := range state.SrcFuncs {
- var funcFact lockFunctionFacts
- // Anonymous(closures) functions do not have an object() but do show up in
- // the SSA.
- if obj := fn.Object(); obj != nil {
- pc.pass.ImportObjectFact(fn.Object(), &funcFact)
+ // Import function facts generated above.
+ //
+ // Note that anonymous(closures) functions do not have an
+ // object but do show up in the SSA. They can only be invoked
+ // by named functions in the package, and they are analyzing
+ // inline on every call. Thus we skip the analysis here. They
+ // will be hit on calls, or picked up in the pass below.
+ if obj := fn.Object(); obj == nil {
+ continue
}
+ var lff lockFunctionFacts
+ pc.pass.ImportObjectFact(fn.Object(), &lff)
- log.Debugf("checking function: %s", fn)
- var b bytes.Buffer
- ssa.WriteFunction(&b, fn)
- log.Debugf("function SSA: %s", b.String())
- if fn.Recover != nil {
- pc.checkBasicBlocks([]*ssa.BasicBlock{fn.Recover}, nil, fn, funcFact)
+ // Do we ignore this?
+ if lff.Ignore {
+ continue
}
- pc.checkBasicBlocks(fn.Blocks, fn.Recover, fn, funcFact)
- }
- // Scan for remaining failures we expect.
- for _, failure := range pc.failures {
- // We are missing expect failures, report as much as possible.
- pass.Reportf(failure.pos, "expected %d failures", failure.count)
+ // Check the basic blocks in the function.
+ pc.checkFunction(nil, fn, &lff, nil, false /* force */)
}
+ for _, fn := range state.SrcFuncs {
+ // Ensure all anonymous functions are hit. They are not
+ // permitted to have any lock preconditions.
+ if obj := fn.Object(); obj != nil {
+ continue
+ }
+ var nolff lockFunctionFacts
+ pc.checkFunction(nil, fn, &nolff, nil, false /* force */)
+ }
+
+ // Check for expected failures.
+ pc.checkFailures()
return nil, nil
}
diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go
new file mode 100644
index 000000000..1a43dbbe6
--- /dev/null
+++ b/tools/checklocks/facts.go
@@ -0,0 +1,614 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package checklocks
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "go/types"
+ "regexp"
+ "strings"
+
+ "golang.org/x/tools/go/ssa"
+)
+
+// atomicAlignment is saved per type.
+//
+// This represents the alignment required for the type, which may
+// be implied and imposed by other types within the aggregate type.
+type atomicAlignment int
+
+// AFact implements analysis.Fact.AFact.
+func (*atomicAlignment) AFact() {}
+
+// atomicDisposition is saved per field.
+//
+// This represents how the field must be accessed. It must either
+// be non-atomic (default), atomic or ignored.
+type atomicDisposition int
+
+const (
+ atomicDisallow atomicDisposition = iota
+ atomicIgnore
+ atomicRequired
+)
+
+// fieldList is a simple list of fields, used in two types below.
+//
+// Note that the integers in this list refer to one of two things:
+// - A positive integer refers to a field index in a struct.
+// - A negative integer refers to a field index in a struct, where
+// that field is a pointer and must be subsequently resolved.
+type fieldList []int
+
+// resolvedValue is an ssa.Value with additional fields.
+//
+// This can be resolved to a string as part of a lock state.
+type resolvedValue struct {
+ value ssa.Value
+ valid bool
+ fieldList []int
+}
+
+// findExtract finds a relevant extract. This must exist within the referrers
+// to the call object. If this doesn't then the object which is locked is never
+// consumed, and we should consider this a bug.
+func findExtract(v ssa.Value, index int) (ssa.Value, bool) {
+ if refs := v.Referrers(); refs != nil {
+ for _, inst := range *refs {
+ if x, ok := inst.(*ssa.Extract); ok && x.Tuple == v && x.Index == index {
+ return inst.(ssa.Value), true
+ }
+ }
+ }
+ return nil, false
+}
+
+// resolve resolves the given field list.
+func (fl fieldList) resolve(v ssa.Value) (rv resolvedValue) {
+ return resolvedValue{
+ value: v,
+ fieldList: fl,
+ valid: true,
+ }
+}
+
+// valueAsString returns a string representing this value.
+//
+// This must align with how the string is generated in valueAsString.
+func (rv resolvedValue) valueAsString(ls *lockState) string {
+ typ := rv.value.Type()
+ s := ls.valueAsString(rv.value)
+ for i, fieldNumber := range rv.fieldList {
+ switch {
+ case fieldNumber > 0:
+ field, ok := findField(typ, fieldNumber-1)
+ if !ok {
+ // This can't be resolved, return for debugging.
+ return fmt.Sprintf("{%s+%v}", s, rv.fieldList[i:])
+ }
+ s = fmt.Sprintf("&(%s.%s)", s, field.Name())
+ typ = field.Type()
+ case fieldNumber < 1:
+ field, ok := findField(typ, (-fieldNumber)-1)
+ if !ok {
+ // See above.
+ return fmt.Sprintf("{%s+%v}", s, rv.fieldList[i:])
+ }
+ s = fmt.Sprintf("*(&(%s.%s))", s, field.Name())
+ typ = field.Type()
+ }
+ }
+ return s
+}
+
+// lockFieldFacts apply on every struct field.
+type lockFieldFacts struct {
+ // IsMutex is true if the field is of type sync.Mutex.
+ IsMutex bool
+
+ // IsRWMutex is true if the field is of type sync.RWMutex.
+ IsRWMutex bool
+
+ // IsPointer indicates if the field is a pointer.
+ IsPointer bool
+
+ // FieldNumber is the number of this field in the struct.
+ FieldNumber int
+}
+
+// AFact implements analysis.Fact.AFact.
+func (*lockFieldFacts) AFact() {}
+
+// lockGuardFacts contains guard information.
+type lockGuardFacts struct {
+ // GuardedBy is the set of locks that are guarding this field. The key
+ // is the original annotation value, and the field list is the object
+ // traversal path.
+ GuardedBy map[string]fieldList
+
+ // AtomicDisposition is the disposition for this field. Note that this
+ // can affect the interpretation of the GuardedBy field above, see the
+ // relevant comment.
+ AtomicDisposition atomicDisposition
+}
+
+// AFact implements analysis.Fact.AFact.
+func (*lockGuardFacts) AFact() {}
+
+// functionGuard is used by lockFunctionFacts, below.
+type functionGuard struct {
+ // ParameterNumber is the index of the object that contains the
+ // guarding mutex. From this parameter, a walk is performed
+ // subsequently using the resolve method.
+ //
+ // Note that is ParameterNumber is beyond the size of parameters, then
+ // it may return to a return value. This applies only for the Acquires
+ // relation below.
+ ParameterNumber int
+
+ // NeedsExtract is used in the case of a return value, and indicates
+ // that the field must be extracted from a tuple.
+ NeedsExtract bool
+
+ // FieldList is the traversal path to the object.
+ FieldList fieldList
+}
+
+// resolveReturn resolves a return value.
+//
+// Precondition: rv is either an ssa.Value, or an *ssa.Return.
+func (fg *functionGuard) resolveReturn(rv interface{}, args int) resolvedValue {
+ if rv == nil {
+ // For defers and other objects, this may be nil. This is
+ // handled in state.go in the actual lock checking logic.
+ return resolvedValue{
+ value: nil,
+ valid: false,
+ }
+ }
+ index := fg.ParameterNumber - args
+ // If this is a *ssa.Return object, i.e. we are analyzing the function
+ // and not the call site, then we can just pull the result directly.
+ if r, ok := rv.(*ssa.Return); ok {
+ return fg.FieldList.resolve(r.Results[index])
+ }
+ if fg.NeedsExtract {
+ // Resolve on the extracted field, this is necessary if the
+ // type here is not an explicit return. Note that rv must be an
+ // ssa.Value, since it is not an *ssa.Return.
+ v, ok := findExtract(rv.(ssa.Value), index)
+ if !ok {
+ return resolvedValue{
+ value: v,
+ valid: false,
+ }
+ }
+ return fg.FieldList.resolve(v)
+ }
+ if index != 0 {
+ // This should not happen, NeedsExtract should always be set.
+ panic("NeedsExtract is false, but return value index is non-zero")
+ }
+ // Resolve on the single return.
+ return fg.FieldList.resolve(rv.(ssa.Value))
+}
+
+// resolveStatic returns an ssa.Value representing the given field.
+//
+// Precondition: per resolveReturn.
+func (fg *functionGuard) resolveStatic(fn *ssa.Function, rv interface{}) resolvedValue {
+ if fg.ParameterNumber >= len(fn.Params) {
+ return fg.resolveReturn(rv, len(fn.Params))
+ }
+ return fg.FieldList.resolve(fn.Params[fg.ParameterNumber])
+}
+
+// resolveCall returns an ssa.Value representing the given field.
+func (fg *functionGuard) resolveCall(args []ssa.Value, rv ssa.Value) resolvedValue {
+ if fg.ParameterNumber >= len(args) {
+ return fg.resolveReturn(rv, len(args))
+ }
+ return fg.FieldList.resolve(args[fg.ParameterNumber])
+}
+
+// lockFunctionFacts apply on every method.
+type lockFunctionFacts struct {
+ // HeldOnEntry tracks the names and number of parameter (including receiver)
+ // lockFuncfields that guard calls to this function.
+ //
+ // The key is the name specified in the checklocks annotation. e.g given
+ // the following code:
+ //
+ // ```
+ // type A struct {
+ // mu sync.Mutex
+ // a int
+ // }
+ //
+ // // +checklocks:a.mu
+ // func xyz(a *A) {..}
+ // ```
+ //
+ // '`+checklocks:a.mu' will result in an entry in this map as shown below.
+ // HeldOnEntry: {"a.mu" => {ParameterNumber: 0, FieldNumbers: {0}}
+ //
+ // Unlikely lockFieldFacts, there is no atomic interpretation.
+ HeldOnEntry map[string]functionGuard
+
+ // HeldOnExit tracks the locks that are expected to be held on exit.
+ HeldOnExit map[string]functionGuard
+
+ // Ignore means this function has local analysis ignores.
+ //
+ // This is not used outside the local package.
+ Ignore bool
+}
+
+// AFact implements analysis.Fact.AFact.
+func (*lockFunctionFacts) AFact() {}
+
+// checkGuard validates the guardName.
+func (lff *lockFunctionFacts) checkGuard(pc *passContext, d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) {
+ if _, ok := lff.HeldOnEntry[guardName]; ok {
+ pc.maybeFail(d.Pos(), "annotation %s specified more than once, already required", guardName)
+ return functionGuard{}, false
+ }
+ if _, ok := lff.HeldOnExit[guardName]; ok {
+ pc.maybeFail(d.Pos(), "annotation %s specified more than once, already acquired", guardName)
+ return functionGuard{}, false
+ }
+ fg, ok := pc.findFunctionGuard(d, guardName, allowReturn)
+ return fg, ok
+}
+
+// addGuardedBy adds a field to both HeldOnEntry and HeldOnExit.
+func (lff *lockFunctionFacts) addGuardedBy(pc *passContext, d *ast.FuncDecl, guardName string) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok {
+ if lff.HeldOnEntry == nil {
+ lff.HeldOnEntry = make(map[string]functionGuard)
+ }
+ if lff.HeldOnExit == nil {
+ lff.HeldOnExit = make(map[string]functionGuard)
+ }
+ lff.HeldOnEntry[guardName] = fg
+ lff.HeldOnExit[guardName] = fg
+ }
+}
+
+// addAcquires adds a field to HeldOnExit.
+func (lff *lockFunctionFacts) addAcquires(pc *passContext, d *ast.FuncDecl, guardName string) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, true /* allowReturn */); ok {
+ if lff.HeldOnExit == nil {
+ lff.HeldOnExit = make(map[string]functionGuard)
+ }
+ lff.HeldOnExit[guardName] = fg
+ }
+}
+
+// addReleases adds a field to HeldOnEntry.
+func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guardName string) {
+ if fg, ok := lff.checkGuard(pc, d, guardName, false /* allowReturn */); ok {
+ if lff.HeldOnEntry == nil {
+ lff.HeldOnEntry = make(map[string]functionGuard)
+ }
+ lff.HeldOnEntry[guardName] = fg
+ }
+}
+
+// fieldListFor returns the fieldList for the given object.
+func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool) (int, bool) {
+ var lff lockFieldFacts
+ if !pc.pass.ImportObjectFact(fieldObj, &lff) {
+ // This should not happen: we export facts for all fields.
+ panic(fmt.Sprintf("no lockFieldFacts available for field %s", fieldName))
+ }
+ // Check that it is indeed a mutex.
+ if checkMutex && !lff.IsMutex && !lff.IsRWMutex {
+ pc.maybeFail(pos, "field %s is not a mutex or an rwmutex", fieldName)
+ return 0, false
+ }
+ // Return the resolution path.
+ if lff.IsPointer {
+ return -(index + 1), true
+ }
+ return (index + 1), true
+}
+
+// resolveOneField resolves a field in a single struct.
+func (pc *passContext) resolveOneField(pos token.Pos, structType *types.Struct, fieldName string, checkMutex bool) (fl fieldList, fieldObj types.Object, ok bool) {
+ // Scan to match the next field.
+ for i := 0; i < structType.NumFields(); i++ {
+ fieldObj := structType.Field(i)
+ if fieldObj.Name() != fieldName {
+ continue
+ }
+ flOne, ok := pc.fieldListFor(pos, fieldObj, i, fieldName, checkMutex)
+ if !ok {
+ return nil, nil, false
+ }
+ fl = append(fl, flOne)
+ return fl, fieldObj, true
+ }
+ // Is this an embed?
+ for i := 0; i < structType.NumFields(); i++ {
+ fieldObj := structType.Field(i)
+ if !fieldObj.Embedded() {
+ continue
+ }
+ // Is this an embedded struct?
+ structType, ok := resolveStruct(fieldObj.Type())
+ if !ok {
+ continue
+ }
+ // Need to check that there is a resolution path. If there is
+ // no resolution path that's not a failure: we just continue
+ // scanning the next embed to find a match.
+ flEmbed, okEmbed := pc.fieldListFor(pos, fieldObj, i, fieldName, false)
+ flCont, fieldObjCont, okCont := pc.resolveOneField(pos, structType, fieldName, checkMutex)
+ if okEmbed && okCont {
+ fl = append(fl, flEmbed)
+ fl = append(fl, flCont...)
+ return fl, fieldObjCont, true
+ }
+ }
+ pc.maybeFail(pos, "field %s does not exist", fieldName)
+ return nil, nil, false
+}
+
+// resolveField resolves a set of fields given a string, such a 'a.b.c'.
+//
+// Note that this checks that the final element is a mutex of some kind, and
+// will fail appropriately.
+func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, parts []string) (fl fieldList, ok bool) {
+ for partNumber, fieldName := range parts {
+ flOne, fieldObj, ok := pc.resolveOneField(pos, structType, fieldName, partNumber >= len(parts)-1 /* checkMutex */)
+ if !ok {
+ // Error already reported.
+ return nil, false
+ }
+ fl = append(fl, flOne...)
+ if partNumber < len(parts)-1 {
+ // Traverse to the next type.
+ structType, ok = resolveStruct(fieldObj.Type())
+ if !ok {
+ pc.maybeFail(pos, "invalid intermediate field %s", fieldName)
+ return fl, false
+ }
+ }
+ }
+ return fl, true
+}
+
+var (
+ mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)")
+ rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)")
+)
+
+// exportLockFieldFacts finds all struct fields that are mutexes, and ensures
+// that they are annotated approperly.
+//
+// This information is consumed subsequently by exportLockGuardFacts, and this
+// function must be called first on all structures.
+func (pc *passContext) exportLockFieldFacts(ts *ast.TypeSpec, ss *ast.StructType) {
+ structType := pc.pass.TypesInfo.TypeOf(ts.Name).Underlying().(*types.Struct)
+ for i := range ss.Fields.List {
+ lff := &lockFieldFacts{
+ FieldNumber: i,
+ }
+ // We use HasSuffix below because fieldType can be fully
+ // qualified with the package name eg for the gvisor sync
+ // package mutex fields have the type:
+ // "<package path>/sync/sync.Mutex"
+ fieldObj := structType.Field(i)
+ s := fieldObj.Type().String()
+ switch {
+ case mutexRE.MatchString(s):
+ lff.IsMutex = true
+ case rwMutexRE.MatchString(s):
+ lff.IsRWMutex = true
+ }
+ // Save whether this is a pointer.
+ _, lff.IsPointer = fieldObj.Type().Underlying().(*types.Pointer)
+ // We must always export the lockFieldFacts, since traversal
+ // can take place along any object in the struct.
+ pc.pass.ExportObjectFact(fieldObj, lff)
+ }
+}
+
+// exportLockGuardFacts finds all relevant guard information for structures.
+//
+// This function requires exportLockFieldFacts be called first on all
+// structures.
+func (pc *passContext) exportLockGuardFacts(ts *ast.TypeSpec, ss *ast.StructType) {
+ structType := pc.pass.TypesInfo.TypeOf(ts.Name).Underlying().(*types.Struct)
+ for i, field := range ss.Fields.List {
+ if field.Doc == nil {
+ continue
+ }
+ var (
+ lff lockFieldFacts
+ lgf lockGuardFacts
+ )
+ pc.pass.ImportObjectFact(structType.Field(i), &lff)
+ fieldObj := structType.Field(i)
+ for _, l := range field.Doc.List {
+ pc.extractAnnotations(l.Text, map[string]func(string){
+ checkAtomicAnnotation: func(string) {
+ switch lgf.AtomicDisposition {
+ case atomicRequired:
+ pc.maybeFail(fieldObj.Pos(), "annotation is redundant, already atomic required")
+ case atomicIgnore:
+ pc.maybeFail(fieldObj.Pos(), "annotation is contradictory, already atomic ignored")
+ }
+ lgf.AtomicDisposition = atomicRequired
+ },
+ checkLocksIgnore: func(string) {
+ switch lgf.AtomicDisposition {
+ case atomicIgnore:
+ pc.maybeFail(fieldObj.Pos(), "annotation is redundant, already atomic ignored")
+ case atomicRequired:
+ pc.maybeFail(fieldObj.Pos(), "annotation is contradictory, already atomic required")
+ }
+ lgf.AtomicDisposition = atomicIgnore
+ },
+ checkLocksAnnotation: func(guardName string) {
+ // Check for a duplicate annotation.
+ if _, ok := lgf.GuardedBy[guardName]; ok {
+ pc.maybeFail(fieldObj.Pos(), "annotation %s specified more than once", guardName)
+ return
+ }
+ fl, ok := pc.resolveField(fieldObj.Pos(), structType, strings.Split(guardName, "."))
+ if ok {
+ // If we successfully resolved
+ // the field, then save it.
+ if lgf.GuardedBy == nil {
+ lgf.GuardedBy = make(map[string]fieldList)
+ }
+ lgf.GuardedBy[guardName] = fl
+ }
+ },
+ })
+ }
+ // Save only if there is something meaningful.
+ if len(lgf.GuardedBy) > 0 || lgf.AtomicDisposition != atomicDisallow {
+ pc.pass.ExportObjectFact(structType.Field(i), &lgf)
+ }
+ }
+}
+
+// countFields gives an accurate field count, according for unnamed arguments
+// and return values and the compact identifier format.
+func countFields(fl []*ast.Field) (count int) {
+ for _, field := range fl {
+ if len(field.Names) == 0 {
+ count++
+ continue
+ }
+ count += len(field.Names)
+ }
+ return
+}
+
+// matchFieldList attempts to match the given field.
+func (pc *passContext) matchFieldList(pos token.Pos, fl []*ast.Field, guardName string) (functionGuard, bool) {
+ parts := strings.Split(guardName, ".")
+ parameterName := parts[0]
+ parameterNumber := 0
+ for _, field := range fl {
+ // See countFields, above.
+ if len(field.Names) == 0 {
+ parameterNumber++
+ continue
+ }
+ for _, name := range field.Names {
+ if name.Name != parameterName {
+ parameterNumber++
+ continue
+ }
+ ptrType, ok := pc.pass.TypesInfo.TypeOf(field.Type).Underlying().(*types.Pointer)
+ if !ok {
+ // Since mutexes cannot be copied we only care
+ // about parameters that are pointer types when
+ // checking for guards.
+ pc.maybeFail(pos, "parameter name %s does not refer to a pointer type", parameterName)
+ return functionGuard{}, false
+ }
+ structType, ok := ptrType.Elem().Underlying().(*types.Struct)
+ if !ok {
+ // Fields can only be in named structures.
+ pc.maybeFail(pos, "parameter name %s does not refer to a pointer to a struct", parameterName)
+ return functionGuard{}, false
+ }
+ fg := functionGuard{
+ ParameterNumber: parameterNumber,
+ }
+ fl, ok := pc.resolveField(pos, structType, parts[1:])
+ fg.FieldList = fl
+ return fg, ok // If ok is false, already failed.
+ }
+ }
+ return functionGuard{}, false
+}
+
+// findFunctionGuard identifies the parameter number and field number for a
+// particular string of the 'a.b'.
+//
+// This function will report any errors directly.
+func (pc *passContext) findFunctionGuard(d *ast.FuncDecl, guardName string, allowReturn bool) (functionGuard, bool) {
+ var (
+ parameterList []*ast.Field
+ returnList []*ast.Field
+ )
+ if d.Recv != nil {
+ parameterList = append(parameterList, d.Recv.List...)
+ }
+ if d.Type.Params != nil {
+ parameterList = append(parameterList, d.Type.Params.List...)
+ }
+ if fg, ok := pc.matchFieldList(d.Pos(), parameterList, guardName); ok {
+ return fg, ok
+ }
+ if allowReturn {
+ if d.Type.Results != nil {
+ returnList = append(returnList, d.Type.Results.List...)
+ }
+ if fg, ok := pc.matchFieldList(d.Pos(), returnList, guardName); ok {
+ // Fix this up to apply to the return value, as noted
+ // in fg.ParameterNumber. For the ssa analysis, we must
+ // record whether this has multiple results, since
+ // *ssa.Call indicates: "The Call instruction yields
+ // the function result if there is exactly one.
+ // Otherwise it returns a tuple, the components of
+ // which are accessed via Extract."
+ fg.ParameterNumber += countFields(parameterList)
+ fg.NeedsExtract = countFields(returnList) > 1
+ return fg, ok
+ }
+ }
+ // We never saw a matching parameter.
+ pc.maybeFail(d.Pos(), "annotation %s does not have a matching parameter", guardName)
+ return functionGuard{}, false
+}
+
+// exportFunctionFacts exports relevant function findings.
+func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) {
+ if d.Doc == nil || d.Doc.List == nil {
+ return
+ }
+ var lff lockFunctionFacts
+ for _, l := range d.Doc.List {
+ pc.extractAnnotations(l.Text, map[string]func(string){
+ checkLocksIgnore: func(string) {
+ // Note that this applies to all atomic
+ // analysis as well. There is no provided way
+ // to selectively ignore only lock analysis or
+ // atomic analysis, as we expect this use to be
+ // extremely rare.
+ lff.Ignore = true
+ },
+ checkLocksAnnotation: func(guardName string) { lff.addGuardedBy(pc, d, guardName) },
+ checkLocksAcquires: func(guardName string) { lff.addAcquires(pc, d, guardName) },
+ checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName) },
+ })
+ }
+
+ // Export the function facts if there is anything to save.
+ if lff.Ignore || len(lff.HeldOnEntry) > 0 || len(lff.HeldOnExit) > 0 {
+ funcObj := pc.pass.TypesInfo.Defs[d.Name].(*types.Func)
+ pc.pass.ExportObjectFact(funcObj, &lff)
+ }
+}
diff --git a/tools/checklocks/state.go b/tools/checklocks/state.go
new file mode 100644
index 000000000..57061a32e
--- /dev/null
+++ b/tools/checklocks/state.go
@@ -0,0 +1,315 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package checklocks
+
+import (
+ "fmt"
+ "go/token"
+ "go/types"
+ "strings"
+ "sync/atomic"
+
+ "golang.org/x/tools/go/ssa"
+)
+
+// lockState tracks the locking state and aliases.
+type lockState struct {
+ // lockedMutexes is used to track which mutexes in a given struct are
+ // currently locked. Note that most of the heavy lifting is done by
+ // valueAsString below, which maps to specific structure fields, etc.
+ lockedMutexes []string
+
+ // stored stores values that have been stored in memory, bound to
+ // FreeVars or passed as Parameterse.
+ stored map[ssa.Value]ssa.Value
+
+ // used is a temporary map, used only for valueAsString. It prevents
+ // multiple use of the same memory location.
+ used map[ssa.Value]struct{}
+
+ // defers are the stack of defers that have been pushed.
+ defers []*ssa.Defer
+
+ // refs indicates the number of references on this structure. If it's
+ // greater than one, we will do copy-on-write.
+ refs *int32
+}
+
+// newLockState makes a new lockState.
+func newLockState() *lockState {
+ refs := int32(1) // Not shared.
+ return &lockState{
+ lockedMutexes: make([]string, 0),
+ used: make(map[ssa.Value]struct{}),
+ stored: make(map[ssa.Value]ssa.Value),
+ defers: make([]*ssa.Defer, 0),
+ refs: &refs,
+ }
+}
+
+// fork forks the locking state. When a lockState is forked, any modifications
+// will cause maps to be copied.
+func (l *lockState) fork() *lockState {
+ if l == nil {
+ return newLockState()
+ }
+ atomic.AddInt32(l.refs, 1)
+ return &lockState{
+ lockedMutexes: l.lockedMutexes,
+ used: make(map[ssa.Value]struct{}),
+ stored: l.stored,
+ defers: l.defers,
+ refs: l.refs,
+ }
+}
+
+// modify indicates that this state will be modified.
+func (l *lockState) modify() {
+ if atomic.LoadInt32(l.refs) > 1 {
+ // Copy the lockedMutexes.
+ lm := make([]string, len(l.lockedMutexes))
+ copy(lm, l.lockedMutexes)
+ l.lockedMutexes = lm
+
+ // Copy the stored values.
+ s := make(map[ssa.Value]ssa.Value)
+ for k, v := range l.stored {
+ s[k] = v
+ }
+ l.stored = s
+
+ // Reset the used values.
+ l.used = make(map[ssa.Value]struct{})
+
+ // Copy the defers.
+ ds := make([]*ssa.Defer, len(l.defers))
+ copy(ds, l.defers)
+ l.defers = ds
+
+ // Drop our reference.
+ atomic.AddInt32(l.refs, -1)
+ newRefs := int32(1) // Not shared.
+ l.refs = &newRefs
+ }
+}
+
+// isHeld indicates whether the field is held is not.
+func (l *lockState) isHeld(rv resolvedValue) (string, bool) {
+ if !rv.valid {
+ return rv.valueAsString(l), false
+ }
+ s := rv.valueAsString(l)
+ for _, k := range l.lockedMutexes {
+ if k == s {
+ return s, true
+ }
+ }
+ return s, false
+}
+
+// lockField locks the given field.
+//
+// If false is returned, the field was already locked.
+func (l *lockState) lockField(rv resolvedValue) (string, bool) {
+ if !rv.valid {
+ return rv.valueAsString(l), false
+ }
+ s := rv.valueAsString(l)
+ for _, k := range l.lockedMutexes {
+ if k == s {
+ return s, false
+ }
+ }
+ l.modify()
+ l.lockedMutexes = append(l.lockedMutexes, s)
+ return s, true
+}
+
+// unlockField unlocks the given field.
+//
+// If false is returned, the field was not locked.
+func (l *lockState) unlockField(rv resolvedValue) (string, bool) {
+ if !rv.valid {
+ return rv.valueAsString(l), false
+ }
+ s := rv.valueAsString(l)
+ for i, k := range l.lockedMutexes {
+ if k == s {
+ // Copy the last lock in and truncate.
+ l.modify()
+ l.lockedMutexes[i] = l.lockedMutexes[len(l.lockedMutexes)-1]
+ l.lockedMutexes = l.lockedMutexes[:len(l.lockedMutexes)-1]
+ return s, true
+ }
+ }
+ return s, false
+}
+
+// store records an alias.
+func (l *lockState) store(addr ssa.Value, v ssa.Value) {
+ l.modify()
+ l.stored[addr] = v
+}
+
+// isSubset indicates other holds all the locks held by l.
+func (l *lockState) isSubset(other *lockState) bool {
+ held := 0 // Number in l, held by other.
+ for _, k := range l.lockedMutexes {
+ for _, ok := range other.lockedMutexes {
+ if k == ok {
+ held++
+ break
+ }
+ }
+ }
+ return held >= len(l.lockedMutexes)
+}
+
+// count indicates the number of locks held.
+func (l *lockState) count() int {
+ return len(l.lockedMutexes)
+}
+
+// isCompatible returns true if the states are compatible.
+func (l *lockState) isCompatible(other *lockState) bool {
+ return l.isSubset(other) && other.isSubset(l)
+}
+
+// elemType is a type that implements the Elem function.
+type elemType interface {
+ Elem() types.Type
+}
+
+// valueAsString returns a string for a given value.
+//
+// This decomposes the value into the simplest possible representation in terms
+// of parameters, free variables and globals. During resolution, stored values
+// may be transferred, as well as bound free variables.
+//
+// Nil may not be passed here.
+func (l *lockState) valueAsString(v ssa.Value) string {
+ switch x := v.(type) {
+ case *ssa.Parameter:
+ // Was this provided as a paramter for a local anonymous
+ // function invocation?
+ v, ok := l.stored[x]
+ if ok {
+ return l.valueAsString(v)
+ }
+ return fmt.Sprintf("{param:%s}", x.Name())
+ case *ssa.Global:
+ return fmt.Sprintf("{global:%s}", x.Name())
+ case *ssa.FreeVar:
+ // Attempt to resolve this, in case we are being invoked in a
+ // scope where all the variables are bound.
+ v, ok := l.stored[x]
+ if ok {
+ // The FreeVar is typically bound to a location, so we
+ // check what's been stored there. Note that the second
+ // may map to the same FreeVar, which we can check.
+ stored, ok := l.stored[v]
+ if ok {
+ return l.valueAsString(stored)
+ }
+ }
+ return fmt.Sprintf("{freevar:%s}", x.Name())
+ case *ssa.Convert:
+ // Just disregard conversion.
+ return l.valueAsString(x.X)
+ case *ssa.ChangeType:
+ // Ditto, disregard.
+ return l.valueAsString(x.X)
+ case *ssa.UnOp:
+ if x.Op != token.MUL {
+ break
+ }
+ // Is this loading a free variable? If yes, then this can be
+ // resolved in the original isAlias function.
+ if fv, ok := x.X.(*ssa.FreeVar); ok {
+ return l.valueAsString(fv)
+ }
+ // Should be try to resolve via a memory address? This needs to
+ // be done since a memory location can hold its own value.
+ if _, ok := l.used[x.X]; !ok {
+ // Check if we know what the accessed location holds.
+ // This is used to disambiguate memory locations.
+ v, ok := l.stored[x.X]
+ if ok {
+ l.used[x.X] = struct{}{}
+ defer func() { delete(l.used, x.X) }()
+ return l.valueAsString(v)
+ }
+ }
+ // x.X.Type is pointer. We must construct this type
+ // dynamically, since the ssa.Value could be synthetic.
+ return fmt.Sprintf("*(%s)", l.valueAsString(x.X))
+ case *ssa.Field:
+ structType, ok := resolveStruct(x.X.Type())
+ if !ok {
+ // This should not happen.
+ panic(fmt.Sprintf("structType not available for struct: %#v", x.X))
+ }
+ fieldObj := structType.Field(x.Field)
+ return fmt.Sprintf("%s.%s", l.valueAsString(x.X), fieldObj.Name())
+ case *ssa.FieldAddr:
+ structType, ok := resolveStruct(x.X.Type())
+ if !ok {
+ // This should not happen.
+ panic(fmt.Sprintf("structType not available for struct: %#v", x.X))
+ }
+ fieldObj := structType.Field(x.Field)
+ return fmt.Sprintf("&(%s.%s)", l.valueAsString(x.X), fieldObj.Name())
+ case *ssa.Index:
+ return fmt.Sprintf("%s[%s]", l.valueAsString(x.X), l.valueAsString(x.Index))
+ case *ssa.IndexAddr:
+ return fmt.Sprintf("&(%s[%s])", l.valueAsString(x.X), l.valueAsString(x.Index))
+ case *ssa.Lookup:
+ return fmt.Sprintf("%s[%s]", l.valueAsString(x.X), l.valueAsString(x.Index))
+ case *ssa.Extract:
+ return fmt.Sprintf("%s[%d]", l.valueAsString(x.Tuple), x.Index)
+ }
+
+ // In the case of any other type (e.g. this may be an alloc, a return
+ // value, etc.), just return the literal pointer value to the Value.
+ // This will be unique within the ssa graph, and so if two values are
+ // equal, they are from the same type.
+ return fmt.Sprintf("{%T:%p}", v, v)
+}
+
+// String returns the full lock state.
+func (l *lockState) String() string {
+ if l.count() == 0 {
+ return "no locks held"
+ }
+ return strings.Join(l.lockedMutexes, ",")
+}
+
+// pushDefer pushes a defer onto the stack.
+func (l *lockState) pushDefer(d *ssa.Defer) {
+ l.modify()
+ l.defers = append(l.defers, d)
+}
+
+// popDefer pops a defer from the stack.
+func (l *lockState) popDefer() *ssa.Defer {
+ // Does not technically modify the underlying slice.
+ count := len(l.defers)
+ if count == 0 {
+ return nil
+ }
+ d := l.defers[count-1]
+ l.defers = l.defers[:count-1]
+ return d
+}
diff --git a/tools/checklocks/test/BUILD b/tools/checklocks/test/BUILD
index b055e71d9..966bbac22 100644
--- a/tools/checklocks/test/BUILD
+++ b/tools/checklocks/test/BUILD
@@ -4,5 +4,17 @@ package(licenses = ["notice"])
go_library(
name = "test",
- srcs = ["test.go"],
+ srcs = [
+ "alignment.go",
+ "atomics.go",
+ "basics.go",
+ "branches.go",
+ "closures.go",
+ "defer.go",
+ "incompat.go",
+ "methods.go",
+ "parameters.go",
+ "return.go",
+ "test.go",
+ ],
)
diff --git a/tools/checklocks/test/alignment.go b/tools/checklocks/test/alignment.go
new file mode 100644
index 000000000..cd857ff73
--- /dev/null
+++ b/tools/checklocks/test/alignment.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+type alignedStruct32 struct {
+ v int32
+}
+
+type alignedStruct64 struct {
+ v int64
+}
+
+type alignedStructGood struct {
+ v0 alignedStruct32
+ v1 alignedStruct32
+ v2 alignedStruct64
+}
+
+type alignedStructGoodArray0 struct {
+ v0 [3]alignedStruct32
+ v1 [3]alignedStruct32
+ v2 alignedStruct64
+}
+
+type alignedStructGoodArray1 [16]alignedStructGood
+
+type alignedStructBad struct {
+ v0 alignedStruct32
+ v1 alignedStruct64
+ v2 alignedStruct32
+}
+
+type alignedStructBadArray0 struct {
+ v0 [3]alignedStruct32
+ v1 [2]alignedStruct64
+ v2 [1]alignedStruct32
+}
+
+type alignedStructBadArray1 [16]alignedStructBad
diff --git a/tools/checklocks/test/atomics.go b/tools/checklocks/test/atomics.go
new file mode 100644
index 000000000..8e060d8a2
--- /dev/null
+++ b/tools/checklocks/test/atomics.go
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "sync"
+ "sync/atomic"
+)
+
+type atomicStruct struct {
+ accessedNormally int32
+
+ // +checkatomic
+ accessedAtomically int32
+
+ // +checklocksignore
+ ignored int32
+}
+
+func testNormalAccess(tc *atomicStruct, v chan int32, p chan *int32) {
+ v <- tc.accessedNormally
+ p <- &tc.accessedNormally
+}
+
+func testAtomicAccess(tc *atomicStruct, v chan int32) {
+ v <- atomic.LoadInt32(&tc.accessedAtomically)
+}
+
+func testAtomicAccessInvalid(tc *atomicStruct, v chan int32) {
+ v <- atomic.LoadInt32(&tc.accessedNormally) // +checklocksfail
+}
+
+func testNormalAccessInvalid(tc *atomicStruct, v chan int32, p chan *int32) {
+ v <- tc.accessedAtomically // +checklocksfail
+ p <- &tc.accessedAtomically // +checklocksfail
+}
+
+func testIgnored(tc *atomicStruct, v chan int32, p chan *int32) {
+ v <- atomic.LoadInt32(&tc.ignored)
+ v <- tc.ignored
+ p <- &tc.ignored
+}
+
+type atomicMixedStruct struct {
+ mu sync.Mutex
+
+ // +checkatomic
+ // +checklocks:mu
+ accessedMixed int32
+}
+
+func testAtomicMixedValidRead(tc *atomicMixedStruct, v chan int32) {
+ v <- atomic.LoadInt32(&tc.accessedMixed)
+}
+
+func testAtomicMixedInvalidRead(tc *atomicMixedStruct, v chan int32, p chan *int32) {
+ v <- tc.accessedMixed // +checklocksfail
+ p <- &tc.accessedMixed // +checklocksfail
+}
+
+func testAtomicMixedValidLockedWrite(tc *atomicMixedStruct, v chan int32, p chan *int32) {
+ tc.mu.Lock()
+ atomic.StoreInt32(&tc.accessedMixed, 1)
+ tc.mu.Unlock()
+}
+
+func testAtomicMixedInvalidLockedWrite(tc *atomicMixedStruct, v chan int32, p chan *int32) {
+ tc.mu.Lock()
+ tc.accessedMixed = 1 // +checklocksfail:2
+ tc.mu.Unlock()
+}
+
+func testAtomicMixedInvalidAtomicWrite(tc *atomicMixedStruct, v chan int32, p chan *int32) {
+ atomic.StoreInt32(&tc.accessedMixed, 1) // +checklocksfail
+}
+
+func testAtomicMixedInvalidWrite(tc *atomicMixedStruct, v chan int32, p chan *int32) {
+ tc.accessedMixed = 1 // +checklocksfail:2
+}
diff --git a/tools/checklocks/test/basics.go b/tools/checklocks/test/basics.go
new file mode 100644
index 000000000..7a773171f
--- /dev/null
+++ b/tools/checklocks/test/basics.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "sync"
+)
+
+func testLockedAccessValid(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ tc.guardedField = 1
+ tc.mu.Unlock()
+}
+
+func testLockedAccessIgnore(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ tc.unguardedField = 1
+ tc.mu.Unlock()
+}
+
+func testUnlockedAccessInvalidWrite(tc *oneGuardStruct) {
+ tc.guardedField = 2 // +checklocksfail
+}
+
+func testUnlockedAccessInvalidRead(tc *oneGuardStruct) {
+ x := tc.guardedField // +checklocksfail
+ _ = x
+}
+
+func testUnlockedAccessValid(tc *oneGuardStruct) {
+ tc.unguardedField = 2
+}
+
+func testCallValidAccess(tc *oneGuardStruct) {
+ callValidAccess(tc)
+}
+
+func callValidAccess(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ tc.guardedField = 1
+ tc.mu.Unlock()
+}
+
+func testCallValueMixup(tc *oneGuardStruct) {
+ callValueMixup(tc, tc)
+}
+
+func callValueMixup(tc1, tc2 *oneGuardStruct) {
+ tc1.mu.Lock()
+ tc2.guardedField = 2 // +checklocksfail
+ tc1.mu.Unlock()
+}
+
+func testCallPreconditionsInvalid(tc *oneGuardStruct) {
+ callPreconditions(tc) // +checklocksfail
+}
+
+func testCallPreconditionsValid(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ callPreconditions(tc)
+ tc.mu.Unlock()
+}
+
+// +checklocks:tc.mu
+func callPreconditions(tc *oneGuardStruct) {
+ tc.guardedField = 1
+}
+
+type nestedFieldsStruct struct {
+ mu sync.Mutex
+
+ // +checklocks:mu
+ nestedStruct struct {
+ nested1 int
+ nested2 int
+ }
+}
+
+func testNestedGuardValid(tc *nestedFieldsStruct) {
+ tc.mu.Lock()
+ tc.nestedStruct.nested1 = 1
+ tc.nestedStruct.nested2 = 2
+ tc.mu.Unlock()
+}
+
+func testNestedGuardInvalid(tc *nestedFieldsStruct) {
+ tc.nestedStruct.nested1 = 1 // +checklocksfail
+}
+
+type rwGuardStruct struct {
+ rwMu sync.RWMutex
+
+ // +checklocks:rwMu
+ guardedField int
+}
+
+func testRWValidRead(tc *rwGuardStruct) {
+ tc.rwMu.Lock()
+ tc.guardedField = 1
+ tc.rwMu.Unlock()
+}
+
+func testRWValidWrite(tc *rwGuardStruct) {
+ tc.rwMu.RLock()
+ tc.guardedField = 2
+ tc.rwMu.RUnlock()
+}
+
+func testRWInvalidWrite(tc *rwGuardStruct) {
+ tc.guardedField = 3 // +checklocksfail
+}
+
+func testRWInvalidRead(tc *rwGuardStruct) {
+ x := tc.guardedField + 3 // +checklocksfail
+ _ = x
+}
+
+func testTwoLocksDoubleGuardStructValid(tc *twoLocksDoubleGuardStruct) {
+ tc.mu.Lock()
+ tc.secondMu.Lock()
+ tc.doubleGuardedField = 1
+ tc.secondMu.Unlock()
+}
+
+func testTwoLocksDoubleGuardStructOnlyOne(tc *twoLocksDoubleGuardStruct) {
+ tc.mu.Lock()
+ tc.doubleGuardedField = 2 // +checklocksfail
+ tc.mu.Unlock()
+}
+
+func testTwoLocksDoubleGuardStructInvalid(tc *twoLocksDoubleGuardStruct) {
+ tc.doubleGuardedField = 3 // +checklocksfail:2
+}
diff --git a/tools/checklocks/test/branches.go b/tools/checklocks/test/branches.go
new file mode 100644
index 000000000..81fec29e5
--- /dev/null
+++ b/tools/checklocks/test/branches.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "math/rand"
+)
+
+func testInconsistentReturn(tc *oneGuardStruct) { // +checklocksfail
+ if x := rand.Intn(10); x%2 == 1 {
+ tc.mu.Lock()
+ }
+}
+
+func testConsistentBranching(tc *oneGuardStruct) {
+ x := rand.Intn(10)
+ if x%2 == 1 {
+ tc.mu.Lock()
+ } else {
+ tc.mu.Lock()
+ }
+ tc.guardedField = 1
+ if x%2 == 1 {
+ tc.mu.Unlock()
+ } else {
+ tc.mu.Unlock()
+ }
+}
+
+func testInconsistentBranching(tc *oneGuardStruct) { // +checklocksfail:2
+ // We traverse the control flow graph in all consistent ways. We cannot
+ // determine however, that the first if block and second if block will
+ // evaluate to the same condition. Therefore, there are two consistent
+ // paths through this code, and two inconsistent paths. Either way, the
+ // guardedField should be also marked as an invalid access.
+ x := rand.Intn(10)
+ if x%2 == 1 {
+ tc.mu.Lock()
+ }
+ tc.guardedField = 1 // +checklocksfail
+ if x%2 == 1 {
+ tc.mu.Unlock() // +checklocksforce
+ }
+}
diff --git a/tools/checklocks/test/closures.go b/tools/checklocks/test/closures.go
new file mode 100644
index 000000000..7da87540a
--- /dev/null
+++ b/tools/checklocks/test/closures.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+func testClosureInvalid(tc *oneGuardStruct) {
+ // This is expected to fail.
+ callClosure(func() {
+ tc.guardedField = 1 // +checklocksfail
+ })
+}
+
+func testClosureUnsupported(tc *oneGuardStruct) {
+ // Locked outside the closure, so may or may not be valid. This cannot
+ // be handled and we should explicitly fail. This can't be handled
+ // because of the call through callClosure, below, which means the
+ // closure will actually be passed as a value somewhere.
+ tc.mu.Lock()
+ callClosure(func() {
+ tc.guardedField = 1 // +checklocksfail
+ })
+ tc.mu.Unlock()
+}
+
+func testClosureValid(tc *oneGuardStruct) {
+ // All locking happens within the closure. This should not present a
+ // problem for analysis.
+ callClosure(func() {
+ tc.mu.Lock()
+ tc.guardedField = 1
+ tc.mu.Unlock()
+ })
+}
+
+func testClosureInline(tc *oneGuardStruct) {
+ // If the closure is being dispatching inline only, then we should be
+ // able to analyze this call and give it a thumbs up.
+ tc.mu.Lock()
+ func() {
+ tc.guardedField = 1
+ }()
+ tc.mu.Unlock()
+}
+
+func testAnonymousInvalid(tc *oneGuardStruct) {
+ // Invalid, as per testClosureInvalid above.
+ callAnonymous(func(tc *oneGuardStruct) {
+ tc.guardedField = 1 // +checklocksfail
+ }, tc)
+}
+
+func testAnonymousUnsupported(tc *oneGuardStruct) {
+ // Not supportable, as per testClosureUnsupported above.
+ tc.mu.Lock()
+ callAnonymous(func(tc *oneGuardStruct) {
+ tc.guardedField = 1 // +checklocksfail
+ }, tc)
+ tc.mu.Unlock()
+}
+
+func testAnonymousValid(tc *oneGuardStruct) {
+ // Valid, as per testClosureValid above.
+ callAnonymous(func(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ tc.guardedField = 1
+ tc.mu.Unlock()
+ }, tc)
+}
+
+func testAnonymousInline(tc *oneGuardStruct) {
+ // Unlike the closure case, we are able to dynamically infer the set of
+ // preconditions for the function dispatch and assert that this is
+ // a valid call.
+ tc.mu.Lock()
+ func(tc *oneGuardStruct) {
+ tc.guardedField = 1
+ }(tc)
+ tc.mu.Unlock()
+}
+
+//go:noinline
+func callClosure(fn func()) {
+ fn()
+}
+
+//go:noinline
+func callAnonymous(fn func(*oneGuardStruct), tc *oneGuardStruct) {
+ fn(tc)
+}
diff --git a/tools/checklocks/test/defer.go b/tools/checklocks/test/defer.go
new file mode 100644
index 000000000..6e574e5eb
--- /dev/null
+++ b/tools/checklocks/test/defer.go
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+func testDeferValidUnlock(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ tc.guardedField = 1
+ defer tc.mu.Unlock()
+}
+
+func testDeferValidAccess(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ defer func() {
+ tc.guardedField = 1
+ tc.mu.Unlock()
+ }()
+}
+
+func testDeferInvalidAccess(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ defer func() {
+ // N.B. Executed after tc.mu.Unlock().
+ tc.guardedField = 1 // +checklocksfail
+ }()
+ tc.mu.Unlock()
+}
diff --git a/tools/checklocks/test/incompat.go b/tools/checklocks/test/incompat.go
new file mode 100644
index 000000000..b39bc66c1
--- /dev/null
+++ b/tools/checklocks/test/incompat.go
@@ -0,0 +1,54 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "sync"
+)
+
+// unsupportedLockerStruct verifies that trying to annotate a field that is not a
+// sync.Mutex or sync.RWMutex results in a failure.
+type unsupportedLockerStruct struct {
+ mu sync.Locker
+
+ // +checklocks:mu
+ x int // +checklocksfail
+}
+
+// badFieldsStruct verifies that refering invalid fields fails.
+type badFieldsStruct struct {
+ // +checklocks:mu
+ x int // +checklocksfail
+}
+
+// redundantStruct verifies that redundant annotations fail.
+type redundantStruct struct {
+ mu sync.Mutex
+
+ // +checklocks:mu
+ // +checklocks:mu
+ x int // +checklocksfail
+}
+
+// conflictsStruct verifies that conflicting annotations fail.
+type conflictsStruct struct {
+ // +checkatomicignore
+ // +checkatomic
+ x int // +checklocksfail
+
+ // +checkatomic
+ // +checkatomicignore
+ y int // +checklocksfail
+}
diff --git a/tools/checklocks/test/methods.go b/tools/checklocks/test/methods.go
new file mode 100644
index 000000000..72e26fca6
--- /dev/null
+++ b/tools/checklocks/test/methods.go
@@ -0,0 +1,117 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "sync"
+)
+
+type testMethods struct {
+ mu sync.Mutex
+
+ // +checklocks:mu
+ guardedField int
+}
+
+func (t *testMethods) methodValid() {
+ t.mu.Lock()
+ t.guardedField = 1
+ t.mu.Unlock()
+}
+
+func (t *testMethods) methodInvalid() {
+ t.guardedField = 2 // +checklocksfail
+}
+
+// +checklocks:t.mu
+func (t *testMethods) MethodLocked(a, b, c int) {
+ t.guardedField = 3
+}
+
+// +checklocksignore
+func (t *testMethods) methodIgnore() {
+ t.guardedField = 2
+}
+
+func testMethodCallsValid(tc *testMethods) {
+ tc.methodValid()
+}
+
+func testMethodCallsValidPreconditions(tc *testMethods) {
+ tc.mu.Lock()
+ tc.MethodLocked(1, 2, 3)
+ tc.mu.Unlock()
+}
+
+func testMethodCallsInvalid(tc *testMethods) {
+ tc.MethodLocked(4, 5, 6) // +checklocksfail
+}
+
+func testMultipleParameters(tc1, tc2, tc3 *testMethods) {
+ tc1.mu.Lock()
+ tc1.guardedField = 1
+ tc2.guardedField = 2 // +checklocksfail
+ tc3.guardedField = 3 // +checklocksfail
+ tc1.mu.Unlock()
+}
+
+type testMethodsWithParameters struct {
+ mu sync.Mutex
+
+ // +checklocks:mu
+ guardedField int
+}
+
+type ptrToTestMethodsWithParameters *testMethodsWithParameters
+
+// +checklocks:t.mu
+// +checklocks:a.mu
+func (t *testMethodsWithParameters) methodLockedWithParameters(a *testMethodsWithParameters, b *testMethodsWithParameters) {
+ t.guardedField = a.guardedField
+ b.guardedField = a.guardedField // +checklocksfail
+}
+
+// +checklocks:t.mu
+// +checklocks:a.mu
+// +checklocks:b.mu
+func (t *testMethodsWithParameters) methodLockedWithPtrType(a *testMethodsWithParameters, b ptrToTestMethodsWithParameters) {
+ t.guardedField = a.guardedField
+ b.guardedField = a.guardedField
+}
+
+// +checklocks:a.mu
+func standaloneFunctionWithGuard(a *testMethodsWithParameters) {
+ a.guardedField = 1
+ a.mu.Unlock()
+ a.guardedField = 1 // +checklocksfail
+}
+
+type testMethodsWithEmbedded struct {
+ mu sync.Mutex
+
+ // +checklocks:mu
+ guardedField int
+ p *testMethodsWithParameters
+}
+
+// +checklocks:t.mu
+func (t *testMethodsWithEmbedded) DoLocked(a, b *testMethodsWithParameters) {
+ t.guardedField = 1
+ a.mu.Lock()
+ b.mu.Lock()
+ t.p.methodLockedWithParameters(a, b) // +checklocksfail
+ a.mu.Unlock()
+ b.mu.Unlock()
+}
diff --git a/tools/checklocks/test/parameters.go b/tools/checklocks/test/parameters.go
new file mode 100644
index 000000000..5b9e664b6
--- /dev/null
+++ b/tools/checklocks/test/parameters.go
@@ -0,0 +1,48 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+func testParameterPassingbyAddrValid(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ nestedWithGuardByAddr(&tc.guardedField, &tc.unguardedField)
+ tc.mu.Unlock()
+}
+
+func testParameterPassingByAddrInalid(tc *oneGuardStruct) {
+ nestedWithGuardByAddr(&tc.guardedField, &tc.unguardedField) // +checklocksfail
+}
+
+func testParameterPassingByValueValid(tc *oneGuardStruct) {
+ tc.mu.Lock()
+ nestedWithGuardByValue(tc.guardedField, tc.unguardedField)
+ tc.mu.Unlock()
+}
+
+func testParameterPassingByValueInalid(tc *oneGuardStruct) {
+ nestedWithGuardByValue(tc.guardedField, tc.unguardedField) // +checklocksfail
+}
+
+func nestedWithGuardByAddr(guardedField, unguardedField *int) {
+ *guardedField = 4
+ *unguardedField = 5
+}
+
+func nestedWithGuardByValue(guardedField, unguardedField int) {
+ // read the fields to keep SA4009 static analyzer happy.
+ _ = guardedField
+ _ = unguardedField
+ guardedField = 4
+ unguardedField = 5
+}
diff --git a/tools/checklocks/test/return.go b/tools/checklocks/test/return.go
new file mode 100644
index 000000000..47c7b6773
--- /dev/null
+++ b/tools/checklocks/test/return.go
@@ -0,0 +1,61 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+// +checklocks:tc.mu
+func testReturnInvalidGuard() (tc *oneGuardStruct) { // +checklocksfail
+ return new(oneGuardStruct)
+}
+
+// +checklocksrelease:tc.mu
+func testReturnInvalidRelease() (tc *oneGuardStruct) { // +checklocksfail
+ return new(oneGuardStruct)
+}
+
+// +checklocksacquire:tc.mu
+func testReturnInvalidAcquire() (tc *oneGuardStruct) {
+ return new(oneGuardStruct) // +checklocksfail
+}
+
+// +checklocksacquire:tc.mu
+func testReturnValidAcquire() (tc *oneGuardStruct) {
+ tc = new(oneGuardStruct)
+ tc.mu.Lock()
+ return tc
+}
+
+func testReturnAcquireCall() {
+ tc := testReturnValidAcquire()
+ tc.guardedField = 1
+ tc.mu.Unlock()
+}
+
+// +checklocksacquire:tc.val.mu
+// +checklocksacquire:tc.ptr.mu
+func testReturnValidNestedAcquire() (tc *nestedGuardStruct) {
+ tc = new(nestedGuardStruct)
+ tc.ptr = new(oneGuardStruct)
+ tc.val.mu.Lock()
+ tc.ptr.mu.Lock()
+ return tc
+}
+
+func testReturnNestedAcquireCall() {
+ tc := testReturnValidNestedAcquire()
+ tc.val.guardedField = 1
+ tc.ptr.guardedField = 1
+ tc.val.mu.Unlock()
+ tc.ptr.mu.Unlock()
+}
diff --git a/tools/checklocks/test/test.go b/tools/checklocks/test/test.go
index 05693c183..cbf6b1635 100644
--- a/tools/checklocks/test/test.go
+++ b/tools/checklocks/test/test.go
@@ -13,99 +13,24 @@
// limitations under the License.
// Package test is a test package.
+//
+// Tests are all compilation tests in separate files.
package test
import (
- "math/rand"
"sync"
)
-type oneGuarded struct {
+// oneGuardStruct has one guarded field.
+type oneGuardStruct struct {
mu sync.Mutex
// +checklocks:mu
- guardedField int
-
+ guardedField int
unguardedField int
}
-func testAccessOne() {
- var tc oneGuarded
- // Valid access
- tc.mu.Lock()
- tc.guardedField = 1
- tc.unguardedField = 1
- tc.mu.Unlock()
-
- // Valid access as unguarded field is not protected by mu.
- tc.unguardedField = 2
-
- // Invalid access
- tc.guardedField = 2 // +checklocksfail
-
- // Invalid read of a guarded field.
- x := tc.guardedField // +checklocksfail
- _ = x
-}
-
-func testFunctionCallsNoParameters() {
- // Couple of regular function calls with no parameters.
- funcCallWithValidAccess()
- funcCallWithInvalidAccess()
-}
-
-func funcCallWithValidAccess() {
- var tc2 oneGuarded
- // Valid tc2 access
- tc2.mu.Lock()
- tc2.guardedField = 1
- tc2.mu.Unlock()
-}
-
-func funcCallWithInvalidAccess() {
- var tc oneGuarded
- var tc2 oneGuarded
- // Invalid access, wrong mutex is held.
- tc.mu.Lock()
- tc2.guardedField = 2 // +checklocksfail
- tc.mu.Unlock()
-}
-
-func testParameterPassing() {
- var tc oneGuarded
-
- // Valid call where a guardedField is passed to a function as a parameter.
- tc.mu.Lock()
- nestedWithGuardByAddr(&tc.guardedField, &tc.unguardedField)
- tc.mu.Unlock()
-
- // Invalid call where a guardedField is passed to a function as a parameter
- // without holding locks.
- nestedWithGuardByAddr(&tc.guardedField, &tc.unguardedField) // +checklocksfail
-
- // Valid call where a guardedField is passed to a function as a parameter.
- tc.mu.Lock()
- nestedWithGuardByValue(tc.guardedField, tc.unguardedField)
- tc.mu.Unlock()
-
- // Invalid call where a guardedField is passed to a function as a parameter
- // without holding locks.
- nestedWithGuardByValue(tc.guardedField, tc.unguardedField) // +checklocksfail
-}
-
-func nestedWithGuardByAddr(guardedField, unguardedField *int) {
- *guardedField = 4
- *unguardedField = 5
-}
-
-func nestedWithGuardByValue(guardedField, unguardedField int) {
- // read the fields to keep SA4009 static analyzer happy.
- _ = guardedField
- _ = unguardedField
- guardedField = 4
- unguardedField = 5
-}
-
-type twoGuarded struct {
+// twoGuardStruct has two guarded fields.
+type twoGuardStruct struct {
mu sync.Mutex
// +checklocks:mu
guardedField1 int
@@ -113,250 +38,27 @@ type twoGuarded struct {
guardedField2 int
}
-type twoLocks struct {
+// twoLocksStruct has two locks and two fields.
+type twoLocksStruct struct {
mu sync.Mutex
secondMu sync.Mutex
-
// +checklocks:mu
guardedField1 int
// +checklocks:secondMu
guardedField2 int
}
-type twoLocksDoubleGuard struct {
+// twoLocksDoubleGuardStruct has two locks and a single field with two guards.
+type twoLocksDoubleGuardStruct struct {
mu sync.Mutex
secondMu sync.Mutex
-
// +checklocks:mu
// +checklocks:secondMu
doubleGuardedField int
}
-func testTwoLocksDoubleGuard() {
- var tc twoLocksDoubleGuard
-
- // Double guarded field
- tc.mu.Lock()
- tc.secondMu.Lock()
- tc.doubleGuardedField = 1
- tc.secondMu.Unlock()
-
- // This should fail as we released the secondMu.
- tc.doubleGuardedField = 2 // +checklocksfail
- tc.mu.Unlock()
-
- // This should fail as well as now we are not holding any locks.
- //
- // This line triggers two failures one for each mutex, hence the 2 after
- // fail.
- tc.doubleGuardedField = 3 // +checklocksfail:2
-}
-
-type rwGuarded struct {
- rwMu sync.RWMutex
-
- // +checklocks:rwMu
- rwGuardedField int
-}
-
-func testRWGuarded() {
- var tc rwGuarded
-
- // Assignment w/ exclusive lock should pass.
- tc.rwMu.Lock()
- tc.rwGuardedField = 1
- tc.rwMu.Unlock()
-
- // Assignment w/ RWLock should pass as we don't differentiate between
- // Lock/RLock.
- tc.rwMu.RLock()
- tc.rwGuardedField = 2
- tc.rwMu.RUnlock()
-
- // Assignment w/o hold Lock() should fail.
- tc.rwGuardedField = 3 // +checklocksfail
-
- // Reading w/o holding lock should fail.
- x := tc.rwGuardedField + 3 // +checklocksfail
- _ = x
-}
-
-type nestedFields struct {
- mu sync.Mutex
-
- // +checklocks:mu
- nestedStruct struct {
- nested1 int
- nested2 int
- }
-}
-
-func testNestedStructGuards() {
- var tc nestedFields
- // Valid access with mu held.
- tc.mu.Lock()
- tc.nestedStruct.nested1 = 1
- tc.nestedStruct.nested2 = 2
- tc.mu.Unlock()
-
- // Invalid access to nested1 wihout holding mu.
- tc.nestedStruct.nested1 = 1 // +checklocksfail
-}
-
-type testCaseMethods struct {
- mu sync.Mutex
-
- // +checklocks:mu
- guardedField int
-}
-
-func (t *testCaseMethods) Method() {
- // Valid access
- t.mu.Lock()
- t.guardedField = 1
- t.mu.Unlock()
-
- // invalid access
- t.guardedField = 2 // +checklocksfail
-}
-
-// +checklocks:t.mu
-func (t *testCaseMethods) MethodLocked(a, b, c int) {
- t.guardedField = 3
-}
-
-// +checklocksignore
-func (t *testCaseMethods) IgnoredMethod() {
- // Invalid access but should not fail as the function is annotated
- // with "// +checklocksignore"
- t.guardedField = 2
-}
-
-func testMethodCalls() {
- var tc2 testCaseMethods
-
- // Valid use, tc2.Method acquires lock.
- tc2.Method()
-
- // Valid access tc2.mu is held before calling tc2.MethodLocked.
- tc2.mu.Lock()
- tc2.MethodLocked(1, 2, 3)
- tc2.mu.Unlock()
-
- // Invalid access no locks are being held.
- tc2.MethodLocked(4, 5, 6) // +checklocksfail
-}
-
-type noMutex struct {
- f int
- g int
-}
-
-func (n noMutex) method() {
- n.f = 1
- n.f = n.g
-}
-
-func testNoMutex() {
- var n noMutex
- n.method()
-}
-
-func testMultiple() {
- var tc1, tc2, tc3 testCaseMethods
-
- tc1.mu.Lock()
-
- // Valid access we are holding tc1's lock.
- tc1.guardedField = 1
-
- // Invalid access we are not holding tc2 or tc3's lock.
- tc2.guardedField = 2 // +checklocksfail
- tc3.guardedField = 3 // +checklocksfail
- tc1.mu.Unlock()
-}
-
-func testConditionalBranchingLocks() {
- var tc2 testCaseMethods
- x := rand.Intn(10)
- if x%2 == 1 {
- tc2.mu.Lock()
- }
- // This is invalid access as tc2.mu is not held if we never entered
- // the if block.
- tc2.guardedField = 1 // +checklocksfail
-
- var tc3 testCaseMethods
- if x%2 == 1 {
- tc3.mu.Lock()
- } else {
- tc3.mu.Lock()
- }
- // This is valid as tc3.mu is held in if and else blocks.
- tc3.guardedField = 1
-}
-
-type testMethodWithParams struct {
- mu sync.Mutex
-
- // +checklocks:mu
- guardedField int
-}
-
-type ptrToTestMethodWithParams *testMethodWithParams
-
-// +checklocks:t.mu
-// +checklocks:a.mu
-func (t *testMethodWithParams) methodLockedWithParams(a *testMethodWithParams, b *testMethodWithParams) {
- t.guardedField = a.guardedField
- b.guardedField = a.guardedField // +checklocksfail
-}
-
-// +checklocks:t.mu
-// +checklocks:a.mu
-// +checklocks:b.mu
-func (t *testMethodWithParams) methodLockedWithPtrType(a *testMethodWithParams, b ptrToTestMethodWithParams) {
- t.guardedField = a.guardedField
- b.guardedField = a.guardedField
-}
-
-// +checklocks:a.mu
-func standaloneFunctionWithGuard(a *testMethodWithParams) {
- a.guardedField = 1
- a.mu.Unlock()
- a.guardedField = 1 // +checklocksfail
-}
-
-type testMethodWithEmbedded struct {
- mu sync.Mutex
-
- // +checklocks:mu
- guardedField int
- p *testMethodWithParams
-}
-
-// +checklocks:t.mu
-func (t *testMethodWithEmbedded) DoLocked() {
- var a, b testMethodWithParams
- t.guardedField = 1
- a.mu.Lock()
- b.mu.Lock()
- t.p.methodLockedWithParams(&a, &b) // +checklocksfail
- a.mu.Unlock()
- b.mu.Unlock()
-}
-
-// UnsupportedLockerExample is a test that verifies that trying to annotate a
-// field that is not a sync.Mutex/RWMutex results in a failure.
-type UnsupportedLockerExample struct {
- mu sync.Locker
-
- // +checklocks:mu
- x int // +checklocksfail
-}
-
-func abc() {
- var mu sync.Mutex
- a := UnsupportedLockerExample{mu: &mu}
- a.x = 1
+// nestedGuardStruct nests oneGuardStruct fields.
+type nestedGuardStruct struct {
+ val oneGuardStruct
+ ptr *oneGuardStruct
}
diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD
index 6c6f604b5..a7e280b32 100644
--- a/tools/nogo/BUILD
+++ b/tools/nogo/BUILD
@@ -37,6 +37,7 @@ go_library(
"//tools/checkescape",
"//tools/checklocks",
"//tools/checkunsafe",
+ "//tools/nogo/objdump",
"//tools/worker",
"@co_honnef_go_tools//staticcheck:go_default_library",
"@co_honnef_go_tools//stylecheck:go_default_library",
@@ -68,6 +69,7 @@ go_library(
"@org_golang_x_tools//go/analysis/passes/unsafeptr:go_default_library",
"@org_golang_x_tools//go/analysis/passes/unusedresult:go_default_library",
"@org_golang_x_tools//go/gcexportdata:go_default_library",
+ "@org_golang_x_tools//go/types/objectpath:go_default_library",
],
)
diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go
index 3a6c3fb08..0e7e92965 100644
--- a/tools/nogo/check/main.go
+++ b/tools/nogo/check/main.go
@@ -62,7 +62,8 @@ func run([]string) int {
// Check & load the configuration.
if *packageFile != "" && *stdlibFile != "" {
- log.Fatalf("unable to perform stdlib and package analysis; provide only one!")
+ fmt.Fprintf(os.Stderr, "unable to perform stdlib and package analysis; provide only one!")
+ return 1
}
// Run the configuration.
@@ -75,18 +76,21 @@ func run([]string) int {
c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig)
findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil)
} else {
- log.Fatalf("please provide at least one of package or stdlib!")
+ fmt.Fprintf(os.Stderr, "please provide at least one of package or stdlib!")
+ return 1
}
// Check that analysis was successful.
if err != nil {
- log.Fatalf("error performing analysis: %v", err)
+ fmt.Fprintf(os.Stderr, "error performing analysis: %v", err)
+ return 1
}
// Save facts.
if *factsOutput != "" {
if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil {
- log.Fatalf("error saving findings to %q: %v", *factsOutput, err)
+ fmt.Fprintf(os.Stderr, "error saving findings to %q: %v", *factsOutput, err)
+ return 1
}
}
@@ -94,10 +98,12 @@ func run([]string) int {
if *findingsOutput != "" {
w, err := os.OpenFile(*findingsOutput, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
- log.Fatalf("error opening output file %q: %v", *findingsOutput, err)
+ fmt.Fprintf(os.Stderr, "error opening output file %q: %v", *findingsOutput, err)
+ return 1
}
if err := nogo.WriteFindingsTo(w, findings, false /* json */); err != nil {
- log.Fatalf("error writing findings to %q: %v", *findingsOutput, err)
+ fmt.Fprintf(os.Stderr, "error writing findings to %q: %v", *findingsOutput, err)
+ return 1
}
} else {
for _, finding := range findings {
diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl
index ddf5816a6..80182ff6c 100644
--- a/tools/nogo/defs.bzl
+++ b/tools/nogo/defs.bzl
@@ -198,6 +198,22 @@ NogoInfo = provider(
},
)
+def _select_objfile(files):
+ """Returns (.a file, .x file, is_archive).
+
+ If no .a file is available, then the first .x file will be returned
+ instead, and vice versa. If neither are available, then the first provided
+ file will be returned."""
+ a_files = [f for f in files if f.path.endswith(".a")]
+ x_files = [f for f in files if f.path.endswith(".x")]
+ if not len(x_files) and not len(a_files):
+ return (files[0], files[0], False)
+ if not len(x_files):
+ x_files = a_files
+ if not len(a_files):
+ a_files = x_files
+ return a_files[0], x_files[0], True
+
def _nogo_aspect_impl(target, ctx):
# If this is a nogo rule itself (and not the shadow of a go_library or
# go_binary rule created by such a rule), then we simply return nothing.
@@ -232,20 +248,14 @@ def _nogo_aspect_impl(target, ctx):
deps = deps + info.deps
# Start with all target files and srcs as input.
- inputs = target.files.to_list() + srcs
+ binaries = target.files.to_list()
+ inputs = binaries + srcs
# Generate a shell script that dumps the binary. Annoyingly, this seems
# necessary as the context in which a run_shell command runs does not seem
# to cleanly allow us redirect stdout to the actual output file. Perhaps
# I'm missing something here, but the intermediate script does work.
- binaries = target.files.to_list()
- objfiles = [f for f in binaries if f.path.endswith(".a")]
- if len(objfiles) > 0:
- # Prefer the .a files for go_library targets.
- target_objfile = objfiles[0]
- else:
- # Use the raw binary for go_binary and go_test targets.
- target_objfile = binaries[0]
+ target_objfile, target_xfile, has_objfile = _select_objfile(binaries)
inputs.append(target_objfile)
# Extract the importpath for this package.
@@ -274,10 +284,8 @@ def _nogo_aspect_impl(target, ctx):
# Configure where to find the binary & fact files. Note that this will
# use .x and .a regardless of whether this is a go_binary rule, since
# these dependencies must be go_library rules.
- x_files = [f.path for f in info.binaries if f.path.endswith(".x")]
- if not len(x_files):
- x_files = [f.path for f in info.binaries if f.path.endswith(".a")]
- import_map[info.importpath] = x_files[0]
+ _, x_file, _ = _select_objfile(info.binaries)
+ import_map[info.importpath] = x_file.path
fact_map[info.importpath] = info.facts.path
# Collect all findings; duplicates are resolved at the end.
@@ -287,6 +295,11 @@ def _nogo_aspect_impl(target, ctx):
inputs.append(info.facts)
inputs += info.binaries
+ # Add the module itself, for the type sanity check. This applies only to
+ # the libraries, and not binaries or tests.
+ if has_objfile:
+ import_map[importpath] = target_xfile.path
+
# Add the standard library facts.
stdlib_info = ctx.attr._nogo_stdlib[NogoStdlibInfo]
stdlib_facts = stdlib_info.facts
diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go
index acee7c8bc..d95d7652f 100644
--- a/tools/nogo/nogo.go
+++ b/tools/nogo/nogo.go
@@ -41,9 +41,10 @@ import (
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/internal/facts"
"golang.org/x/tools/go/gcexportdata"
+ "golang.org/x/tools/go/types/objectpath"
// Special case: flags live here and change overall behavior.
- "gvisor.dev/gvisor/tools/checkescape"
+ "gvisor.dev/gvisor/tools/nogo/objdump"
"gvisor.dev/gvisor/tools/worker"
)
@@ -216,6 +217,11 @@ func (i *importer) Import(path string) (*types.Package, error) {
}
}
+ // Check the cache.
+ if pkg, ok := i.cache[path]; ok && pkg.Complete() {
+ return pkg, nil
+ }
+
// Actually load the data.
realPath, ok := i.ImportMap[path]
var (
@@ -327,6 +333,9 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi
// Closure to check a single package.
localStdlibFacts := make(stdlibFacts)
localStdlibErrs := make(map[string]error)
+ stdlibCachedFacts.Lookup([]string{""}, func() worker.Sizer {
+ return localStdlibFacts
+ })
var checkOne func(pkg string) error // Recursive.
checkOne = func(pkg string) error {
// Is this already done?
@@ -355,11 +364,11 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi
}
// Provide the input.
- oldReader := checkescape.Reader
- checkescape.Reader = rc // For analysis.
+ oldReader := objdump.Reader
+ objdump.Reader = rc // For analysis.
defer func() {
rc.Close()
- checkescape.Reader = oldReader // Restore.
+ objdump.Reader = oldReader // Restore.
}()
// Run the analysis.
@@ -406,6 +415,56 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi
return allFindings, buf.Bytes(), nil
}
+// sanityCheckScope checks that all object in astTypes map to the correct
+// objects in binaryTypes. Note that we don't check whether the sets are the
+// same, we only care about the fidelity of objects in astTypes.
+//
+// When an inconsistency is identified, we record it in the astToBinaryMap.
+// This allows us to dynamically replace facts and correct for the issue. The
+// total number of mismatches is returned.
+func sanityCheckScope(astScope *types.Scope, binaryTypes *types.Package, binaryScope *types.Scope, astToBinary map[types.Object]types.Object) error {
+ for _, x := range astScope.Names() {
+ fe := astScope.Lookup(x)
+ path, err := objectpath.For(fe)
+ if err != nil {
+ continue // Not an encoded object.
+ }
+ se, err := objectpath.Object(binaryTypes, path)
+ if err != nil {
+ continue // May be unused, see below.
+ }
+ if fe.Id() != se.Id() {
+ // These types are incompatible. This means that when
+ // this objectpath is loading from the binaryTypes (for
+ // dependencies) it will resolve to a fact for that
+ // type. We don't actually care about this error since
+ // we do the rewritten, but may as well alert.
+ log.Printf("WARNING: Object %s is a victim of go/issues/44195.", fe.Id())
+ }
+ se = binaryScope.Lookup(x)
+ if se == nil {
+ // The fact may not be exported in the objectdata, if
+ // it is package internal. This is fine, as nothing out
+ // of this package can use these symbols.
+ continue
+ }
+ // Save the translation.
+ astToBinary[fe] = se
+ }
+ for i := 0; i < astScope.NumChildren(); i++ {
+ if err := sanityCheckScope(astScope.Child(i), binaryTypes, binaryScope, astToBinary); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// sanityCheckTypes checks that two types are sane. The total number of
+// mismatches is returned.
+func sanityCheckTypes(astTypes, binaryTypes *types.Package, astToBinary map[types.Object]types.Object) error {
+ return sanityCheckScope(astTypes.Scope(), binaryTypes, binaryTypes.Scope(), astToBinary)
+}
+
// CheckPackage runs all given analyzers.
//
// The implementation was adapted from [1], which was in turn adpated from [2].
@@ -450,17 +509,46 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC
Scopes: make(map[ast.Node]*types.Scope),
Selections: make(map[*ast.SelectorExpr]*types.Selection),
}
- types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo)
+ astTypes, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo)
if err != nil && imp.lastErr != ErrSkip {
return nil, nil, fmt.Errorf("error checking types: %w", err)
}
- // Load all package facts.
- facts, err := facts.Decode(types, config.factLoader)
+ // Load all facts using the astTypes, although it may need reconciling
+ // later on. See the fact functions below.
+ astFacts, err := facts.Decode(astTypes, config.factLoader)
if err != nil {
return nil, nil, fmt.Errorf("error decoding facts: %w", err)
}
+ // Sanity check all types and record metadata to prevent
+ // https://github.com/golang/go/issues/44195.
+ //
+ // This block loads the binary types, whose encoding will be well
+ // defined and aligned with any downstream consumers. Below in the fact
+ // functions for the analysis, we serialize types to both the astFacts
+ // and the binaryFacts if available. The binaryFacts are the final
+ // encoded facts in order to ensure compatibility. We keep the
+ // intermediate astTypes in order to allow exporting and importing
+ // within the local package under analysis.
+ var (
+ astToBinary = make(map[types.Object]types.Object)
+ binaryFacts *facts.Set
+ )
+ if _, ok := config.ImportMap[config.ImportPath]; ok {
+ binaryTypes, err := imp.Import(config.ImportPath)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error loading self: %w", err)
+ }
+ if err := sanityCheckTypes(astTypes, binaryTypes, astToBinary); err != nil {
+ return nil, nil, fmt.Errorf("error sanity checking types: %w", err)
+ }
+ binaryFacts, err = facts.Decode(binaryTypes, config.factLoader)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error decoding facts: %w", err)
+ }
+ }
+
// Register fact types and establish dependencies between analyzers.
// The visit closure will execute recursively, and populate results
// will all required analysis results.
@@ -479,15 +567,15 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC
}
// Run the analysis.
- factFilter := make(map[reflect.Type]bool)
+ localFactsFilter := make(map[reflect.Type]bool)
for _, f := range a.FactTypes {
- factFilter[reflect.TypeOf(f)] = true
+ localFactsFilter[reflect.TypeOf(f)] = true
}
p := &analysis.Pass{
Analyzer: a,
Fset: imp.fset,
Files: syntax,
- Pkg: types,
+ Pkg: astTypes,
TypesInfo: typesInfo,
ResultOf: results, // All results.
Report: func(d analysis.Diagnostic) {
@@ -497,13 +585,29 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC
Message: d.Message,
})
},
- ImportPackageFact: facts.ImportPackageFact,
- ExportPackageFact: facts.ExportPackageFact,
- ImportObjectFact: facts.ImportObjectFact,
- ExportObjectFact: facts.ExportObjectFact,
- AllPackageFacts: func() []analysis.PackageFact { return facts.AllPackageFacts(factFilter) },
- AllObjectFacts: func() []analysis.ObjectFact { return facts.AllObjectFacts(factFilter) },
- TypesSizes: typesSizes,
+ ImportPackageFact: astFacts.ImportPackageFact,
+ ExportPackageFact: func(fact analysis.Fact) {
+ astFacts.ExportPackageFact(fact)
+ if binaryFacts != nil {
+ binaryFacts.ExportPackageFact(fact)
+ }
+ },
+ ImportObjectFact: astFacts.ImportObjectFact,
+ ExportObjectFact: func(obj types.Object, fact analysis.Fact) {
+ astFacts.ExportObjectFact(obj, fact)
+ // Note that if no object is recorded in
+ // astToBinary and binaryFacts != nil, then the
+ // object doesn't appear in the exported data.
+ // It was likely an internal object to the
+ // package, and there is no meaningful
+ // downstream consumer of the fact.
+ if binaryObj, ok := astToBinary[obj]; ok && binaryFacts != nil {
+ binaryFacts.ExportObjectFact(binaryObj, fact)
+ }
+ },
+ AllPackageFacts: func() []analysis.PackageFact { return astFacts.AllPackageFacts(localFactsFilter) },
+ AllObjectFacts: func() []analysis.ObjectFact { return astFacts.AllObjectFacts(localFactsFilter) },
+ TypesSizes: typesSizes,
}
result, err := a.Run(p)
if err != nil {
@@ -528,8 +632,14 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC
}
}
- // Return all findings.
- return findings, facts.Encode(), nil
+ // Return all findings. Note that we have a preference to returning the
+ // binary facts if available, so that downstream consumers of these
+ // facts will find the export aligns with the internal type details.
+ // See the block above with the call to sanityCheckTypes.
+ if binaryFacts != nil {
+ return findings, binaryFacts.Encode(), nil
+ }
+ return findings, astFacts.Encode(), nil
}
func init() {
diff --git a/tools/nogo/objdump/BUILD b/tools/nogo/objdump/BUILD
new file mode 100644
index 000000000..da56efdf7
--- /dev/null
+++ b/tools/nogo/objdump/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "objdump",
+ srcs = ["objdump.go"],
+ nogo = False,
+ visibility = ["//tools:__subpackages__"],
+)
diff --git a/tools/nogo/objdump/objdump.go b/tools/nogo/objdump/objdump.go
new file mode 100644
index 000000000..48484abf3
--- /dev/null
+++ b/tools/nogo/objdump/objdump.go
@@ -0,0 +1,96 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package objdump is a wrapper around relevant objdump flags.
+package objdump
+
+import (
+ "flag"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+)
+
+var (
+ // Binary is the binary under analysis.
+ //
+ // See Reader, below.
+ binary = flag.String("binary", "", "binary under analysis")
+
+ // Reader is the input stream.
+ //
+ // This may be set instead of Binary.
+ Reader io.Reader
+
+ // objdumpTool is the tool used to dump a binary.
+ objdumpTool = flag.String("objdump_tool", "", "tool used to dump a binary")
+)
+
+// LoadRaw reads the raw object output.
+func LoadRaw(fn func(r io.Reader) error) error {
+ var r io.Reader
+ if *binary != "" {
+ f, err := os.Open(*binary)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ r = f
+ } else if Reader != nil {
+ r = Reader
+ } else {
+ // We have no input stream.
+ return fmt.Errorf("no binary or reader provided")
+ }
+ return fn(r)
+}
+
+// Load reads the objdump output.
+func Load(fn func(r io.Reader) error) error {
+ var (
+ args []string
+ stdin io.Reader
+ )
+ if *binary != "" {
+ args = append(args, *binary)
+ } else if Reader != nil {
+ stdin = Reader
+ } else {
+ // We have no input stream or binary.
+ return fmt.Errorf("no binary or reader provided")
+ }
+
+ // Construct our command.
+ cmd := exec.Command(*objdumpTool, args...)
+ cmd.Stdin = stdin
+ cmd.Stderr = os.Stderr
+ out, err := cmd.StdoutPipe()
+ if err != nil {
+ return err
+ }
+ if err := cmd.Start(); err != nil {
+ return err
+ }
+
+ // Call the user hook.
+ userErr := fn(out)
+
+ // Wait for the dump to finish.
+ if err := cmd.Wait(); userErr == nil && err != nil {
+ return err
+ }
+
+ return userErr
+}