diff options
author | Zach Koopmans <zkoopmans@google.com> | 2021-03-02 14:08:33 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-03-02 14:10:51 -0800 |
commit | b8a5420f49a2afd622ec08b5019e1bf537f7da82 (patch) | |
tree | ef1e6167acf886322ea42c7056515cec0c24adcd /runsc | |
parent | a317174673562996a98f5a735771955d6651e233 (diff) |
Add reverse flag to mitigate.
Add reverse operation to mitigate that just enables
all CPUs.
PiperOrigin-RevId: 360511215
Diffstat (limited to 'runsc')
-rw-r--r-- | runsc/cmd/mitigate.go | 10 | ||||
-rw-r--r-- | runsc/mitigate/cpu.go | 118 | ||||
-rw-r--r-- | runsc/mitigate/cpu_test.go | 121 | ||||
-rw-r--r-- | runsc/mitigate/mitigate.go | 82 | ||||
-rw-r--r-- | runsc/mitigate/mitigate_conf.go | 2 | ||||
-rw-r--r-- | runsc/mitigate/mitigate_test.go | 139 |
6 files changed, 415 insertions, 57 deletions
diff --git a/runsc/cmd/mitigate.go b/runsc/cmd/mitigate.go index 9052f091d..822af1917 100644 --- a/runsc/cmd/mitigate.go +++ b/runsc/cmd/mitigate.go @@ -16,7 +16,6 @@ package cmd import ( "context" - "io/ioutil" "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" @@ -56,14 +55,7 @@ func (m *Mitigate) Execute(_ context.Context, f *flag.FlagSet, args ...interface return subcommands.ExitUsageError } - const path = "/proc/cpuinfo" - data, err := ioutil.ReadFile(path) - if err != nil { - log.Warningf("Failed to read %s: %v", path, err) - return subcommands.ExitFailure - } - - if err := m.mitigate.Execute(data); err != nil { + if err := m.mitigate.Execute(); err != nil { log.Warningf("Execute failed: %v", err) return subcommands.ExitFailure } diff --git a/runsc/mitigate/cpu.go b/runsc/mitigate/cpu.go index 38f9b787a..4b2aa351f 100644 --- a/runsc/mitigate/cpu.go +++ b/runsc/mitigate/cpu.go @@ -45,7 +45,7 @@ const ( type cpuSet map[cpuID]*threadGroup // newCPUSet creates a CPUSet from data read from /proc/cpuinfo. -func newCPUSet(data []byte, vulnerable func(*thread) bool) (cpuSet, error) { +func newCPUSet(data []byte, vulnerable func(thread) bool) (cpuSet, error) { processors, err := getThreads(string(data)) if err != nil { return nil, err @@ -68,6 +68,26 @@ func newCPUSet(data []byte, vulnerable func(*thread) bool) (cpuSet, error) { return set, nil } +// newCPUSetFromPossible makes a cpuSet data read from +// /sys/devices/system/cpu/possible. This is used in enable operations +// where the caller simply wants to enable all CPUS. +func newCPUSetFromPossible(data []byte) (cpuSet, error) { + threads, err := getThreadsFromPossible(data) + if err != nil { + return nil, err + } + + // We don't care if a CPU is vulnerable or not, we just + // want to return a list of all CPUs on the host. + set := cpuSet{ + threads[0].id: &threadGroup{ + threads: threads, + isVulnerable: false, + }, + } + return set, nil +} + // String implements the String method for CPUSet. func (c cpuSet) String() string { ret := "" @@ -79,8 +99,8 @@ func (c cpuSet) String() string { // getRemainingList returns the list of threads that will remain active // after mitigation. -func (c cpuSet) getRemainingList() []*thread { - threads := make([]*thread, 0, len(c)) +func (c cpuSet) getRemainingList() []thread { + threads := make([]thread, 0, len(c)) for _, core := range c { // If we're vulnerable, take only one thread from the pair. if core.isVulnerable { @@ -95,8 +115,8 @@ func (c cpuSet) getRemainingList() []*thread { // getShutdownList returns the list of threads that will be shutdown on // mitigation. -func (c cpuSet) getShutdownList() []*thread { - threads := make([]*thread, 0) +func (c cpuSet) getShutdownList() []thread { + threads := make([]thread, 0) for _, core := range c { // Only if we're vulnerable do shutdown anything. In this case, // shutdown all but the first entry. @@ -109,12 +129,12 @@ func (c cpuSet) getShutdownList() []*thread { // threadGroup represents Hyperthread pairs on the same physical/core ID. type threadGroup struct { - threads []*thread + threads []thread isVulnerable bool } // String implements the String method for threadGroup. -func (c *threadGroup) String() string { +func (c threadGroup) String() string { ret := fmt.Sprintf("ThreadGroup:\nIsVulnerable: %t\n", c.isVulnerable) for _, processor := range c.threads { ret += fmt.Sprintf("%s\n", processor) @@ -123,13 +143,13 @@ func (c *threadGroup) String() string { } // getThreads returns threads structs from reading /proc/cpuinfo. -func getThreads(data string) ([]*thread, error) { +func getThreads(data string) ([]thread, error) { // Each processor entry should start with the // processor key. Find the beginings of each. r := buildRegex(processorKey, `\d+`) indices := r.FindAllStringIndex(data, -1) if len(indices) < 1 { - return nil, fmt.Errorf("no cpus found for: %s", data) + return nil, fmt.Errorf("no cpus found for: %q", data) } // Add the ending index for last entry. @@ -139,7 +159,7 @@ func getThreads(data string) ([]*thread, error) { // indexes (e.g. data[index[i], index[i+1]]). // There should be len(indicies) - 1 CPUs // since the last index is the end of the string. - var cpus = make([]*thread, 0, len(indices)-1) + cpus := make([]thread, 0, len(indices)) // Find each string that represents a CPU. These begin "processor". for i := 1; i < len(indices); i++ { start := indices[i-1][0] @@ -154,6 +174,45 @@ func getThreads(data string) ([]*thread, error) { return cpus, nil } +// getThreadsFromPossible makes threads from data read from /sys/devices/system/cpu/possible. +func getThreadsFromPossible(data []byte) ([]thread, error) { + possibleRegex := regexp.MustCompile(`(?m)^(\d+)(-(\d+))?$`) + matches := possibleRegex.FindStringSubmatch(string(data)) + if len(matches) != 4 { + return nil, fmt.Errorf("mismatch regex from %s: %q", allPossibleCPUs, string(data)) + } + + // If matches[3] is empty, we only have one cpu entry. + if matches[3] == "" { + matches[3] = matches[1] + } + + begin, err := strconv.ParseInt(matches[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse begin: %v", err) + } + end, err := strconv.ParseInt(matches[3], 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse end: %v", err) + } + if begin > end || begin < 0 || end < 0 { + return nil, fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", begin, end) + } + + ret := make([]thread, 0, end-begin) + for i := begin; i <= end; i++ { + ret = append(ret, thread{ + processorNumber: i, + id: cpuID{ + physicalID: 0, // we don't care about id for enable ops. + coreID: 0, + }, + }) + } + + return ret, nil +} + // cpuID for each thread is defined by the physical and // core IDs. If equal, two threads are Hyperthread pairs. type cpuID struct { @@ -172,43 +231,44 @@ type thread struct { } // newThread parses a CPU from a single cpu entry from /proc/cpuinfo. -func newThread(data string) (*thread, error) { +func newThread(data string) (thread, error) { + empty := thread{} processor, err := parseProcessor(data) if err != nil { - return nil, err + return empty, err } vendorID, err := parseVendorID(data) if err != nil { - return nil, err + return empty, err } cpuFamily, err := parseCPUFamily(data) if err != nil { - return nil, err + return empty, err } model, err := parseModel(data) if err != nil { - return nil, err + return empty, err } physicalID, err := parsePhysicalID(data) if err != nil { - return nil, err + return empty, err } coreID, err := parseCoreID(data) if err != nil { - return nil, err + return empty, err } bugs, err := parseBugs(data) if err != nil { - return nil, err + return empty, err } - return &thread{ + return thread{ processorNumber: processor, vendorID: vendorID, cpuFamily: cpuFamily, @@ -222,7 +282,7 @@ func newThread(data string) (*thread, error) { } // String implements the String method for thread. -func (t *thread) String() string { +func (t thread) String() string { template := `CPU: %d CPU ID: %+v Vendor: %s @@ -237,21 +297,27 @@ Bugs: %s return fmt.Sprintf(template, t.processorNumber, t.id, t.vendorID, t.cpuFamily, t.model, strings.Join(bugs, ",")) } -// shutdown turns off the CPU by writing 0 to /sys/devices/cpu/cpu{N}/online. -func (t *thread) shutdown() error { +// enable turns on the CPU by writing 1 to /sys/devices/cpu/cpu{N}/online. +func (t thread) enable() error { + cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) + return ioutil.WriteFile(cpuPath, []byte{'1'}, 0644) +} + +// disable turns off the CPU by writing 0 to /sys/devices/cpu/cpu{N}/online. +func (t thread) disable() error { cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) return ioutil.WriteFile(cpuPath, []byte{'0'}, 0644) } // isVulnerable checks if a CPU is vulnerable to mds. -func (t *thread) isVulnerable() bool { +func (t thread) isVulnerable() bool { _, ok := t.bugs[mds] return ok } // isActive checks if a CPU is active from /sys/devices/system/cpu/cpu{N}/online // If the file does not exist (ioutil returns in error), we assume the CPU is on. -func (t *thread) isActive() bool { +func (t thread) isActive() bool { cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) data, err := ioutil.ReadFile(cpuPath) if err != nil { @@ -262,7 +328,7 @@ func (t *thread) isActive() bool { // similarTo checks family/model/bugs fields for equality of two // processors. -func (t *thread) similarTo(other *thread) bool { +func (t thread) similarTo(other thread) bool { if t.vendorID != other.vendorID { return false } @@ -351,7 +417,7 @@ func parseRegex(data, key, match string) (string, error) { r := buildRegex(key, match) matches := r.FindStringSubmatch(data) if len(matches) < 2 { - return "", fmt.Errorf("failed to match key %s: %s", key, data) + return "", fmt.Errorf("failed to match key %q: %q", key, data) } return matches[1], nil } diff --git a/runsc/mitigate/cpu_test.go b/runsc/mitigate/cpu_test.go index 21c12f586..374333465 100644 --- a/runsc/mitigate/cpu_test.go +++ b/runsc/mitigate/cpu_test.go @@ -21,8 +21,8 @@ import ( "testing" ) -// cpuTestCase represents data from CPUs that will be mitigated. -type cpuTestCase struct { +// mockCPU represents data from CPUs that will be mitigated. +type mockCPU struct { name string vendorID string family int @@ -34,7 +34,7 @@ type cpuTestCase struct { threadsPerCore int } -var cascadeLake4 = cpuTestCase{ +var cascadeLake4 = mockCPU{ name: "CascadeLake", vendorID: "GenuineIntel", family: 6, @@ -46,7 +46,7 @@ var cascadeLake4 = cpuTestCase{ threadsPerCore: 2, } -var haswell2 = cpuTestCase{ +var haswell2 = mockCPU{ name: "Haswell", vendorID: "GenuineIntel", family: 6, @@ -58,7 +58,7 @@ var haswell2 = cpuTestCase{ threadsPerCore: 2, } -var haswell2core = cpuTestCase{ +var haswell2core = mockCPU{ name: "Haswell2Physical", vendorID: "GenuineIntel", family: 6, @@ -70,7 +70,7 @@ var haswell2core = cpuTestCase{ threadsPerCore: 1, } -var amd8 = cpuTestCase{ +var amd8 = mockCPU{ name: "AMD", vendorID: "AuthenticAMD", family: 23, @@ -83,7 +83,7 @@ var amd8 = cpuTestCase{ } // makeCPUString makes a string formated like /proc/cpuinfo for each cpuTestCase -func (tc cpuTestCase) makeCPUString() string { +func (tc mockCPU) makeCPUString() string { template := `processor : %d vendor_id : %s cpu family : %d @@ -115,10 +115,18 @@ bugs : %s return ret } +func (tc mockCPU) makeSysPossibleString() string { + max := tc.physicalCores * tc.cores * tc.threadsPerCore + if max == 1 { + return "0" + } + return fmt.Sprintf("0-%d", max-1) +} + // TestMockCPUSet tests mock cpu test cases against the cpuSet functions. func TestMockCPUSet(t *testing.T) { for _, tc := range []struct { - testCase cpuTestCase + testCase mockCPU isVulnerable bool }{ { @@ -141,7 +149,7 @@ func TestMockCPUSet(t *testing.T) { } { t.Run(tc.testCase.name, func(t *testing.T) { data := tc.testCase.makeCPUString() - vulnerable := func(t *thread) bool { + vulnerable := func(t thread) bool { return t.isVulnerable() } set, err := newCPUSet([]byte(data), vulnerable) @@ -170,6 +178,18 @@ func TestMockCPUSet(t *testing.T) { } delete(set, r.id) } + + possible := tc.testCase.makeSysPossibleString() + set, err = newCPUSetFromPossible([]byte(possible)) + if err != nil { + t.Fatalf("Failed to make cpuSet: %v", err) + } + + want = tc.testCase.physicalCores * tc.testCase.cores * tc.testCase.threadsPerCore + got := len(set.getRemainingList()) + if got != want { + t.Fatalf("Returned the wrong number of CPUs want: %d got: %d", want, got) + } }) } } @@ -328,7 +348,7 @@ func TestReadFile(t *testing.T) { t.Fatalf("Failed to read cpuinfo: %v", err) } - vulnerable := func(t *thread) bool { + vulnerable := func(t thread) bool { return t.isVulnerable() } @@ -502,3 +522,84 @@ power management:` }) } } + +func TestReverse(t *testing.T) { + const noParse = "-1-" + for _, tc := range []struct { + name string + output string + wantErr error + wantCount int + }{ + { + name: "base", + output: "0-7", + wantErr: nil, + wantCount: 8, + }, + { + name: "huge", + output: "0-111", + wantErr: nil, + wantCount: 112, + }, + { + name: "not zero", + output: "50-53", + wantErr: nil, + wantCount: 4, + }, + { + name: "small", + output: "0", + wantErr: nil, + wantCount: 1, + }, + { + name: "invalid order", + output: "10-6", + wantErr: fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", 10, 6), + }, + { + name: "no parse", + output: noParse, + wantErr: fmt.Errorf(`mismatch regex from /sys/devices/system/cpu/possible: %q`, noParse), + }, + } { + t.Run(tc.name, func(t *testing.T) { + threads, err := getThreadsFromPossible([]byte(tc.output)) + + switch { + case tc.wantErr == nil: + if err != nil { + t.Fatalf("Wanted nil err, got: %v", err) + } + case err == nil: + t.Fatalf("Want error: %v got: %v", tc.wantErr, err) + default: + if tc.wantErr.Error() != err.Error() { + t.Fatalf("Want error: %v got error: %v", tc.wantErr, err) + } + } + + if len(threads) != tc.wantCount { + t.Fatalf("Want count: %d got: %d", tc.wantCount, len(threads)) + } + }) + } +} + +func TestReverseSmoke(t *testing.T) { + data, err := ioutil.ReadFile(allPossibleCPUs) + if err != nil { + t.Fatalf("Failed to read from possible: %v", err) + } + threads, err := getThreadsFromPossible(data) + if err != nil { + t.Fatalf("Could not parse possible output: %v", err) + } + + if len(threads) <= 0 { + t.Fatalf("Didn't get any CPU cores: %d", len(threads)) + } +} diff --git a/runsc/mitigate/mitigate.go b/runsc/mitigate/mitigate.go index 3ea58454f..91de623e3 100644 --- a/runsc/mitigate/mitigate.go +++ b/runsc/mitigate/mitigate.go @@ -21,15 +21,23 @@ package mitigate import ( "fmt" + "io/ioutil" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/flag" ) +const ( + cpuInfo = "/proc/cpuinfo" + allPossibleCPUs = "/sys/devices/system/cpu/possible" +) + // Mitigate handles high level mitigate operations provided to runsc. type Mitigate struct { - dryRun bool // Run the command without changing the underlying system. - other mitigate // Struct holds extra mitigate logic. + dryRun bool // Run the command without changing the underlying system. + reverse bool // Reverse mitigate by turning on all CPU cores. + other mitigate // Struct holds extra mitigate logic. + path string // path to read for each operation (e.g. /proc/cpuinfo). } // Usage implments Usage for cmd.Mitigate. @@ -37,6 +45,8 @@ func (m Mitigate) Usage() string { usageString := `mitigate [flags] Mitigate mitigates a system to the "MDS" vulnerability by implementing a manual shutdown of SMT. The command checks /proc/cpuinfo for cpus having the MDS vulnerability, and if found, shutdown all but one CPU per hyperthread pair via /sys/devices/system/cpu/cpu{N}/online. CPUs can be restored by writing "2" to each file in /sys/devices/system/cpu/cpu{N}/online or performing a system reboot. + +The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online. ` return usageString + m.other.usage() } @@ -44,31 +54,81 @@ Mitigate mitigates a system to the "MDS" vulnerability by implementing a manual // SetFlags sets flags for the command Mitigate. func (m Mitigate) SetFlags(f *flag.FlagSet) { f.BoolVar(&m.dryRun, "dryrun", false, "run the command without changing system") + f.BoolVar(&m.reverse, "reverse", false, "reverse mitigate by enabling all CPUs") m.other.setFlags(f) + m.path = cpuInfo + if m.reverse { + m.path = allPossibleCPUs + } } // Execute executes the Mitigate command. -func (m Mitigate) Execute(data []byte) error { +func (m Mitigate) Execute() error { + data, err := ioutil.ReadFile(m.path) + if err != nil { + return fmt.Errorf("failed to read %s: %v", m.path, err) + } + + if m.reverse { + err := m.doReverse(data) + if err != nil { + return fmt.Errorf("reverse operation failed: %v", err) + } + return nil + } + + set, err := m.doMitigate(data) + if err != nil { + return fmt.Errorf("mitigate operation failed: %v", err) + } + return m.other.execute(set, m.dryRun) +} + +func (m Mitigate) doMitigate(data []byte) (cpuSet, error) { set, err := newCPUSet(data, m.other.vulnerable) if err != nil { - return err + return nil, err } log.Infof("Mitigate found the following CPUs...") log.Infof("%s", set) - shutdownList := set.getShutdownList() - log.Infof("Shutting down threads on thread pairs.") - for _, t := range shutdownList { - log.Infof("Shutting down thread: %s", t) + disableList := set.getShutdownList() + log.Infof("Disabling threads on thread pairs.") + for _, t := range disableList { + log.Infof("Disable thread: %s", t) if m.dryRun { continue } - if err := t.shutdown(); err != nil { - return fmt.Errorf("error shutting down thread: %s err: %v", t, err) + if err := t.disable(); err != nil { + return nil, fmt.Errorf("error disabling thread: %s err: %v", t, err) } } log.Infof("Shutdown successful.") - m.other.execute(set, m.dryRun) + return set, nil +} + +func (m Mitigate) doReverse(data []byte) error { + set, err := newCPUSetFromPossible(data) + if err != nil { + return err + } + + log.Infof("Reverse mitigate found the following CPUs...") + log.Infof("%s", set) + + enableList := set.getRemainingList() + + log.Infof("Enabling all CPUs...") + for _, t := range enableList { + log.Infof("Enabling thread: %s", t) + if m.dryRun { + continue + } + if err := t.enable(); err != nil { + return fmt.Errorf("error enabling thread: %s err: %v", t, err) + } + } + log.Infof("Enable successful.") return nil } diff --git a/runsc/mitigate/mitigate_conf.go b/runsc/mitigate/mitigate_conf.go index 1e74f5891..ee326324b 100644 --- a/runsc/mitigate/mitigate_conf.go +++ b/runsc/mitigate/mitigate_conf.go @@ -32,6 +32,6 @@ func (m mitigate) execute(set cpuSet, dryrun bool) error { return nil } -func (m mitigate) vulnerable(other *thread) bool { +func (m mitigate) vulnerable(other thread) bool { return other.isVulnerable() } diff --git a/runsc/mitigate/mitigate_test.go b/runsc/mitigate/mitigate_test.go index c6c825b72..b3a9a9b18 100644 --- a/runsc/mitigate/mitigate_test.go +++ b/runsc/mitigate/mitigate_test.go @@ -13,3 +13,142 @@ // limitations under the License. package mitigate + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +type executeTestCase struct { + name string + mitigateData string + mitigateError error + reverseData string + reverseError error +} + +func TestExecute(t *testing.T) { + + partial := `processor : 1 +vendor_id : AuthenticAMD +cpu family : 23 +model : 49 +model name : AMD EPYC 7B12 +physical id : 0 +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass +power management: +` + + for _, tc := range []executeTestCase{ + { + name: "CascadeLake4", + mitigateData: cascadeLake4.makeCPUString(), + reverseData: cascadeLake4.makeSysPossibleString(), + }, + { + name: "Empty", + mitigateData: "", + mitigateError: fmt.Errorf(`mitigate operation failed: no cpus found for: ""`), + reverseData: "", + reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from %s: ""`, allPossibleCPUs), + }, + { + name: "Partial", + mitigateData: `processor : 0 +vendor_id : AuthenticAMD +cpu family : 23 +model : 49 +model name : AMD EPYC 7B12 +physical id : 0 +core id : 0 +cpu cores : 1 +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass +power management: + +` + partial, + mitigateError: fmt.Errorf(`mitigate operation failed: failed to match key "core id": %q`, partial), + reverseData: "1-", + reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from %s: %q`, allPossibleCPUs, "1-"), + }, + } { + doExecuteTest(t, Mitigate{}, tc) + } +} + +func TestExecuteSmoke(t *testing.T) { + smokeMitigate, err := ioutil.ReadFile(cpuInfo) + if err != nil { + t.Fatalf("Failed to read %s: %v", cpuInfo, err) + } + smokeReverse, err := ioutil.ReadFile(allPossibleCPUs) + if err != nil { + t.Fatalf("Failed to read %s: %v", allPossibleCPUs, err) + } + doExecuteTest(t, Mitigate{}, executeTestCase{ + name: "SmokeTest", + mitigateData: string(smokeMitigate), + reverseData: string(smokeReverse), + }) + +} + +// doExecuteTest runs Execute with the mitigate operation and reverse operation. +func doExecuteTest(t *testing.T, m Mitigate, tc executeTestCase) { + t.Run("Mitigate"+tc.name, func(t *testing.T) { + m.dryRun = true + file, err := ioutil.TempFile("", "outfile.txt") + if err != nil { + t.Fatalf("Failed to create tmpfile: %v", err) + } + defer os.Remove(file.Name()) + + if _, err := file.WriteString(tc.mitigateData); err != nil { + t.Fatalf("Failed to write to file: %v", err) + } + + m.path = file.Name() + + got := m.Execute() + if err = checkErr(tc.mitigateError, got); err != nil { + t.Fatalf("Mitigate error mismatch: %v", err) + } + }) + t.Run("Reverse"+tc.name, func(t *testing.T) { + m.dryRun = true + m.reverse = true + + file, err := ioutil.TempFile("", "outfile.txt") + if err != nil { + t.Fatalf("Failed to create tmpfile: %v", err) + } + defer os.Remove(file.Name()) + + if _, err := file.WriteString(tc.reverseData); err != nil { + t.Fatalf("Failed to write to file: %v", err) + } + + m.path = file.Name() + got := m.Execute() + if err = checkErr(tc.reverseError, got); err != nil { + t.Fatalf("Mitigate error mismatch: %v", err) + } + }) + +} + +// checkErr checks error for equality. +func checkErr(want, got error) error { + switch { + case want == nil && got == nil: + case want != nil && got == nil: + fallthrough + case want == nil && got != nil: + fallthrough + case want.Error() != strings.Trim(got.Error(), " "): + return fmt.Errorf("got: %v want: %v", got, want) + } + return nil +} |