diff options
43 files changed, 1110 insertions, 418 deletions
diff --git a/g3doc/README.md b/g3doc/README.md index dc4179037..5e23aa5ec 100644 --- a/g3doc/README.md +++ b/g3doc/README.md @@ -23,7 +23,7 @@ links below to see detailed instructions for each of them: gVisor provides a virtualized environment in order to sandbox containers. The system interfaces normally implemented by the host kernel are moved into a -distinct, per-sandbox application kernel in order to minimize the risk of an +distinct, per-sandbox application kernel in order to minimize the risk of a container escape exploit. gVisor does not introduce large fixed overheads however, and still retains a process-like model with respect to resource utilization. diff --git a/pkg/atomicbitops/aligned_32bit_unsafe.go b/pkg/atomicbitops/aligned_32bit_unsafe.go index 0e4765c48..a143c027d 100644 --- a/pkg/atomicbitops/aligned_32bit_unsafe.go +++ b/pkg/atomicbitops/aligned_32bit_unsafe.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build arm || mips || 386 -// +build arm mips 386 +//go:build arm || mips || mipsle || 386 +// +build arm mips mipsle 386 package atomicbitops diff --git a/pkg/atomicbitops/aligned_64bit.go b/pkg/atomicbitops/aligned_64bit.go index 2c421d920..634f0ed2c 100644 --- a/pkg/atomicbitops/aligned_64bit.go +++ b/pkg/atomicbitops/aligned_64bit.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build !arm && !mips && !386 -// +build !arm,!mips,!386 +//go:build !arm && !mips && !mipsle && !386 +// +build !arm,!mips,!mipsle,!386 package atomicbitops diff --git a/pkg/context/context.go b/pkg/context/context.go index f3031fc60..e86c14195 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -29,26 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/log" ) -type contextID int - -// Globally accessible values from a context. These keys are defined in the -// context package to resolve dependency cycles by not requiring the caller to -// import packages usually required to get these information. -const ( - // CtxThreadGroupID is the current thread group ID when a context represents - // a task context. The value is represented as an int32. - CtxThreadGroupID contextID = iota -) - -// ThreadGroupIDFromContext returns the current thread group ID when ctx -// represents a task context. -func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { - if tgid := ctx.Value(CtxThreadGroupID); tgid != nil { - return tgid.(int32), true - } - return 0, false -} - // A Context represents a thread of execution (hereafter "goroutine" to reflect // Go idiosyncrasy). It carries state associated with the goroutine across API // boundaries. diff --git a/pkg/errors/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD index e73b0e28a..5b59ebd6e 100644 --- a/pkg/errors/linuxerr/BUILD +++ b/pkg/errors/linuxerr/BUILD @@ -21,7 +21,6 @@ go_test( srcs = ["linuxerr_test.go"], deps = [ ":linuxerr", - "//pkg/abi/linux/errno", "//pkg/errors", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/errors/linuxerr/linuxerr.go b/pkg/errors/linuxerr/linuxerr.go index 5905ef593..e44a55afd 100644 --- a/pkg/errors/linuxerr/linuxerr.go +++ b/pkg/errors/linuxerr/linuxerr.go @@ -34,41 +34,41 @@ const maxErrno uint32 = errno.EHWPOISON + 1 // (e.g. unix.Errno(EPERM.Errno()) == unix.EPERM is true). Converting unix/syscall.Errno // to the errors should be done via the lookup methods provided. var ( - NOERROR = errors.New(errno.NOERRNO, "not an error") - EPERM = errors.New(errno.EPERM, "operation not permitted") - ENOENT = errors.New(errno.ENOENT, "no such file or directory") - ESRCH = errors.New(errno.ESRCH, "no such process") - EINTR = errors.New(errno.EINTR, "interrupted system call") - EIO = errors.New(errno.EIO, "I/O error") - ENXIO = errors.New(errno.ENXIO, "no such device or address") - E2BIG = errors.New(errno.E2BIG, "argument list too long") - ENOEXEC = errors.New(errno.ENOEXEC, "exec format error") - EBADF = errors.New(errno.EBADF, "bad file number") - ECHILD = errors.New(errno.ECHILD, "no child processes") - EAGAIN = errors.New(errno.EAGAIN, "try again") - ENOMEM = errors.New(errno.ENOMEM, "out of memory") - EACCES = errors.New(errno.EACCES, "permission denied") - EFAULT = errors.New(errno.EFAULT, "bad address") - ENOTBLK = errors.New(errno.ENOTBLK, "block device required") - EBUSY = errors.New(errno.EBUSY, "device or resource busy") - EEXIST = errors.New(errno.EEXIST, "file exists") - EXDEV = errors.New(errno.EXDEV, "cross-device link") - ENODEV = errors.New(errno.ENODEV, "no such device") - ENOTDIR = errors.New(errno.ENOTDIR, "not a directory") - EISDIR = errors.New(errno.EISDIR, "is a directory") - EINVAL = errors.New(errno.EINVAL, "invalid argument") - ENFILE = errors.New(errno.ENFILE, "file table overflow") - EMFILE = errors.New(errno.EMFILE, "too many open files") - ENOTTY = errors.New(errno.ENOTTY, "not a typewriter") - ETXTBSY = errors.New(errno.ETXTBSY, "text file busy") - EFBIG = errors.New(errno.EFBIG, "file too large") - ENOSPC = errors.New(errno.ENOSPC, "no space left on device") - ESPIPE = errors.New(errno.ESPIPE, "illegal seek") - EROFS = errors.New(errno.EROFS, "read-only file system") - EMLINK = errors.New(errno.EMLINK, "too many links") - EPIPE = errors.New(errno.EPIPE, "broken pipe") - EDOM = errors.New(errno.EDOM, "math argument out of domain of func") - ERANGE = errors.New(errno.ERANGE, "math result not representable") + noError *errors.Error = nil + EPERM = errors.New(errno.EPERM, "operation not permitted") + ENOENT = errors.New(errno.ENOENT, "no such file or directory") + ESRCH = errors.New(errno.ESRCH, "no such process") + EINTR = errors.New(errno.EINTR, "interrupted system call") + EIO = errors.New(errno.EIO, "I/O error") + ENXIO = errors.New(errno.ENXIO, "no such device or address") + E2BIG = errors.New(errno.E2BIG, "argument list too long") + ENOEXEC = errors.New(errno.ENOEXEC, "exec format error") + EBADF = errors.New(errno.EBADF, "bad file number") + ECHILD = errors.New(errno.ECHILD, "no child processes") + EAGAIN = errors.New(errno.EAGAIN, "try again") + ENOMEM = errors.New(errno.ENOMEM, "out of memory") + EACCES = errors.New(errno.EACCES, "permission denied") + EFAULT = errors.New(errno.EFAULT, "bad address") + ENOTBLK = errors.New(errno.ENOTBLK, "block device required") + EBUSY = errors.New(errno.EBUSY, "device or resource busy") + EEXIST = errors.New(errno.EEXIST, "file exists") + EXDEV = errors.New(errno.EXDEV, "cross-device link") + ENODEV = errors.New(errno.ENODEV, "no such device") + ENOTDIR = errors.New(errno.ENOTDIR, "not a directory") + EISDIR = errors.New(errno.EISDIR, "is a directory") + EINVAL = errors.New(errno.EINVAL, "invalid argument") + ENFILE = errors.New(errno.ENFILE, "file table overflow") + EMFILE = errors.New(errno.EMFILE, "too many open files") + ENOTTY = errors.New(errno.ENOTTY, "not a typewriter") + ETXTBSY = errors.New(errno.ETXTBSY, "text file busy") + EFBIG = errors.New(errno.EFBIG, "file too large") + ENOSPC = errors.New(errno.ENOSPC, "no space left on device") + ESPIPE = errors.New(errno.ESPIPE, "illegal seek") + EROFS = errors.New(errno.EROFS, "read-only file system") + EMLINK = errors.New(errno.EMLINK, "too many links") + EPIPE = errors.New(errno.EPIPE, "broken pipe") + EDOM = errors.New(errno.EDOM, "math argument out of domain of func") + ERANGE = errors.New(errno.ERANGE, "math result not representable") // Errno values from include/uapi/asm-generic/errno.h. EDEADLK = errors.New(errno.EDEADLK, "resource deadlock would occur") @@ -186,7 +186,7 @@ var errNotValidError = errors.New(errno.Errno(maxErrno), "not a valid error") // errnos (especially uint32(sycall.Errno)) and *errors.Error. var errorSlice = []*errors.Error{ // Errno values from include/uapi/asm-generic/errno-base.h. - errno.NOERRNO: NOERROR, + errno.NOERRNO: noError, errno.EPERM: EPERM, errno.ENOENT: ENOENT, errno.ESRCH: ESRCH, @@ -324,32 +324,45 @@ var errorSlice = []*errors.Error{ errno.EHWPOISON: EHWPOISON, } -// ErrorFromErrno gets an error from the list and panics if an invalid entry is requested. -func ErrorFromErrno(e errno.Errno) *errors.Error { - err := errorSlice[e] +// ErrorFromUnix returns a linuxerr from a unix.Errno. +func ErrorFromUnix(err unix.Errno) error { + if err == unix.Errno(0) { + return nil + } + e := errorSlice[errno.Errno(err)] // Done this way because a single comparison in benchmarks is 2-3 faster // than something like ( if err == nil && err > 0 ). - if err != errNotValidError { - return err + if e == errNotValidError { + panic(fmt.Sprintf("invalid error requested with errno: %v", e)) } - panic(fmt.Sprintf("invalid error requested with errno: %d", e)) + return e } -// Equals compars a linuxerr to a given error -// TODO(b/34162363): Remove when syserror is removed. -func Equals(e *errors.Error, err error) bool { - if err == nil { - return e == NOERROR || e == nil +// ToError converts a linuxerr to an error type. +func ToError(err *errors.Error) error { + if err == noError { + return nil } - if e == nil { - return err == NOERROR || err == unix.Errno(0) + return err +} + +// ToUnix converts a linuxerr to a unix.Errno. +func ToUnix(e *errors.Error) unix.Errno { + var unixErr unix.Errno + if e != noError { + unixErr = unix.Errno(e.Errno()) } + return unixErr +} - switch err.(type) { - case *errors.Error: - return e == err - case unix.Errno, error: - return unix.Errno(e.Errno()) == err +// Equals compars a linuxerr to a given error. +func Equals(e *errors.Error, err error) bool { + var unixErr unix.Errno + if e != noError { + unixErr = unix.Errno(e.Errno()) + } + if err == nil { + err = noError } - return false + return e == err || unixErr == err } diff --git a/pkg/errors/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go index df7cd1c5a..b99884b22 100644 --- a/pkg/errors/linuxerr/linuxerr_test.go +++ b/pkg/errors/linuxerr/linuxerr_test.go @@ -23,7 +23,6 @@ import ( "testing" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux/errno" gErrors "gvisor.dev/gvisor/pkg/errors" "gvisor.dev/gvisor/pkg/errors/linuxerr" ) @@ -121,7 +120,7 @@ func BenchmarkReturnLinuxerr(b *testing.B) { func BenchmarkConvertUnixLinuxerr(b *testing.B) { var localError error for i := b.N; i > 0; i-- { - localError = linuxerr.ErrorFromErrno(errno.Errno(unix.EINVAL)) + localError = linuxerr.ErrorFromUnix(unix.EINVAL) } if localError != nil { return @@ -131,7 +130,7 @@ func BenchmarkConvertUnixLinuxerr(b *testing.B) { func BenchmarkConvertUnixLinuxerrZero(b *testing.B) { var localError error for i := b.N; i > 0; i-- { - localError = linuxerr.ErrorFromErrno(errno.Errno(0)) + localError = linuxerr.ErrorFromUnix(unix.Errno(0)) } if localError != nil { return @@ -179,7 +178,7 @@ func TestErrorTranslation(t *testing.T) { func TestSyscallErrnoToErrors(t *testing.T) { for _, tc := range []struct { errno syscall.Errno - err *gErrors.Error + err error }{ {errno: syscall.EACCES, err: linuxerr.EACCES}, {errno: syscall.EAGAIN, err: linuxerr.EAGAIN}, @@ -200,9 +199,9 @@ func TestSyscallErrnoToErrors(t *testing.T) { {errno: syscall.EWOULDBLOCK, err: linuxerr.EAGAIN}, } { t.Run(tc.errno.Error(), func(t *testing.T) { - e := linuxerr.ErrorFromErrno(errno.Errno(tc.errno)) + e := linuxerr.ErrorFromUnix(unix.Errno(tc.errno)) if e != tc.err { - t.Fatalf("Mismatch errors: want: %+v (%d) got: %+v %d", tc.err, tc.err.Errno(), e, e.Errno()) + t.Fatalf("Mismatch errors: want: %+v %T got: %+v %T", tc.err, tc.err, e, e) } }) } @@ -212,6 +211,7 @@ func TestSyscallErrnoToErrors(t *testing.T) { // unix.Errno and linuxerr. // TODO (b/34162363): Remove this. func TestEqualsMethod(t *testing.T) { + noError := linuxerr.ErrorFromUnix(unix.Errno(0)) for _, tc := range []struct { name string linuxErr []*gErrors.Error @@ -220,20 +220,20 @@ func TestEqualsMethod(t *testing.T) { }{ { name: "compare nil", - linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, - err: []error{nil, linuxerr.NOERROR, unix.Errno(0)}, + linuxErr: []*gErrors.Error{nil}, + err: []error{nil, noError, unix.Errno(0)}, equal: true, }, { name: "linuxerr nil error not", - linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, + linuxErr: []*gErrors.Error{nil}, err: []error{unix.Errno(1), linuxerr.EPERM, linuxerr.EACCES}, equal: false, }, { name: "linuxerr not nil error nil", linuxErr: []*gErrors.Error{linuxerr.ENOENT}, - err: []error{nil, unix.Errno(0), linuxerr.NOERROR}, + err: []error{nil, unix.Errno(0)}, equal: false, }, { @@ -250,7 +250,7 @@ func TestEqualsMethod(t *testing.T) { }, { name: "other error", - linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR, linuxerr.E2BIG, linuxerr.EINVAL}, + linuxErr: []*gErrors.Error{nil, linuxerr.E2BIG, linuxerr.EINVAL}, err: []error{fs.ErrInvalid, io.EOF}, equal: false, }, diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index 2b22b2203..2a307df38 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -22,7 +22,6 @@ go_library( "version.go", ], deps = [ - "//pkg/abi/linux/errno", "//pkg/errors", "//pkg/errors/linuxerr", "//pkg/fd", diff --git a/pkg/p9/file.go b/pkg/p9/file.go index 8d6af2d6b..b4b556cb9 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -21,13 +21,37 @@ import ( "gvisor.dev/gvisor/pkg/fd" ) +// AttacherOptions contains Attacher configuration. +type AttacherOptions struct { + // SetAttrOnDeleted is set to true if it's safe to call File.SetAttr for + // deleted files. + SetAttrOnDeleted bool + + // AllocateOnDeleted is set to true if it's safe to call File.Allocate for + // deleted files. + AllocateOnDeleted bool +} + +// NoServerOptions partially implements Attacher with empty AttacherOptions. +type NoServerOptions struct{} + +// ServerOptions implements Attacher. +func (*NoServerOptions) ServerOptions() AttacherOptions { + return AttacherOptions{} +} + // Attacher is provided by the server. type Attacher interface { // Attach returns a new File. // - // The client-side attach will be translate to a series of walks from + // The client-side attach will be translated to a series of walks from // the file returned by this Attach call. Attach() (File, error) + + // ServerOptions returns configuration options for this attach point. + // + // This is never caller in the client-side. + ServerOptions() AttacherOptions } // File is a set of operations corresponding to a single node. @@ -301,7 +325,7 @@ type File interface { type DefaultWalkGetAttr struct{} // WalkGetAttr implements File.WalkGetAttr. -func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) { +func (*DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) { return nil, nil, AttrMask{}, Attr{}, unix.ENOSYS } @@ -309,7 +333,7 @@ func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, er type DisallowClientCalls struct{} // SetAttrClose implements File.SetAttrClose. -func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { +func (*DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { panic("SetAttrClose should not be called on the server") } @@ -321,6 +345,11 @@ func (*DisallowServerCalls) Renamed(File, string) { panic("Renamed should not be called on the client") } +// ServerOptions implements Attacher. +func (*DisallowServerCalls) ServerOptions() AttacherOptions { + panic("ServerOptions should not be called on the client") +} + // DefaultMultiGetAttr implements File.MultiGetAttr() on top of File. func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { stats := make([]FullStat, 0, len(names)) diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index a8f8a9d03..c85af5e9e 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -23,7 +23,6 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/errors" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" @@ -46,7 +45,7 @@ func ExtractErrno(err error) unix.Errno { // Attempt to unwrap. switch e := err.(type) { case *errors.Error: - return unix.Errno(e.Errno()) + return linuxerr.ToUnix(e) case unix.Errno: return e case *os.PathError: @@ -69,7 +68,7 @@ func newErr(err error) *Rlerror { // ExtractLinuxerrErrno extracts a *errors.Error from a error, best effort. // TODO(b/34162363): Merge this with ExtractErrno. -func ExtractLinuxerrErrno(err error) *errors.Error { +func ExtractLinuxerrErrno(err error) error { switch err { case os.ErrNotExist: return linuxerr.ENOENT @@ -84,9 +83,9 @@ func ExtractLinuxerrErrno(err error) *errors.Error { // Attempt to unwrap. switch e := err.(type) { case *errors.Error: - return e + return linuxerr.ToError(e) case unix.Errno: - return linuxerr.ErrorFromErrno(errno.Errno(e)) + return linuxerr.ErrorFromUnix(e) case *os.PathError: return ExtractLinuxerrErrno(e.Err) case *os.SyscallError: @@ -103,7 +102,7 @@ func ExtractLinuxerrErrno(err error) *errors.Error { // newErrFromLinuxerr returns an Rlerror from the linuxerr list. // TODO(b/34162363): Merge this with newErr. func newErrFromLinuxerr(err error) *Rlerror { - return &Rlerror{Error: uint32(ExtractLinuxerrErrno(err).Errno())} + return &Rlerror{Error: uint32(ExtractErrno(err))} } // handler is implemented for server-handled messages. @@ -179,7 +178,7 @@ func (t *Tsetattrclunk) handle(cs *connState) message { // This might be technically incorrect, as it's possible that // there were multiple links and you can still change the // corresponding inode information. - if ref.isDeleted() { + if !cs.server.options.SetAttrOnDeleted && ref.isDeleted() { return unix.EINVAL } @@ -914,7 +913,7 @@ func (t *Tsetattr) handle(cs *connState) message { // This might be technically incorrect, as it's possible that // there were multiple links and you can still change the // corresponding inode information. - if ref.isDeleted() { + if !cs.server.options.SetAttrOnDeleted && ref.isDeleted() { return unix.EINVAL } @@ -947,7 +946,7 @@ func (t *Tallocate) handle(cs *connState) message { } // We don't allow allocate on files that have been deleted. - if ref.isDeleted() { + if !cs.server.options.AllocateOnDeleted && ref.isDeleted() { return unix.EINVAL } diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 9c1ada0cb..f3eb8468b 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -12,7 +12,7 @@ MOCK_SRC_PACKAGE = "gvisor.dev/gvisor/pkg/p9" # mockgen_reflect is a source file that contains mock generation code that # imports the p9 package and generates a specification via reflection. The # usual generation path must be split into two distinct parts because the full -# source tree is not available to all build targets. Only declared depencies +# source tree is not available to all build targets. Only declared dependencies # are available (and even then, not the Go source files). genrule( name = "mockgen_reflect", diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index fd5ac3dbe..56939d100 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -307,6 +307,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) { } // Start the server, synchronized on exit. + h.Attacher.EXPECT().ServerOptions().Return(p9.AttacherOptions{}).Times(1) server := p9.NewServer(h.Attacher) h.wg.Add(1) go func() { diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 241ab44ef..e7d129f9d 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -19,7 +19,7 @@ import ( "runtime/debug" "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux/errno" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fdchannel" @@ -34,6 +34,8 @@ type Server struct { // attacher provides the attach function. attacher Attacher + options AttacherOptions + // pathTree is the full set of paths opened on this server. // // These may be across different connections, but rename operations @@ -48,10 +50,15 @@ type Server struct { renameMu sync.RWMutex } -// NewServer returns a new server. +// NewServer returns a new server. attacher may be nil. func NewServer(attacher Attacher) *Server { + opts := AttacherOptions{} + if attacher != nil { + opts = attacher.ServerOptions() + } return &Server{ attacher: attacher, + options: opts, pathTree: newPathNode(), } } @@ -510,7 +517,7 @@ func (cs *connState) handle(m message) (r message) { // It will be removed a followup, when all the unix.Errno errors are // replaced with linuxerr. if rlError, ok := r.(*Rlerror); ok { - e := linuxerr.ErrorFromErrno(errno.Errno(rlError.Error)) + e := linuxerr.ErrorFromUnix(unix.Errno(rlError.Error)) r = newErrFromLinuxerr(e) } } else { diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go index c08d47787..2039a96ad 100644 --- a/pkg/sentry/kernel/auth/context.go +++ b/pkg/sentry/kernel/auth/context.go @@ -24,6 +24,10 @@ type contextID int const ( // CtxCredentials is a Context.Value key for Credentials. CtxCredentials contextID = iota + + // CtxThreadGroupID is the current thread group ID when a context represents + // a task context. The value is represented as an int32. + CtxThreadGroupID contextID = iota ) // CredentialsFromContext returns a copy of the Credentials used by ctx, or a @@ -35,6 +39,15 @@ func CredentialsFromContext(ctx context.Context) *Credentials { return NewAnonymousCredentials() } +// ThreadGroupIDFromContext returns the current thread group ID when ctx +// represents a task context. +func ThreadGroupIDFromContext(ctx context.Context) (tgid int32, ok bool) { + if tgid := ctx.Value(CtxThreadGroupID); tgid != nil { + return tgid.(int32), true + } + return 0, false +} + // ContextWithCredentials returns a copy of ctx carrying creds. func ContextWithCredentials(ctx context.Context, creds *Credentials) context.Context { return &authContext{ctx, creds} diff --git a/pkg/sentry/kernel/mq/mq.go b/pkg/sentry/kernel/mq/mq.go index 07482decf..7515a2772 100644 --- a/pkg/sentry/kernel/mq/mq.go +++ b/pkg/sentry/kernel/mq/mq.go @@ -399,7 +399,7 @@ func (q *Queue) Flush(ctx context.Context) { q.mu.Lock() defer q.mu.Unlock() - pid, ok := context.ThreadGroupIDFromContext(ctx) + pid, ok := auth.ThreadGroupIDFromContext(ctx) if ok { if q.subscriber != nil && pid == q.subscriber.pid { q.subscriber = nil diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index ab938fa3c..bb9a129ab 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -444,7 +444,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ hostarch. s.mu.Lock() defer s.mu.Unlock() s.attachTime = ktime.NowFromContext(ctx) - if pid, ok := context.ThreadGroupIDFromContext(ctx); ok { + if pid, ok := auth.ThreadGroupIDFromContext(ctx); ok { s.lastAttachDetachPID = pid } else { // AddMapping is called during a syscall, so ctx should always be a task @@ -468,7 +468,7 @@ func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ hostar // If called from a non-task context we also won't have a threadgroup // id. Silently skip updating the lastAttachDetachPid in that case. - if pid, ok := context.ThreadGroupIDFromContext(ctx); ok { + if pid, ok := auth.ThreadGroupIDFromContext(ctx); ok { s.lastAttachDetachPID = pid } else { log.Debugf("Couldn't obtain pid when removing mapping to %s, not updating the last detach pid.", s.debugLocked()) diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index cb9bcd7c0..ce38d9342 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -86,7 +86,7 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { return t case auth.CtxCredentials: return t.creds.Load() - case context.CtxThreadGroupID: + case auth.CtxThreadGroupID: return int32(t.tg.ID()) case fs.CtxRoot: if !isTaskGoroutine { diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 2b1d7e114..2b85fe1ca 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -381,7 +381,7 @@ func ExtractErrno(err error, sysno int) int { case unix.Errno: return int(err) case *errors.Error: - return int(err.Errno()) + return int(linuxerr.ToUnix(err)) case *memmap.BusError: // Bus errors may generate SIGBUS, but for syscalls they still // return EFAULT. See case in task_run.go where the fault is @@ -395,7 +395,7 @@ func ExtractErrno(err error, sysno int) int { return ExtractErrno(err.Err, sysno) default: if errno, ok := linuxerr.TranslateError(err); ok { - return int(errno.Errno()) + return int(linuxerr.ToUnix(errno)) } } panic(fmt.Sprintf("Unknown syscall %d error: %v", sysno, err)) diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index 4d73d46ef..fd0ab4c76 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -136,14 +136,14 @@ func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if err != nil { return 0, nil, err } - tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink) + tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.UMOUNT_NOFOLLOW == 0)) if err != nil { return 0, nil, err } defer tpop.Release(t) opts := vfs.UmountOptions{ - Flags: uint32(flags), + Flags: uint32(flags &^ linux.UMOUNT_NOFOLLOW), } return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index ead36880f..5d76adac1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -134,6 +134,7 @@ go_test( srcs = [ "conntrack_test.go", "forwarding_test.go", + "iptables_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index c489506bb..1c6060b70 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -119,22 +119,24 @@ type conn struct { // // +checklocks:mu destinationManip bool + + stateMu sync.RWMutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states // of tcp connection. // - // +checklocks:mu + // +checklocks:stateMu tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. // - // +checklocks:mu + // +checklocks:stateMu lastUsed tcpip.MonotonicTime } // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { - cn.mu.RLock() - defer cn.mu.RUnlock() + cn.stateMu.RLock() + defer cn.stateMu.RUnlock() if cn.tcb.State() == tcpconntrack.ResultAlive { // Use the same default as Linux, which doesn't delete // established connections for 5(!) days. @@ -147,7 +149,7 @@ func (cn *conn) timedOut(now tcpip.MonotonicTime) bool { // update the connection tracking state. // -// +checklocks:cn.mu +// +checklocks:cn.stateMu func (cn *conn) updateLocked(pkt *PacketBuffer, reply bool) { if pkt.TransportProtocolNumber != header.TCPProtocolNumber { return @@ -209,17 +211,41 @@ type bucket struct { tuples tupleList } -func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, netHdrFunc func([]byte) header.Network) (header.Network, header.ChecksummableTransport, bool) { - switch pkt.tuple.id().transProto { +// A netAndTransHeadersFunc returns the network and transport headers found +// in an ICMP payload. The transport layer's payload will not be returned. +// +// May panic if the packet does not hold the transport header. +type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte) + +func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv4(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the buffer is smaller than + // the total length specified in the IPv4 header. + transHdr := icmpPayload[netHdr.HeaderLength():] + return netHdr, transHdr[:minTransHdrLen] +} + +func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) { + netHdr := header.IPv6(icmpPayload) + // Do not use netHdr.Payload() as we might not hold the full packet + // in the ICMP error; Payload() panics if the IP payload is smaller than + // the payload length specified in the IPv6 header. + transHdr := icmpPayload[header.IPv6MinimumSize:] + return netHdr, transHdr[:minTransHdrLen] +} + +func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) { + switch transProto { case header.TCPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.TCP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize) + return netHeader, header.TCP(transHeaderBytes), true } case header.UDPProtocolNumber: if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok { - netHeader := netHdrFunc(netAndTransHeader) - return netHeader, header.UDP(netHeader.Payload()), true + netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize) + return netHeader, header.UDP(transHeaderBytes), true } } return nil, nil, false @@ -246,7 +272,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic("should have dropped packets with IPv4 options") } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, func(b []byte) header.Network { return header.IPv4(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.id().transProto); ok { return netHdr, transHdr, true, true } case header.ICMPv6ProtocolNumber: @@ -264,7 +290,7 @@ func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Check panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto)) } - if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, func(b []byte) header.Network { return header.IPv6(b) }); ok { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok { return netHdr, transHdr, true, true } } @@ -283,34 +309,16 @@ func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkPro } } -func getTupleIDForPacketInICMPError(pkt *PacketBuffer, netHdrFunc func([]byte) header.Network, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { - switch transProto { - case header.TCPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.TCPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.TCP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } - case header.UDPProtocolNumber: - if netAndTransHeader, ok := pkt.Data().PullUp(netLen + header.UDPMinimumSize); ok { - netHdr := netHdrFunc(netAndTransHeader) - transHdr := header.UDP(netHdr.Payload()) - return tupleID{ - srcAddr: netHdr.DestinationAddress(), - srcPort: transHdr.DestinationPort(), - dstAddr: netHdr.SourceAddress(), - dstPort: transHdr.SourcePort(), - transProto: transProto, - netProto: netProto, - }, true - } +func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) { + if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok { + return tupleID{ + srcAddr: netHdr.DestinationAddress(), + srcPort: transHdr.DestinationPort(), + dstAddr: netHdr.SourceAddress(), + dstPort: transHdr.SourcePort(), + transProto: transProto, + netProto: netProto, + }, true } return tupleID{}, false @@ -349,7 +357,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { return tupleID{}, false, false } - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv4(b) }, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok { return tid, true, true } case header.ICMPv6ProtocolNumber: @@ -370,7 +378,7 @@ func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) { } // TODO(https://gvisor.dev/issue/6789): Handle extension headers. - if tid, ok := getTupleIDForPacketInICMPError(pkt, func(b []byte) header.Network { return header.IPv6(b) }, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { + if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok { return tid, true, true } } @@ -601,14 +609,17 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool { // packets are fragmented. reply := pkt.tuple.reply - tid, performManip := func() (tupleID, bool) { - cn.mu.Lock() - defer cn.mu.Unlock() - // Mark the connection as having been used recently so it isn't reaped. - cn.lastUsed = cn.ct.clock.NowMonotonic() - // Update connection state. - cn.updateLocked(pkt, reply) + cn.stateMu.Lock() + // Mark the connection as having been used recently so it isn't reaped. + cn.lastUsed = cn.ct.clock.NowMonotonic() + // Update connection state. + cn.updateLocked(pkt, reply) + cn.stateMu.Unlock() + + tid, performManip := func() (tupleID, bool) { + cn.mu.RLock() + defer cn.mu.RUnlock() var tuple *tuple if reply { @@ -730,9 +741,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused deletes timed out entries from the conntrack map. The rules for // reaping are: -// - Most reaping occurs in connFor, which is called on each packet. connFor -// cleans up the bucket the packet's connection maps to. Thus calls to -// reapUnused should be fast. // - Each call to reapUnused traverses a fraction of the conntrack table. // Specifically, it traverses len(ct.buckets)/fractionPerReaping. // - After reaping, reapUnused decides when it should next run based on the @@ -799,45 +807,48 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // Precondition: ct.mu is read locked and bkt.mu is write locked. // +checklocksread:ct.mu // +checklocks:bkt.mu -func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { - if !tuple.conn.timedOut(now) { +func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool { + if !reapingTuple.conn.timedOut(now) { return false } - // To maintain lock order, we can only reap both tuples if the reply appears - // later in the table. - replyBktID := ct.bucket(tuple.id().reply()) - tuple.conn.mu.RLock() - replyTupleInserted := tuple.conn.finalized - tuple.conn.mu.RUnlock() - if bktID > replyBktID && replyTupleInserted { - return true + var otherTuple *tuple + if reapingTuple.reply { + otherTuple = &reapingTuple.conn.original + } else { + otherTuple = &reapingTuple.conn.reply } - // Reap the reply. - if replyTupleInserted { - // Don't re-lock if both tuples are in the same bucket. - if bktID != replyBktID { - replyBkt := &ct.buckets[replyBktID] - replyBkt.mu.Lock() - removeConnFromBucket(replyBkt, tuple) - replyBkt.mu.Unlock() - } else { - removeConnFromBucket(bkt, tuple) - } + otherTupleBktID := ct.bucket(otherTuple.id()) + reapingTuple.conn.mu.RLock() + replyTupleInserted := reapingTuple.conn.finalized + reapingTuple.conn.mu.RUnlock() + + // To maintain lock order, we can only reap both tuples if the tuple for the + // other direction appears later in the table. + if bktID > otherTupleBktID && replyTupleInserted { + return true } - bkt.tuples.Remove(tuple) - return true -} + bkt.tuples.Remove(reapingTuple) + + if !replyTupleInserted { + // The other tuple is the reply which has not yet been inserted. + return true + } -// +checklocks:b.mu -func removeConnFromBucket(b *bucket, tuple *tuple) { - if tuple.reply { - b.tuples.Remove(&tuple.conn.original) + // Reap the other connection. + if bktID == otherTupleBktID { + // Don't re-lock if both tuples are in the same bucket. + bkt.tuples.Remove(otherTuple) } else { - b.tuples.Remove(&tuple.conn.reply) + otherTupleBkt := &ct.buckets[otherTupleBktID] + otherTupleBkt.mu.Lock() + otherTupleBkt.tuples.Remove(otherTuple) + otherTupleBkt.mu.Unlock() } + + return true } func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { diff --git a/pkg/tcpip/stack/iptables_test.go b/pkg/tcpip/stack/iptables_test.go new file mode 100644 index 000000000..1788e98c9 --- /dev/null +++ b/pkg/tcpip/stack/iptables_test.go @@ -0,0 +1,220 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// TestNATedConnectionReap tests that NATed connections are properly reaped. +func TestNATedConnectionReap(t *testing.T) { + // Note that the network protocol used for this test doesn't matter as this + // test focuses on reaping, not anything related to a specific network + // protocol. + + const ( + nattedDstPort = 1 + srcPort = 2 + dstPort = 3 + + nattedDstAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + srcAddr = tcpip.Address("\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + dstAddr = tcpip.Address("\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ) + + clock := faketime.NewManualClock() + iptables := DefaultTables(0 /* seed */, clock) + + table := Table{ + Rules: []Rule{ + // Prerouting + { + Target: &DNATTarget{NetworkProtocol: header.IPv6ProtocolNumber, Addr: nattedDstAddr, Port: nattedDstPort}, + }, + { + Target: &AcceptTarget{}, + }, + + // Input + { + Target: &AcceptTarget{}, + }, + + // Forward + { + Target: &AcceptTarget{}, + }, + + // Output + { + Target: &AcceptTarget{}, + }, + + // Postrouting + { + Target: &AcceptTarget{}, + }, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 2, + Forward: 3, + Output: 4, + Postrouting: 5, + }, + } + if err := iptables.ReplaceTable(NATID, table, true /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, true): %s", NATID, err) + } + + // Stop the reaper if it is running so we can reap manually as it is started + // on the first change to IPTables. + iptables.reaperDone <- struct{}{} + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + header.UDPMinimumSize, + }) + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udp.SetSourcePort(srcPort) + udp.SetDestinationPort(dstPort) + udp.SetChecksum(0) + udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + srcAddr, + dstAddr, + uint16(len(udp)), + ))) + pkt.TransportProtocolNumber = header.UDPProtocolNumber + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(udp)), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 64, + SrcAddr: srcAddr, + DstAddr: dstAddr, + }) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + + originalTID, _, ok := getTupleID(pkt) + if !ok { + t.Fatal("failed to get original tuple ID") + } + + if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) { + t.Fatal("got ipt.CheckPrerouting(...) = false, want = true") + } + if !iptables.CheckInput(pkt, "" /* inNicName */) { + t.Fatal("got ipt.CheckInput(...) = false, want = true") + } + + invertedReplyTID, _, ok := getTupleID(pkt) + if !ok { + t.Fatal("failed to get NATed packet's tuple ID") + } + if invertedReplyTID == originalTID { + t.Fatalf("NAT not performed; got invertedReplyTID = %#v", invertedReplyTID) + } + replyTID := invertedReplyTID.reply() + + originalBktID := iptables.connections.bucket(originalTID) + replyBktID := iptables.connections.bucket(replyTID) + + // This test depends on the original and reply tuples mapping to different + // buckets. + if originalBktID == replyBktID { + t.Fatalf("expected bucket IDs to be different; got = %d", originalBktID) + } + + lowerBktID := originalBktID + if lowerBktID > replyBktID { + lowerBktID = replyBktID + } + + runReaper := func() { + // Reaping the bucket with the lower ID should reap both tuples of the + // connection if it has timed out. + // + // We will manually pick the next start bucket ID and don't use the + // interval so we ignore the return values. + _, _ = iptables.connections.reapUnused(lowerBktID, 0 /* prevInterval */) + } + + iptables.connections.mu.RLock() + buckets := iptables.connections.buckets + iptables.connections.mu.RUnlock() + + originalBkt := &buckets[originalBktID] + replyBkt := &buckets[replyBktID] + + // Run the reaper and make sure the tuples were not reaped. + reapAndCheckForConnections := func() { + t.Helper() + + runReaper() + + now := clock.NowMonotonic() + if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple == nil { + t.Error("expected to get original tuple") + } + + if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple == nil { + t.Error("expected to get reply tuple") + } + + if t.Failed() { + t.FailNow() + } + } + + // Connection was just added and no time has passed - it should not be reaped. + reapAndCheckForConnections() + + // Time must advance past the unestablished timeout for a connection to be + // reaped. + clock.Advance(unestablishedTimeout) + reapAndCheckForConnections() + + // Connection should now be reaped. + clock.Advance(1) + runReaper() + now := clock.NowMonotonic() + if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple != nil { + t.Errorf("got originalBkt.connForTID(%#v, %#v) = %#v, want = nil", originalTID, now, originalTuple) + } + if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple != nil { + t.Errorf("got replyBkt.connForTID(%#v, %#v) = %#v, want = nil", replyTID, now, replyTuple) + } + // Make sure we don't have stale tuples just lying around. + // + // We manually check the buckets as connForTID will skip over tuples that + // have timed out. + checkNoTupleInBucket := func(bkt *bucket, tid tupleID, reply bool) { + t.Helper() + + bkt.mu.RLock() + defer bkt.mu.RUnlock() + for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() { + if tuple.id() == originalTID { + t.Errorf("unexpectedly found tuple with ID = %#v; reply = %t", tid, reply) + } + } + } + checkNoTupleInBucket(originalBkt, originalTID, false /* reply */) + checkNoTupleInBucket(replyBkt, replyTID, true /* reply */) +} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 7fe3b29d9..b2383576c 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -1781,8 +1781,11 @@ func TestNAT(t *testing.T) { } func TestNATICMPError(t *testing.T) { - const srcPort = 1234 - const dstPort = 5432 + const ( + srcPort = 1234 + dstPort = 5432 + dataSize = 4 + ) type icmpTypeTest struct { name string @@ -1836,8 +1839,7 @@ func TestNATICMPError(t *testing.T) { netProto: ipv4.ProtocolNumber, host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View { - totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original) - hdr := buffer.NewPrependable(totalLen) + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)) if n := copy(hdr.Prepend(len(original)), original); n != len(original) { t.Fatalf("got copy(...) = %d, want = %d", n, len(original)) } @@ -1845,8 +1847,9 @@ func TestNATICMPError(t *testing.T) { icmp.SetType(header.ICMPv4Type(icmpType)) icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0)) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.ICMPv4ProtocolNumber, utils.Host1IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, @@ -1875,9 +1878,9 @@ func TestNATICMPError(t *testing.T) { name: "UDP", proto: header.UDPProtocolNumber, buf: func() buffer.View { - totalLen := header.IPv4MinimumSize + header.UDPMinimumSize - hdr := buffer.NewPrependable(totalLen) - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpSize := header.UDPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv4MinimumSize + udpSize) + udp := header.UDP(hdr.Prepend(udpSize)) udp.SetSourcePort(srcPort) udp.SetDestinationPort(dstPort) udp.SetChecksum(0) @@ -1887,8 +1890,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, uint16(len(udp)), ))) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.UDPProtocolNumber, utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, @@ -1910,9 +1914,9 @@ func TestNATICMPError(t *testing.T) { name: "TCP", proto: header.TCPProtocolNumber, buf: func() buffer.View { - totalLen := header.IPv4MinimumSize + header.TCPMinimumSize - hdr := buffer.NewPrependable(totalLen) - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcpSize := header.TCPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv4MinimumSize + tcpSize) + tcp := header.TCP(hdr.Prepend(tcpSize)) tcp.SetSourcePort(srcPort) tcp.SetDestinationPort(dstPort) tcp.SetDataOffset(header.TCPMinimumSize) @@ -1923,8 +1927,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, uint16(len(tcp)), ))) - ipHdr(hdr.Prepend(header.IPv4MinimumSize), - totalLen, + ipHdr( + hdr.Prepend(header.IPv4MinimumSize), + hdr.UsedLength(), header.TCPProtocolNumber, utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, @@ -1989,7 +1994,8 @@ func TestNATICMPError(t *testing.T) { Src: utils.Host1IPv6Addr.AddressWithPrefix.Address, Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, })) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), payloadLen, header.ICMPv6ProtocolNumber, utils.Host1IPv6Addr.AddressWithPrefix.Address, @@ -2016,8 +2022,9 @@ func TestNATICMPError(t *testing.T) { name: "UDP", proto: header.UDPProtocolNumber, buf: func() buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpSize := header.UDPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + udpSize) + udp := header.UDP(hdr.Prepend(udpSize)) udp.SetSourcePort(srcPort) udp.SetDestinationPort(dstPort) udp.SetChecksum(0) @@ -2027,8 +2034,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, uint16(len(udp)), ))) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), - header.UDPMinimumSize, + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), + len(udp), header.UDPProtocolNumber, utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, @@ -2050,8 +2058,9 @@ func TestNATICMPError(t *testing.T) { name: "TCP", proto: header.TCPProtocolNumber, buf: func() buffer.View { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.TCPMinimumSize) - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcpSize := header.TCPMinimumSize + dataSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + tcpSize) + tcp := header.TCP(hdr.Prepend(tcpSize)) tcp.SetSourcePort(srcPort) tcp.SetDestinationPort(dstPort) tcp.SetDataOffset(header.TCPMinimumSize) @@ -2062,8 +2071,9 @@ func TestNATICMPError(t *testing.T) { utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, uint16(len(tcp)), ))) - ip6Hdr(hdr.Prepend(header.IPv6MinimumSize), - header.TCPMinimumSize, + ip6Hdr( + hdr.Prepend(header.IPv6MinimumSize), + len(tcp), header.TCPProtocolNumber, utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, @@ -2117,109 +2127,141 @@ func TestNATICMPError(t *testing.T) { }, } + trimTests := []struct { + name string + trimLen int + expectNATedICMP bool + }{ + { + name: "Trim nothing", + trimLen: 0, + expectNATedICMP: true, + }, + { + name: "Trim data", + trimLen: dataSize, + expectNATedICMP: true, + }, + { + name: "Trim data and transport header", + trimLen: dataSize + 1, + expectNATedICMP: false, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { for _, transportType := range test.transportTypes { t.Run(transportType.name, func(t *testing.T) { for _, icmpType := range test.icmpTypes { t.Run(icmpType.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, - }) - - ep1 := channel.New(1, header.IPv6MinimumMTU, "") - ep2 := channel.New(1, header.IPv6MinimumMTU, "") - utils.SetupRouterStack(t, s, ep1, ep2) - - ipv6 := test.netProto == ipv6.ProtocolNumber - ipt := s.IPTables() - - table := stack.Table{ - Rules: []stack.Rule{ - // Prerouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - InputInterface: utils.RouterNIC2Name, + for _, trimTest := range trimTests { + t.Run(trimTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + + ep1 := channel.New(1, header.IPv6MinimumMTU, "") + ep2 := channel.New(1, header.IPv6MinimumMTU, "") + utils.SetupRouterStack(t, s, ep1, ep2) + + ipv6 := test.netProto == ipv6.ProtocolNumber + ipt := s.IPTables() + + table := stack.Table{ + Rules: []stack.Rule{ + // Prerouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + InputInterface: utils.RouterNIC2Name, + }, + Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, + }, + { + Target: &stack.AcceptTarget{}, + }, + + // Input + { + Target: &stack.AcceptTarget{}, + }, + + // Forward + { + Target: &stack.AcceptTarget{}, + }, + + // Output + { + Target: &stack.AcceptTarget{}, + }, + + // Postrouting + { + Filter: stack.IPHeaderFilter{ + Protocol: transportType.proto, + CheckProtocol: true, + OutputInterface: utils.RouterNIC1Name, + }, + Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, + }, + { + Target: &stack.AcceptTarget{}, + }, }, - Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort}, - }, - { - Target: &stack.AcceptTarget{}, - }, - - // Input - { - Target: &stack.AcceptTarget{}, - }, - - // Forward - { - Target: &stack.AcceptTarget{}, - }, - - // Output - { - Target: &stack.AcceptTarget{}, - }, - - // Postrouting - { - Filter: stack.IPHeaderFilter{ - Protocol: transportType.proto, - CheckProtocol: true, - OutputInterface: utils.RouterNIC1Name, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 2, + stack.Forward: 3, + stack.Output: 4, + stack.Postrouting: 5, }, - Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, - }, - { - Target: &stack.AcceptTarget{}, - }, - }, - BuiltinChains: [stack.NumHooks]int{ - stack.Prerouting: 0, - stack.Input: 2, - stack.Forward: 3, - stack.Output: 4, - stack.Postrouting: 5, - }, - } + } - if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { - t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) - } + if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err) + } - ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(), - })) + buf := transportType.buf - { - pkt, ok := ep1.Read() - if !ok { - t.Fatal("expected to read a packet on ep1") - } - pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) - transportType.checkNATed(t, pktView) - if t.Failed() { - t.FailNow() - } + ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: append(buffer.View(nil), buf...).ToVectorisedView(), + })) - ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), - })) - } + { + pkt, ok := ep1.Read() + if !ok { + t.Fatal("expected to read a packet on ep1") + } + pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader()) + transportType.checkNATed(t, pktView) + if t.Failed() { + t.FailNow() + } + + pktView = pktView[:len(pktView)-trimTest.trimLen] + buf = buf[:len(buf)-trimTest.trimLen] + + ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(), + })) + } - pkt, ok := ep2.Read() - if ok != icmpType.expectResponse { - t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse) - } - if !icmpType.expectResponse { - return + pkt, ok := ep2.Read() + expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP + if ok != expectResponse { + t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, expectResponse) + } + if !expectResponse { + return + } + test.decrementTTL(buf) + test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), buf, icmpType.val) + }) } - test.decrementTTL(transportType.buf) - test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val) }) } }) diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 9fb3ebd95..f819cf8fb 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -78,6 +78,11 @@ type DefaultRoute struct { Name string } +type Neighbor struct { + IP net.IP + HardwareAddr net.HardwareAddr +} + // FDBasedLink configures an fd-based link. type FDBasedLink struct { Name string @@ -90,6 +95,7 @@ type FDBasedLink struct { RXChecksumOffload bool LinkAddress net.HardwareAddr QDisc config.QueueingDiscipline + Neighbors []Neighbor // NumChannels controls how many underlying FD's are to be used to // create this endpoint. @@ -241,6 +247,11 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } routes = append(routes, route) } + + for _, neigh := range link.Neighbors { + proto, tcpipAddr := ipToAddressAndProto(neigh.IP) + n.Stack.AddStaticNeighbor(nicID, proto, tcpipAddr, tcpip.LinkAddress(neigh.HardwareAddr)) + } } if !args.Defaultv4Gateway.Route.Empty() { diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index 600b21189..3d610199c 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -140,6 +140,17 @@ func (a *attachPoint) Attach() (p9.File, error) { return lf, nil } +// ServerOptions implements p9.Attacher. It's safe to call SetAttr and Allocate +// on deleted files because fsgofer either uses an existing FD or opens a new +// one using the magic symlink in `/proc/[pid]/fd` and cannot mistakely open +// a file that was created in the same path as the delete file. +func (a *attachPoint) ServerOptions() p9.AttacherOptions { + return p9.AttacherOptions{ + SetAttrOnDeleted: true, + AllocateOnDeleted: true, + } +} + // makeQID returns a unique QID for the given stat buffer. func (a *attachPoint) makeQID(stat *unix.Stat_t) p9.QID { a.deviceMu.Lock() diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 3451d1037..03c5de2c6 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -173,6 +173,23 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG continue } + // Collect data from the ARP table. + dump, err := netlink.NeighList(iface.Index, 0) + if err != nil { + return fmt.Errorf("fetching ARP table for %q: %w", iface.Name, err) + } + + var neighbors []boot.Neighbor + for _, n := range dump { + // There are only two "good" states NUD_PERMANENT and NUD_REACHABLE, + // but NUD_REACHABLE is fully dynamic and will be re-probed anyway. + if n.State == netlink.NUD_PERMANENT { + log.Debugf("Copying a static ARP entry: %+v %+v", n.IP, n.HardwareAddr) + // No flags are copied because Stack.AddStaticNeighbor does not support flags right now. + neighbors = append(neighbors, boot.Neighbor{IP: n.IP, HardwareAddr: n.HardwareAddr}) + } + } + // Scrape the routes before removing the address, since that // will remove the routes as well. routes, defv4, defv6, err := routesForIface(iface) @@ -203,6 +220,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG RXChecksumOffload: rxChecksumOffload, NumChannels: numNetworkChannels, QDisc: qDisc, + Neighbors: neighbors, } // Get the link for the interface. diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index f748d685a..7952fd969 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1053,3 +1053,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:verity_mount_test", ) + +syscall_test( + test = "//test/syscalls/linux:deleted_test", +) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 6217ff4dc..020c4673a 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -4432,3 +4432,18 @@ cc_binary( "@com_google_absl//absl/container:flat_hash_set", ], ) + +cc_binary( + name = "deleted_test", + testonly = 1, + srcs = ["deleted.cc"], + linkstatic = 1, + deps = [ + "//test/util:file_descriptor", + "//test/util:fs_util", + gtest, + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/syscalls/linux/deleted.cc b/test/syscalls/linux/deleted.cc new file mode 100644 index 000000000..695ceafd3 --- /dev/null +++ b/test/syscalls/linux/deleted.cc @@ -0,0 +1,116 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <time.h> +#include <unistd.h> + +#include <string> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +constexpr mode_t mode = 1; + +namespace gvisor { +namespace testing { + +namespace { + +PosixErrorOr<FileDescriptor> createdDeleted() { + auto path = NewTempAbsPath(); + PosixErrorOr<FileDescriptor> fd = Open(path, O_RDWR | O_CREAT, mode); + if (!fd.ok()) { + return fd.error(); + } + + auto err = Unlink(path); + if (!err.ok()) { + return err; + } + return fd; +} + +TEST(DeletedTest, Utime) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted()); + + const struct timespec times[2] = {{10, 0}, {20, 0}}; + EXPECT_THAT(futimens(fd.get(), times), SyscallSucceeds()); + + struct stat stat; + ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds()); + EXPECT_EQ(10, stat.st_atime); + EXPECT_EQ(20, stat.st_mtime); +} + +TEST(DeletedTest, Chmod) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted()); + + ASSERT_THAT(fchmod(fd.get(), mode + 1), SyscallSucceeds()); + + struct stat stat; + ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds()); + EXPECT_EQ(mode + 1, stat.st_mode & ~S_IFMT); +} + +TEST(DeletedTest, Truncate) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted()); + const std::string data = "foobar"; + ASSERT_THAT(write(fd.get(), data.c_str(), data.size()), SyscallSucceeds()); + + ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds()); + + struct stat stat; + ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds()); + ASSERT_EQ(stat.st_size, 0); +} + +TEST(DeletedTest, Fallocate) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(createdDeleted()); + + ASSERT_THAT(fallocate(fd.get(), 0, 0, 123), SyscallSucceeds()); + + struct stat stat; + ASSERT_THAT(fstat(fd.get(), &stat), SyscallSucceeds()); + EXPECT_EQ(123, stat.st_size); +} + +// Tests that a file can be created with the same path as a deleted file that +// still have an open FD to it. +TEST(DeletedTest, Replace) { + auto path = NewTempAbsPath(); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT, mode)); + ASSERT_NO_ERRNO(Unlink(path)); + + auto other = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT | O_EXCL, mode)); + + auto stat = ASSERT_NO_ERRNO_AND_VALUE(Fstat(fd.get())); + auto stat_other = ASSERT_NO_ERRNO_AND_VALUE(Fstat(other.get())); + ASSERT_NE(stat.st_ino, stat_other.st_ino); + + // Check that the path points to the new file. + stat = ASSERT_NO_ERRNO_AND_VALUE(Stat(path)); + ASSERT_EQ(stat.st_ino, stat_other.st_ino); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index 3c7311782..e2a41d172 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -115,6 +115,40 @@ TEST(MountTest, OpenFileBusy) { EXPECT_THAT(umount(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); } +TEST(MountTest, UmountNoFollow) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + auto const mountPoint = NewTempAbsPathInDir(dir.path()); + ASSERT_THAT(mkdir(mountPoint.c_str(), 0777), SyscallSucceeds()); + + // Create a symlink in dir which will point to the actual mountpoint. + const std::string symlinkInDir = NewTempAbsPathInDir(dir.path()); + EXPECT_THAT(symlink(mountPoint.c_str(), symlinkInDir.c_str()), + SyscallSucceeds()); + + // Create a symlink to the dir. + const std::string symlinkToDir = NewTempAbsPath(); + EXPECT_THAT(symlink(dir.path().c_str(), symlinkToDir.c_str()), + SyscallSucceeds()); + + // Should fail with ELOOP when UMOUNT_NOFOLLOW is specified and the last + // component is a symlink. + auto mount = ASSERT_NO_ERRNO_AND_VALUE( + Mount("", mountPoint, "tmpfs", 0, "mode=0700", 0)); + EXPECT_THAT(umount2(symlinkInDir.c_str(), UMOUNT_NOFOLLOW), + SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(unlink(symlinkInDir.c_str()), SyscallSucceeds()); + + // UMOUNT_NOFOLLOW should only apply to the last path component. A symlink in + // non-last path component should be just fine. + EXPECT_THAT(umount2(JoinPath(symlinkToDir, Basename(mountPoint)).c_str(), + UMOUNT_NOFOLLOW), + SyscallSucceeds()); + mount.Release(); +} + TEST(MountTest, UmountDetach) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 253411858..1c24d9ffc 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -188,6 +188,14 @@ PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, return NoError(); } +PosixError Unlink(absl::string_view path) { + int res = unlink(std::string(path).c_str()); + if (res < 0) { + return PosixError(errno, absl::StrCat("unlink ", path)); + } + return NoError(); +} + PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, int flags) { int res = unlinkat(dfd.get(), std::string(path).c_str(), flags); diff --git a/test/util/fs_util.h b/test/util/fs_util.h index bb2d1d3c8..3ae0a725a 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -71,6 +71,7 @@ PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, dev_t dev); // Unlink the file. +PosixError Unlink(absl::string_view path); PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, int flags); diff --git a/tools/checklocks/README.md b/tools/checklocks/README.md index bd4beb649..eaad69399 100644 --- a/tools/checklocks/README.md +++ b/tools/checklocks/README.md @@ -1,6 +1,6 @@ # CheckLocks Analyzer -<!--* freshness: { owner: 'gvisor-eng' reviewed: '2021-03-21' } *--> +<!--* freshness: { owner: 'gvisor-eng' reviewed: '2021-10-15' } *--> Checklocks is an analyzer for lock and atomic constraints. The analyzer relies on explicit annotations to identify fields that should be checked for access. @@ -100,29 +100,6 @@ func abc() { ### Explicitly Not Supported -1. Checking for embedded mutexes as sync.Locker rather than directly as - 'sync.Mutex'. In other words, the checker will not track mutex Lock and - Unlock() methods where the mutex is behind an interface dispatch. - -An example that we won't handle is shown below (this in fact will fail to -build): - -```go -type A struct { - mu sync.Locker - - // +checklocks:mu - x int -} - -func abc() { - mu sync.Mutex - a := A{mu: &mu} - a.x = 1 // This won't be flagged by copylocks checker. -} - -``` - 1. The checker will not support guards on anything other than the cases described above. For example, global mutexes cannot be referred to by checklocks. Only struct members can be used. diff --git a/tools/checklocks/analysis.go b/tools/checklocks/analysis.go index ec0cba7f9..2def09744 100644 --- a/tools/checklocks/analysis.go +++ b/tools/checklocks/analysis.go @@ -183,8 +183,11 @@ type instructionWithReferrers interface { // checkFieldAccess checks the validity of a field access. // // This also enforces atomicity constraints for fields that must be accessed -// atomically. The parameter isWrite indicates whether this field is used for -// a write operation. +// atomically. The parameter isWrite indicates whether this field is used +// downstream for a write operation. +// +// Note that this function is not called if lff.Ignore is true, since it cannot +// discover any local anonymous functions or closures. func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj ssa.Value, field int, ls *lockState, isWrite bool) { var ( lff lockFieldFacts @@ -200,7 +203,8 @@ func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj for guardName, fl := range lgf.GuardedBy { guardsFound++ r := fl.resolve(structObj) - if _, ok := ls.isHeld(r, isWrite); ok { + s, ok := ls.isHeld(r, isWrite) + if ok { guardsHeld++ continue } @@ -218,7 +222,7 @@ func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj // true for this case, so we require it to be read-only. if lgf.AtomicDisposition != atomicRequired { // There is no force key, no atomic access and no lock held. - pc.maybeFail(inst.Pos(), "invalid field access, %s must be locked when accessing %s (locks: %s)", guardName, fieldObj.Name(), ls.String()) + pc.maybeFail(inst.Pos(), "invalid field access, must hold %s (%s) when accessing %s (locks: %s)", guardName, s, fieldObj.Name(), ls.String()) } } @@ -247,10 +251,19 @@ func (pc *passContext) checkFieldAccess(inst instructionWithReferrers, structObj } } -func (pc *passContext) checkCall(call callCommon, ls *lockState) { +func (pc *passContext) checkCall(call callCommon, lff *lockFunctionFacts, ls *lockState) { // See: https://godoc.org/golang.org/x/tools/go/ssa#CallCommon // - // 1. "call" mode: when Method is nil (!IsInvoke), a CallCommon represents an ordinary + // "invoke" mode: Method is non-nil, and Value is the underlying value. + if fn := call.Common().Method; fn != nil { + var nlff lockFunctionFacts + pc.pass.ImportObjectFact(fn, &nlff) + nlff.Ignore = nlff.Ignore || lff.Ignore // Inherit ignore. + pc.checkFunctionCall(call, fn, &nlff, ls) + return + } + + // "call" mode: when Method is nil (!IsInvoke), a CallCommon represents an ordinary // function call of the value in Value, which may be a *Builtin, a *Function or any // other value of kind 'func'. // @@ -269,10 +282,13 @@ func (pc *passContext) checkCall(call callCommon, ls *lockState) { // function call. switch fn := call.Common().Value.(type) { case *ssa.Function: - var lff lockFunctionFacts - if fn.Object() != nil { - pc.pass.ImportObjectFact(fn.Object(), &lff) - pc.checkFunctionCall(call, fn, &lff, ls) + nlff := lockFunctionFacts{ + Ignore: lff.Ignore, // Inherit ignore. + } + if obj := fn.Object(); obj != nil { + pc.pass.ImportObjectFact(obj, &nlff) + nlff.Ignore = nlff.Ignore || lff.Ignore // See above. + pc.checkFunctionCall(call, obj.(*types.Func), &nlff, ls) } else { // Anonymous functions have no facts, and cannot be // annotated. We don't check for violations using the @@ -282,28 +298,31 @@ func (pc *passContext) checkCall(call callCommon, ls *lockState) { for i, arg := range call.Common().Args { fnls.store(fn.Params[i], arg) } - pc.checkFunction(call, fn, &lff, fnls, true /* force */) + pc.checkFunction(call, fn, &nlff, fnls, true /* force */) } case *ssa.MakeClosure: // Note that creating and then invoking closures locally is // allowed, but analysis of passing closures is done when // checking individual instructions. - pc.checkClosure(call, fn, ls) + pc.checkClosure(call, fn, lff, ls) default: return } } // postFunctionCallUpdate updates all conditions. -func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunctionFacts, ls *lockState) { +func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunctionFacts, ls *lockState, aliases bool) { // Release all locks not still held. for fieldName, fg := range lff.HeldOnEntry { if _, ok := lff.HeldOnExit[fieldName]; ok { continue } + if fg.IsAlias && !aliases { + continue + } r := fg.resolveCall(call.Common().Args, call.Value()) if s, ok := ls.unlockField(r, fg.Exclusive); !ok { - if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { pc.maybeFail(call.Pos(), "attempt to release %s (%s), but not held (locks: %s)", fieldName, s, ls.String()) } } @@ -314,10 +333,13 @@ func (pc *passContext) postFunctionCallUpdate(call callCommon, lff *lockFunction if _, ok := lff.HeldOnEntry[fieldName]; ok { continue } + if fg.IsAlias && !aliases { + continue + } // Acquire the lock per the annotation. r := fg.resolveCall(call.Common().Args, call.Value()) if s, ok := ls.lockField(r, fg.Exclusive); !ok { - if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { pc.maybeFail(call.Pos(), "attempt to acquire %s (%s), but already held (locks: %s)", fieldName, s, ls.String()) } } @@ -337,12 +359,29 @@ func exclusiveStr(exclusive bool) string { // atomic functions are tracked by checkFieldAccess by looking directly at the // referrers (because ordering doesn't matter there, so we need not scan in // instruction order). -func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff *lockFunctionFacts, ls *lockState) { - // Check all guards required are held. +func (pc *passContext) checkFunctionCall(call callCommon, fn *types.Func, lff *lockFunctionFacts, ls *lockState) { + // Extract the "receiver" properly. + var rcvr ssa.Value + if call.Common().Method != nil { + // This is an interface dispatch for sync.Locker. + rcvr = call.Common().Value + } else if args := call.Common().Args; len(args) > 0 && fn.Type().(*types.Signature).Recv() != nil { + // This matches the signature for the relevant + // sync.Lock/sync.Unlock functions below. + rcvr = args[0] + } + // Note that at this point, rcvr may be nil, but it should not match any + // of the function signatures below where rcvr may be used. + + // Check all guards required are held. Note that this explicitly does + // not include aliases, hence false being passed below. for fieldName, fg := range lff.HeldOnEntry { + if fg.IsAlias { + continue + } r := fg.resolveCall(call.Common().Args, call.Value()) if s, ok := ls.isHeld(r, fg.Exclusive); !ok { - if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { pc.maybeFail(call.Pos(), "must hold %s %s (%s) to call %s, but not held (locks: %s)", fieldName, exclusiveStr(fg.Exclusive), s, fn.Name(), ls.String()) } else { // Force the lock to be acquired. @@ -352,19 +391,19 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff } // Update all lock state accordingly. - pc.postFunctionCallUpdate(call, lff, ls) + pc.postFunctionCallUpdate(call, lff, ls, false /* aliases */) // Check if it's a method dispatch for something in the sync package. // See: https://godoc.org/golang.org/x/tools/go/ssa#Function - if fn.Package() != nil && fn.Package().Pkg.Name() == "sync" && fn.Signature.Recv() != nil { + if fn.Pkg() != nil && fn.Pkg().Name() == "sync" { isExclusive := false switch fn.Name() { case "Lock": isExclusive = true fallthrough case "RLock": - if s, ok := ls.lockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok { - if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + if s, ok := ls.lockField(resolvedValue{value: rcvr, valid: true}, isExclusive); !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { // Double locking a mutex that is already locked. pc.maybeFail(call.Pos(), "%s already locked (locks: %s)", s, ls.String()) } @@ -373,15 +412,15 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff isExclusive = true fallthrough case "RUnlock": - if s, ok := ls.unlockField(resolvedValue{value: call.Common().Args[0], valid: true}, isExclusive); !ok { - if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + if s, ok := ls.unlockField(resolvedValue{value: rcvr, valid: true}, isExclusive); !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { // Unlocking something that is already unlocked. pc.maybeFail(call.Pos(), "%s already unlocked or locked differently (locks: %s)", s, ls.String()) } } case "DowngradeLock": if s, ok := ls.downgradeField(resolvedValue{value: call.Common().Args[0], valid: true}); !ok { - if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok { + if _, ok := pc.forced[pc.positionKey(call.Pos())]; !ok && !lff.Ignore { // Downgrading something that may not be downgraded. pc.maybeFail(call.Pos(), "%s already unlocked or not exclusive (locks: %s)", s, ls.String()) } @@ -392,7 +431,7 @@ func (pc *passContext) checkFunctionCall(call callCommon, fn *ssa.Function, lff // checkClosure forks the lock state, and creates a binding for the FreeVars of // the closure. This allows the analysis to resolve the closure. -func (pc *passContext) checkClosure(call callCommon, fn *ssa.MakeClosure, ls *lockState) { +func (pc *passContext) checkClosure(call callCommon, fn *ssa.MakeClosure, lff *lockFunctionFacts, ls *lockState) { clls := ls.fork() clfn := fn.Fn.(*ssa.Function) for i, fv := range clfn.FreeVars { @@ -402,8 +441,10 @@ func (pc *passContext) checkClosure(call callCommon, fn *ssa.MakeClosure, ls *lo // Note that this is *not* a call to check function call, which checks // against the function preconditions. Instead, this does a fresh // analysis of the function from source code with a different state. - var nolff lockFunctionFacts - pc.checkFunction(call, clfn, &nolff, clls, true /* force */) + nlff := lockFunctionFacts{ + Ignore: lff.Ignore, // Inherit ignore. + } + pc.checkFunction(call, clfn, &nlff, clls, true /* force */) } // freshAlloc indicates that v has been allocated within the local scope. There @@ -455,7 +496,7 @@ type callCommon interface { // checkInstruction checks the legality the single instruction based on the // current lockState. -func (pc *passContext) checkInstruction(inst ssa.Instruction, ls *lockState) (*ssa.Return, *lockState) { +func (pc *passContext) checkInstruction(inst ssa.Instruction, lff *lockFunctionFacts, ls *lockState) (*ssa.Return, *lockState) { switch x := inst.(type) { case *ssa.Store: // Record that this value is holding this other value. This is @@ -468,52 +509,55 @@ func (pc *passContext) checkInstruction(inst ssa.Instruction, ls *lockState) (*s // state, but this is intentional. ls.store(x.Addr, x.Val) case *ssa.Field: - if !freshAlloc(x.X) { + if !freshAlloc(x.X) && !lff.Ignore { pc.checkFieldAccess(x, x.X, x.Field, ls, false) } case *ssa.FieldAddr: - if !freshAlloc(x.X) { + if !freshAlloc(x.X) && !lff.Ignore { pc.checkFieldAccess(x, x.X, x.Field, ls, isWrite(x)) } case *ssa.Call: - pc.checkCall(x, ls) + pc.checkCall(x, lff, ls) case *ssa.Defer: ls.pushDefer(x) case *ssa.RunDefers: for d := ls.popDefer(); d != nil; d = ls.popDefer() { - pc.checkCall(d, ls) + pc.checkCall(d, lff, ls) } case *ssa.MakeClosure: - refs := x.Referrers() - if refs == nil { - // This is strange, it's not used? Ignore this case, - // since it will probably be optimized away. - return nil, nil - } - hasNonCall := false - for _, ref := range *refs { - switch ref.(type) { - case *ssa.Call, *ssa.Defer: - // Analysis will be done on the call itself - // subsequently, including the lock state at - // the time of the call. - default: - // We need to analyze separately. Per below, - // this means that we'll analyze at closure - // construction time no zero assumptions about - // when it will be called. - hasNonCall = true + if refs := x.Referrers(); refs != nil { + var ( + calls int + nonCalls int + ) + for _, ref := range *refs { + switch ref.(type) { + case *ssa.Call, *ssa.Defer: + // Analysis will be done on the call + // itself subsequently, including the + // lock state at the time of the call. + calls++ + default: + // We need to analyze separately. Per + // below, this means that we'll analyze + // at closure construction time no zero + // assumptions about when it will be + // called. + nonCalls++ + } + } + if calls > 0 && nonCalls == 0 { + return nil, nil } - } - if !hasNonCall { - return nil, nil } // Analyze the closure without bindings. This means that we // assume no lock facts or have any existing lock state. Only // trivial closures are acceptable in this case. clfn := x.Fn.(*ssa.Function) - var nolff lockFunctionFacts - pc.checkFunction(nil, clfn, &nolff, nil, false /* force */) + nlff := lockFunctionFacts{ + Ignore: lff.Ignore, // Inherit ignore. + } + pc.checkFunction(nil, clfn, &nlff, nil, false /* force */) case *ssa.Return: return x, ls // Valid return state. } @@ -522,11 +566,25 @@ func (pc *passContext) checkInstruction(inst ssa.Instruction, ls *lockState) (*s // checkBasicBlock traverses the control flow graph starting at a set of given // block and checks each instruction for allowed operations. -func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, lff *lockFunctionFacts, parent *lockState, seen map[*ssa.BasicBlock]*lockState) *lockState { +func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, lff *lockFunctionFacts, parent *lockState, seen map[*ssa.BasicBlock]*lockState, rg map[*ssa.BasicBlock]struct{}) *lockState { + // Check for cached results from entering this block from a *different* + // execution path. Note that this is not the same path, which is + // checked with the recursion guard below. if oldLS, ok := seen[block]; ok && oldLS.isCompatible(parent) { return nil } + // Prevent recursion. If the lock state is constantly changing and we + // are a recursive path, then there will never be a return block. + if rg == nil { + rg = make(map[*ssa.BasicBlock]struct{}) + } + if _, ok := rg[block]; ok { + return nil + } + rg[block] = struct{}{} + defer func() { delete(rg, block) }() + // If the lock state is not compatible, then we need to do the // recursive analysis to ensure that it is still sane. For example, the // following is guaranteed to generate incompatible locking states: @@ -548,14 +606,14 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, seen[block] = parent ls := parent.fork() for _, inst := range block.Instrs { - rv, rls = pc.checkInstruction(inst, ls) + rv, rls = pc.checkInstruction(inst, lff, ls) if rls != nil { failed := false // Validate held locks. for fieldName, fg := range lff.HeldOnExit { r := fg.resolveStatic(fn, rv) if s, ok := rls.isHeld(r, fg.Exclusive); !ok { - if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok { + if _, ok := pc.forced[pc.positionKey(rv.Pos())]; !ok && !lff.Ignore { pc.maybeFail(rv.Pos(), "lock %s (%s) not held %s (locks: %s)", fieldName, s, exclusiveStr(fg.Exclusive), rls.String()) failed = true } else { @@ -565,7 +623,7 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, } } // Check for other locks, but only if the above didn't trip. - if !failed && rls.count() != len(lff.HeldOnExit) { + if !failed && rls.count() != len(lff.HeldOnExit) && !lff.Ignore { pc.maybeFail(rv.Pos(), "return with unexpected locks held (locks: %s)", rls.String()) } } @@ -578,9 +636,9 @@ func (pc *passContext) checkBasicBlock(fn *ssa.Function, block *ssa.BasicBlock, // above. Note that checkBasicBlock will recursively analyze // the lock state to ensure that Releases and Acquires are // respected. - if pls := pc.checkBasicBlock(fn, succ, lff, ls, seen); pls != nil { + if pls := pc.checkBasicBlock(fn, succ, lff, ls, seen, rg); pls != nil { if rls != nil && !rls.isCompatible(pls) { - if _, ok := pc.forced[pc.positionKey(fn.Pos())]; !ok { + if _, ok := pc.forced[pc.positionKey(fn.Pos())]; !ok && !lff.Ignore { pc.maybeFail(fn.Pos(), "incompatible return states (first: %s, second: %s)", rls.String(), pls.String()) } } @@ -619,12 +677,15 @@ func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *loc // for the method to be invoked. Note that in the overwhleming majority // of cases, parent will be nil. However, in the case of closures and // anonymous functions, we may start with a non-nil lock state. + // + // Note that this will include all aliases, which are also released + // appropriately below. ls := parent.fork() for fieldName, fg := range lff.HeldOnEntry { // The first is the method object itself so we skip that when looking // for receiver/function parameters. r := fg.resolveStatic(fn, call.Value()) - if s, ok := ls.lockField(r, fg.Exclusive); !ok { + if s, ok := ls.lockField(r, fg.Exclusive); !ok && !lff.Ignore { // This can only happen if the same value is declared // multiple times, and should be caught by the earlier // fact scanning. Keep it here as a sanity check. @@ -635,17 +696,17 @@ func (pc *passContext) checkFunction(call callCommon, fn *ssa.Function, lff *loc // Scan the blocks. seen := make(map[*ssa.BasicBlock]*lockState) if len(fn.Blocks) > 0 { - pc.checkBasicBlock(fn, fn.Blocks[0], lff, ls, seen) + pc.checkBasicBlock(fn, fn.Blocks[0], lff, ls, seen, nil) } // Scan the recover block. if fn.Recover != nil { - pc.checkBasicBlock(fn, fn.Recover, lff, ls, seen) + pc.checkBasicBlock(fn, fn.Recover, lff, ls, seen, nil) } // Update all lock state accordingly. This will be called only if we // are doing inline analysis for e.g. an anonymous function. if call != nil && parent != nil { - pc.postFunctionCallUpdate(call, lff, parent) + pc.postFunctionCallUpdate(call, lff, parent, true /* aliases */) } } diff --git a/tools/checklocks/annotations.go b/tools/checklocks/annotations.go index 1f679e5be..950168ee1 100644 --- a/tools/checklocks/annotations.go +++ b/tools/checklocks/annotations.go @@ -32,6 +32,7 @@ const ( checkLocksIgnore = "// +checklocksignore" checkLocksForce = "// +checklocksforce" checkLocksFail = "// +checklocksfail" + checkLocksAlias = "// +checklocksalias:" checkAtomicAnnotation = "// +checkatomic" ) diff --git a/tools/checklocks/checklocks.go b/tools/checklocks/checklocks.go index 401fb55ec..ae8db1a36 100644 --- a/tools/checklocks/checklocks.go +++ b/tools/checklocks/checklocks.go @@ -131,11 +131,6 @@ func run(pass *analysis.Pass) (interface{}, error) { var lff lockFunctionFacts pc.pass.ImportObjectFact(fn.Object(), &lff) - // Do we ignore this? - if lff.Ignore { - continue - } - // Check the basic blocks in the function. pc.checkFunction(nil, fn, &lff, nil, false /* force */) } diff --git a/tools/checklocks/facts.go b/tools/checklocks/facts.go index fd681adc3..17aef5790 100644 --- a/tools/checklocks/facts.go +++ b/tools/checklocks/facts.go @@ -164,6 +164,9 @@ type functionGuard struct { // that the field must be extracted from a tuple. NeedsExtract bool + // IsAlias indicates that this guard is an alias. + IsAlias bool + // FieldList is the traversal path to the object. FieldList fieldList @@ -312,6 +315,36 @@ func (lff *lockFunctionFacts) addReleases(pc *passContext, d *ast.FuncDecl, guar } } +// addAlias adds an alias. +func (lff *lockFunctionFacts) addAlias(pc *passContext, d *ast.FuncDecl, guardName string) { + // Parse the alias. + parts := strings.Split(guardName, "=") + if len(parts) != 2 { + pc.maybeFail(d.Pos(), "invalid annotation %s for alias", guardName) + return + } + + // Parse the actual guard. + fg, ok := lff.checkGuard(pc, d, parts[0], true /* exclusive */, true /* allowReturn */) + if !ok { + return + } + fg.IsAlias = true + + // Find the existing specification. + _, entryOk := lff.HeldOnEntry[parts[1]] + if entryOk { + lff.HeldOnEntry[guardName] = fg + } + _, exitOk := lff.HeldOnExit[parts[1]] + if exitOk { + lff.HeldOnExit[guardName] = fg + } + if !entryOk && !exitOk { + pc.maybeFail(d.Pos(), "alias annotation %s does not refer to an existing guard", guardName) + } +} + // fieldListFor returns the fieldList for the given object. func (pc *passContext) fieldListFor(pos token.Pos, fieldObj types.Object, index int, fieldName string, checkMutex bool, exclusive bool) (int, bool) { var lff lockFieldFacts @@ -403,6 +436,7 @@ func (pc *passContext) resolveField(pos token.Pos, structType *types.Struct, par var ( mutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineMutex|Mutex)") rwMutexRE = regexp.MustCompile("((.*/)|^)sync.(CrossGoroutineRWMutex|RWMutex)") + lockerRE = regexp.MustCompile("((.*/)|^)sync.Locker") ) // exportLockFieldFacts finds all struct fields that are mutexes, and ensures @@ -426,9 +460,14 @@ func (pc *passContext) exportLockFieldFacts(structType *types.Struct, ss *ast.St lff.IsMutex = true case rwMutexRE.MatchString(s): lff.IsRWMutex = true + case lockerRE.MatchString(s): + lff.IsMutex = true } // Save whether this is a pointer. _, lff.IsPointer = fieldObj.Type().Underlying().(*types.Pointer) + if !lff.IsPointer { + _, lff.IsPointer = fieldObj.Type().Underlying().(*types.Interface) + } // We must always export the lockFieldFacts, since traversal // can take place along any object in the struct. pc.pass.ExportObjectFact(fieldObj, lff) @@ -630,6 +669,7 @@ func (pc *passContext) exportFunctionFacts(d *ast.FuncDecl) { checkLocksAcquiresRead: func(guardName string) { lff.addAcquires(pc, d, guardName, false /* exclusive */) }, checkLocksReleases: func(guardName string) { lff.addReleases(pc, d, guardName, true /* exclusive */) }, checkLocksReleasesRead: func(guardName string) { lff.addReleases(pc, d, guardName, false /* exclusive */) }, + checkLocksAlias: func(guardName string) { lff.addAlias(pc, d, guardName) }, }) } diff --git a/tools/checklocks/test/BUILD b/tools/checklocks/test/BUILD index f2ea6c7c6..4b90731f5 100644 --- a/tools/checklocks/test/BUILD +++ b/tools/checklocks/test/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_library( name = "test", srcs = [ + "aliases.go", "alignment.go", "anon.go", "atomics.go", @@ -13,6 +14,7 @@ go_library( "closures.go", "defer.go", "incompat.go", + "locker.go", "methods.go", "parameters.go", "return.go", diff --git a/tools/checklocks/test/aliases.go b/tools/checklocks/test/aliases.go new file mode 100644 index 000000000..e28027fe5 --- /dev/null +++ b/tools/checklocks/test/aliases.go @@ -0,0 +1,26 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +// +checklocks:tc.mu +// +checklocksalias:tc2.mu=tc.mu +func testAliasValid(tc *oneGuardStruct, tc2 *oneGuardStruct) { + tc2.guardedField = 1 +} + +// +checklocks:tc.mu +func testAliasInvalid(tc *oneGuardStruct, tc2 *oneGuardStruct) { + tc2.guardedField = 1 // +checklocksfail +} diff --git a/tools/checklocks/test/branches.go b/tools/checklocks/test/branches.go index 81fec29e5..247885a49 100644 --- a/tools/checklocks/test/branches.go +++ b/tools/checklocks/test/branches.go @@ -54,3 +54,19 @@ func testInconsistentBranching(tc *oneGuardStruct) { // +checklocksfail:2 tc.mu.Unlock() // +checklocksforce } } + +func testUnboundedLocks(tc []*oneGuardStruct) { + for _, l := range tc { + l.mu.Lock() + } + // This test should have the above *not fail*, though the exact + // lock state cannot be tracked through the below. Therefore, we + // expect the next loop to actually fail, and we force the unlock + // loop to succeed in exactly the same way. + for _, l := range tc { + l.guardedField = 1 // +checklocksfail + } + for _, l := range tc { + l.mu.Unlock() // +checklocksforce + } +} diff --git a/tools/checklocks/test/closures.go b/tools/checklocks/test/closures.go index 7da87540a..316d12ce1 100644 --- a/tools/checklocks/test/closures.go +++ b/tools/checklocks/test/closures.go @@ -53,6 +53,15 @@ func testClosureInline(tc *oneGuardStruct) { tc.mu.Unlock() } +// +checklocksignore +func testClosureIgnore(tc *oneGuardStruct) { + // Inherit the checklocksignore. + x := func() { + tc.guardedField = 1 + } + x() +} + func testAnonymousInvalid(tc *oneGuardStruct) { // Invalid, as per testClosureInvalid above. callAnonymous(func(tc *oneGuardStruct) { @@ -89,6 +98,15 @@ func testAnonymousInline(tc *oneGuardStruct) { tc.mu.Unlock() } +// +checklocksignore +func testAnonymousIgnore(tc *oneGuardStruct) { + // Inherit the checklocksignore. + x := func(tc *oneGuardStruct) { + tc.guardedField = 1 + } + x(tc) +} + //go:noinline func callClosure(fn func()) { fn() diff --git a/tools/checklocks/test/incompat.go b/tools/checklocks/test/incompat.go index b39bc66c1..f55fa532d 100644 --- a/tools/checklocks/test/incompat.go +++ b/tools/checklocks/test/incompat.go @@ -18,15 +18,6 @@ import ( "sync" ) -// unsupportedLockerStruct verifies that trying to annotate a field that is not a -// sync.Mutex or sync.RWMutex results in a failure. -type unsupportedLockerStruct struct { - mu sync.Locker - - // +checklocks:mu - x int // +checklocksfail -} - // badFieldsStruct verifies that refering invalid fields fails. type badFieldsStruct struct { // +checklocks:mu diff --git a/tools/checklocks/test/locker.go b/tools/checklocks/test/locker.go new file mode 100644 index 000000000..b0e7d1143 --- /dev/null +++ b/tools/checklocks/test/locker.go @@ -0,0 +1,33 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import "sync" + +type lockerStruct struct { + mu sync.Locker + // +checklocks:mu + guardedField int +} + +func testLockerValid(tc *lockerStruct) { + tc.mu.Lock() + tc.guardedField = 1 + tc.mu.Unlock() +} + +func testLockerInvalid(tc *lockerStruct) { + tc.guardedField = 1 // +checklocksfail +} |