diff options
Diffstat (limited to 'pkg/errors')
-rw-r--r-- | pkg/errors/linuxerr/BUILD | 1 | ||||
-rw-r--r-- | pkg/errors/linuxerr/linuxerr.go | 20 | ||||
-rw-r--r-- | pkg/errors/linuxerr/linuxerr_test.go | 61 |
3 files changed, 82 insertions, 0 deletions
diff --git a/pkg/errors/linuxerr/BUILD b/pkg/errors/linuxerr/BUILD index 8afc9688c..201727780 100644 --- a/pkg/errors/linuxerr/BUILD +++ b/pkg/errors/linuxerr/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/abi/linux/errno", "//pkg/errors", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/errors/linuxerr/linuxerr.go b/pkg/errors/linuxerr/linuxerr.go index bbdcdecd0..9246f2e89 100644 --- a/pkg/errors/linuxerr/linuxerr.go +++ b/pkg/errors/linuxerr/linuxerr.go @@ -20,6 +20,7 @@ package linuxerr import ( "fmt" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux/errno" "gvisor.dev/gvisor/pkg/errors" ) @@ -325,3 +326,22 @@ func ErrorFromErrno(e errno.Errno) *errors.Error { } panic(fmt.Sprintf("invalid error requested with errno: %d", e)) } + +// Equals compars a linuxerr to a given error +// TODO(b/34162363): Remove when syserror is removed. +func Equals(e *errors.Error, err error) bool { + if err == nil { + return e == NOERROR || e == nil + } + if e == nil { + return err == NOERROR || err == unix.Errno(0) + } + + switch err.(type) { + case *errors.Error: + return e == err + case unix.Errno, error: + return unix.Errno(e.Errno()) == err + } + return false +} diff --git a/pkg/errors/linuxerr/linuxerr_test.go b/pkg/errors/linuxerr/linuxerr_test.go index a81dd9560..62743c338 100644 --- a/pkg/errors/linuxerr/linuxerr_test.go +++ b/pkg/errors/linuxerr/linuxerr_test.go @@ -16,6 +16,8 @@ package syserror_test import ( "errors" + "io" + "io/fs" "syscall" "testing" @@ -243,3 +245,62 @@ func TestSyscallErrnoToErrors(t *testing.T) { }) } } + +// TestEqualsMethod tests that the Equals method correctly compares syerror, +// unix.Errno and linuxerr. +// TODO (b/34162363): Remove this. +func TestEqualsMethod(t *testing.T) { + for _, tc := range []struct { + name string + linuxErr []*gErrors.Error + err []error + equal bool + }{ + { + name: "compare nil", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, + err: []error{nil, linuxerr.NOERROR, unix.Errno(0)}, + equal: true, + }, + { + name: "linuxerr nil error not", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR}, + err: []error{unix.Errno(1), linuxerr.EPERM, syserror.EACCES}, + equal: false, + }, + { + name: "linuxerr not nil error nil", + linuxErr: []*gErrors.Error{linuxerr.ENOENT}, + err: []error{nil, unix.Errno(0), linuxerr.NOERROR}, + equal: false, + }, + { + name: "equal errors", + linuxErr: []*gErrors.Error{linuxerr.ESRCH}, + err: []error{linuxerr.ESRCH, syserror.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())}, + equal: true, + }, + { + name: "unequal errors", + linuxErr: []*gErrors.Error{linuxerr.ENOENT}, + err: []error{linuxerr.ESRCH, syserror.ESRCH, unix.Errno(linuxerr.ESRCH.Errno())}, + equal: false, + }, + { + name: "other error", + linuxErr: []*gErrors.Error{nil, linuxerr.NOERROR, linuxerr.E2BIG, linuxerr.EINVAL}, + err: []error{fs.ErrInvalid, io.EOF}, + equal: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + for _, le := range tc.linuxErr { + for _, e := range tc.err { + if linuxerr.Equals(le, e) != tc.equal { + t.Fatalf("Expected %t from Equals method for linuxerr: %s %T and error: %s %T", tc.equal, le, le, e, e) + } + } + } + }) + } +} |