summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--runsc/cmd/mitigate.go10
-rw-r--r--runsc/mitigate/cpu.go118
-rw-r--r--runsc/mitigate/cpu_test.go121
-rw-r--r--runsc/mitigate/mitigate.go82
-rw-r--r--runsc/mitigate/mitigate_conf.go2
-rw-r--r--runsc/mitigate/mitigate_test.go139
6 files changed, 415 insertions, 57 deletions
diff --git a/runsc/cmd/mitigate.go b/runsc/cmd/mitigate.go
index 9052f091d..822af1917 100644
--- a/runsc/cmd/mitigate.go
+++ b/runsc/cmd/mitigate.go
@@ -16,7 +16,6 @@ package cmd
import (
"context"
- "io/ioutil"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
@@ -56,14 +55,7 @@ func (m *Mitigate) Execute(_ context.Context, f *flag.FlagSet, args ...interface
return subcommands.ExitUsageError
}
- const path = "/proc/cpuinfo"
- data, err := ioutil.ReadFile(path)
- if err != nil {
- log.Warningf("Failed to read %s: %v", path, err)
- return subcommands.ExitFailure
- }
-
- if err := m.mitigate.Execute(data); err != nil {
+ if err := m.mitigate.Execute(); err != nil {
log.Warningf("Execute failed: %v", err)
return subcommands.ExitFailure
}
diff --git a/runsc/mitigate/cpu.go b/runsc/mitigate/cpu.go
index 38f9b787a..4b2aa351f 100644
--- a/runsc/mitigate/cpu.go
+++ b/runsc/mitigate/cpu.go
@@ -45,7 +45,7 @@ const (
type cpuSet map[cpuID]*threadGroup
// newCPUSet creates a CPUSet from data read from /proc/cpuinfo.
-func newCPUSet(data []byte, vulnerable func(*thread) bool) (cpuSet, error) {
+func newCPUSet(data []byte, vulnerable func(thread) bool) (cpuSet, error) {
processors, err := getThreads(string(data))
if err != nil {
return nil, err
@@ -68,6 +68,26 @@ func newCPUSet(data []byte, vulnerable func(*thread) bool) (cpuSet, error) {
return set, nil
}
+// newCPUSetFromPossible makes a cpuSet data read from
+// /sys/devices/system/cpu/possible. This is used in enable operations
+// where the caller simply wants to enable all CPUS.
+func newCPUSetFromPossible(data []byte) (cpuSet, error) {
+ threads, err := getThreadsFromPossible(data)
+ if err != nil {
+ return nil, err
+ }
+
+ // We don't care if a CPU is vulnerable or not, we just
+ // want to return a list of all CPUs on the host.
+ set := cpuSet{
+ threads[0].id: &threadGroup{
+ threads: threads,
+ isVulnerable: false,
+ },
+ }
+ return set, nil
+}
+
// String implements the String method for CPUSet.
func (c cpuSet) String() string {
ret := ""
@@ -79,8 +99,8 @@ func (c cpuSet) String() string {
// getRemainingList returns the list of threads that will remain active
// after mitigation.
-func (c cpuSet) getRemainingList() []*thread {
- threads := make([]*thread, 0, len(c))
+func (c cpuSet) getRemainingList() []thread {
+ threads := make([]thread, 0, len(c))
for _, core := range c {
// If we're vulnerable, take only one thread from the pair.
if core.isVulnerable {
@@ -95,8 +115,8 @@ func (c cpuSet) getRemainingList() []*thread {
// getShutdownList returns the list of threads that will be shutdown on
// mitigation.
-func (c cpuSet) getShutdownList() []*thread {
- threads := make([]*thread, 0)
+func (c cpuSet) getShutdownList() []thread {
+ threads := make([]thread, 0)
for _, core := range c {
// Only if we're vulnerable do shutdown anything. In this case,
// shutdown all but the first entry.
@@ -109,12 +129,12 @@ func (c cpuSet) getShutdownList() []*thread {
// threadGroup represents Hyperthread pairs on the same physical/core ID.
type threadGroup struct {
- threads []*thread
+ threads []thread
isVulnerable bool
}
// String implements the String method for threadGroup.
-func (c *threadGroup) String() string {
+func (c threadGroup) String() string {
ret := fmt.Sprintf("ThreadGroup:\nIsVulnerable: %t\n", c.isVulnerable)
for _, processor := range c.threads {
ret += fmt.Sprintf("%s\n", processor)
@@ -123,13 +143,13 @@ func (c *threadGroup) String() string {
}
// getThreads returns threads structs from reading /proc/cpuinfo.
-func getThreads(data string) ([]*thread, error) {
+func getThreads(data string) ([]thread, error) {
// Each processor entry should start with the
// processor key. Find the beginings of each.
r := buildRegex(processorKey, `\d+`)
indices := r.FindAllStringIndex(data, -1)
if len(indices) < 1 {
- return nil, fmt.Errorf("no cpus found for: %s", data)
+ return nil, fmt.Errorf("no cpus found for: %q", data)
}
// Add the ending index for last entry.
@@ -139,7 +159,7 @@ func getThreads(data string) ([]*thread, error) {
// indexes (e.g. data[index[i], index[i+1]]).
// There should be len(indicies) - 1 CPUs
// since the last index is the end of the string.
- var cpus = make([]*thread, 0, len(indices)-1)
+ cpus := make([]thread, 0, len(indices))
// Find each string that represents a CPU. These begin "processor".
for i := 1; i < len(indices); i++ {
start := indices[i-1][0]
@@ -154,6 +174,45 @@ func getThreads(data string) ([]*thread, error) {
return cpus, nil
}
+// getThreadsFromPossible makes threads from data read from /sys/devices/system/cpu/possible.
+func getThreadsFromPossible(data []byte) ([]thread, error) {
+ possibleRegex := regexp.MustCompile(`(?m)^(\d+)(-(\d+))?$`)
+ matches := possibleRegex.FindStringSubmatch(string(data))
+ if len(matches) != 4 {
+ return nil, fmt.Errorf("mismatch regex from %s: %q", allPossibleCPUs, string(data))
+ }
+
+ // If matches[3] is empty, we only have one cpu entry.
+ if matches[3] == "" {
+ matches[3] = matches[1]
+ }
+
+ begin, err := strconv.ParseInt(matches[1], 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse begin: %v", err)
+ }
+ end, err := strconv.ParseInt(matches[3], 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse end: %v", err)
+ }
+ if begin > end || begin < 0 || end < 0 {
+ return nil, fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", begin, end)
+ }
+
+ ret := make([]thread, 0, end-begin)
+ for i := begin; i <= end; i++ {
+ ret = append(ret, thread{
+ processorNumber: i,
+ id: cpuID{
+ physicalID: 0, // we don't care about id for enable ops.
+ coreID: 0,
+ },
+ })
+ }
+
+ return ret, nil
+}
+
// cpuID for each thread is defined by the physical and
// core IDs. If equal, two threads are Hyperthread pairs.
type cpuID struct {
@@ -172,43 +231,44 @@ type thread struct {
}
// newThread parses a CPU from a single cpu entry from /proc/cpuinfo.
-func newThread(data string) (*thread, error) {
+func newThread(data string) (thread, error) {
+ empty := thread{}
processor, err := parseProcessor(data)
if err != nil {
- return nil, err
+ return empty, err
}
vendorID, err := parseVendorID(data)
if err != nil {
- return nil, err
+ return empty, err
}
cpuFamily, err := parseCPUFamily(data)
if err != nil {
- return nil, err
+ return empty, err
}
model, err := parseModel(data)
if err != nil {
- return nil, err
+ return empty, err
}
physicalID, err := parsePhysicalID(data)
if err != nil {
- return nil, err
+ return empty, err
}
coreID, err := parseCoreID(data)
if err != nil {
- return nil, err
+ return empty, err
}
bugs, err := parseBugs(data)
if err != nil {
- return nil, err
+ return empty, err
}
- return &thread{
+ return thread{
processorNumber: processor,
vendorID: vendorID,
cpuFamily: cpuFamily,
@@ -222,7 +282,7 @@ func newThread(data string) (*thread, error) {
}
// String implements the String method for thread.
-func (t *thread) String() string {
+func (t thread) String() string {
template := `CPU: %d
CPU ID: %+v
Vendor: %s
@@ -237,21 +297,27 @@ Bugs: %s
return fmt.Sprintf(template, t.processorNumber, t.id, t.vendorID, t.cpuFamily, t.model, strings.Join(bugs, ","))
}
-// shutdown turns off the CPU by writing 0 to /sys/devices/cpu/cpu{N}/online.
-func (t *thread) shutdown() error {
+// enable turns on the CPU by writing 1 to /sys/devices/cpu/cpu{N}/online.
+func (t thread) enable() error {
+ cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
+ return ioutil.WriteFile(cpuPath, []byte{'1'}, 0644)
+}
+
+// disable turns off the CPU by writing 0 to /sys/devices/cpu/cpu{N}/online.
+func (t thread) disable() error {
cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
return ioutil.WriteFile(cpuPath, []byte{'0'}, 0644)
}
// isVulnerable checks if a CPU is vulnerable to mds.
-func (t *thread) isVulnerable() bool {
+func (t thread) isVulnerable() bool {
_, ok := t.bugs[mds]
return ok
}
// isActive checks if a CPU is active from /sys/devices/system/cpu/cpu{N}/online
// If the file does not exist (ioutil returns in error), we assume the CPU is on.
-func (t *thread) isActive() bool {
+func (t thread) isActive() bool {
cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
data, err := ioutil.ReadFile(cpuPath)
if err != nil {
@@ -262,7 +328,7 @@ func (t *thread) isActive() bool {
// similarTo checks family/model/bugs fields for equality of two
// processors.
-func (t *thread) similarTo(other *thread) bool {
+func (t thread) similarTo(other thread) bool {
if t.vendorID != other.vendorID {
return false
}
@@ -351,7 +417,7 @@ func parseRegex(data, key, match string) (string, error) {
r := buildRegex(key, match)
matches := r.FindStringSubmatch(data)
if len(matches) < 2 {
- return "", fmt.Errorf("failed to match key %s: %s", key, data)
+ return "", fmt.Errorf("failed to match key %q: %q", key, data)
}
return matches[1], nil
}
diff --git a/runsc/mitigate/cpu_test.go b/runsc/mitigate/cpu_test.go
index 21c12f586..374333465 100644
--- a/runsc/mitigate/cpu_test.go
+++ b/runsc/mitigate/cpu_test.go
@@ -21,8 +21,8 @@ import (
"testing"
)
-// cpuTestCase represents data from CPUs that will be mitigated.
-type cpuTestCase struct {
+// mockCPU represents data from CPUs that will be mitigated.
+type mockCPU struct {
name string
vendorID string
family int
@@ -34,7 +34,7 @@ type cpuTestCase struct {
threadsPerCore int
}
-var cascadeLake4 = cpuTestCase{
+var cascadeLake4 = mockCPU{
name: "CascadeLake",
vendorID: "GenuineIntel",
family: 6,
@@ -46,7 +46,7 @@ var cascadeLake4 = cpuTestCase{
threadsPerCore: 2,
}
-var haswell2 = cpuTestCase{
+var haswell2 = mockCPU{
name: "Haswell",
vendorID: "GenuineIntel",
family: 6,
@@ -58,7 +58,7 @@ var haswell2 = cpuTestCase{
threadsPerCore: 2,
}
-var haswell2core = cpuTestCase{
+var haswell2core = mockCPU{
name: "Haswell2Physical",
vendorID: "GenuineIntel",
family: 6,
@@ -70,7 +70,7 @@ var haswell2core = cpuTestCase{
threadsPerCore: 1,
}
-var amd8 = cpuTestCase{
+var amd8 = mockCPU{
name: "AMD",
vendorID: "AuthenticAMD",
family: 23,
@@ -83,7 +83,7 @@ var amd8 = cpuTestCase{
}
// makeCPUString makes a string formated like /proc/cpuinfo for each cpuTestCase
-func (tc cpuTestCase) makeCPUString() string {
+func (tc mockCPU) makeCPUString() string {
template := `processor : %d
vendor_id : %s
cpu family : %d
@@ -115,10 +115,18 @@ bugs : %s
return ret
}
+func (tc mockCPU) makeSysPossibleString() string {
+ max := tc.physicalCores * tc.cores * tc.threadsPerCore
+ if max == 1 {
+ return "0"
+ }
+ return fmt.Sprintf("0-%d", max-1)
+}
+
// TestMockCPUSet tests mock cpu test cases against the cpuSet functions.
func TestMockCPUSet(t *testing.T) {
for _, tc := range []struct {
- testCase cpuTestCase
+ testCase mockCPU
isVulnerable bool
}{
{
@@ -141,7 +149,7 @@ func TestMockCPUSet(t *testing.T) {
} {
t.Run(tc.testCase.name, func(t *testing.T) {
data := tc.testCase.makeCPUString()
- vulnerable := func(t *thread) bool {
+ vulnerable := func(t thread) bool {
return t.isVulnerable()
}
set, err := newCPUSet([]byte(data), vulnerable)
@@ -170,6 +178,18 @@ func TestMockCPUSet(t *testing.T) {
}
delete(set, r.id)
}
+
+ possible := tc.testCase.makeSysPossibleString()
+ set, err = newCPUSetFromPossible([]byte(possible))
+ if err != nil {
+ t.Fatalf("Failed to make cpuSet: %v", err)
+ }
+
+ want = tc.testCase.physicalCores * tc.testCase.cores * tc.testCase.threadsPerCore
+ got := len(set.getRemainingList())
+ if got != want {
+ t.Fatalf("Returned the wrong number of CPUs want: %d got: %d", want, got)
+ }
})
}
}
@@ -328,7 +348,7 @@ func TestReadFile(t *testing.T) {
t.Fatalf("Failed to read cpuinfo: %v", err)
}
- vulnerable := func(t *thread) bool {
+ vulnerable := func(t thread) bool {
return t.isVulnerable()
}
@@ -502,3 +522,84 @@ power management:`
})
}
}
+
+func TestReverse(t *testing.T) {
+ const noParse = "-1-"
+ for _, tc := range []struct {
+ name string
+ output string
+ wantErr error
+ wantCount int
+ }{
+ {
+ name: "base",
+ output: "0-7",
+ wantErr: nil,
+ wantCount: 8,
+ },
+ {
+ name: "huge",
+ output: "0-111",
+ wantErr: nil,
+ wantCount: 112,
+ },
+ {
+ name: "not zero",
+ output: "50-53",
+ wantErr: nil,
+ wantCount: 4,
+ },
+ {
+ name: "small",
+ output: "0",
+ wantErr: nil,
+ wantCount: 1,
+ },
+ {
+ name: "invalid order",
+ output: "10-6",
+ wantErr: fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", 10, 6),
+ },
+ {
+ name: "no parse",
+ output: noParse,
+ wantErr: fmt.Errorf(`mismatch regex from /sys/devices/system/cpu/possible: %q`, noParse),
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ threads, err := getThreadsFromPossible([]byte(tc.output))
+
+ switch {
+ case tc.wantErr == nil:
+ if err != nil {
+ t.Fatalf("Wanted nil err, got: %v", err)
+ }
+ case err == nil:
+ t.Fatalf("Want error: %v got: %v", tc.wantErr, err)
+ default:
+ if tc.wantErr.Error() != err.Error() {
+ t.Fatalf("Want error: %v got error: %v", tc.wantErr, err)
+ }
+ }
+
+ if len(threads) != tc.wantCount {
+ t.Fatalf("Want count: %d got: %d", tc.wantCount, len(threads))
+ }
+ })
+ }
+}
+
+func TestReverseSmoke(t *testing.T) {
+ data, err := ioutil.ReadFile(allPossibleCPUs)
+ if err != nil {
+ t.Fatalf("Failed to read from possible: %v", err)
+ }
+ threads, err := getThreadsFromPossible(data)
+ if err != nil {
+ t.Fatalf("Could not parse possible output: %v", err)
+ }
+
+ if len(threads) <= 0 {
+ t.Fatalf("Didn't get any CPU cores: %d", len(threads))
+ }
+}
diff --git a/runsc/mitigate/mitigate.go b/runsc/mitigate/mitigate.go
index 3ea58454f..91de623e3 100644
--- a/runsc/mitigate/mitigate.go
+++ b/runsc/mitigate/mitigate.go
@@ -21,15 +21,23 @@ package mitigate
import (
"fmt"
+ "io/ioutil"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/flag"
)
+const (
+ cpuInfo = "/proc/cpuinfo"
+ allPossibleCPUs = "/sys/devices/system/cpu/possible"
+)
+
// Mitigate handles high level mitigate operations provided to runsc.
type Mitigate struct {
- dryRun bool // Run the command without changing the underlying system.
- other mitigate // Struct holds extra mitigate logic.
+ dryRun bool // Run the command without changing the underlying system.
+ reverse bool // Reverse mitigate by turning on all CPU cores.
+ other mitigate // Struct holds extra mitigate logic.
+ path string // path to read for each operation (e.g. /proc/cpuinfo).
}
// Usage implments Usage for cmd.Mitigate.
@@ -37,6 +45,8 @@ func (m Mitigate) Usage() string {
usageString := `mitigate [flags]
Mitigate mitigates a system to the "MDS" vulnerability by implementing a manual shutdown of SMT. The command checks /proc/cpuinfo for cpus having the MDS vulnerability, and if found, shutdown all but one CPU per hyperthread pair via /sys/devices/system/cpu/cpu{N}/online. CPUs can be restored by writing "2" to each file in /sys/devices/system/cpu/cpu{N}/online or performing a system reboot.
+
+The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online.
`
return usageString + m.other.usage()
}
@@ -44,31 +54,81 @@ Mitigate mitigates a system to the "MDS" vulnerability by implementing a manual
// SetFlags sets flags for the command Mitigate.
func (m Mitigate) SetFlags(f *flag.FlagSet) {
f.BoolVar(&m.dryRun, "dryrun", false, "run the command without changing system")
+ f.BoolVar(&m.reverse, "reverse", false, "reverse mitigate by enabling all CPUs")
m.other.setFlags(f)
+ m.path = cpuInfo
+ if m.reverse {
+ m.path = allPossibleCPUs
+ }
}
// Execute executes the Mitigate command.
-func (m Mitigate) Execute(data []byte) error {
+func (m Mitigate) Execute() error {
+ data, err := ioutil.ReadFile(m.path)
+ if err != nil {
+ return fmt.Errorf("failed to read %s: %v", m.path, err)
+ }
+
+ if m.reverse {
+ err := m.doReverse(data)
+ if err != nil {
+ return fmt.Errorf("reverse operation failed: %v", err)
+ }
+ return nil
+ }
+
+ set, err := m.doMitigate(data)
+ if err != nil {
+ return fmt.Errorf("mitigate operation failed: %v", err)
+ }
+ return m.other.execute(set, m.dryRun)
+}
+
+func (m Mitigate) doMitigate(data []byte) (cpuSet, error) {
set, err := newCPUSet(data, m.other.vulnerable)
if err != nil {
- return err
+ return nil, err
}
log.Infof("Mitigate found the following CPUs...")
log.Infof("%s", set)
- shutdownList := set.getShutdownList()
- log.Infof("Shutting down threads on thread pairs.")
- for _, t := range shutdownList {
- log.Infof("Shutting down thread: %s", t)
+ disableList := set.getShutdownList()
+ log.Infof("Disabling threads on thread pairs.")
+ for _, t := range disableList {
+ log.Infof("Disable thread: %s", t)
if m.dryRun {
continue
}
- if err := t.shutdown(); err != nil {
- return fmt.Errorf("error shutting down thread: %s err: %v", t, err)
+ if err := t.disable(); err != nil {
+ return nil, fmt.Errorf("error disabling thread: %s err: %v", t, err)
}
}
log.Infof("Shutdown successful.")
- m.other.execute(set, m.dryRun)
+ return set, nil
+}
+
+func (m Mitigate) doReverse(data []byte) error {
+ set, err := newCPUSetFromPossible(data)
+ if err != nil {
+ return err
+ }
+
+ log.Infof("Reverse mitigate found the following CPUs...")
+ log.Infof("%s", set)
+
+ enableList := set.getRemainingList()
+
+ log.Infof("Enabling all CPUs...")
+ for _, t := range enableList {
+ log.Infof("Enabling thread: %s", t)
+ if m.dryRun {
+ continue
+ }
+ if err := t.enable(); err != nil {
+ return fmt.Errorf("error enabling thread: %s err: %v", t, err)
+ }
+ }
+ log.Infof("Enable successful.")
return nil
}
diff --git a/runsc/mitigate/mitigate_conf.go b/runsc/mitigate/mitigate_conf.go
index 1e74f5891..ee326324b 100644
--- a/runsc/mitigate/mitigate_conf.go
+++ b/runsc/mitigate/mitigate_conf.go
@@ -32,6 +32,6 @@ func (m mitigate) execute(set cpuSet, dryrun bool) error {
return nil
}
-func (m mitigate) vulnerable(other *thread) bool {
+func (m mitigate) vulnerable(other thread) bool {
return other.isVulnerable()
}
diff --git a/runsc/mitigate/mitigate_test.go b/runsc/mitigate/mitigate_test.go
index c6c825b72..b3a9a9b18 100644
--- a/runsc/mitigate/mitigate_test.go
+++ b/runsc/mitigate/mitigate_test.go
@@ -13,3 +13,142 @@
// limitations under the License.
package mitigate
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+ "testing"
+)
+
+type executeTestCase struct {
+ name string
+ mitigateData string
+ mitigateError error
+ reverseData string
+ reverseError error
+}
+
+func TestExecute(t *testing.T) {
+
+ partial := `processor : 1
+vendor_id : AuthenticAMD
+cpu family : 23
+model : 49
+model name : AMD EPYC 7B12
+physical id : 0
+bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass
+power management:
+`
+
+ for _, tc := range []executeTestCase{
+ {
+ name: "CascadeLake4",
+ mitigateData: cascadeLake4.makeCPUString(),
+ reverseData: cascadeLake4.makeSysPossibleString(),
+ },
+ {
+ name: "Empty",
+ mitigateData: "",
+ mitigateError: fmt.Errorf(`mitigate operation failed: no cpus found for: ""`),
+ reverseData: "",
+ reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from %s: ""`, allPossibleCPUs),
+ },
+ {
+ name: "Partial",
+ mitigateData: `processor : 0
+vendor_id : AuthenticAMD
+cpu family : 23
+model : 49
+model name : AMD EPYC 7B12
+physical id : 0
+core id : 0
+cpu cores : 1
+bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass
+power management:
+
+` + partial,
+ mitigateError: fmt.Errorf(`mitigate operation failed: failed to match key "core id": %q`, partial),
+ reverseData: "1-",
+ reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from %s: %q`, allPossibleCPUs, "1-"),
+ },
+ } {
+ doExecuteTest(t, Mitigate{}, tc)
+ }
+}
+
+func TestExecuteSmoke(t *testing.T) {
+ smokeMitigate, err := ioutil.ReadFile(cpuInfo)
+ if err != nil {
+ t.Fatalf("Failed to read %s: %v", cpuInfo, err)
+ }
+ smokeReverse, err := ioutil.ReadFile(allPossibleCPUs)
+ if err != nil {
+ t.Fatalf("Failed to read %s: %v", allPossibleCPUs, err)
+ }
+ doExecuteTest(t, Mitigate{}, executeTestCase{
+ name: "SmokeTest",
+ mitigateData: string(smokeMitigate),
+ reverseData: string(smokeReverse),
+ })
+
+}
+
+// doExecuteTest runs Execute with the mitigate operation and reverse operation.
+func doExecuteTest(t *testing.T, m Mitigate, tc executeTestCase) {
+ t.Run("Mitigate"+tc.name, func(t *testing.T) {
+ m.dryRun = true
+ file, err := ioutil.TempFile("", "outfile.txt")
+ if err != nil {
+ t.Fatalf("Failed to create tmpfile: %v", err)
+ }
+ defer os.Remove(file.Name())
+
+ if _, err := file.WriteString(tc.mitigateData); err != nil {
+ t.Fatalf("Failed to write to file: %v", err)
+ }
+
+ m.path = file.Name()
+
+ got := m.Execute()
+ if err = checkErr(tc.mitigateError, got); err != nil {
+ t.Fatalf("Mitigate error mismatch: %v", err)
+ }
+ })
+ t.Run("Reverse"+tc.name, func(t *testing.T) {
+ m.dryRun = true
+ m.reverse = true
+
+ file, err := ioutil.TempFile("", "outfile.txt")
+ if err != nil {
+ t.Fatalf("Failed to create tmpfile: %v", err)
+ }
+ defer os.Remove(file.Name())
+
+ if _, err := file.WriteString(tc.reverseData); err != nil {
+ t.Fatalf("Failed to write to file: %v", err)
+ }
+
+ m.path = file.Name()
+ got := m.Execute()
+ if err = checkErr(tc.reverseError, got); err != nil {
+ t.Fatalf("Mitigate error mismatch: %v", err)
+ }
+ })
+
+}
+
+// checkErr checks error for equality.
+func checkErr(want, got error) error {
+ switch {
+ case want == nil && got == nil:
+ case want != nil && got == nil:
+ fallthrough
+ case want == nil && got != nil:
+ fallthrough
+ case want.Error() != strings.Trim(got.Error(), " "):
+ return fmt.Errorf("got: %v want: %v", got, want)
+ }
+ return nil
+}