diff options
Diffstat (limited to 'test')
200 files changed, 12766 insertions, 1948 deletions
diff --git a/test/cmd/test_app/BUILD b/test/cmd/test_app/BUILD new file mode 100644 index 000000000..98ba5a3d9 --- /dev/null +++ b/test/cmd/test_app/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "test_app", + testonly = 1, + srcs = [ + "fds.go", + "test_app.go", + ], + pure = True, + visibility = ["//runsc/container:__pkg__"], + deps = [ + "//pkg/test/testutil", + "//pkg/unet", + "//runsc/flag", + "@com_github_google_subcommands//:go_default_library", + "@com_github_kr_pty//:go_default_library", + ], +) diff --git a/test/cmd/test_app/fds.go b/test/cmd/test_app/fds.go new file mode 100644 index 000000000..a7658eefd --- /dev/null +++ b/test/cmd/test_app/fds.go @@ -0,0 +1,185 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "io/ioutil" + "log" + "os" + "time" + + "github.com/google/subcommands" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/pkg/unet" + "gvisor.dev/gvisor/runsc/flag" +) + +const fileContents = "foobarbaz" + +// fdSender will open a file and send the FD over a unix domain socket. +type fdSender struct { + socketPath string +} + +// Name implements subcommands.Command.Name. +func (*fdSender) Name() string { + return "fd_sender" +} + +// Synopsis implements subcommands.Command.Synopsys. +func (*fdSender) Synopsis() string { + return "creates a file and sends the FD over the socket" +} + +// Usage implements subcommands.Command.Usage. +func (*fdSender) Usage() string { + return "fd_sender <flags>" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (fds *fdSender) SetFlags(f *flag.FlagSet) { + f.StringVar(&fds.socketPath, "socket", "", "path to socket") +} + +// Execute implements subcommands.Command.Execute. +func (fds *fdSender) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if fds.socketPath == "" { + log.Fatalf("socket flag must be set") + } + + dir, err := ioutil.TempDir("", "") + if err != nil { + log.Fatalf("TempDir failed: %v", err) + } + + fileToSend, err := ioutil.TempFile(dir, "") + if err != nil { + log.Fatalf("TempFile failed: %v", err) + } + defer fileToSend.Close() + + if _, err := fileToSend.WriteString(fileContents); err != nil { + log.Fatalf("Write(%q) failed: %v", fileContents, err) + } + + // Receiver may not be started yet, so try connecting in a poll loop. + var s *unet.Socket + if err := testutil.Poll(func() error { + var err error + s, err = unet.Connect(fds.socketPath, true /* SEQPACKET, so we can send empty message with FD */) + return err + }, 10*time.Second); err != nil { + log.Fatalf("Error connecting to socket %q: %v", fds.socketPath, err) + } + defer s.Close() + + w := s.Writer(true) + w.ControlMessage.PackFDs(int(fileToSend.Fd())) + if _, err := w.WriteVec([][]byte{[]byte{'a'}}); err != nil { + log.Fatalf("Error sending FD %q over socket %q: %v", fileToSend.Fd(), fds.socketPath, err) + } + + log.Print("FD SENDER exiting successfully") + return subcommands.ExitSuccess +} + +// fdReceiver receives an FD from a unix domain socket and does things to it. +type fdReceiver struct { + socketPath string +} + +// Name implements subcommands.Command.Name. +func (*fdReceiver) Name() string { + return "fd_receiver" +} + +// Synopsis implements subcommands.Command.Synopsys. +func (*fdReceiver) Synopsis() string { + return "reads an FD from a unix socket, and then does things to it" +} + +// Usage implements subcommands.Command.Usage. +func (*fdReceiver) Usage() string { + return "fd_receiver <flags>" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (fdr *fdReceiver) SetFlags(f *flag.FlagSet) { + f.StringVar(&fdr.socketPath, "socket", "", "path to socket") +} + +// Execute implements subcommands.Command.Execute. +func (fdr *fdReceiver) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if fdr.socketPath == "" { + log.Fatalf("Flags cannot be empty, given: socket: %q", fdr.socketPath) + } + + ss, err := unet.BindAndListen(fdr.socketPath, true /* packet */) + if err != nil { + log.Fatalf("BindAndListen(%q) failed: %v", fdr.socketPath, err) + } + defer ss.Close() + + var s *unet.Socket + c := make(chan error, 1) + go func() { + var err error + s, err = ss.Accept() + c <- err + }() + + select { + case err := <-c: + if err != nil { + log.Fatalf("Accept() failed: %v", err) + } + case <-time.After(10 * time.Second): + log.Fatalf("Timeout waiting for accept") + } + + r := s.Reader(true) + r.EnableFDs(1) + b := [][]byte{{'a'}} + if n, err := r.ReadVec(b); n != 1 || err != nil { + log.Fatalf("ReadVec got n=%d err %v (wanted 0, nil)", n, err) + } + + fds, err := r.ExtractFDs() + if err != nil { + log.Fatalf("ExtractFD() got err %v", err) + } + if len(fds) != 1 { + log.Fatalf("ExtractFD() got %d FDs, wanted 1", len(fds)) + } + fd := fds[0] + + file := os.NewFile(uintptr(fd), "received file") + defer file.Close() + if _, err := file.Seek(0, os.SEEK_SET); err != nil { + log.Fatalf("Seek(0, 0) failed: %v", err) + } + + got, err := ioutil.ReadAll(file) + if err != nil { + log.Fatalf("ReadAll failed: %v", err) + } + if string(got) != fileContents { + log.Fatalf("ReadAll got %q want %q", string(got), fileContents) + } + + log.Print("FD RECEIVER exiting successfully") + return subcommands.ExitSuccess +} diff --git a/test/cmd/test_app/test_app.go b/test/cmd/test_app/test_app.go new file mode 100644 index 000000000..3ba4f38f8 --- /dev/null +++ b/test/cmd/test_app/test_app.go @@ -0,0 +1,394 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary test_app is like a swiss knife for tests that need to run anything +// inside the sandbox. New functionality can be added with new commands. +package main + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "regexp" + "strconv" + sys "syscall" + "time" + + "github.com/google/subcommands" + "github.com/kr/pty" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/runsc/flag" +) + +func main() { + subcommands.Register(subcommands.HelpCommand(), "") + subcommands.Register(subcommands.FlagsCommand(), "") + subcommands.Register(new(capability), "") + subcommands.Register(new(fdReceiver), "") + subcommands.Register(new(fdSender), "") + subcommands.Register(new(forkBomb), "") + subcommands.Register(new(ptyRunner), "") + subcommands.Register(new(reaper), "") + subcommands.Register(new(syscall), "") + subcommands.Register(new(taskTree), "") + subcommands.Register(new(uds), "") + + flag.Parse() + + exitCode := subcommands.Execute(context.Background()) + os.Exit(int(exitCode)) +} + +type uds struct { + fileName string + socketPath string +} + +// Name implements subcommands.Command.Name. +func (*uds) Name() string { + return "uds" +} + +// Synopsis implements subcommands.Command.Synopsys. +func (*uds) Synopsis() string { + return "creates unix domain socket client and server. Client sends a contant flow of sequential numbers. Server prints them to --file" +} + +// Usage implements subcommands.Command.Usage. +func (*uds) Usage() string { + return "uds <flags>" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (c *uds) SetFlags(f *flag.FlagSet) { + f.StringVar(&c.fileName, "file", "", "name of output file") + f.StringVar(&c.socketPath, "socket", "", "path to socket") +} + +// Execute implements subcommands.Command.Execute. +func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if c.fileName == "" || c.socketPath == "" { + log.Fatalf("Flags cannot be empty, given: fileName: %q, socketPath: %q", c.fileName, c.socketPath) + return subcommands.ExitFailure + } + outputFile, err := os.OpenFile(c.fileName, os.O_WRONLY|os.O_CREATE, 0666) + if err != nil { + log.Fatal("error opening output file:", err) + } + + defer os.Remove(c.socketPath) + + listener, err := net.Listen("unix", c.socketPath) + if err != nil { + log.Fatalf("error listening on socket %q: %v", c.socketPath, err) + } + + go server(listener, outputFile) + for i := 0; ; i++ { + conn, err := net.Dial("unix", c.socketPath) + if err != nil { + log.Fatal("error dialing:", err) + } + if _, err := conn.Write([]byte(strconv.Itoa(i))); err != nil { + log.Fatal("error writing:", err) + } + conn.Close() + time.Sleep(100 * time.Millisecond) + } +} + +func server(listener net.Listener, out *os.File) { + buf := make([]byte, 16) + + for { + c, err := listener.Accept() + if err != nil { + log.Fatal("error accepting connection:", err) + } + nr, err := c.Read(buf) + if err != nil { + log.Fatal("error reading from buf:", err) + } + data := buf[0:nr] + fmt.Fprint(out, string(data)+"\n") + } +} + +type taskTree struct { + depth int + width int + pause bool +} + +// Name implements subcommands.Command. +func (*taskTree) Name() string { + return "task-tree" +} + +// Synopsis implements subcommands.Command. +func (*taskTree) Synopsis() string { + return "creates a tree of tasks" +} + +// Usage implements subcommands.Command. +func (*taskTree) Usage() string { + return "task-tree <flags>" +} + +// SetFlags implements subcommands.Command. +func (c *taskTree) SetFlags(f *flag.FlagSet) { + f.IntVar(&c.depth, "depth", 1, "number of levels to create") + f.IntVar(&c.width, "width", 1, "number of tasks at each level") + f.BoolVar(&c.pause, "pause", false, "whether the tasks should pause perpetually") +} + +// Execute implements subcommands.Command. +func (c *taskTree) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + stop := testutil.StartReaper() + defer stop() + + if c.depth == 0 { + log.Printf("Child sleeping, PID: %d\n", os.Getpid()) + select {} + } + log.Printf("Parent %d sleeping, PID: %d\n", c.depth, os.Getpid()) + + var cmds []*exec.Cmd + for i := 0; i < c.width; i++ { + cmd := exec.Command( + "/proc/self/exe", c.Name(), + "--depth", strconv.Itoa(c.depth-1), + "--width", strconv.Itoa(c.width), + "--pause", strconv.FormatBool(c.pause)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + log.Fatal("failed to call self:", err) + } + cmds = append(cmds, cmd) + } + + for _, c := range cmds { + c.Wait() + } + + if c.pause { + select {} + } + + return subcommands.ExitSuccess +} + +type forkBomb struct { + delay time.Duration +} + +// Name implements subcommands.Command. +func (*forkBomb) Name() string { + return "fork-bomb" +} + +// Synopsis implements subcommands.Command. +func (*forkBomb) Synopsis() string { + return "creates child process until the end of times" +} + +// Usage implements subcommands.Command. +func (*forkBomb) Usage() string { + return "fork-bomb <flags>" +} + +// SetFlags implements subcommands.Command. +func (c *forkBomb) SetFlags(f *flag.FlagSet) { + f.DurationVar(&c.delay, "delay", 100*time.Millisecond, "amount of time to delay creation of child") +} + +// Execute implements subcommands.Command. +func (c *forkBomb) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + time.Sleep(c.delay) + + cmd := exec.Command("/proc/self/exe", c.Name()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + log.Fatal("failed to call self:", err) + } + return subcommands.ExitSuccess +} + +type reaper struct{} + +// Name implements subcommands.Command. +func (*reaper) Name() string { + return "reaper" +} + +// Synopsis implements subcommands.Command. +func (*reaper) Synopsis() string { + return "reaps all children in a loop" +} + +// Usage implements subcommands.Command. +func (*reaper) Usage() string { + return "reaper <flags>" +} + +// SetFlags implements subcommands.Command. +func (*reaper) SetFlags(*flag.FlagSet) {} + +// Execute implements subcommands.Command. +func (c *reaper) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + stop := testutil.StartReaper() + defer stop() + select {} +} + +type syscall struct { + sysno uint64 +} + +// Name implements subcommands.Command. +func (*syscall) Name() string { + return "syscall" +} + +// Synopsis implements subcommands.Command. +func (*syscall) Synopsis() string { + return "syscall makes a syscall" +} + +// Usage implements subcommands.Command. +func (*syscall) Usage() string { + return "syscall <flags>" +} + +// SetFlags implements subcommands.Command. +func (s *syscall) SetFlags(f *flag.FlagSet) { + f.Uint64Var(&s.sysno, "syscall", 0, "syscall to call") +} + +// Execute implements subcommands.Command. +func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if _, _, errno := sys.Syscall(uintptr(s.sysno), 0, 0, 0); errno != 0 { + fmt.Printf("syscall(%d, 0, 0...) failed: %v\n", s.sysno, errno) + } else { + fmt.Printf("syscall(%d, 0, 0...) success\n", s.sysno) + } + return subcommands.ExitSuccess +} + +type capability struct { + enabled uint64 + disabled uint64 +} + +// Name implements subcommands.Command. +func (*capability) Name() string { + return "capability" +} + +// Synopsis implements subcommands.Command. +func (*capability) Synopsis() string { + return "checks if effective capabilities are set/unset" +} + +// Usage implements subcommands.Command. +func (*capability) Usage() string { + return "capability [--enabled=number] [--disabled=number]" +} + +// SetFlags implements subcommands.Command. +func (c *capability) SetFlags(f *flag.FlagSet) { + f.Uint64Var(&c.enabled, "enabled", 0, "") + f.Uint64Var(&c.disabled, "disabled", 0, "") +} + +// Execute implements subcommands.Command. +func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if c.enabled == 0 && c.disabled == 0 { + fmt.Println("One of the flags must be set") + return subcommands.ExitUsageError + } + + status, err := ioutil.ReadFile("/proc/self/status") + if err != nil { + fmt.Printf("Error reading %q: %v\n", "proc/self/status", err) + return subcommands.ExitFailure + } + re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n") + matches := re.FindStringSubmatch(string(status)) + if matches == nil || len(matches) != 2 { + fmt.Printf("Effective capabilities not found in\n%s\n", status) + return subcommands.ExitFailure + } + caps, err := strconv.ParseUint(matches[1], 16, 64) + if err != nil { + fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err) + return subcommands.ExitFailure + } + + if c.enabled != 0 && (caps&c.enabled) != c.enabled { + fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps) + return subcommands.ExitFailure + } + if c.disabled != 0 && (caps&c.disabled) != 0 { + fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps) + return subcommands.ExitFailure + } + + return subcommands.ExitSuccess +} + +type ptyRunner struct{} + +// Name implements subcommands.Command. +func (*ptyRunner) Name() string { + return "pty-runner" +} + +// Synopsis implements subcommands.Command. +func (*ptyRunner) Synopsis() string { + return "runs the given command with an open pty terminal" +} + +// Usage implements subcommands.Command. +func (*ptyRunner) Usage() string { + return "pty-runner [command]" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (*ptyRunner) SetFlags(f *flag.FlagSet) {} + +// Execute implements subcommands.Command. +func (*ptyRunner) Execute(_ context.Context, fs *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + c := exec.Command(fs.Args()[0], fs.Args()[1:]...) + f, err := pty.Start(c) + if err != nil { + fmt.Printf("pty.Start failed: %v", err) + return subcommands.ExitFailure + } + defer f.Close() + + // Copy stdout from the command to keep this process alive until the + // subprocess exits. + io.Copy(os.Stdout, f) + + return subcommands.ExitSuccess +} diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 76e04f878..44cce0e3b 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -20,9 +20,9 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/bits", - "//runsc/dockerutil", + "//pkg/test/dockerutil", + "//pkg/test/testutil", "//runsc/specutils", - "//runsc/testutil", ], ) diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go index 4074d2285..6a63b1232 100644 --- a/test/e2e/exec_test.go +++ b/test/e2e/exec_test.go @@ -23,6 +23,8 @@ package integration import ( "fmt" + "os" + "os/exec" "strconv" "strings" "syscall" @@ -31,23 +33,23 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/runsc/specutils" ) // Test that exec uses the exact same capability set as the container. func TestExecCapabilities(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-capabilities-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start the container. - if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + // Check that capability. matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) if err != nil { t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) @@ -59,7 +61,7 @@ func TestExecCapabilities(t *testing.T) { t.Log("Root capabilities:", want) // Now check that exec'd process capabilities match the root. - got, err := d.Exec("grep", "CapEff:", "/proc/self/status") + got, err := d.Exec(dockerutil.RunOpts{}, "grep", "CapEff:", "/proc/self/status") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -72,16 +74,16 @@ func TestExecCapabilities(t *testing.T) { // Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW // which is removed from the container when --net-raw=false. func TestExecPrivileged(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-privileged-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start the container with all capabilities dropped. - if err := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + CapDrop: []string{"all"}, + }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Check that all capabilities where dropped from container. matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) @@ -100,9 +102,11 @@ func TestExecPrivileged(t *testing.T) { t.Fatalf("Container should have no capabilities: %x", containerCaps) } - // Check that 'exec --privileged' adds all capabilities, except - // for CAP_NET_RAW. - got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status") + // Check that 'exec --privileged' adds all capabilities, except for + // CAP_NET_RAW. + got, err := d.Exec(dockerutil.RunOpts{ + Privileged: true, + }, "grep", "CapEff:", "/proc/self/status") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -114,97 +118,99 @@ func TestExecPrivileged(t *testing.T) { } func TestExecJobControl(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-job-control-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start the container. - if err := d.Run("alpine", "sleep", "1000"); err != nil { + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Exec 'sh' with an attached pty. - cmd, ptmx, err := d.ExecWithTerminal("sh") - if err != nil { + if _, err := d.Exec(dockerutil.RunOpts{ + Pty: func(cmd *exec.Cmd, ptmx *os.File) { + // Call "sleep 100 | cat" in the shell. We pipe to cat + // so that there will be two processes in the + // foreground process group. + if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil { + t.Fatalf("error writing to pty: %v", err) + } + + // Give shell a few seconds to start executing the sleep. + time.Sleep(2 * time.Second) + + // Send a ^C to the pty, which should kill sleep and + // cat, but not the shell. \x03 is ASCII "end of + // text", which is the same as ^C. + if _, err := ptmx.Write([]byte{'\x03'}); err != nil { + t.Fatalf("error writing to pty: %v", err) + } + + // The shell should still be alive at this point. Sleep + // should have exited with code 2+128=130. We'll exit + // with 10 plus that number, so that we can be sure + // that the shell did not get signalled. + if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { + t.Fatalf("error writing to pty: %v", err) + } + + // Exec process should exit with code 10+130=140. + ps, err := cmd.Process.Wait() + if err != nil { + t.Fatalf("error waiting for exec process: %v", err) + } + ws := ps.Sys().(syscall.WaitStatus) + if !ws.Exited() { + t.Errorf("ws.Exited got false, want true") + } + if got, want := ws.ExitStatus(), 140; got != want { + t.Errorf("ws.ExitedStatus got %d, want %d", got, want) + } + }, + }, "sh"); err != nil { t.Fatalf("docker exec failed: %v", err) } - defer ptmx.Close() - - // Call "sleep 100 | cat" in the shell. We pipe to cat so that there - // will be two processes in the foreground process group. - if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) - - // Send a ^C to the pty, which should kill sleep and cat, but not the - // shell. \x03 is ASCII "end of text", which is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // The shell should still be alive at this point. Sleep should have - // exited with code 2+128=130. We'll exit with 10 plus that number, so - // that we can be sure that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Exec process should exit with code 10+130=140. - ps, err := cmd.Process.Wait() - if err != nil { - t.Fatalf("error waiting for exec process: %v", err) - } - ws := ps.Sys().(syscall.WaitStatus) - if !ws.Exited() { - t.Errorf("ws.Exited got false, want true") - } - if got, want := ws.ExitStatus(), 140; got != want { - t.Errorf("ws.ExitedStatus got %d, want %d", got, want) - } } // Test that failure to exec returns proper error message. func TestExecError(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-error-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start the container. - if err := d.Run("alpine", "sleep", "1000"); err != nil { + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() - _, err := d.Exec("no_can_find") + // Attempt to exec a binary that doesn't exist. + out, err := d.Exec(dockerutil.RunOpts{}, "no_can_find") if err == nil { t.Fatalf("docker exec didn't fail") } - if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(err.Error(), want) { - t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want) + if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(out, want) { + t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", out, want) } } // Test that exec inherits environment from run. func TestExecEnv(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-env-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start the container with env FOO=BAR. - if err := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000"); err != nil { + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + Env: []string{"FOO=BAR"}, + }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Exec "echo $FOO". - got, err := d.Exec("/bin/sh", "-c", "echo $FOO") + got, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "echo $FOO") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -216,17 +222,19 @@ func TestExecEnv(t *testing.T) { // TestRunEnvHasHome tests that run always has HOME environment set. func TestRunEnvHasHome(t *testing.T) { // Base alpine image does not have any environment variables set. - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("run-env-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Exec "echo $HOME". The 'bin' user's home dir is '/bin'. - got, err := d.RunFg("--user", "bin", "alpine", "/bin/sh", "-c", "echo $HOME") + got, err := d.Run(dockerutil.RunOpts{ + Image: "basic/alpine", + User: "bin", + }, "/bin/sh", "-c", "echo $HOME") if err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + + // Check that the directory matches. if got, want := strings.TrimSpace(got), "/bin"; got != want { t.Errorf("bad output from 'docker run'. Got %q; Want %q.", got, want) } @@ -235,28 +243,17 @@ func TestRunEnvHasHome(t *testing.T) { // Test that exec always has HOME environment set, even when not set in run. func TestExecEnvHasHome(t *testing.T) { // Base alpine image does not have any environment variables set. - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("exec-env-home-test") - - // We will check that HOME is set for root user, and also for a new - // non-root user we will create. - newUID := 1234 - newHome := "/foo/bar" + d := dockerutil.MakeDocker(t) + defer d.CleanUp() - // Create a new user with a home directory, and then sleep. - script := fmt.Sprintf(` - mkdir -p -m 777 %s && \ - adduser foo -D -u %d -h %s && \ - sleep 1000`, newHome, newUID, newHome) - if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil { + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Exec "echo $HOME", and expect to see "/root". - got, err := d.Exec("/bin/sh", "-c", "echo $HOME") + got, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "echo $HOME") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -264,8 +261,18 @@ func TestExecEnvHasHome(t *testing.T) { t.Errorf("wanted exec output to contain %q, got %q", want, got) } - // Execute the same as uid 123 and expect newHome. - got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME") + // Create a new user with a home directory. + newUID := 1234 + newHome := "/foo/bar" + cmd := fmt.Sprintf("mkdir -p -m 777 %q && adduser foo -D -u %d -h %q", newHome, newUID, newHome) + if _, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", cmd); err != nil { + t.Fatalf("docker exec failed: %v", err) + } + + // Execute the same as the new user and expect newHome. + got, err = d.Exec(dockerutil.RunOpts{ + User: strconv.Itoa(newUID), + }, "/bin/sh", "-c", "echo $HOME") if err != nil { t.Fatalf("docker exec failed: %v", err) } diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 28064e557..404e37689 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -27,14 +27,15 @@ import ( "net" "net/http" "os" + "os/exec" "strconv" "strings" "syscall" "testing" "time" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) // httpRequestSucceeds sends a request to a given url and checks that the status is OK. @@ -53,65 +54,66 @@ func httpRequestSucceeds(client http.Client, server string, port int) error { // TestLifeCycle tests a basic Create/Start/Stop docker container life cycle. func TestLifeCycle(t *testing.T) { - if err := dockerutil.Pull("nginx"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("lifecycle-test") - if err := d.Create("-p", "80", "nginx"); err != nil { - t.Fatal("docker create failed:", err) + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + // Start the container. + if err := d.Create(dockerutil.RunOpts{ + Image: "basic/nginx", + Ports: []int{80}, + }); err != nil { + t.Fatalf("docker create failed: %v", err) } if err := d.Start(); err != nil { - d.CleanUp() - t.Fatal("docker start failed:", err) + t.Fatalf("docker start failed: %v", err) } - // Test that container is working + // Test that container is working. port, err := d.FindPort(80) if err != nil { - t.Fatal("docker.FindPort(80) failed: ", err) + t.Fatalf("docker.FindPort(80) failed: %v", err) } if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) + t.Fatalf("WaitForHTTP() timeout: %v", err) } client := http.Client{Timeout: time.Duration(2 * time.Second)} if err := httpRequestSucceeds(client, "localhost", port); err != nil { - t.Error("http request failed:", err) + t.Errorf("http request failed: %v", err) } if err := d.Stop(); err != nil { - d.CleanUp() - t.Fatal("docker stop failed:", err) + t.Fatalf("docker stop failed: %v", err) } if err := d.Remove(); err != nil { - t.Fatal("docker rm failed:", err) + t.Fatalf("docker rm failed: %v", err) } } func TestPauseResume(t *testing.T) { - const img = "gcr.io/gvisor-presubmit/python-hello" if !testutil.IsCheckpointSupported() { - t.Log("Checkpoint is not supported, skipping test.") - return + t.Skip("Checkpoint is not supported.") } - if err := dockerutil.Pull(img); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("pause-resume-test") - if err := d.Run("-p", "8080", img); err != nil { + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + // Start the container. + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/python", + Ports: []int{8080}, // See Dockerfile. + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 8080 is mapped to. port, err := d.FindPort(8080) if err != nil { - t.Fatal("docker.FindPort(8080) failed:", err) + t.Fatalf("docker.FindPort(8080) failed: %v", err) } // Wait until it's up and running. if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) + t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check that container is working. @@ -121,7 +123,7 @@ func TestPauseResume(t *testing.T) { } if err := d.Pause(); err != nil { - t.Fatal("docker pause failed:", err) + t.Fatalf("docker pause failed: %v", err) } // Check if container is paused. @@ -137,12 +139,12 @@ func TestPauseResume(t *testing.T) { } if err := d.Unpause(); err != nil { - t.Fatal("docker unpause failed:", err) + t.Fatalf("docker unpause failed: %v", err) } // Wait until it's up and running. if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) + t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. @@ -152,45 +154,43 @@ func TestPauseResume(t *testing.T) { } func TestCheckpointRestore(t *testing.T) { - const img = "gcr.io/gvisor-presubmit/python-hello" if !testutil.IsCheckpointSupported() { - t.Log("Pause/resume is not supported, skipping test.") - return + t.Skip("Pause/resume is not supported.") } - if err := dockerutil.Pull(img); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("save-restore-test") - if err := d.Run("-p", "8080", img); err != nil { + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + // Start the container. + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/python", + Ports: []int{8080}, // See Dockerfile. + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + // Create a snapshot. if err := d.Checkpoint("test"); err != nil { - t.Fatal("docker checkpoint failed:", err) + t.Fatalf("docker checkpoint failed: %v", err) } - if _, err := d.Wait(30 * time.Second); err != nil { - t.Fatal(err) + t.Fatalf("wait failed: %v", err) } - // TODO(b/143498576): Remove after github.com/moby/moby/issues/38963 is fixed. - time.Sleep(1 * time.Second) - - if err := d.Restore("test"); err != nil { - t.Fatal("docker restore failed:", err) + // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed. + if err := testutil.Poll(func() error { return d.Restore("test") }, 15*time.Second); err != nil { + t.Fatalf("docker restore failed: %v", err) } // Find where port 8080 is mapped to. port, err := d.FindPort(8080) if err != nil { - t.Fatal("docker.FindPort(8080) failed:", err) + t.Fatalf("docker.FindPort(8080) failed: %v", err) } // Wait until it's up and running. if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatal("WaitForHTTP() timeout:", err) + t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. @@ -202,26 +202,28 @@ func TestCheckpointRestore(t *testing.T) { // Create client and server that talk to each other using the local IP. func TestConnectToSelf(t *testing.T) { - d := dockerutil.MakeDocker("connect-to-self-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Creates server that replies "server" and exists. Sleeps at the end because // 'docker exec' gets killed if the init process exists before it can finish. - if err := d.Run("ubuntu:trusty", "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil { - t.Fatal("docker run failed:", err) + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/ubuntu", + }, "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil { + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Finds IP address for host. - ip, err := d.Exec("/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") + ip, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") if err != nil { - t.Fatal("docker exec failed:", err) + t.Fatalf("docker exec failed: %v", err) } ip = strings.TrimRight(ip, "\n") // Runs client that sends "client" to the server and exits. - reply, err := d.Exec("/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) + reply, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) if err != nil { - t.Fatal("docker exec failed:", err) + t.Fatalf("docker exec failed: %v", err) } // Ensure both client and server got the message from each other. @@ -229,21 +231,22 @@ func TestConnectToSelf(t *testing.T) { t.Errorf("Error on server, want: %q, got: %q", want, reply) } if _, err := d.WaitForOutput("^client\n$", 1*time.Second); err != nil { - t.Fatal("docker.WaitForOutput(client) timeout:", err) + t.Fatalf("docker.WaitForOutput(client) timeout: %v", err) } } func TestMemLimit(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") - cmd := "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'" - out, err := d.RunFg("--memory=500MB", "alpine", "sh", "-c", cmd) + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + allocMemory := 500 * 1024 + out, err := d.Run(dockerutil.RunOpts{ + Image: "basic/alpine", + Memory: allocMemory, // In kB. + }, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'") if err != nil { - t.Fatal("docker run failed:", err) + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Remove warning message that swap isn't present. if strings.HasPrefix(out, "WARNING") { @@ -254,27 +257,30 @@ func TestMemLimit(t *testing.T) { out = lines[1] } + // Ensure the memory matches what we want. got, err := strconv.ParseUint(strings.TrimSpace(out), 10, 64) if err != nil { t.Fatalf("failed to parse %q: %v", out, err) } - if want := uint64(500 * 1024); got != want { + if want := uint64(allocMemory); got != want { t.Errorf("MemTotal got: %d, want: %d", got, want) } } func TestNumCPU(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") - cmd := "cat /proc/cpuinfo | grep 'processor.*:' | wc -l" - out, err := d.RunFg("--cpuset-cpus=0", "alpine", "sh", "-c", cmd) + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + // Read how many cores are in the container. + out, err := d.Run(dockerutil.RunOpts{ + Image: "basic/alpine", + Extra: []string{"--cpuset-cpus=0"}, + }, "sh", "-c", "cat /proc/cpuinfo | grep 'processor.*:' | wc -l") if err != nil { - t.Fatal("docker run failed:", err) + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + // Ensure it matches what we want. got, err := strconv.Atoi(strings.TrimSpace(out)) if err != nil { t.Fatalf("failed to parse %q: %v", out, err) @@ -286,39 +292,39 @@ func TestNumCPU(t *testing.T) { // TestJobControl tests that job control characters are handled properly. func TestJobControl(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("job-control-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start the container with an attached PTY. - _, ptmx, err := d.RunWithPty("alpine", "sh") - if err != nil { + if _, err := d.Run(dockerutil.RunOpts{ + Image: "basic/alpine", + Pty: func(_ *exec.Cmd, ptmx *os.File) { + // Call "sleep 100" in the shell. + if _, err := ptmx.Write([]byte("sleep 100\n")); err != nil { + t.Fatalf("error writing to pty: %v", err) + } + + // Give shell a few seconds to start executing the sleep. + time.Sleep(2 * time.Second) + + // Send a ^C to the pty, which should kill sleep, but + // not the shell. \x03 is ASCII "end of text", which + // is the same as ^C. + if _, err := ptmx.Write([]byte{'\x03'}); err != nil { + t.Fatalf("error writing to pty: %v", err) + } + + // The shell should still be alive at this point. Sleep + // should have exited with code 2+128=130. We'll exit + // with 10 plus that number, so that we can be sure + // that the shell did not get signalled. + if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { + t.Fatalf("error writing to pty: %v", err) + } + }, + }, "sh"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer ptmx.Close() - defer d.CleanUp() - - // Call "sleep 100" in the shell. - if _, err := ptmx.Write([]byte("sleep 100\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) - - // Send a ^C to the pty, which should kill sleep, but not the shell. - // \x03 is ASCII "end of text", which is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // The shell should still be alive at this point. Sleep should have - // exited with code 2+128=130. We'll exit with 10 plus that number, so - // that we can be sure that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } // Wait for the container to exit. got, err := d.Wait(5 * time.Second) @@ -334,14 +340,25 @@ func TestJobControl(t *testing.T) { // TestTmpFile checks that files inside '/tmp' are not overridden. In addition, // it checks that working dir is created if it doesn't exit. func TestTmpFile(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + // Should work without ReadOnly + if _, err := d.Run(dockerutil.RunOpts{ + Image: "basic/alpine", + WorkDir: "/tmp/foo/bar", + }, "touch", "/tmp/foo/bar/file"); err != nil { + t.Fatalf("docker run failed: %v", err) } - d := dockerutil.MakeDocker("tmp-file-test") - if err := d.Run("-w=/tmp/foo/bar", "--read-only", "alpine", "touch", "/tmp/foo/bar/file"); err != nil { - t.Fatal("docker run failed:", err) + + // Expect failure. + if _, err := d.Run(dockerutil.RunOpts{ + Image: "basic/alpine", + WorkDir: "/tmp/foo/bar", + ReadOnly: true, + }, "touch", "/tmp/foo/bar/file"); err == nil { + t.Fatalf("docker run expected failure, but succeeded") } - defer d.CleanUp() } func TestMain(m *testing.M) { diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go index 2488be383..327a2174c 100644 --- a/test/e2e/regression_test.go +++ b/test/e2e/regression_test.go @@ -18,7 +18,7 @@ import ( "strings" "testing" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" ) // Test that UDS can be created using overlay when parent directory is in lower @@ -27,19 +27,19 @@ import ( // Prerequisite: the directory where the socket file is created must not have // been open for write before bind(2) is called. func TestBindOverlay(t *testing.T) { - if err := dockerutil.Pull("ubuntu:trusty"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("bind-overlay-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() - cmd := "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p" - got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd) + // Run the container. + got, err := d.Run(dockerutil.RunOpts{ + Image: "basic/ubuntu", + }, "bash", "-c", "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p") if err != nil { - t.Fatal("docker run failed:", err) + t.Fatalf("docker run failed: %v", err) } + // Check the output contains what we want. if want := "foobar-asdf"; !strings.Contains(got, want) { t.Fatalf("docker run output is missing %q: %s", want, got) } - defer d.CleanUp() } diff --git a/test/image/BUILD b/test/image/BUILD index 7392ac54e..e749e47d4 100644 --- a/test/image/BUILD +++ b/test/image/BUILD @@ -22,8 +22,8 @@ go_test( ], visibility = ["//:sandbox"], deps = [ - "//runsc/dockerutil", - "//runsc/testutil", + "//pkg/test/dockerutil", + "//pkg/test/testutil", ], ) diff --git a/test/image/image_test.go b/test/image/image_test.go index 0a1e19d6f..2e3543109 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -28,24 +28,29 @@ import ( "log" "net/http" "os" - "path/filepath" "strings" "testing" "time" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) func TestHelloWorld(t *testing.T) { - d := dockerutil.MakeDocker("hello-test") - if err := d.Run("hello-world"); err != nil { + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + // Run the basic container. + out, err := d.Run(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "echo", "Hello world!") + if err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() - if _, err := d.WaitForOutput("Hello from Docker!", 5*time.Second); err != nil { - t.Fatalf("docker didn't say hello: %v", err) + // Check the output. + if !strings.Contains(out, "Hello world!") { + t.Fatalf("docker didn't say hello: got %s", out) } } @@ -102,27 +107,22 @@ func testHTTPServer(t *testing.T, port int) { } func TestHttpd(t *testing.T) { - if err := dockerutil.Pull("httpd"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("http-test") - - dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start the container. - mountArg := dockerutil.MountArg(dir, "/usr/local/apache2/htdocs", dockerutil.ReadOnly) - if err := d.Run("-p", "80", mountArg, "httpd"); err != nil { + d.CopyFiles("/usr/local/apache2/htdocs", "test/image/latin10k.txt") + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/httpd", + Ports: []int{80}, + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 80 is mapped to. port, err := d.FindPort(80) if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) + t.Fatalf("FindPort(80) failed: %v", err) } // Wait until it's up and running. @@ -134,27 +134,22 @@ func TestHttpd(t *testing.T) { } func TestNginx(t *testing.T) { - if err := dockerutil.Pull("nginx"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("net-test") - - dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start the container. - mountArg := dockerutil.MountArg(dir, "/usr/share/nginx/html", dockerutil.ReadOnly) - if err := d.Run("-p", "80", mountArg, "nginx"); err != nil { + d.CopyFiles("/usr/share/nginx/html", "test/image/latin10k.txt") + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/nginx", + Ports: []int{80}, + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 80 is mapped to. port, err := d.FindPort(80) if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) + t.Fatalf("FindPort(80) failed: %v", err) } // Wait until it's up and running. @@ -166,99 +161,58 @@ func TestNginx(t *testing.T) { } func TestMysql(t *testing.T) { - if err := dockerutil.Pull("mysql"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("mysql-test") + server := dockerutil.MakeDocker(t) + defer server.CleanUp() // Start the container. - if err := d.Run("-e", "MYSQL_ROOT_PASSWORD=foobar123", "mysql"); err != nil { + if err := server.Spawn(dockerutil.RunOpts{ + Image: "basic/mysql", + Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"}, + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Wait until it's up and running. - if _, err := d.WaitForOutput("port: 3306 MySQL Community Server", 3*time.Minute); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) + if _, err := server.WaitForOutput("port: 3306 MySQL Community Server", 3*time.Minute); err != nil { + t.Fatalf("WaitForOutput() timeout: %v", err) } - client := dockerutil.MakeDocker("mysql-client-test") - dir, err := dockerutil.PrepareFiles("test/image/mysql.sql") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } + // Generate the client and copy in the SQL payload. + client := dockerutil.MakeDocker(t) + defer client.CleanUp() - // Tell mysql client to connect to the server and execute the file in verbose - // mode to verify the output. - args := []string{ - dockerutil.LinkArg(&d, "mysql"), - dockerutil.MountArg(dir, "/sql", dockerutil.ReadWrite), - "mysql", - "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql", - } - if err := client.Run(args...); err != nil { + // Tell mysql client to connect to the server and execute the file in + // verbose mode to verify the output. + client.CopyFiles("/sql", "test/image/mysql.sql") + client.Link("mysql", server) + if _, err := client.Run(dockerutil.RunOpts{ + Image: "basic/mysql", + }, "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer client.CleanUp() // Ensure file executed to the end and shutdown mysql. - if _, err := client.WaitForOutput("--------------\nshutdown\n--------------", 15*time.Second); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) - } - if _, err := d.WaitForOutput("mysqld: Shutdown complete", 30*time.Second); err != nil { - t.Fatalf("docker.WaitForOutput() timeout: %v", err) + if _, err := server.WaitForOutput("mysqld: Shutdown complete", 30*time.Second); err != nil { + t.Fatalf("WaitForOutput() timeout: %v", err) } } -func TestPythonHello(t *testing.T) { - // TODO(b/136503277): Once we have more complete python runtime tests, - // we can drop this one. - const img = "gcr.io/gvisor-presubmit/python-hello" - if err := dockerutil.Pull(img); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("python-hello-test") - if err := d.Run("-p", "8080", img); err != nil { - t.Fatalf("docker run failed: %v", err) - } +func TestTomcat(t *testing.T) { + d := dockerutil.MakeDocker(t) defer d.CleanUp() - // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) - if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) - } - - // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { - t.Fatalf("WaitForHTTP() timeout: %v", err) - } - - // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) - resp, err := http.Get(url) - if err != nil { - t.Errorf("Error reaching http server: %v", err) - } - if want := http.StatusOK; resp.StatusCode != want { - t.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) - } -} - -func TestTomcat(t *testing.T) { - if err := dockerutil.Pull("tomcat:8.0"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("tomcat-test") - if err := d.Run("-p", "8080", "tomcat:8.0"); err != nil { + // Start the server. + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/tomcat", + Ports: []int{8080}, + }); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 8080 is mapped to. port, err := d.FindPort(8080) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("FindPort(8080) failed: %v", err) } // Wait until it's up and running. @@ -278,28 +232,22 @@ func TestTomcat(t *testing.T) { } func TestRuby(t *testing.T) { - if err := dockerutil.Pull("ruby"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("ruby-test") - - dir, err := dockerutil.PrepareFiles("test/image/ruby.rb", "test/image/ruby.sh") - if err != nil { - t.Fatalf("PrepareFiles() failed: %v", err) - } - if err := os.Chmod(filepath.Join(dir, "ruby.sh"), 0333); err != nil { - t.Fatalf("os.Chmod(%q, 0333) failed: %v", dir, err) - } + d := dockerutil.MakeDocker(t) + defer d.CleanUp() - if err := d.Run("-p", "8080", dockerutil.MountArg(dir, "/src", dockerutil.ReadOnly), "ruby", "/src/ruby.sh"); err != nil { + // Execute the ruby workload. + d.CopyFiles("/src", "test/image/ruby.rb", "test/image/ruby.sh") + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/ruby", + Ports: []int{8080}, + }, "/src/ruby.sh"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // Find where port 8080 is mapped to. port, err := d.FindPort(8080) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("FindPort(8080) failed: %v", err) } // Wait until it's up and running, 'gem install' can take some time. @@ -326,18 +274,17 @@ func TestRuby(t *testing.T) { } func TestStdio(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - d := dockerutil.MakeDocker("stdio-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() wantStdout := "hello stdout" wantStderr := "bonjour stderr" cmd := fmt.Sprintf("echo %q; echo %q 1>&2;", wantStdout, wantStderr) - if err := d.Run("alpine", "/bin/sh", "-c", cmd); err != nil { + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "/bin/sh", "-c", cmd); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() for _, want := range []string{wantStdout, wantStderr} { if _, err := d.WaitForOutput(want, 5*time.Second); err != nil { diff --git a/test/image/ruby.sh b/test/image/ruby.sh index ebe8d5b0e..ebe8d5b0e 100644..100755 --- a/test/image/ruby.sh +++ b/test/image/ruby.sh diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 6bb3b82b5..3e29ca90d 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -14,7 +14,7 @@ go_library( ], visibility = ["//test/iptables:__subpackages__"], deps = [ - "//runsc/testutil", + "//pkg/test/testutil", ], ) @@ -23,14 +23,14 @@ go_test( srcs = [ "iptables_test.go", ], + data = ["//test/iptables/runner"], library = ":iptables", tags = [ "local", "manual", ], deps = [ - "//pkg/log", - "//runsc/dockerutil", - "//runsc/testutil", + "//pkg/test/dockerutil", + "//pkg/test/testutil", ], ) diff --git a/test/iptables/README.md b/test/iptables/README.md index cc8a2fcac..b9f44bd40 100644 --- a/test/iptables/README.md +++ b/test/iptables/README.md @@ -38,7 +38,7 @@ Build the testing Docker container. Re-run this when you modify the test code in this directory: ```bash -$ bazel run //test/iptables/runner:runner-image -- --norun +$ make load-iptables ``` Run an individual test via: diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index bd6059921..41e0cfa8d 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -26,6 +26,7 @@ const ( acceptPort = 2402 sendloopDuration = 2 * time.Second network = "udp4" + chainName = "foochain" ) func init() { @@ -36,6 +37,18 @@ func init() { RegisterTestCase(FilterInputDropTCPSrcPort{}) RegisterTestCase(FilterInputDropUDPPort{}) RegisterTestCase(FilterInputDropUDP{}) + RegisterTestCase(FilterInputCreateUserChain{}) + RegisterTestCase(FilterInputDefaultPolicyAccept{}) + RegisterTestCase(FilterInputDefaultPolicyDrop{}) + RegisterTestCase(FilterInputReturnUnderflow{}) + RegisterTestCase(FilterInputSerializeJump{}) + RegisterTestCase(FilterInputJumpBasic{}) + RegisterTestCase(FilterInputJumpReturn{}) + RegisterTestCase(FilterInputJumpReturnDrop{}) + RegisterTestCase(FilterInputJumpBuiltin{}) + RegisterTestCase(FilterInputJumpTwice{}) + RegisterTestCase(FilterInputDestination{}) + RegisterTestCase(FilterInputInvertDestination{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -95,7 +108,7 @@ func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error { func (FilterInputDropOnlyUDP) LocalAction(ip net.IP) error { // Try to establish a TCP connection with the container, which should // succeed. - return connectTCP(ip, acceptPort, dropPort, sendloopDuration) + return connectTCP(ip, acceptPort, sendloopDuration) } // FilterInputDropUDPPort tests that we can drop UDP traffic by port. @@ -181,8 +194,11 @@ func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error { // LocalAction implements TestCase.LocalAction. func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil { - return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) + // Ensure we cannot connect to the container. + for start := time.Now(); time.Since(start) < sendloopDuration; { + if err := connectTCP(ip, dropPort, sendloopDuration-time.Since(start)); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort) + } } return nil @@ -198,13 +214,14 @@ func (FilterInputDropTCPSrcPort) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + // Drop anything from an ephemeral port. + if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. if err := listenTCP(acceptPort, sendloopDuration); err == nil { - return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) + return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort) } return nil @@ -212,8 +229,11 @@ func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error { // LocalAction implements TestCase.LocalAction. func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil { - return fmt.Errorf("connection on port %d should not be acceptedi, but got accepted", dropPort) + // Ensure we cannot connect to the container. + for start := time.Now(); time.Since(start) < sendloopDuration; { + if err := connectTCP(ip, acceptPort, sendloopDuration-time.Since(start)); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort) + } } return nil @@ -263,13 +283,12 @@ func (FilterInputMultiUDPRules) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { - return err - } - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"); err != nil { - return err + rules := [][]string{ + {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, + {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, + {"-L"}, } - return filterTable("-L") + return filterTableRules(rules) } // LocalAction implements TestCase.LocalAction. @@ -295,8 +314,356 @@ func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP) error { return nil } -// LocalAction implements TestCase.LocalAction. func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP) error { // No-op. return nil } + +// FilterInputCreateUserChain tests chain creation. +type FilterInputCreateUserChain struct{} + +// Name implements TestCase.Name. +func (FilterInputCreateUserChain) Name() string { + return "FilterInputCreateUserChain" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error { + rules := [][]string{ + // Create a chain. + {"-N", chainName}, + // Add a simple rule to the chain. + {"-A", chainName, "-j", "DROP"}, + } + return filterTableRules(rules) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputCreateUserChain) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// FilterInputDefaultPolicyAccept tests the default ACCEPT policy. +type FilterInputDefaultPolicyAccept struct{} + +// Name implements TestCase.Name. +func (FilterInputDefaultPolicyAccept) Name() string { + return "FilterInputDefaultPolicyAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP) error { + // Set the default policy to accept, then receive a packet. + if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil { + return err + } + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputDefaultPolicyDrop tests the default DROP policy. +type FilterInputDefaultPolicyDrop struct{} + +// Name implements TestCase.Name. +func (FilterInputDefaultPolicyDrop) Name() string { + return "FilterInputDefaultPolicyDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error { + if err := filterTable("-P", "INPUT", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes +// the underflow rule (i.e. default policy) to be executed. +type FilterInputReturnUnderflow struct{} + +// Name implements TestCase.Name. +func (FilterInputReturnUnderflow) Name() string { + return "FilterInputReturnUnderflow" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { + // Add a RETURN rule followed by an unconditional accept, and set the + // default policy to DROP. + rules := [][]string{ + {"-A", "INPUT", "-j", "RETURN"}, + {"-A", "INPUT", "-j", "DROP"}, + {"-P", "INPUT", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // We should receive packets, as the RETURN rule will trigger the default + // ACCEPT policy. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputSerializeJump verifies that we can serialize jumps. +type FilterInputSerializeJump struct{} + +// Name implements TestCase.Name. +func (FilterInputSerializeJump) Name() string { + return "FilterInputSerializeJump" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputSerializeJump) ContainerAction(ip net.IP) error { + // Write a JUMP rule, the serialize it with `-L`. + rules := [][]string{ + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-L"}, + } + return filterTableRules(rules) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputSerializeJump) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// FilterInputJumpBasic jumps to a chain and executes a rule there. +type FilterInputJumpBasic struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpBasic) Name() string { + return "FilterInputJumpBasic" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpBasic) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpBasic) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputJumpReturn jumps, returns, and executes a rule. +type FilterInputJumpReturn struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpReturn) Name() string { + return "FilterInputJumpReturn" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-N", chainName}, + {"-P", "INPUT", "ACCEPT"}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", "RETURN"}, + {"-A", chainName, "-j", "DROP"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpReturn) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets. +type FilterInputJumpReturnDrop struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpReturnDrop) Name() string { + return "FilterInputJumpReturnDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-A", "INPUT", "-j", "DROP"}, + {"-A", chainName, "-j", "RETURN"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal. +type FilterInputJumpBuiltin struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpBuiltin) Name() string { + return "FilterInputJumpBuiltin" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil { + return fmt.Errorf("iptables should be unable to jump to a built-in chain") + } + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// FilterInputJumpTwice jumps twice, then returns twice and executes a rule. +type FilterInputJumpTwice struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpTwice) Name() string { + return "FilterInputJumpTwice" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { + const chainName2 = chainName + "2" + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-N", chainName}, + {"-N", chainName2}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", chainName2}, + {"-A", "INPUT", "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // UDP packets should jump and return twice, eventually hitting the + // ACCEPT rule. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpTwice) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputDestination verifies that we can filter packets via `-d +// <ipaddr>`. +type FilterInputDestination struct{} + +// Name implements TestCase.Name. +func (FilterInputDestination) Name() string { + return "FilterInputDestination" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDestination) ContainerAction(ip net.IP) error { + addrs, err := localAddrs() + if err != nil { + return err + } + + // Make INPUT's default action DROP, then ACCEPT all packets bound for + // this machine. + rules := [][]string{{"-P", "INPUT", "DROP"}} + for _, addr := range addrs { + rules = append(rules, []string{"-A", "INPUT", "-d", addr, "-j", "ACCEPT"}) + } + if err := filterTableRules(rules); err != nil { + return err + } + + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDestination) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputInvertDestination verifies that we can filter packets via `! -d +// <ipaddr>`. +type FilterInputInvertDestination struct{} + +// Name implements TestCase.Name. +func (FilterInputInvertDestination) Name() string { + return "FilterInputInvertDestination" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInvertDestination) ContainerAction(ip net.IP) error { + // Make INPUT's default action DROP, then ACCEPT all packets not bound + // for 127.0.0.1. + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-A", "INPUT", "!", "-d", localIP, "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInvertDestination) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index ee2c49f9a..f6d974b85 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -22,9 +22,17 @@ import ( func init() { RegisterTestCase(FilterOutputDropTCPDestPort{}) RegisterTestCase(FilterOutputDropTCPSrcPort{}) + RegisterTestCase(FilterOutputDestination{}) + RegisterTestCase(FilterOutputInvertDestination{}) + RegisterTestCase(FilterOutputAcceptTCPOwner{}) + RegisterTestCase(FilterOutputDropTCPOwner{}) + RegisterTestCase(FilterOutputAcceptUDPOwner{}) + RegisterTestCase(FilterOutputDropUDPOwner{}) + RegisterTestCase(FilterOutputOwnerFail{}) } -// FilterOutputDropTCPDestPort tests that connections are not accepted on specified source ports. +// FilterOutputDropTCPDestPort tests that connections are not accepted on +// specified source ports. type FilterOutputDropTCPDestPort struct{} // Name implements TestCase.Name. @@ -48,14 +56,15 @@ func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error { // LocalAction implements TestCase.LocalAction. func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil { + if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) } return nil } -// FilterOutputDropTCPSrcPort tests that connections are not accepted on specified source ports. +// FilterOutputDropTCPSrcPort tests that connections are not accepted on +// specified source ports. type FilterOutputDropTCPSrcPort struct{} // Name implements TestCase.Name. @@ -79,9 +88,201 @@ func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error { // LocalAction implements TestCase.LocalAction. func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil { + if err := connectTCP(ip, dropPort, sendloopDuration); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) } return nil } + +// FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted. +type FilterOutputAcceptTCPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputAcceptTCPOwner) Name() string { + return "FilterOutputAcceptTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection on port %d should be accepted, but got dropped", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection destined to port %d should be accepted, but got dropped", acceptPort) + } + + return nil +} + +// FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped. +type FilterOutputDropTCPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputDropTCPOwner) Name() string { + return "FilterOutputDropTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted. +type FilterOutputAcceptUDPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputAcceptUDPOwner) Name() string { + return "FilterOutputAcceptUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Send UDP packets on acceptPort. + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error { + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped. +type FilterOutputDropUDPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputDropUDPOwner) Name() string { + return "FilterOutputDropUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Send UDP packets on dropPort. + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error { + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets should not be received") + } + + return nil +} + +// FilterOutputOwnerFail tests that without uid/gid option, owner rule +// will fail. +type FilterOutputOwnerFail struct{} + +// Name implements TestCase.Name. +func (FilterOutputOwnerFail) Name() string { + return "FilterOutputOwnerFail" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { + return fmt.Errorf("Invalid argument") + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputOwnerFail) LocalAction(ip net.IP) error { + // no-op. + return nil +} + +// FilterOutputDestination tests that we can selectively allow packets to +// certain destinations. +type FilterOutputDestination struct{} + +// Name implements TestCase.Name. +func (FilterOutputDestination) Name() string { + return "FilterOutputDestination" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDestination) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, + {"-P", "OUTPUT", "DROP"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDestination) LocalAction(ip net.IP) error { + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputInvertDestination tests that we can selectively allow packets +// not headed for a particular destination. +type FilterOutputInvertDestination struct{} + +// Name implements TestCase.Name. +func (FilterOutputInvertDestination) Name() string { + return "FilterOutputInvertDestination" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-A", "OUTPUT", "!", "-d", localIP, "-j", "ACCEPT"}, + {"-P", "OUTPUT", "DROP"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInvertDestination) LocalAction(ip net.IP) error { + return listenUDP(acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go index 2e565d988..16cb4f4da 100644 --- a/test/iptables/iptables.go +++ b/test/iptables/iptables.go @@ -18,12 +18,19 @@ package iptables import ( "fmt" "net" + "time" ) // IPExchangePort is the port the container listens on to receive the IP // address of the local process. const IPExchangePort = 2349 +// TerminalStatement is the last statement in the test runner. +const TerminalStatement = "Finished!" + +// TestTimeout is the timeout used for all tests. +const TestTimeout = 10 * time.Minute + // A TestCase contains one action to run in the container and one to run // locally. The actions run concurrently and each must succeed for the test // pass. diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 41909582a..334d8e676 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -15,28 +15,14 @@ package iptables import ( - "flag" "fmt" "net" - "os" - "path" "testing" - "time" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) -const timeout = 18 * time.Second - -var image = flag.String("image", "bazel/test/iptables/runner:runner-image", "image to run tests in") - -type result struct { - output string - err error -} - // singleTest runs a TestCase. Each test follows a pattern: // - Create a container. // - Get the container's IP. @@ -46,77 +32,45 @@ type result struct { // // Container output is logged to $TEST_UNDECLARED_OUTPUTS_DIR if it exists, or // to stderr. -func singleTest(test TestCase) error { +func singleTest(t *testing.T, test TestCase) { if _, ok := Tests[test.Name()]; !ok { - return fmt.Errorf("no test found with name %q. Has it been registered?", test.Name()) + t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) } + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + // Create and start the container. - cont := dockerutil.MakeDocker("gvisor-iptables") - defer cont.CleanUp() - resultChan := make(chan *result) - go func() { - output, err := cont.RunFg("--cap-add=NET_ADMIN", *image, "-name", test.Name()) - logContainer(output, err) - resultChan <- &result{output, err} - }() + d.CopyFiles("/runner", "test/iptables/runner/runner") + if err := d.Spawn(dockerutil.RunOpts{ + Image: "iptables", + CapAdd: []string{"NET_ADMIN"}, + }, "/runner/runner", "-name", test.Name()); err != nil { + t.Fatalf("docker run failed: %v", err) + } // Get the container IP. - ip, err := getIP(cont) + ip, err := d.FindIP() if err != nil { - return fmt.Errorf("failed to get container IP: %v", err) + t.Fatalf("failed to get container IP: %v", err) } // Give the container our IP. if err := sendIP(ip); err != nil { - return fmt.Errorf("failed to send IP to container: %v", err) + t.Fatalf("failed to send IP to container: %v", err) } // Run our side of the test. - errChan := make(chan error) - go func() { - errChan <- test.LocalAction(ip) - }() - - // Wait for both the container and local tests to finish. - var res *result - to := time.After(timeout) - for localDone := false; res == nil || !localDone; { - select { - case res = <-resultChan: - log.Infof("Container finished.") - case err, localDone = <-errChan: - log.Infof("Local finished.") - if err != nil { - return fmt.Errorf("local test failed: %v", err) - } - case <-to: - return fmt.Errorf("timed out after %f seconds", timeout.Seconds()) - } + if err := test.LocalAction(ip); err != nil { + t.Fatalf("LocalAction failed: %v", err) } - return res.err -} - -func getIP(cont dockerutil.Docker) (net.IP, error) { - // The container might not have started yet, so retry a few times. - var ipStr string - to := time.After(timeout) - for ipStr == "" { - ipStr, _ = cont.FindIP() - select { - case <-to: - return net.IP{}, fmt.Errorf("timed out getting IP after %f seconds", timeout.Seconds()) - default: - time.Sleep(250 * time.Millisecond) - } - } - ip := net.ParseIP(ipStr) - if ip == nil { - return net.IP{}, fmt.Errorf("invalid IP: %q", ipStr) + // Wait for the final statement. This structure has the side effect + // that all container logs will appear within the individual test + // context. + if _, err := d.WaitForOutput(TerminalStatement, TestTimeout); err != nil { + t.Fatalf("test failed: %v", err) } - log.Infof("Container has IP of %s", ipStr) - return ip, nil } func sendIP(ip net.IP) error { @@ -132,7 +86,7 @@ func sendIP(ip net.IP) error { conn = c return err } - if err := testutil.Poll(cb, timeout); err != nil { + if err := testutil.Poll(cb, TestTimeout); err != nil { return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err) } if _, err := conn.Write([]byte{0}); err != nil { @@ -141,87 +95,184 @@ func sendIP(ip net.IP) error { return nil } -func logContainer(output string, err error) { - msg := fmt.Sprintf("Container error: %v\nContainer output:\n%v", err, output) - if artifactsDir := os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); artifactsDir != "" { - fpath := path.Join(artifactsDir, "container.log") - if file, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE, 0644); err != nil { - log.Warningf("Failed to open log file %q: %v", fpath, err) - } else { - defer file.Close() - if _, err := file.Write([]byte(msg)); err == nil { - return - } - log.Warningf("Failed to write to log file %s: %v", fpath, err) - } - } - - // We couldn't write to the output directory -- just log to stderr. - log.Infof(msg) -} - func TestFilterInputDropUDP(t *testing.T) { - if err := singleTest(FilterInputDropUDP{}); err != nil { - t.Fatal(err) - } + singleTest(t, FilterInputDropUDP{}) } func TestFilterInputDropUDPPort(t *testing.T) { - if err := singleTest(FilterInputDropUDPPort{}); err != nil { - t.Fatal(err) - } + singleTest(t, FilterInputDropUDPPort{}) } func TestFilterInputDropDifferentUDPPort(t *testing.T) { - if err := singleTest(FilterInputDropDifferentUDPPort{}); err != nil { - t.Fatal(err) - } + singleTest(t, FilterInputDropDifferentUDPPort{}) } func TestFilterInputDropAll(t *testing.T) { - if err := singleTest(FilterInputDropAll{}); err != nil { - t.Fatal(err) - } + singleTest(t, FilterInputDropAll{}) } func TestFilterInputDropOnlyUDP(t *testing.T) { - if err := singleTest(FilterInputDropOnlyUDP{}); err != nil { - t.Fatal(err) - } + singleTest(t, FilterInputDropOnlyUDP{}) } func TestNATRedirectUDPPort(t *testing.T) { - if err := singleTest(NATRedirectUDPPort{}); err != nil { - t.Fatal(err) - } + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATRedirectUDPPort{}) +} + +func TestNATRedirectTCPPort(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATRedirectTCPPort{}) } func TestNATDropUDP(t *testing.T) { - if err := singleTest(NATDropUDP{}); err != nil { - t.Fatal(err) - } + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATDropUDP{}) +} + +func TestNATAcceptAll(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATAcceptAll{}) } func TestFilterInputDropTCPDestPort(t *testing.T) { - if err := singleTest(FilterInputDropTCPDestPort{}); err != nil { - t.Fatal(err) - } + singleTest(t, FilterInputDropTCPDestPort{}) } func TestFilterInputDropTCPSrcPort(t *testing.T) { - if err := singleTest(FilterInputDropTCPSrcPort{}); err != nil { - t.Fatal(err) - } + singleTest(t, FilterInputDropTCPSrcPort{}) +} + +func TestFilterInputCreateUserChain(t *testing.T) { + singleTest(t, FilterInputCreateUserChain{}) +} + +func TestFilterInputDefaultPolicyAccept(t *testing.T) { + singleTest(t, FilterInputDefaultPolicyAccept{}) +} + +func TestFilterInputDefaultPolicyDrop(t *testing.T) { + singleTest(t, FilterInputDefaultPolicyDrop{}) +} + +func TestFilterInputReturnUnderflow(t *testing.T) { + singleTest(t, FilterInputReturnUnderflow{}) } func TestFilterOutputDropTCPDestPort(t *testing.T) { - if err := singleTest(FilterOutputDropTCPDestPort{}); err != nil { - t.Fatal(err) - } + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, FilterOutputDropTCPDestPort{}) } func TestFilterOutputDropTCPSrcPort(t *testing.T) { - if err := singleTest(FilterOutputDropTCPSrcPort{}); err != nil { - t.Fatal(err) - } + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, FilterOutputDropTCPSrcPort{}) +} + +func TestFilterOutputAcceptTCPOwner(t *testing.T) { + singleTest(t, FilterOutputAcceptTCPOwner{}) +} + +func TestFilterOutputDropTCPOwner(t *testing.T) { + singleTest(t, FilterOutputDropTCPOwner{}) +} + +func TestFilterOutputAcceptUDPOwner(t *testing.T) { + singleTest(t, FilterOutputAcceptUDPOwner{}) +} + +func TestFilterOutputDropUDPOwner(t *testing.T) { + singleTest(t, FilterOutputDropUDPOwner{}) +} + +func TestFilterOutputOwnerFail(t *testing.T) { + singleTest(t, FilterOutputOwnerFail{}) +} + +func TestJumpSerialize(t *testing.T) { + singleTest(t, FilterInputSerializeJump{}) +} + +func TestJumpBasic(t *testing.T) { + singleTest(t, FilterInputJumpBasic{}) +} + +func TestJumpReturn(t *testing.T) { + singleTest(t, FilterInputJumpReturn{}) +} + +func TestJumpReturnDrop(t *testing.T) { + singleTest(t, FilterInputJumpReturnDrop{}) +} + +func TestJumpBuiltin(t *testing.T) { + singleTest(t, FilterInputJumpBuiltin{}) +} + +func TestJumpTwice(t *testing.T) { + singleTest(t, FilterInputJumpTwice{}) +} + +func TestInputDestination(t *testing.T) { + singleTest(t, FilterInputDestination{}) +} + +func TestInputInvertDestination(t *testing.T) { + singleTest(t, FilterInputInvertDestination{}) +} + +func TestOutputDestination(t *testing.T) { + singleTest(t, FilterOutputDestination{}) +} + +func TestOutputInvertDestination(t *testing.T) { + singleTest(t, FilterOutputInvertDestination{}) +} + +func TestNATOutRedirectIP(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATOutRedirectIP{}) +} + +func TestNATOutDontRedirectIP(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATOutDontRedirectIP{}) +} + +func TestNATOutRedirectInvert(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATOutRedirectInvert{}) +} + +func TestNATPreRedirectIP(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATPreRedirectIP{}) +} + +func TestNATPreDontRedirectIP(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATPreDontRedirectIP{}) +} + +func TestNATPreRedirectInvert(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATPreRedirectInvert{}) +} + +func TestNATRedirectRequiresProtocol(t *testing.T) { + // TODO(gvisor.dev/issue/170): Enable when supported. + t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") + singleTest(t, NATRedirectRequiresProtocol{}) } diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 043114c78..2a00677be 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -20,14 +20,24 @@ import ( "os/exec" "time" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) const iptablesBinary = "iptables" +const localIP = "127.0.0.1" // filterTable calls `iptables -t filter` with the given args. func filterTable(args ...string) error { - args = append([]string{"-t", "filter"}, args...) + return tableCmd("filter", args) +} + +// natTable calls `iptables -t nat` with the given args. +func natTable(args ...string) error { + return tableCmd("nat", args) +} + +func tableCmd(table string, args []string) error { + args = append([]string{"-t", table}, args...) cmd := exec.Command(iptablesBinary, args...) if out, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) @@ -35,6 +45,25 @@ func filterTable(args ...string) error { return nil } +// filterTableRules is like filterTable, but runs multiple iptables commands. +func filterTableRules(argsList [][]string) error { + return tableRules("filter", argsList) +} + +// natTableRules is like natTable, but runs multiple iptables commands. +func natTableRules(argsList [][]string) error { + return tableRules("nat", argsList) +} + +func tableRules(table string, argsList [][]string) error { + for _, args := range argsList { + if err := tableCmd(table, args); err != nil { + return err + } + } + return nil +} + // listenUDP listens on a UDP port and returns the value of net.Conn.Read() for // the first read on that port. func listenUDP(port int, timeout time.Duration) error { @@ -106,27 +135,37 @@ func listenTCP(port int, timeout time.Duration) error { return nil } -// connectTCP connects the TCP server over specified local port, server IP and remote/server port. -func connectTCP(ip net.IP, remotePort, localPort int, timeout time.Duration) error { +// connectTCP connects to the given IP and port from an ephemeral local address. +func connectTCP(ip net.IP, port int, timeout time.Duration) error { contAddr := net.TCPAddr{ IP: ip, - Port: remotePort, + Port: port, } // The container may not be listening when we first connect, so retry // upon error. callback := func() error { - localAddr := net.TCPAddr{ - Port: localPort, - } - conn, err := net.DialTCP("tcp4", &localAddr, &contAddr) + conn, err := net.DialTimeout("tcp", contAddr.String(), timeout) if conn != nil { conn.Close() } return err } if err := testutil.Poll(callback, timeout); err != nil { - return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err) + return fmt.Errorf("timed out waiting to connect IP, most recent error: %v", err) } return nil } + +// localAddrs returns a list of local network interface addresses. +func localAddrs() ([]string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + addrStrs := make([]string, 0, len(addrs)) + for _, addr := range addrs { + addrStrs = append(addrStrs, addr.String()) + } + return addrStrs, nil +} diff --git a/test/iptables/nat.go b/test/iptables/nat.go index b5c6f927e..40096901c 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -15,8 +15,10 @@ package iptables import ( + "errors" "fmt" "net" + "time" ) const ( @@ -25,7 +27,16 @@ const ( func init() { RegisterTestCase(NATRedirectUDPPort{}) + RegisterTestCase(NATRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) + RegisterTestCase(NATAcceptAll{}) + RegisterTestCase(NATPreRedirectIP{}) + RegisterTestCase(NATPreDontRedirectIP{}) + RegisterTestCase(NATPreRedirectInvert{}) + RegisterTestCase(NATOutRedirectIP{}) + RegisterTestCase(NATOutDontRedirectIP{}) + RegisterTestCase(NATOutRedirectInvert{}) + RegisterTestCase(NATRedirectRequiresProtocol{}) } // NATRedirectUDPPort tests that packets are redirected to different port. @@ -38,13 +49,14 @@ func (NATRedirectUDPPort) Name() string { // ContainerAction implements TestCase.ContainerAction. func (NATRedirectUDPPort) ContainerAction(ip net.IP) error { - if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } if err := listenUDP(redirectPort, sendloopDuration); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err) } + return nil } @@ -53,7 +65,31 @@ func (NATRedirectUDPPort) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } -// NATDropUDP tests that packets are not received in ports other than redirect port. +// NATRedirectTCPPort tests that connections are redirected on specified ports. +type NATRedirectTCPPort struct{} + +// Name implements TestCase.Name. +func (NATRedirectTCPPort) Name() string { + return "NATRedirectTCPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATRedirectTCPPort) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + return err + } + + // Listen for TCP packets on redirect port. + return listenTCP(redirectPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATRedirectTCPPort) LocalAction(ip net.IP) error { + return connectTCP(ip, dropPort, sendloopDuration) +} + +// NATDropUDP tests that packets are not received in ports other than redirect +// port. type NATDropUDP struct{} // Name implements TestCase.Name. @@ -63,7 +99,7 @@ func (NATDropUDP) Name() string { // ContainerAction implements TestCase.ContainerAction. func (NATDropUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } @@ -78,3 +114,218 @@ func (NATDropUDP) ContainerAction(ip net.IP) error { func (NATDropUDP) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } + +// NATAcceptAll tests that all UDP packets are accepted. +type NATAcceptAll struct{} + +// Name implements TestCase.Name. +func (NATAcceptAll) Name() string { + return "NATAcceptAll" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATAcceptAll) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { + return err + } + + if err := listenUDP(acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATAcceptAll) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// NATOutRedirectIP uses iptables to select packets based on destination IP and +// redirects them. +type NATOutRedirectIP struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectIP) Name() string { + return "NATOutRedirectIP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectIP) ContainerAction(ip net.IP) error { + // Redirect OUTPUT packets to a listening localhost port. + dest := net.IP([]byte{200, 0, 0, 2}) + return loopbackTest(dest, "-A", "OUTPUT", "-d", dest.String(), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectIP) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// NATOutDontRedirectIP tests that iptables matching with "-d" does not match +// packets it shouldn't. +type NATOutDontRedirectIP struct{} + +// Name implements TestCase.Name. +func (NATOutDontRedirectIP) Name() string { + return "NATOutDontRedirectIP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutDontRedirectIP) ContainerAction(ip net.IP) error { + if err := natTable("-A", "OUTPUT", "-d", localIP, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutDontRedirectIP) LocalAction(ip net.IP) error { + return listenUDP(acceptPort, sendloopDuration) +} + +// NATOutRedirectInvert tests that iptables can match with "! -d". +type NATOutRedirectInvert struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectInvert) Name() string { + return "NATOutRedirectInvert" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectInvert) ContainerAction(ip net.IP) error { + // Redirect OUTPUT packets to a listening localhost port. + dest := []byte{200, 0, 0, 3} + destStr := "200.0.0.2" + return loopbackTest(dest, "-A", "OUTPUT", "!", "-d", destStr, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectInvert) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// NATPreRedirectIP tests that we can use iptables to select packets based on +// destination IP and redirect them. +type NATPreRedirectIP struct{} + +// Name implements TestCase.Name. +func (NATPreRedirectIP) Name() string { + return "NATPreRedirectIP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectIP) ContainerAction(ip net.IP) error { + addrs, err := localAddrs() + if err != nil { + return err + } + + var rules [][]string + for _, addr := range addrs { + rules = append(rules, []string{"-A", "PREROUTING", "-p", "udp", "-d", addr, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)}) + } + if err := natTableRules(rules); err != nil { + return err + } + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectIP) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// NATPreDontRedirectIP tests that iptables matching with "-d" does not match +// packets it shouldn't. +type NATPreDontRedirectIP struct{} + +// Name implements TestCase.Name. +func (NATPreDontRedirectIP) Name() string { + return "NATPreDontRedirectIP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreDontRedirectIP) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-p", "udp", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreDontRedirectIP) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// NATPreRedirectInvert tests that iptables can match with "! -d". +type NATPreRedirectInvert struct{} + +// Name implements TestCase.Name. +func (NATPreRedirectInvert) Name() string { + return "NATPreRedirectInvert" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectInvert) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-p", "udp", "!", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectInvert) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a +// protocol to be specified with -p. +type NATRedirectRequiresProtocol struct{} + +// Name implements TestCase.Name. +func (NATRedirectRequiresProtocol) Name() string { + return "NATRedirectRequiresProtocol" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { + return errors.New("expected an error using REDIRECT --to-ports without a protocol") + } + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// loopbackTests runs an iptables rule and ensures that packets sent to +// dest:dropPort are received by localhost:acceptPort. +func loopbackTest(dest net.IP, args ...string) error { + if err := natTable(args...); err != nil { + return err + } + sendCh := make(chan error) + listenCh := make(chan error) + go func() { + sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration) + }() + go func() { + listenCh <- listenUDP(acceptPort, sendloopDuration) + }() + select { + case err := <-listenCh: + if err != nil { + return err + } + case <-time.After(sendloopDuration): + return errors.New("timed out") + } + // sendCh will always take the full sendloop time. + return <-sendCh +} diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD index b9199387a..24504a1b9 100644 --- a/test/iptables/runner/BUILD +++ b/test/iptables/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "container_image", "go_binary", "go_image") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -6,18 +6,7 @@ go_binary( name = "runner", testonly = 1, srcs = ["main.go"], - deps = ["//test/iptables"], -) - -container_image( - name = "iptables-base", - base = "@iptables-test//image", -) - -go_image( - name = "runner-image", - testonly = 1, - srcs = ["main.go"], - base = ":iptables-base", + pure = True, + visibility = ["//test/iptables:__subpackages__"], deps = ["//test/iptables"], ) diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go index 3c794114e..6f77c0684 100644 --- a/test/iptables/runner/main.go +++ b/test/iptables/runner/main.go @@ -46,6 +46,9 @@ func main() { if err := test.ContainerAction(ip); err != nil { log.Fatalf("Failed running test %q: %v", *name, err) } + + // Emit the final line. + log.Printf("%s", iptables.TerminalStatement) } // getIP listens for a connection from the local process and returns the source diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD index d113555b1..dfcd55f60 100644 --- a/test/packetdrill/BUILD +++ b/test/packetdrill/BUILD @@ -3,6 +3,36 @@ load("defs.bzl", "packetdrill_test") package(licenses = ["notice"]) packetdrill_test( - name = "fin_wait2_timeout", + name = "packetdrill_sanity_test", + scripts = ["sanity_test.pkt"], +) + +packetdrill_test( + name = "accept_ack_drop_test", + scripts = ["accept_ack_drop.pkt"], +) + +packetdrill_test( + name = "fin_wait2_timeout_test", scripts = ["fin_wait2_timeout.pkt"], ) + +packetdrill_test( + name = "listen_close_before_handshake_complete_test", + scripts = ["listen_close_before_handshake_complete.pkt"], +) + +packetdrill_test( + name = "no_rst_to_rst_test", + scripts = ["no_rst_to_rst.pkt"], +) + +packetdrill_test( + name = "tcp_defer_accept_test", + scripts = ["tcp_defer_accept.pkt"], +) + +packetdrill_test( + name = "tcp_defer_accept_timeout_test", + scripts = ["tcp_defer_accept_timeout.pkt"], +) diff --git a/test/packetdrill/Dockerfile b/test/packetdrill/Dockerfile index bd4451355..4b75e9527 100644 --- a/test/packetdrill/Dockerfile +++ b/test/packetdrill/Dockerfile @@ -1,9 +1,9 @@ FROM ubuntu:bionic -RUN apt-get update -RUN apt-get install -y net-tools git iptables iputils-ping netcat tcpdump jq tar +RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \ + netcat tcpdump jq tar bison flex make RUN hash -r RUN git clone --branch packetdrill-v2.0 \ https://github.com/google/packetdrill.git -RUN cd packetdrill/gtests/net/packetdrill && ./configure && \ - apt-get install -y bison flex make && make +RUN cd packetdrill/gtests/net/packetdrill && ./configure && make +CMD /bin/bash diff --git a/test/packetdrill/accept_ack_drop.pkt b/test/packetdrill/accept_ack_drop.pkt new file mode 100644 index 000000000..76e638fd4 --- /dev/null +++ b/test/packetdrill/accept_ack_drop.pkt @@ -0,0 +1,27 @@ +// Test that the accept works if the final ACK is dropped and an ack with data +// follows the dropped ack. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + ++0.0 < . 1:5(4) ack 1 win 257 ++0.0 > . 1:1(0) ack 5 <...> + +// This should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + +// Now read the data and we should get 4 bytes. ++0.000 read(4,..., 4) = 4 ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 5 <...> ++0.0 < F. 5:5(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 6 <...>
\ No newline at end of file diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl index 8623ce7b1..f499c177b 100644 --- a/test/packetdrill/defs.bzl +++ b/test/packetdrill/defs.bzl @@ -66,7 +66,7 @@ def packetdrill_linux_test(name, **kwargs): if "tags" not in kwargs: kwargs["tags"] = _PACKETDRILL_TAGS _packetdrill_test( - name = name + "_linux_test", + name = name, flags = ["--dut_platform", "linux"], **kwargs ) @@ -75,13 +75,13 @@ def packetdrill_netstack_test(name, **kwargs): if "tags" not in kwargs: kwargs["tags"] = _PACKETDRILL_TAGS _packetdrill_test( - name = name + "_netstack_test", + name = name, # This is the default runtime unless # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"], **kwargs ) -def packetdrill_test(**kwargs): - packetdrill_linux_test(**kwargs) - packetdrill_netstack_test(**kwargs) +def packetdrill_test(name, **kwargs): + packetdrill_linux_test(name + "_linux_test", **kwargs) + packetdrill_netstack_test(name + "_netstack_test", **kwargs) diff --git a/test/packetdrill/fin_wait2_timeout.pkt b/test/packetdrill/fin_wait2_timeout.pkt index 613f0bec9..93ab08575 100644 --- a/test/packetdrill/fin_wait2_timeout.pkt +++ b/test/packetdrill/fin_wait2_timeout.pkt @@ -19,5 +19,5 @@ +0 > F. 1:1(0) ack 1 <...> +0 < . 1:1(0) ack 2 win 257 -+1.1 < . 1:1(0) ack 2 win 257 ++2 < . 1:1(0) ack 2 win 257 +0 > R 2:2(0) win 0 diff --git a/test/packetdrill/listen_close_before_handshake_complete.pkt b/test/packetdrill/listen_close_before_handshake_complete.pkt new file mode 100644 index 000000000..51c3f1a32 --- /dev/null +++ b/test/packetdrill/listen_close_before_handshake_complete.pkt @@ -0,0 +1,31 @@ +// Test that closing a listening socket closes any connections in SYN-RCVD +// state and any packets bound for these connections generate a RESET. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> + ++0.100 close(3) = 0 ++0.1 < P. 1:1(0) ack 1 win 257 + +// Linux generates a reset with no ack number/bit set. This is contradictory to +// what is specified in Rule 1 under Reset Generation in +// https://tools.ietf.org/html/rfc793#section-3.4. +// "1. If the connection does not exist (CLOSED) then a reset is sent +// in response to any incoming segment except another reset. In +// particular, SYNs addressed to a non-existent connection are rejected +// by this means. +// +// If the incoming segment has an ACK field, the reset takes its +// sequence number from the ACK field of the segment, otherwise the +// reset has sequence number zero and the ACK field is set to the sum +// of the sequence number and segment length of the incoming segment. +// The connection remains in the CLOSED state." + ++0.0 > R 1:1(0) win 0
\ No newline at end of file diff --git a/test/packetdrill/no_rst_to_rst.pkt b/test/packetdrill/no_rst_to_rst.pkt new file mode 100644 index 000000000..612747827 --- /dev/null +++ b/test/packetdrill/no_rst_to_rst.pkt @@ -0,0 +1,36 @@ +// Test a RST is not generated in response to a RST and a RST is correctly +// generated when an accepted endpoint is RST due to an incoming RST. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> ++0 < P. 1:1(0) ack 1 win 257 + ++0.100 accept(3, ..., ...) = 4 + ++0.200 < R 1:1(0) win 0 + ++0.300 read(4,..., 4) = -1 ECONNRESET (Connection Reset by Peer) + ++0.00 < . 1:1(0) ack 1 win 257 + +// Linux generates a reset with no ack number/bit set. This is contradictory to +// what is specified in Rule 1 under Reset Generation in +// https://tools.ietf.org/html/rfc793#section-3.4. +// "1. If the connection does not exist (CLOSED) then a reset is sent +// in response to any incoming segment except another reset. In +// particular, SYNs addressed to a non-existent connection are rejected +// by this means. +// +// If the incoming segment has an ACK field, the reset takes its +// sequence number from the ACK field of the segment, otherwise the +// reset has sequence number zero and the ACK field is set to the sum +// of the sequence number and segment length of the incoming segment. +// The connection remains in the CLOSED state." + ++0.00 > R 1:1(0) win 0
\ No newline at end of file diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh index 0b22dfd5c..922547d65 100755 --- a/test/packetdrill/packetdrill_test.sh +++ b/test/packetdrill/packetdrill_test.sh @@ -85,23 +85,26 @@ if [[ ! -x "${INIT_SCRIPT-}" ]]; then exit 2 fi +function new_net_prefix() { + # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. + echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)" +} + # Variables specific to the control network and interface start with CTRL_. # Variables specific to the test network and interface start with TEST_. # Variables specific to the DUT start with DUT_. # Variables specific to the test runner start with TEST_RUNNER_. declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill" # Use random numbers so that test networks don't collide. -declare -r CTRL_NET="ctrl_net-${RANDOM}${RANDOM}" -declare -r TEST_NET="test_net-${RANDOM}${RANDOM}" +declare CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)" +declare CTRL_NET_PREFIX=$(new_net_prefix) +declare TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)" +declare TEST_NET_PREFIX=$(new_net_prefix) declare -r tolerance_usecs=100000 # On both DUT and test runner, testing packets are on the eth2 interface. declare -r TEST_DEVICE="eth2" # Number of bits in the *_NET_PREFIX variables. declare -r NET_MASK="24" -function new_net_prefix() { - # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. - echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)" -} # Last bits of the DUT's IP address. declare -r DUT_NET_SUFFIX=".10" # Control port. @@ -137,23 +140,21 @@ function finish { trap finish EXIT # Subnet for control packets between test runner and DUT. -declare CTRL_NET_PREFIX=$(new_net_prefix) while ! docker network create \ "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do sleep 0.1 - declare CTRL_NET_PREFIX=$(new_net_prefix) + CTRL_NET_PREFIX=$(new_net_prefix) + CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)" done # Subnet for the packets that are part of the test. -declare TEST_NET_PREFIX=$(new_net_prefix) while ! docker network create \ "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do sleep 0.1 - declare TEST_NET_PREFIX=$(new_net_prefix) + TEST_NET_PREFIX=$(new_net_prefix) + TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)" done -docker pull "${IMAGE_TAG}" - # Create the DUT container and connect to network. DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG}) diff --git a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt new file mode 100644 index 000000000..a86b90ce6 --- /dev/null +++ b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt @@ -0,0 +1,9 @@ +// Test that a listening socket generates a RST when it receives an +// ACK and syn cookies are not in use. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 ++0.1 < . 1:1(0) ack 1 win 32792 ++0 > R 1:1(0) ack 0 win 0
\ No newline at end of file diff --git a/test/packetdrill/sanity_test.pkt b/test/packetdrill/sanity_test.pkt new file mode 100644 index 000000000..b3b58c366 --- /dev/null +++ b/test/packetdrill/sanity_test.pkt @@ -0,0 +1,7 @@ +// Basic sanity test. One system call. +// +// All of the plumbing has to be working however, and the packetdrill wire +// client needs to be able to connect to the wire server and send the script, +// probe local interfaces, run through the test w/ timings, etc. + +0.000 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 diff --git a/test/packetdrill/tcp_defer_accept.pkt b/test/packetdrill/tcp_defer_accept.pkt new file mode 100644 index 000000000..a17f946db --- /dev/null +++ b/test/packetdrill/tcp_defer_accept.pkt @@ -0,0 +1,48 @@ +// Test that a bare ACK does not complete a connection when TCP_DEFER_ACCEPT +// timeout is not hit but an ACK w/ data does complete and deliver the +// connection to the accept queue. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [5], 4) = 0 ++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + +// Send a bare ACK this should not complete the connection as we +// set the TCP_DEFER_ACCEPT above. ++0.0 < . 1:1(0) ack 1 win 257 + +// The bare ACK should be dropped and no connection should be delivered +// to the accept queue. ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT +// to 5 seconds above. ++2.5 < . 1:1(0) ack 1 win 257 ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// set accept socket back to blocking. ++0.000 fcntl(3, F_SETFL, O_RDWR) = 0 + +// Now send an ACK w/ data. This should complete the connection +// and deliver the socket to the accept queue. ++0.1 < . 1:5(4) ack 1 win 257 ++0.0 > . 1:1(0) ack 5 <...> + +// This should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + +// Now read the data and we should get 4 bytes. ++0.000 read(4,..., 4) = 4 ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 5 <...> ++0.0 < F. 5:5(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 6 <...>
\ No newline at end of file diff --git a/test/packetdrill/tcp_defer_accept_timeout.pkt b/test/packetdrill/tcp_defer_accept_timeout.pkt new file mode 100644 index 000000000..201fdeb14 --- /dev/null +++ b/test/packetdrill/tcp_defer_accept_timeout.pkt @@ -0,0 +1,48 @@ +// Test that a bare ACK is accepted after TCP_DEFER_ACCEPT timeout +// is hit and a connection is delivered. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [3], 4) = 0 ++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + +// Send a bare ACK this should not complete the connection as we +// set the TCP_DEFER_ACCEPT above. ++0.0 < . 1:1(0) ack 1 win 257 + +// The bare ACK should be dropped and no connection should be delivered +// to the accept queue. ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT +// to 5 seconds above. ++2.5 < . 1:1(0) ack 1 win 257 ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// set accept socket back to blocking. ++0.000 fcntl(3, F_SETFL, O_RDWR) = 0 + +// We should see one more retransmit of the SYN-ACK as a last ditch +// attempt when TCP_DEFER_ACCEPT timeout is hit to trigger another +// ACK or a packet with data. ++.35~+2.35 > S. 0:0(0) ack 1 <...> + +// Now send another bare ACK after TCP_DEFER_ACCEPT time has been passed. ++0.0 < . 1:1(0) ack 1 win 257 + +// The ACK above should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 1 <...> ++0.0 < F. 1:1(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 2 <...> diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md new file mode 100644 index 000000000..ece4dedc6 --- /dev/null +++ b/test/packetimpact/README.md @@ -0,0 +1,531 @@ +# Packetimpact + +## What is packetimpact? + +Packetimpact is a tool for platform-independent network testing. It is heavily +inspired by [packetdrill](https://github.com/google/packetdrill). It creates two +docker containers connected by a network. One is for the test bench, which +operates the test. The other is for the device-under-test (DUT), which is the +software being tested. The test bench communicates over the network with the DUT +to check correctness of the network. + +### Goals + +Packetimpact aims to provide: + +* A **multi-platform** solution that can test both Linux and gVisor. +* **Conciseness** on par with packetdrill scripts. +* **Control-flow** like for loops, conditionals, and variables. +* **Flexibilty** to specify every byte in a packet or use multiple sockets. + +## When to use packetimpact? + +There are a few ways to write networking tests for gVisor currently: + +* [Go unit tests](https://github.com/google/gvisor/tree/master/pkg/tcpip) +* [syscall tests](https://github.com/google/gvisor/tree/master/test/syscalls/linux) +* [packetdrill tests](https://github.com/google/gvisor/tree/master/test/packetdrill) +* packetimpact tests + +The right choice depends on the needs of the test. + +Feature | Go unit test | syscall test | packetdrill | packetimpact +------------- | ------------ | ------------ | ----------- | ------------ +Multiplatform | no | **YES** | **YES** | **YES** +Concise | no | somewhat | somewhat | **VERY** +Control-flow | **YES** | **YES** | no | **YES** +Flexible | **VERY** | no | somewhat | **VERY** + +### Go unit tests + +If the test depends on the internals of gVisor and doesn't need to run on Linux +or other platforms for comparison purposes, a Go unit test can be appropriate. +They can observe internals of gVisor networking. The downside is that they are +**not concise** and **not multiplatform**. If you require insight on gVisor +internals, this is the right choice. + +### Syscall tests + +Syscall tests are **multiplatform** but cannot examine the internals of gVisor +networking. They are **concise**. They can use **control-flow** structures like +conditionals, for loops, and variables. However, they are limited to only what +the POSIX interface provides so they are **not flexible**. For example, you +would have difficulty writing a syscall test that intentionally sends a bad IP +checksum. Or if you did write that test with raw sockets, it would be very +**verbose** to write a test that intentionally send wrong checksums, wrong +protocols, wrong sequence numbers, etc. + +### Packetdrill tests + +Packetdrill tests are **multiplatform** and can run against both Linux and +gVisor. They are **concise** and use a special packetdrill scripting language. +They are **more flexible** than a syscall test in that they can send packets +that a syscall test would have difficulty sending, like a packet with a +calcuated ACK number. But they are also somewhat limimted in flexibiilty in that +they can't do tests with multiple sockets. They have **no control-flow** ability +like variables or conditionals. For example, it isn't possible to send a packet +that depends on the window size of a previous packet because the packetdrill +language can't express that. Nor could you branch based on whether or not the +other side supports window scaling, for example. + +### Packetimpact tests + +Packetimpact tests are similar to Packetdrill tests except that they are written +in Go instead of the packetdrill scripting language. That gives them all the +**control-flow** abilities of Go (loops, functions, variables, etc). They are +**multiplatform** in the same way as packetdrill tests but even more +**flexible** because Go is more expressive than the scripting language of +packetdrill. However, Go is **not as concise** as the packetdrill language. Many +design decisions below are made to mitigate that. + +## How it works + +``` + +--------------+ +--------------+ + | | TEST NET | | + | | <===========> | Device | + | Test | | Under | + | Bench | | Test | + | | <===========> | (DUT) | + | | CONTROL NET | | + +--------------+ +--------------+ +``` + +Two docker containers are created by a script, one for the test bench and the +other for the device under test (DUT). The script connects the two containers +with a control network and test network. It also does some other tasks like +waiting until the DUT is ready before starting the test and disabling Linux +networking that would interfere with the test bench. + +### DUT + +The DUT container runs a program called the "posix_server". The posix_server is +written in c++ for maximum portability. It is compiled on the host. The script +that starts the containers copies it into the DUT's container and runs it. It's +job is to receive directions from the test bench on what actions to take. For +this, the posix_server does three steps in a loop: + +1. Listen for a request from the test bench. +2. Execute a command. +3. Send the response back to the test bench. + +The requests and responses are +[protobufs](https://developers.google.com/protocol-buffers) and the +communication is done with [gRPC](https://grpc.io/). The commands run are +[POSIX socket commands](https://en.wikipedia.org/wiki/Berkeley_sockets#Socket_API_functions), +with the inputs and outputs converted into protobuf requests and responses. All +communication is on the control network, so that the test network is unaffected +by extra packets. + +For example, this is the request and response pair to call +[`socket()`](http://man7.org/linux/man-pages/man2/socket.2.html): + +```protocol-buffer +message SocketRequest { + int32 domain = 1; + int32 type = 2; + int32 protocol = 3; +} + +message SocketResponse { + int32 fd = 1; + int32 errno_ = 2; +} +``` + +##### Alternatives considered + +* We could have use JSON for communication instead. It would have been a + lighter-touch than protobuf but protobuf handles all the data type and has + strict typing to prevent a class of errors. The test bench could be written + in other languages, too. +* Instead of mimicking the POSIX interfaces, arguments could have had a more + natural form, like the `bind()` getting a string IP address instead of bytes + in a `sockaddr_t`. However, conforming to the existing structures keeps more + of the complexity in Go and keeps the posix_server simpler and thus more + likely to compile everywhere. + +### Test Bench + +The test bench does most of the work in a test. It is a Go program that compiles +on the host and is copied by the script into test bench's container. It is a +regular [go unit test](https://golang.org/pkg/testing/) that imports the test +bench framework. The test bench framwork is based on three basic utilities: + +* Commanding the DUT to run POSIX commands and return responses. +* Sending raw packets to the DUT on the test network. +* Listening for raw packets from the DUT on the test network. + +#### DUT commands + +To keep the interface to the DUT consistent and easy-to-use, each POSIX command +supported by the posix_server is wrapped in functions with signatures similar to +the ones in the [Go unix package](https://godoc.org/golang.org/x/sys/unix). This +way all the details of endianess and (un)marshalling of go structs such as +[unix.Timeval](https://godoc.org/golang.org/x/sys/unix#Timeval) is handled in +one place. This also makes it straight-forward to convert tests that use `unix.` +or `syscall.` calls to `dut.` calls. + +For example, creating a connection to the DUT and commanding it to make a socket +looks like this: + +```go +dut := testbench.NewDut(t) +fd, err := dut.SocketWithErrno(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP) +if fd < 0 { + t.Fatalf(...) +} +``` + +Because the usual case is to fail the test when the DUT fails to create a +socket, there is a concise version of each of the `...WithErrno` functions that +does that: + +```go +dut := testbench.NewDut(t) +fd := dut.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP) +``` + +The DUT and other structs in the code store a `*testing.T` so that they can +provide versions of functions that call `t.Fatalf(...)`. This helps keep tests +concise. + +##### Alternatives considered + +* Instead of mimicking the `unix.` go interface, we could have invented a more + natural one, like using `float64` instead of `Timeval`. However, using the + same function signatures that `unix.` has makes it easier to convert code to + `dut.`. Also, using an existing interface ensures that we don't invent an + interface that isn't extensible. For example, if we invented a function for + `bind()` that didn't support IPv6 and later we had to add a second `bind6()` + function. + +#### Sending/Receiving Raw Packets + +The framework wraps POSIX sockets for sending and receiving raw frames. Both +send and receive are synchronous commands. +[SO_RCVTIMEO](http://man7.org/linux/man-pages/man7/socket.7.html) is used to set +a timeout on the receive commands. For ease of use, these are wrapped in an +`Injector` and a `Sniffer`. They have functions: + +```go +func (s *Sniffer) Recv(timeout time.Duration) []byte {...} +func (i *Injector) Send(b []byte) {...} +``` + +##### Alternatives considered + +* [gopacket](https://github.com/google/gopacket) pcap has raw socket support + but requires cgo. cgo is not guaranteed to be portable from the host to the + container and in practice, the container doesn't recognize binaries built on + the host if they use cgo. +* Both gVisor and gopacket have the ability to read and write pcap files + without cgo but that is insufficient here. +* The sniffer and injector can't share a socket because they need to be bound + differently. +* Sniffing could have been done asynchronously with channels, obviating the + need for `SO_RCVTIMEO`. But that would introduce asynchronous complication. + `SO_RCVTIMEO` is well supported on the test bench. + +#### `Layer` struct + +A large part of packetimpact tests is creating packets to send and comparing +received packets against expectations. To keep tests concise, it is useful to be +able to specify just the important parts of packets that need to be set. For +example, sending a packet with default values except for TCP Flags. And for +packets received, it's useful to be able to compare just the necessary parts of +received packets and ignore the rest. + +To aid in both of those, Go structs with optional fields are created for each +encapsulation type, such as IPv4, TCP, and Ethernet. This is inspired by +[scapy](https://scapy.readthedocs.io/en/latest/). For example, here is the +struct for Ethernet: + +```go +type Ether struct { + LayerBase + SrcAddr *tcpip.LinkAddress + DstAddr *tcpip.LinkAddress + Type *tcpip.NetworkProtocolNumber +} +``` + +Each struct has the same fields as those in the +[gVisor headers](https://github.com/google/gvisor/tree/master/pkg/tcpip/header) +but with a pointer for each field that may be `nil`. + +##### Alternatives considered + +* Just use []byte like gVisor headers do. The drawback is that it makes the + tests more verbose. + * For example, there would be no way to call `Send(myBytes)` concisely and + indicate if the checksum should be calculated automatically versus + overridden. The only way would be to add lines to the test to calculate + it before each Send, which is wordy. Or make multiple versions of Send: + one that checksums IP, one that doesn't, one that checksums TCP, one + that does both, etc. That would be many combinations. + * Filtering inputs would become verbose. Either: + * large conditionals that need to be repeated many places: + `h[FlagOffset] == SYN && h[LengthOffset:LengthOffset+2] == ...` or + * Many functions, one per field, like: `filterByFlag(myBytes, SYN)`, + `filterByLength(myBytes, 20)`, `filterByNextProto(myBytes, 0x8000)`, + etc. + * Using pointers allows us to combine `Layer`s with a one-line call to + `mergo.Merge(...)`. So the default `Layers` can be overridden by a + `Layers` with just the TCP conection's src/dst which can be overridden + by one with just a test specific TCP window size. Each override is + specified as just one call to `mergo.Merge`. + * It's a proven way to separate the details of a packet from the byte + format as shown by scapy's success. +* Use packetgo. It's more general than parsing packets with gVisor. However: + * packetgo doesn't have optional fields so many of the above problems + still apply. + * It would be yet another dependency. + * It's not as well known to engineers that are already writing gVisor + code. + * It might be a good candidate for replacing the parsing of packets into + `Layer`s if all that parsing turns out to be more work than parsing by + packetgo and converting *that* to `Layer`. packetgo has easier to use + getters for the layers. This could be done later in a way that doesn't + break tests. + +#### `Layer` methods + +The `Layer` structs provide a way to partially specify an encapsulation. They +also need methods for using those partially specified encapsulation, for example +to marshal them to bytes or compare them. For those, each encapsulation +implements the `Layer` interface: + +```go +// Layer is the interface that all encapsulations must implement. +// +// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A +// Layer contains all the fields of the encapsulation. Each field is a pointer +// and may be nil. +type Layer interface { + // toBytes converts the Layer into bytes. In places where the Layer's field + // isn't nil, the value that is pointed to is used. When the field is nil, a + // reasonable default for the Layer is used. For example, "64" for IPv4 TTL + // and a calculated checksum for TCP or IP. Some layers require information + // from the previous or next layers in order to compute a default, such as + // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list + // to the layer's neighbors. + toBytes() ([]byte, error) + + // match checks if the current Layer matches the provided Layer. If either + // Layer has a nil in a given field, that field is considered matching. + // Otherwise, the values pointed to by the fields must match. + match(Layer) bool + + // length in bytes of the current encapsulation + length() int + + // next gets a pointer to the encapsulated Layer. + next() Layer + + // prev gets a pointer to the Layer encapsulating this one. + prev() Layer + + // setNext sets the pointer to the encapsulated Layer. + setNext(Layer) + + // setPrev sets the pointer to the Layer encapsulating this one. + setPrev(Layer) +} +``` + +For each `Layer` there is also a parsing function. For example, this one is for +Ethernet: + +``` +func ParseEther(b []byte) (Layers, error) +``` + +The parsing function converts bytes received on the wire into a `Layer` +(actually `Layers`, see below) which has no `nil`s in it. By using +`match(Layer)` to compare against another `Layer` that *does* have `nil`s in it, +the received bytes can be partially compared. The `nil`s behave as +"don't-cares". + +##### Alternatives considered + +* Matching against `[]byte` instead of converting to `Layer` first. + * The downside is that it precludes the use of a `cmp.Equal` one-liner to + do comparisons. + * It creates confusion in the code to deal with both representations at + different times. For example, is the checksum calculated on `[]byte` or + `Layer` when sending? What about when checking received packets? + +#### `Layers` + +``` +type Layers []Layer + +func (ls *Layers) match(other Layers) bool {...} +func (ls *Layers) toBytes() ([]byte, error) {...} +``` + +`Layers` is an array of `Layer`. It represents a stack of encapsulations, such +as `Layers{Ether{},IPv4{},TCP{},Payload{}}`. It also has `toBytes()` and +`match(Layers)`, like `Layer`. The parse functions above actually return +`Layers` and not `Layer` because they know about the headers below and +sequentially call each parser on the remaining, encapsulated bytes. + +All this leads to the ability to write concise packet processing. For example: + +```go +etherType := 0x8000 +flags = uint8(header.TCPFlagSyn|header.TCPFlagAck) +toMatch := Layers{Ether{Type: ðerType}, IPv4{}, TCP{Flags: &flags}} +for { + recvBytes := sniffer.Recv(time.Second) + if recvBytes == nil { + println("Got no packet for 1 second") + } + gotPacket, err := ParseEther(recvBytes) + if err == nil && toMatch.match(gotPacket) { + println("Got a TCP/IPv4/Eth packet with SYNACK") + } +} +``` + +##### Alternatives considered + +* Don't use previous and next pointers. + * Each layer may need to be able to interrogate the layers aroung it, like + for computing the next protocol number or total length. So *some* + mechanism is needed for a `Layer` to see neighboring layers. + * We could pass the entire array `Layers` to the `toBytes()` function. + Passing an array to a method that includes in the array the function + receiver itself seems wrong. + +#### Connections + +Using `Layers` above, we can create connection structures to maintain state +about connections. For example, here is the `TCPIPv4` struct: + +``` +type TCPIPv4 struct { + outgoing Layers + incoming Layers + localSeqNum uint32 + remoteSeqNum uint32 + sniffer Sniffer + injector Injector + t *testing.T +} +``` + +`TCPIPv4` contains an `outgoing Layers` which holds the defaults for the +connection, such as the source and destination MACs, IPs, and ports. When +`outgoing.toBytes()` is called a valid packet for this TCPIPv4 flow is built. + +It also contains `incoming Layers` which holds filter for incoming packets that +belong to this flow. `incoming.match(Layers)` is used on received bytes to check +if they are part of the flow. + +The `sniffer` and `injector` are for receiving and sending raw packet bytes. The +`localSeqNum` and `remoteSeqNum` are updated by `Send` and `Recv` so that +outgoing packets will have, by default, the correct sequence number and ack +number. + +TCPIPv4 provides some functions: + +``` +func (conn *TCPIPv4) Send(tcp TCP) {...} +func (conn *TCPIPv4) Recv(timeout time.Duration) *TCP {...} +``` + +`Send(tcp TCP)` uses [mergo](https://github.com/imdario/mergo) to merge the +provided `TCP` (a `Layer`) into `outgoing`. This way the user can specify +concisely just which fields of `outgoing` to modify. The packet is sent using +the `injector`. + +`Recv(timeout time.Duration)` reads packets from the sniffer until either the +timeout has elapsed or a packet that matches `incoming` arrives. + +Using those, we can perform a TCP 3-way handshake without too much code: + +```go +func (conn *TCPIPv4) Handshake() { + syn := uint8(header.TCPFlagSyn) + synack := uint8(header.TCPFlagSyn) + ack := uint8(header.TCPFlagAck) + conn.Send(TCP{Flags: &syn}) // Send a packet with all defaults but set TCP-SYN. + + // Wait for the SYN-ACK response. + for { + newTCP := conn.Recv(time.Second) // This already filters by MAC, IP, and ports. + if TCP{Flags: &synack}.match(newTCP) { + break // Only if it's a SYN-ACK proceed. + } + } + + conn.Send(TCP{Flags: &ack}) // Send an ACK. The seq and ack numbers are set correctly. +} +``` + +The handshake code is part of the testbench utilities so tests can share this +common sequence, making tests even more concise. + +##### Alternatives considered + +* Instead of storing `outgoing` and `incoming`, store values. + * There would be many more things to store instead, like `localMac`, + `remoteMac`, `localIP`, `remoteIP`, `localPort`, and `remotePort`. + * Construction of a packet would be many lines to copy each of these + values into a `[]byte`. And there would be slight variations needed for + each encapsulation stack, like TCPIPv6 and ARP. + * Filtering incoming packets would be a long sequence: + * Compare the MACs, then + * Parse the next header, then + * Compare the IPs, then + * Parse the next header, then + * Compare the TCP ports. Instead it's all just one call to + `cmp.Equal(...)`, for all sequences. + * A TCPIPv6 connection could share most of the code. Only the type of the + IP addresses are different. The types of `outgoing` and `incoming` would + be remain `Layers`. + * An ARP connection could share all the Ethernet parts. The IP `Layer` + could be factored out of `outgoing`. After that, the IPv4 and IPv6 + connections could implement one interface and a single TCP struct could + have either network protocol through composition. + +## Putting it all together + +Here's what te start of a packetimpact unit test looks like. This test creates a +TCP connection with the DUT. There are added comments for explanation in this +document but a real test might not include them in order to stay even more +concise. + +```go +func TestMyTcpTest(t *testing.T) { + // Prepare a DUT for communication. + dut := testbench.NewDUT(t) + + // This does: + // dut.Socket() + // dut.Bind() + // dut.Getsockname() to learn the new port number + // dut.Listen() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) // Tell the DUT to close the socket at the end of the test. + + // Monitor a new TCP connection with sniffer, injector, sequence number tracking, + // and reasonable outgoing and incoming packet field default IPs, MACs, and port numbers. + conn := testbench.NewTCPIPv4(t, dut, remotePort) + + // Perform a 3-way handshake: send SYN, expect SYNACK, send ACK. + conn.Handshake() + + // Tell the DUT to accept the new connection. + acceptFD := dut.Accept(acceptFd) +} +``` + +## Other notes + +* The time between receiving a SYN-ACK and replying with an ACK in `Handshake` + is about 3ms. This is much slower than the native unix response, which is + about 0.3ms. Packetdrill gets closer to 0.3ms. For tests where timing is + crucial, packetdrill is faster and more precise. diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD new file mode 100644 index 000000000..3ce63c2c6 --- /dev/null +++ b/test/packetimpact/dut/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "cc_binary", "grpcpp") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +cc_binary( + name = "posix_server", + srcs = ["posix_server.cc"], + linkstatic = 1, + static = True, # This is needed for running in a docker container. + deps = [ + grpcpp, + "//test/packetimpact/proto:posix_server_cc_grpc_proto", + "//test/packetimpact/proto:posix_server_cc_proto", + ], +) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc new file mode 100644 index 000000000..86e580c6f --- /dev/null +++ b/test/packetimpact/dut/posix_server.cc @@ -0,0 +1,260 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <getopt.h> +#include <netdb.h> +#include <netinet/in.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <iostream> +#include <unordered_map> + +#include "arpa/inet.h" +#include "include/grpcpp/security/server_credentials.h" +#include "include/grpcpp/server_builder.h" +#include "test/packetimpact/proto/posix_server.grpc.pb.h" +#include "test/packetimpact/proto/posix_server.pb.h" + +// Converts a sockaddr_storage to a Sockaddr message. +::grpc::Status sockaddr_to_proto(const sockaddr_storage &addr, + socklen_t addrlen, + posix_server::Sockaddr *sockaddr_proto) { + switch (addr.ss_family) { + case AF_INET: { + auto addr_in = reinterpret_cast<const sockaddr_in *>(&addr); + auto response_in = sockaddr_proto->mutable_in(); + response_in->set_family(addr_in->sin_family); + response_in->set_port(ntohs(addr_in->sin_port)); + response_in->mutable_addr()->assign( + reinterpret_cast<const char *>(&addr_in->sin_addr.s_addr), 4); + return ::grpc::Status::OK; + } + case AF_INET6: { + auto addr_in6 = reinterpret_cast<const sockaddr_in6 *>(&addr); + auto response_in6 = sockaddr_proto->mutable_in6(); + response_in6->set_family(addr_in6->sin6_family); + response_in6->set_port(ntohs(addr_in6->sin6_port)); + response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo)); + response_in6->mutable_addr()->assign( + reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16); + response_in6->set_scope_id(ntohl(addr_in6->sin6_scope_id)); + return ::grpc::Status::OK; + } + } + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr"); +} + +class PosixImpl final : public posix_server::Posix::Service { + ::grpc::Status Accept(grpc_impl::ServerContext *context, + const ::posix_server::AcceptRequest *request, + ::posix_server::AcceptResponse *response) override { + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + response->set_fd(accept(request->sockfd(), + reinterpret_cast<sockaddr *>(&addr), &addrlen)); + response->set_errno_(errno); + return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); + } + + ::grpc::Status Bind(grpc_impl::ServerContext *context, + const ::posix_server::BindRequest *request, + ::posix_server::BindResponse *response) override { + if (!request->has_addr()) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Missing address"); + } + sockaddr_storage addr; + + switch (request->addr().sockaddr_case()) { + case posix_server::Sockaddr::SockaddrCase::kIn: { + auto request_in = request->addr().in(); + if (request_in.addr().size() != 4) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "IPv4 address must be 4 bytes"); + } + auto addr_in = reinterpret_cast<sockaddr_in *>(&addr); + addr_in->sin_family = request_in.family(); + addr_in->sin_port = htons(request_in.port()); + request_in.addr().copy( + reinterpret_cast<char *>(&addr_in->sin_addr.s_addr), 4); + break; + } + case posix_server::Sockaddr::SockaddrCase::kIn6: { + auto request_in6 = request->addr().in6(); + if (request_in6.addr().size() != 16) { + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "IPv6 address must be 16 bytes"); + } + auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(&addr); + addr_in6->sin6_family = request_in6.family(); + addr_in6->sin6_port = htons(request_in6.port()); + addr_in6->sin6_flowinfo = htonl(request_in6.flowinfo()); + request_in6.addr().copy( + reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16); + addr_in6->sin6_scope_id = htonl(request_in6.scope_id()); + break; + } + case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET: + default: + return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, + "Unknown Sockaddr"); + } + response->set_ret(bind(request->sockfd(), + reinterpret_cast<sockaddr *>(&addr), sizeof(addr))); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Close(grpc_impl::ServerContext *context, + const ::posix_server::CloseRequest *request, + ::posix_server::CloseResponse *response) override { + response->set_ret(close(request->fd())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status GetSockName( + grpc_impl::ServerContext *context, + const ::posix_server::GetSockNameRequest *request, + ::posix_server::GetSockNameResponse *response) override { + sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + response->set_ret(getsockname( + request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); + response->set_errno_(errno); + return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); + } + + ::grpc::Status Listen(grpc_impl::ServerContext *context, + const ::posix_server::ListenRequest *request, + ::posix_server::ListenResponse *response) override { + response->set_ret(listen(request->sockfd(), request->backlog())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Send(::grpc::ServerContext *context, + const ::posix_server::SendRequest *request, + ::posix_server::SendResponse *response) override { + response->set_ret(::send(request->sockfd(), request->buf().data(), + request->buf().size(), request->flags())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status SetSockOpt( + grpc_impl::ServerContext *context, + const ::posix_server::SetSockOptRequest *request, + ::posix_server::SetSockOptResponse *response) override { + response->set_ret(setsockopt(request->sockfd(), request->level(), + request->optname(), request->optval().c_str(), + request->optval().size())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status SetSockOptInt( + ::grpc::ServerContext *context, + const ::posix_server::SetSockOptIntRequest *request, + ::posix_server::SetSockOptIntResponse *response) override { + int opt = request->intval(); + response->set_ret(::setsockopt(request->sockfd(), request->level(), + request->optname(), &opt, sizeof(opt))); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status SetSockOptTimeval( + ::grpc::ServerContext *context, + const ::posix_server::SetSockOptTimevalRequest *request, + ::posix_server::SetSockOptTimevalResponse *response) override { + timeval tv = {.tv_sec = static_cast<__time_t>(request->timeval().seconds()), + .tv_usec = static_cast<__suseconds_t>( + request->timeval().microseconds())}; + response->set_ret(setsockopt(request->sockfd(), request->level(), + request->optname(), &tv, sizeof(tv))); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Socket(grpc_impl::ServerContext *context, + const ::posix_server::SocketRequest *request, + ::posix_server::SocketResponse *response) override { + response->set_fd( + socket(request->domain(), request->type(), request->protocol())); + response->set_errno_(errno); + return ::grpc::Status::OK; + } + + ::grpc::Status Recv(::grpc::ServerContext *context, + const ::posix_server::RecvRequest *request, + ::posix_server::RecvResponse *response) override { + std::vector<char> buf(request->len()); + response->set_ret( + recv(request->sockfd(), buf.data(), buf.size(), request->flags())); + response->set_errno_(errno); + response->set_buf(buf.data(), response->ret()); + return ::grpc::Status::OK; + } +}; + +// Parse command line options. Returns a pointer to the first argument beyond +// the options. +void parse_command_line_options(int argc, char *argv[], std::string *ip, + int *port) { + static struct option options[] = {{"ip", required_argument, NULL, 1}, + {"port", required_argument, NULL, 2}, + {0, 0, 0, 0}}; + + // Parse the arguments. + int c; + while ((c = getopt_long(argc, argv, "", options, NULL)) > 0) { + if (c == 1) { + *ip = optarg; + } else if (c == 2) { + *port = std::stoi(std::string(optarg)); + } + } +} + +void run_server(const std::string &ip, int port) { + PosixImpl posix_service; + grpc::ServerBuilder builder; + std::string server_address = ip + ":" + std::to_string(port); + // Set the authentication mechanism. + std::shared_ptr<grpc::ServerCredentials> creds = + grpc::InsecureServerCredentials(); + builder.AddListeningPort(server_address, creds); + builder.RegisterService(&posix_service); + + std::unique_ptr<grpc::Server> server(builder.BuildAndStart()); + std::cerr << "Server listening on " << server_address << std::endl; + server->Wait(); + std::cerr << "posix_server is finished." << std::endl; +} + +int main(int argc, char *argv[]) { + std::cerr << "posix_server is starting." << std::endl; + std::string ip; + int port; + parse_command_line_options(argc, argv, &ip, &port); + + std::cerr << "Got IP " << ip << " and port " << port << "." << std::endl; + run_server(ip, port); +} diff --git a/test/packetimpact/proto/BUILD b/test/packetimpact/proto/BUILD new file mode 100644 index 000000000..4a4370f42 --- /dev/null +++ b/test/packetimpact/proto/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "proto_library") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +proto_library( + name = "posix_server", + srcs = ["posix_server.proto"], + has_services = 1, +) diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto new file mode 100644 index 000000000..4035e1ee6 --- /dev/null +++ b/test/packetimpact/proto/posix_server.proto @@ -0,0 +1,193 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package posix_server; + +message SockaddrIn { + int32 family = 1; + uint32 port = 2; + bytes addr = 3; +} + +message SockaddrIn6 { + uint32 family = 1; + uint32 port = 2; + uint32 flowinfo = 3; + bytes addr = 4; + uint32 scope_id = 5; +} + +message Sockaddr { + oneof sockaddr { + SockaddrIn in = 1; + SockaddrIn6 in6 = 2; + } +} + +message Timeval { + int64 seconds = 1; + int64 microseconds = 2; +} + +// Request and Response pairs for each Posix service RPC call, sorted. + +message AcceptRequest { + int32 sockfd = 1; +} + +message AcceptResponse { + int32 fd = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + Sockaddr addr = 3; +} + +message BindRequest { + int32 sockfd = 1; + Sockaddr addr = 2; +} + +message BindResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message CloseRequest { + int32 fd = 1; +} + +message CloseResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message GetSockNameRequest { + int32 sockfd = 1; +} + +message GetSockNameResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + Sockaddr addr = 3; +} + +message ListenRequest { + int32 sockfd = 1; + int32 backlog = 2; +} + +message ListenResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SendRequest { + int32 sockfd = 1; + bytes buf = 2; + int32 flags = 3; +} + +message SendResponse { + int32 ret = 1; + int32 errno_ = 2; +} + +message SetSockOptRequest { + int32 sockfd = 1; + int32 level = 2; + int32 optname = 3; + bytes optval = 4; +} + +message SetSockOptResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SetSockOptIntRequest { + int32 sockfd = 1; + int32 level = 2; + int32 optname = 3; + int32 intval = 4; +} + +message SetSockOptIntResponse { + int32 ret = 1; + int32 errno_ = 2; +} + +message SetSockOptTimevalRequest { + int32 sockfd = 1; + int32 level = 2; + int32 optname = 3; + Timeval timeval = 4; +} + +message SetSockOptTimevalResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message SocketRequest { + int32 domain = 1; + int32 type = 2; + int32 protocol = 3; +} + +message SocketResponse { + int32 fd = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. +} + +message RecvRequest { + int32 sockfd = 1; + int32 len = 2; + int32 flags = 3; +} + +message RecvResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + bytes buf = 3; +} + +service Posix { + // Call accept() on the DUT. + rpc Accept(AcceptRequest) returns (AcceptResponse); + // Call bind() on the DUT. + rpc Bind(BindRequest) returns (BindResponse); + // Call close() on the DUT. + rpc Close(CloseRequest) returns (CloseResponse); + // Call getsockname() on the DUT. + rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse); + // Call listen() on the DUT. + rpc Listen(ListenRequest) returns (ListenResponse); + // Call send() on the DUT. + rpc Send(SendRequest) returns (SendResponse); + // Call setsockopt() on the DUT. You should prefer one of the other + // SetSockOpt* functions with a more structured optval or else you may get the + // encoding wrong, such as making a bad assumption about the server's word + // sizes or endianness. + rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse); + // Call setsockopt() on the DUT with an int optval. + rpc SetSockOptInt(SetSockOptIntRequest) returns (SetSockOptIntResponse); + // Call setsockopt() on the DUT with a Timeval optval. + rpc SetSockOptTimeval(SetSockOptTimevalRequest) + returns (SetSockOptTimevalResponse); + // Call socket() on the DUT. + rpc Socket(SocketRequest) returns (SocketResponse); + // Call recv() on the DUT. + rpc Recv(RecvRequest) returns (RecvResponse); +} diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD new file mode 100644 index 000000000..b6a254882 --- /dev/null +++ b/test/packetimpact/testbench/BUILD @@ -0,0 +1,41 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +go_library( + name = "testbench", + srcs = [ + "connections.go", + "dut.go", + "dut_client.go", + "layers.go", + "rawsockets.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//pkg/usermem", + "//test/packetimpact/proto:posix_server_go_proto", + "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + "@com_github_imdario_mergo//:go_default_library", + "@com_github_mohae_deepcopy//:go_default_library", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//keepalive:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + "@org_uber_go_multierr//:go_default_library", + ], +) + +go_test( + name = "testbench_test", + size = "small", + srcs = ["layers_test.go"], + library = ":testbench", + deps = ["//pkg/tcpip"], +) diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go new file mode 100644 index 000000000..952a717e0 --- /dev/null +++ b/test/packetimpact/testbench/connections.go @@ -0,0 +1,692 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testbench has utilities to send and receive packets and also command +// the DUT to run POSIX functions. +package testbench + +import ( + "flag" + "fmt" + "math/rand" + "net" + "strings" + "testing" + "time" + + "github.com/mohae/deepcopy" + "go.uber.org/multierr" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +var localIPv4 = flag.String("local_ipv4", "", "local IPv4 address for test packets") +var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test packets") +var localMAC = flag.String("local_mac", "", "local mac address for test packets") +var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets") + +// pickPort makes a new socket and returns the socket FD and port. The caller +// must close the FD when done with the port if there is no error. +func pickPort() (int, uint16, error) { + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0) + if err != nil { + return -1, 0, err + } + var sa unix.SockaddrInet4 + copy(sa.Addr[0:4], net.ParseIP(*localIPv4).To4()) + if err := unix.Bind(fd, &sa); err != nil { + unix.Close(fd) + return -1, 0, err + } + newSockAddr, err := unix.Getsockname(fd) + if err != nil { + unix.Close(fd) + return -1, 0, err + } + newSockAddrInet4, ok := newSockAddr.(*unix.SockaddrInet4) + if !ok { + unix.Close(fd) + return -1, 0, fmt.Errorf("can't cast Getsockname result to SockaddrInet4") + } + return fd, uint16(newSockAddrInet4.Port), nil +} + +// layerState stores the state of a layer of a connection. +type layerState interface { + // outgoing returns an outgoing layer to be sent in a frame. + outgoing() Layer + + // incoming creates an expected Layer for comparing against a received Layer. + // Because the expectation can depend on values in the received Layer, it is + // an input to incoming. For example, the ACK number needs to be checked in a + // TCP packet but only if the ACK flag is set in the received packet. The + // calles takes ownership of the returned Layer. + incoming(received Layer) Layer + + // sent updates the layerState based on the Layer that was sent. The input is + // a Layer with all prev and next pointers populated so that the entire frame + // as it was sent is available. + sent(sent Layer) error + + // received updates the layerState based on a Layer that is receieved. The + // input is a Layer with all prev and next pointers populated so that the + // entire frame as it was receieved is available. + received(received Layer) error + + // close frees associated resources held by the LayerState. + close() error +} + +// etherState maintains state about an Ethernet connection. +type etherState struct { + out, in Ether +} + +var _ layerState = (*etherState)(nil) + +// newEtherState creates a new etherState. +func newEtherState(out, in Ether) (*etherState, error) { + lMAC, err := tcpip.ParseMACAddress(*localMAC) + if err != nil { + return nil, err + } + + rMAC, err := tcpip.ParseMACAddress(*remoteMAC) + if err != nil { + return nil, err + } + s := etherState{ + out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC}, + in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *etherState) outgoing() Layer { + return &s.out +} + +// incoming implements layerState.incoming. +func (s *etherState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*etherState) sent(Layer) error { + return nil +} + +func (*etherState) received(Layer) error { + return nil +} + +func (*etherState) close() error { + return nil +} + +// ipv4State maintains state about an IPv4 connection. +type ipv4State struct { + out, in IPv4 +} + +var _ layerState = (*ipv4State)(nil) + +// newIPv4State creates a new ipv4State. +func newIPv4State(out, in IPv4) (*ipv4State, error) { + lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) + rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) + s := ipv4State{ + out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, + in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *ipv4State) outgoing() Layer { + return &s.out +} + +// incoming implements layerState.incoming. +func (s *ipv4State) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*ipv4State) sent(Layer) error { + return nil +} + +func (*ipv4State) received(Layer) error { + return nil +} + +func (*ipv4State) close() error { + return nil +} + +// tcpState maintains state about a TCP connection. +type tcpState struct { + out, in TCP + localSeqNum, remoteSeqNum *seqnum.Value + synAck *TCP + portPickerFD int + finSent bool +} + +var _ layerState = (*tcpState)(nil) + +// SeqNumValue is a helper routine that allocates a new seqnum.Value value to +// store v and returns a pointer to it. +func SeqNumValue(v seqnum.Value) *seqnum.Value { + return &v +} + +// newTCPState creates a new TCPState. +func newTCPState(out, in TCP) (*tcpState, error) { + portPickerFD, localPort, err := pickPort() + if err != nil { + return nil, err + } + s := tcpState{ + out: TCP{SrcPort: &localPort}, + in: TCP{DstPort: &localPort}, + localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())), + portPickerFD: portPickerFD, + finSent: false, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *tcpState) outgoing() Layer { + newOutgoing := deepcopy.Copy(s.out).(TCP) + if s.localSeqNum != nil { + newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum)) + } + if s.remoteSeqNum != nil { + newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum)) + } + return &newOutgoing +} + +// incoming implements layerState.incoming. +func (s *tcpState) incoming(received Layer) Layer { + tcpReceived, ok := received.(*TCP) + if !ok { + return nil + } + newIn := deepcopy.Copy(s.in).(TCP) + if s.remoteSeqNum != nil { + newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) + } + if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 { + // The caller didn't specify an AckNum so we'll expect the calculated one, + // but only if the ACK flag is set because the AckNum is not valid in a + // header if ACK is not set. + newIn.AckNum = Uint32(uint32(*s.localSeqNum)) + } + return &newIn +} + +func (s *tcpState) sent(sent Layer) error { + tcp, ok := sent.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", sent) + } + if !s.finSent { + // update localSeqNum by the payload only when FIN is not yet sent by us + for current := tcp.next(); current != nil; current = current.next() { + s.localSeqNum.UpdateForward(seqnum.Size(current.length())) + } + } + if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.localSeqNum.UpdateForward(1) + } + if *tcp.Flags&(header.TCPFlagFin) != 0 { + s.finSent = true + } + return nil +} + +func (s *tcpState) received(l Layer) error { + tcp, ok := l.(*TCP) + if !ok { + return fmt.Errorf("can't update tcpState with %T Layer", l) + } + s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum)) + if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { + s.remoteSeqNum.UpdateForward(1) + } + for current := tcp.next(); current != nil; current = current.next() { + s.remoteSeqNum.UpdateForward(seqnum.Size(current.length())) + } + return nil +} + +// close frees the port associated with this connection. +func (s *tcpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err + } + s.portPickerFD = -1 + return nil +} + +// udpState maintains state about a UDP connection. +type udpState struct { + out, in UDP + portPickerFD int +} + +var _ layerState = (*udpState)(nil) + +// newUDPState creates a new udpState. +func newUDPState(out, in UDP) (*udpState, error) { + portPickerFD, localPort, err := pickPort() + if err != nil { + return nil, err + } + s := udpState{ + out: UDP{SrcPort: &localPort}, + in: UDP{DstPort: &localPort}, + portPickerFD: portPickerFD, + } + if err := s.out.merge(&out); err != nil { + return nil, err + } + if err := s.in.merge(&in); err != nil { + return nil, err + } + return &s, nil +} + +func (s *udpState) outgoing() Layer { + return &s.out +} + +// incoming implements layerState.incoming. +func (s *udpState) incoming(Layer) Layer { + return deepcopy.Copy(&s.in).(Layer) +} + +func (*udpState) sent(l Layer) error { + return nil +} + +func (*udpState) received(l Layer) error { + return nil +} + +// close frees the port associated with this connection. +func (s *udpState) close() error { + if err := unix.Close(s.portPickerFD); err != nil { + return err + } + s.portPickerFD = -1 + return nil +} + +// Connection holds a collection of layer states for maintaining a connection +// along with sockets for sniffer and injecting packets. +type Connection struct { + layerStates []layerState + injector Injector + sniffer Sniffer + t *testing.T +} + +// match tries to match each Layer in received against the incoming filter. If +// received is longer than layerStates then that may still count as a match. The +// reverse is never a match. override overrides the default matchers for each +// Layer. +func (conn *Connection) match(override, received Layers) bool { + var layersToMatch int + if len(override) < len(conn.layerStates) { + layersToMatch = len(conn.layerStates) + } else { + layersToMatch = len(override) + } + if len(received) < layersToMatch { + return false + } + for i := 0; i < layersToMatch; i++ { + var toMatch Layer + if i < len(conn.layerStates) { + s := conn.layerStates[i] + toMatch = s.incoming(received[i]) + if toMatch == nil { + return false + } + if i < len(override) { + if err := toMatch.merge(override[i]); err != nil { + conn.t.Fatalf("failed to merge: %s", err) + } + } + } else { + toMatch = override[i] + if toMatch == nil { + conn.t.Fatalf("expect the overriding layers to be non-nil") + } + } + if !toMatch.match(received[i]) { + return false + } + } + return true +} + +// Close frees associated resources held by the Connection. +func (conn *Connection) Close() { + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) + for _, s := range conn.layerStates { + if err := s.close(); err != nil { + errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err)) + } + } + if errs != nil { + conn.t.Fatalf("unable to close %+v: %s", conn, errs) + } +} + +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers { + var layersToSend Layers + for _, s := range conn.layerStates { + layersToSend = append(layersToSend, s.outgoing()) + } + if err := layersToSend[len(layersToSend)-1].merge(layer); err != nil { + conn.t.Fatalf("can't merge %+v into %+v: %s", layer, layersToSend[len(layersToSend)-1], err) + } + layersToSend = append(layersToSend, additionalLayers...) + return layersToSend +} + +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *Connection) SendFrame(frame Layers) { + outBytes, err := frame.toBytes() + if err != nil { + conn.t.Fatalf("can't build outgoing TCP packet: %s", err) + } + conn.injector.Send(outBytes) + + // frame might have nil values where the caller wanted to use default values. + // sentFrame will have no nil values in it because it comes from parsing the + // bytes that were actually sent. + sentFrame := parse(parseEther, outBytes) + // Update the state of each layer based on what was sent. + for i, s := range conn.layerStates { + if err := s.sent(sentFrame[i]); err != nil { + conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) + } + } +} + +// Send a packet with reasonable defaults. Potentially override the final layer +// in the connection with the provided layer and add additionLayers. +func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) { + conn.SendFrame(conn.CreateFrame(layer, additionalLayers...)) +} + +// recvFrame gets the next successfully parsed frame (of type Layers) within the +// timeout provided. If no parsable frame arrives before the timeout, it returns +// nil. +func (conn *Connection) recvFrame(timeout time.Duration) Layers { + if timeout <= 0 { + return nil + } + b := conn.sniffer.Recv(timeout) + if b == nil { + return nil + } + return parse(parseEther, b) +} + +// Expect a frame with the final layerStates layer matching the provided Layer +// within the timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) { + // Make a frame that will ignore all but the final layer. + layers := make([]Layer, len(conn.layerStates)) + layers[len(layers)-1] = layer + + gotFrame, err := conn.ExpectFrame(layers, timeout) + if err != nil { + return nil, err + } + if len(conn.layerStates)-1 < len(gotFrame) { + return gotFrame[len(conn.layerStates)-1], nil + } + conn.t.Fatal("the received frame should be at least as long as the expected layers") + return nil, fmt.Errorf("the received frame should be at least as long as the expected layers") +} + +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) { + deadline := time.Now().Add(timeout) + var allLayers []string + for { + var gotLayers Layers + if timeout = time.Until(deadline); timeout > 0 { + gotLayers = conn.recvFrame(timeout) + } + if gotLayers == nil { + return nil, fmt.Errorf("got %d packets:\n%s", len(allLayers), strings.Join(allLayers, "\n")) + } + if conn.match(layers, gotLayers) { + for i, s := range conn.layerStates { + if err := s.received(gotLayers[i]); err != nil { + conn.t.Fatal(err) + } + } + return gotLayers, nil + } + allLayers = append(allLayers, fmt.Sprintf("%s", gotLayers)) + } +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *Connection) Drain() { + conn.sniffer.Drain() +} + +// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. +type TCPIPv4 Connection + +// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. +func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) + } + tcpState, err := newTCPState(outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv4{ + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, + } +} + +// Handshake performs a TCP 3-way handshake. The input Connection should have a +// final TCP Layer. +func (conn *TCPIPv4) Handshake() { + // Send the SYN. + conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) + + // Wait for the SYN-ACK. + synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if synAck == nil { + conn.t.Fatalf("didn't get synack during handshake: %s", err) + } + conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck + + // Send an ACK. + conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(expected, timeout) +} + +// Send a packet with reasonable defaults. Potentially override the TCP layer in +// the connection with the provided layer and add additionLayers. +func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { + (*Connection)(conn).Send(&tcp, additionalLayers...) +} + +// Close frees associated resources held by the TCPIPv4 connection. +func (conn *TCPIPv4) Close() { + (*Connection)(conn).Close() +} + +// Expect a frame with the TCP layer matching the provided TCP within the +// timeout specified. If it doesn't arrive in time, it returns nil. +func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { + layer, err := (*Connection)(conn).Expect(&tcp, timeout) + if layer == nil { + return nil, err + } + gotTCP, ok := layer.(*TCP) + if !ok { + conn.t.Fatalf("expected %s to be TCP", layer) + } + return gotTCP, err +} + +func (conn *TCPIPv4) state() *tcpState { + state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState) + if !ok { + conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates) + } + return state +} + +// RemoteSeqNum returns the next expected sequence number from the DUT. +func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { + return conn.state().remoteSeqNum +} + +// LocalSeqNum returns the next sequence number to send from the testbench. +func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { + return conn.state().localSeqNum +} + +// SynAck returns the SynAck that was part of the handshake. +func (conn *TCPIPv4) SynAck() *TCP { + return conn.state().synAck +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *TCPIPv4) Drain() { + conn.sniffer.Drain() +} + +// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. +type UDPIPv4 Connection + +// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. +func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv4State, err := newIPv4State(IPv4{}, IPv4{}) + if err != nil { + t.Fatalf("can't make ipv4State: %s", err) + } + tcpState, err := newUDPState(outgoingUDP, incomingUDP) + if err != nil { + t.Fatalf("can't make udpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return UDPIPv4{ + layerStates: []layerState{etherState, ipv4State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, + } +} + +// CreateFrame builds a frame for the connection with layer overriding defaults +// of the innermost layer and additionalLayers added after it. +func (conn *UDPIPv4) CreateFrame(layer Layer, additionalLayers ...Layer) Layers { + return (*Connection)(conn).CreateFrame(layer, additionalLayers...) +} + +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *UDPIPv4) SendFrame(frame Layers) { + (*Connection)(conn).SendFrame(frame) +} + +// Close frees associated resources held by the UDPIPv4 connection. +func (conn *UDPIPv4) Close() { + (*Connection)(conn).Close() +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain() { + conn.sniffer.Drain() +} diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go new file mode 100644 index 000000000..3f340c6bc --- /dev/null +++ b/test/packetimpact/testbench/dut.go @@ -0,0 +1,473 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "context" + "flag" + "net" + "strconv" + "syscall" + "testing" + "time" + + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" + + "golang.org/x/sys/unix" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +var ( + posixServerIP = flag.String("posix_server_ip", "", "ip address to listen to for UDP commands") + posixServerPort = flag.Int("posix_server_port", 40000, "port to listen to for UDP commands") + rpcTimeout = flag.Duration("rpc_timeout", 100*time.Millisecond, "gRPC timeout") + rpcKeepalive = flag.Duration("rpc_keepalive", 10*time.Second, "gRPC keepalive") +) + +// DUT communicates with the DUT to force it to make POSIX calls. +type DUT struct { + t *testing.T + conn *grpc.ClientConn + posixServer PosixClient +} + +// NewDUT creates a new connection with the DUT over gRPC. +func NewDUT(t *testing.T) DUT { + flag.Parse() + posixServerAddress := *posixServerIP + ":" + strconv.Itoa(*posixServerPort) + conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: *rpcKeepalive})) + if err != nil { + t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err) + } + posixServer := NewPosixClient(conn) + return DUT{ + t: t, + conn: conn, + posixServer: posixServer, + } +} + +// TearDown closes the underlying connection. +func (dut *DUT) TearDown() { + dut.conn.Close() +} + +func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { + dut.t.Helper() + switch s := sa.(type) { + case *unix.SockaddrInet4: + return &pb.Sockaddr{ + Sockaddr: &pb.Sockaddr_In{ + In: &pb.SockaddrIn{ + Family: unix.AF_INET, + Port: uint32(s.Port), + Addr: s.Addr[:], + }, + }, + } + case *unix.SockaddrInet6: + return &pb.Sockaddr{ + Sockaddr: &pb.Sockaddr_In6{ + In6: &pb.SockaddrIn6{ + Family: unix.AF_INET6, + Port: uint32(s.Port), + Flowinfo: 0, + ScopeId: s.ZoneId, + Addr: s.Addr[:], + }, + }, + } + } + dut.t.Fatalf("can't parse Sockaddr: %+v", sa) + return nil +} + +func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { + dut.t.Helper() + switch s := sa.Sockaddr.(type) { + case *pb.Sockaddr_In: + ret := unix.SockaddrInet4{ + Port: int(s.In.GetPort()), + } + copy(ret.Addr[:], s.In.GetAddr()) + return &ret + case *pb.Sockaddr_In6: + ret := unix.SockaddrInet6{ + Port: int(s.In6.GetPort()), + ZoneId: s.In6.GetScopeId(), + } + copy(ret.Addr[:], s.In6.GetAddr()) + } + dut.t.Fatalf("can't parse Sockaddr: %+v", sa) + return nil +} + +// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol +// proto, and bound to the IP address addr. Returns the new file descriptor and +// the port that was selected on the DUT. +func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) { + dut.t.Helper() + var fd int32 + if addr.To4() != nil { + fd = dut.Socket(unix.AF_INET, typ, proto) + sa := unix.SockaddrInet4{} + copy(sa.Addr[:], addr.To4()) + dut.Bind(fd, &sa) + } else if addr.To16() != nil { + fd = dut.Socket(unix.AF_INET6, typ, proto) + sa := unix.SockaddrInet6{} + copy(sa.Addr[:], addr.To16()) + dut.Bind(fd, &sa) + } else { + dut.t.Fatalf("unknown ip addr type for remoteIP") + } + sa := dut.GetSockName(fd) + var port int + switch s := sa.(type) { + case *unix.SockaddrInet4: + port = s.Port + case *unix.SockaddrInet6: + port = s.Port + default: + dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa) + } + return fd, uint16(port) +} + +// CreateListener makes a new TCP connection. If it fails, the test ends. +func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { + fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4)) + dut.Listen(fd, backlog) + return fd, remotePort +} + +// All the functions that make gRPC calls to the Posix service are below, sorted +// alphabetically. + +// Accept calls accept on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// AcceptWithErrno. +func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + fd, sa, err := dut.AcceptWithErrno(ctx, sockfd) + if fd < 0 { + dut.t.Fatalf("failed to accept: %s", err) + } + return fd, sa +} + +// AcceptWithErrno calls accept on the DUT. +func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { + dut.t.Helper() + req := pb.AcceptRequest{ + Sockfd: sockfd, + } + resp, err := dut.posixServer.Accept(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Accept: %s", err) + } + return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) +} + +// Bind calls bind on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is +// needed, use BindWithErrno. +func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.BindWithErrno(ctx, fd, sa) + if ret != 0 { + dut.t.Fatalf("failed to bind socket: %s", err) + } +} + +// BindWithErrno calls bind on the DUT. +func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { + dut.t.Helper() + req := pb.BindRequest{ + Sockfd: fd, + Addr: dut.sockaddrToProto(sa), + } + resp, err := dut.posixServer.Bind(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Bind: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Close calls close on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// CloseWithErrno. +func (dut *DUT) Close(fd int32) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.CloseWithErrno(ctx, fd) + if ret != 0 { + dut.t.Fatalf("failed to close: %s", err) + } +} + +// CloseWithErrno calls close on the DUT. +func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { + dut.t.Helper() + req := pb.CloseRequest{ + Fd: fd, + } + resp, err := dut.posixServer.Close(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Close: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// GetSockName calls getsockname on the DUT and causes a fatal test failure if +// it doesn't succeed. If more control over the timeout or error handling is +// needed, use GetSockNameWithErrno. +func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd) + if ret != 0 { + dut.t.Fatalf("failed to getsockname: %s", err) + } + return sa +} + +// GetSockNameWithErrno calls getsockname on the DUT. +func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { + dut.t.Helper() + req := pb.GetSockNameRequest{ + Sockfd: sockfd, + } + resp, err := dut.posixServer.GetSockName(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Bind: %s", err) + } + return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) +} + +// Listen calls listen on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// ListenWithErrno. +func (dut *DUT) Listen(sockfd, backlog int32) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.ListenWithErrno(ctx, sockfd, backlog) + if ret != 0 { + dut.t.Fatalf("failed to listen: %s", err) + } +} + +// ListenWithErrno calls listen on the DUT. +func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) { + dut.t.Helper() + req := pb.ListenRequest{ + Sockfd: sockfd, + Backlog: backlog, + } + resp, err := dut.posixServer.Listen(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Listen: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Send calls send on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// SendWithErrno. +func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags) + if ret == -1 { + dut.t.Fatalf("failed to send: %s", err) + } + return ret +} + +// SendWithErrno calls send on the DUT. +func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) { + dut.t.Helper() + req := pb.SendRequest{ + Sockfd: sockfd, + Buf: buf, + Flags: flags, + } + resp, err := dut.posixServer.Send(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Send: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it +// doesn't succeed. If more control over the timeout or error handling is +// needed, use SetSockOptWithErrno. Because endianess and the width of values +// might differ between the testbench and DUT architectures, prefer to use a +// more specific SetSockOptXxx function. +func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval) + if ret != 0 { + dut.t.Fatalf("failed to SetSockOpt: %s", err) + } +} + +// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the +// width of values might differ between the testbench and DUT architectures, +// prefer to use a more specific SetSockOptXxxWithErrno function. +func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) { + dut.t.Helper() + req := pb.SetSockOptRequest{ + Sockfd: sockfd, + Level: level, + Optname: optname, + Optval: optval, + } + resp, err := dut.posixServer.SetSockOpt(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call SetSockOpt: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the int optval or error handling +// is needed, use SetSockOptIntWithErrno. +func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval) + if ret != 0 { + dut.t.Fatalf("failed to SetSockOptInt: %s", err) + } +} + +// SetSockOptIntWithErrno calls setsockopt with an integer optval. +func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) { + dut.t.Helper() + req := pb.SetSockOptIntRequest{ + Sockfd: sockfd, + Level: level, + Optname: optname, + Intval: optval, + } + resp, err := dut.posixServer.SetSockOptInt(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call SetSockOptInt: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure +// if it doesn't succeed. If more control over the timeout or error handling is +// needed, use SetSockOptTimevalWithErrno. +func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv) + if ret != 0 { + dut.t.Fatalf("failed to SetSockOptTimeval: %s", err) + } +} + +// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to +// bytes. +func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { + dut.t.Helper() + timeval := pb.Timeval{ + Seconds: int64(tv.Sec), + Microseconds: int64(tv.Usec), + } + req := pb.SetSockOptTimevalRequest{ + Sockfd: sockfd, + Level: level, + Optname: optname, + Timeval: &timeval, + } + resp, err := dut.posixServer.SetSockOptTimeval(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call SetSockOptTimeval: %s", err) + } + return resp.GetRet(), syscall.Errno(resp.GetErrno_()) +} + +// Socket calls socket on the DUT and returns the file descriptor. If socket +// fails on the DUT, the test ends. +func (dut *DUT) Socket(domain, typ, proto int32) int32 { + dut.t.Helper() + fd, err := dut.SocketWithErrno(domain, typ, proto) + if fd < 0 { + dut.t.Fatalf("failed to create socket: %s", err) + } + return fd +} + +// SocketWithErrno calls socket on the DUT and returns the fd and errno. +func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { + dut.t.Helper() + req := pb.SocketRequest{ + Domain: domain, + Type: typ, + Protocol: proto, + } + ctx := context.Background() + resp, err := dut.posixServer.Socket(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Socket: %s", err) + } + return resp.GetFd(), syscall.Errno(resp.GetErrno_()) +} + +// Recv calls recv on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// RecvWithErrno. +func (dut *DUT) Recv(sockfd, len, flags int32) []byte { + dut.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + defer cancel() + ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags) + if ret == -1 { + dut.t.Fatalf("failed to recv: %s", err) + } + return buf +} + +// RecvWithErrno calls recv on the DUT. +func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) { + dut.t.Helper() + req := pb.RecvRequest{ + Sockfd: sockfd, + Len: len, + Flags: flags, + } + resp, err := dut.posixServer.Recv(ctx, &req) + if err != nil { + dut.t.Fatalf("failed to call Recv: %s", err) + } + return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) +} diff --git a/test/root/testdata/httpd.go b/test/packetimpact/testbench/dut_client.go index 45d5e33d4..b130a33a2 100644 --- a/test/root/testdata/httpd.go +++ b/test/packetimpact/testbench/dut_client.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testdata +package testbench -// Httpd is a JSON config for an httpd container. -const Httpd = ` -{ - "metadata": { - "name": "httpd" - }, - "image":{ - "image": "httpd" - }, - "mounts": [ - ], - "linux": { - }, - "log_path": "httpd.log" +import ( + "google.golang.org/grpc" + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" +) + +// PosixClient is a gRPC client for the Posix service. +type PosixClient pb.PosixClient + +// NewPosixClient makes a new gRPC client for the Posix service. +func NewPosixClient(c grpc.ClientConnInterface) PosixClient { + return pb.NewPosixClient(c) } -` diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go new file mode 100644 index 000000000..5ce324f0d --- /dev/null +++ b/test/packetimpact/testbench/layers.go @@ -0,0 +1,710 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "encoding/hex" + "fmt" + "reflect" + "strings" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/imdario/mergo" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// Layer is the interface that all encapsulations must implement. +// +// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A +// Layer contains all the fields of the encapsulation. Each field is a pointer +// and may be nil. +type Layer interface { + fmt.Stringer + + // toBytes converts the Layer into bytes. In places where the Layer's field + // isn't nil, the value that is pointed to is used. When the field is nil, a + // reasonable default for the Layer is used. For example, "64" for IPv4 TTL + // and a calculated checksum for TCP or IP. Some layers require information + // from the previous or next layers in order to compute a default, such as + // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list + // to the layer's neighbors. + toBytes() ([]byte, error) + + // match checks if the current Layer matches the provided Layer. If either + // Layer has a nil in a given field, that field is considered matching. + // Otherwise, the values pointed to by the fields must match. The LayerBase is + // ignored. + match(Layer) bool + + // length in bytes of the current encapsulation + length() int + + // next gets a pointer to the encapsulated Layer. + next() Layer + + // prev gets a pointer to the Layer encapsulating this one. + prev() Layer + + // setNext sets the pointer to the encapsulated Layer. + setNext(Layer) + + // setPrev sets the pointer to the Layer encapsulating this one. + setPrev(Layer) + + // merge overrides the values in the interface with the provided values. + merge(Layer) error +} + +// LayerBase is the common elements of all layers. +type LayerBase struct { + nextLayer Layer + prevLayer Layer +} + +func (lb *LayerBase) next() Layer { + return lb.nextLayer +} + +func (lb *LayerBase) prev() Layer { + return lb.prevLayer +} + +func (lb *LayerBase) setNext(l Layer) { + lb.nextLayer = l +} + +func (lb *LayerBase) setPrev(l Layer) { + lb.prevLayer = l +} + +// equalLayer compares that two Layer structs match while ignoring field in +// which either input has a nil and also ignoring the LayerBase of the inputs. +func equalLayer(x, y Layer) bool { + if x == nil || y == nil { + return true + } + // opt ignores comparison pairs where either of the inputs is a nil. + opt := cmp.FilterValues(func(x, y interface{}) bool { + for _, l := range []interface{}{x, y} { + v := reflect.ValueOf(l) + if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() { + return true + } + } + return false + }, cmp.Ignore()) + return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{})) +} + +// mergeLayer merges other in layer. Any non-nil value in other overrides the +// corresponding value in layer. If other is nil, no action is performed. +func mergeLayer(layer, other Layer) error { + if other == nil { + return nil + } + return mergo.Merge(layer, other, mergo.WithOverride) +} + +func stringLayer(l Layer) string { + v := reflect.ValueOf(l).Elem() + t := v.Type() + var ret []string + for i := 0; i < v.NumField(); i++ { + t := t.Field(i) + if t.Anonymous { + // Ignore the LayerBase in the Layer struct. + continue + } + v := v.Field(i) + if v.IsNil() { + continue + } + v = reflect.Indirect(v) + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes()))) + } else { + ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v)) + } + } + return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " ")) +} + +// Ether can construct and match an ethernet encapsulation. +type Ether struct { + LayerBase + SrcAddr *tcpip.LinkAddress + DstAddr *tcpip.LinkAddress + Type *tcpip.NetworkProtocolNumber +} + +func (l *Ether) String() string { + return stringLayer(l) +} + +func (l *Ether) toBytes() ([]byte, error) { + b := make([]byte, header.EthernetMinimumSize) + h := header.Ethernet(b) + fields := &header.EthernetFields{} + if l.SrcAddr != nil { + fields.SrcAddr = *l.SrcAddr + } + if l.DstAddr != nil { + fields.DstAddr = *l.DstAddr + } + if l.Type != nil { + fields.Type = *l.Type + } else { + switch n := l.next().(type) { + case *IPv4: + fields.Type = header.IPv4ProtocolNumber + default: + // TODO(b/150301488): Support more protocols, like IPv6. + return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n) + } + } + h.Encode(fields) + return h, nil +} + +// LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value +// to store v and returns a pointer to it. +func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress { + return &v +} + +// NetworkProtocolNumber is a helper routine that allocates a new +// tcpip.NetworkProtocolNumber value to store v and returns a pointer to it. +func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber { + return &v +} + +// layerParser parses the input bytes and returns a Layer along with the next +// layerParser to run. If there is no more parsing to do, the returned +// layerParser is nil. +type layerParser func([]byte) (Layer, layerParser) + +// parse parses bytes starting with the first layerParser and using successive +// layerParsers until all the bytes are parsed. +func parse(parser layerParser, b []byte) Layers { + var layers Layers + for { + var layer Layer + layer, parser = parser(b) + layers = append(layers, layer) + if parser == nil { + break + } + b = b[layer.length():] + } + layers.linkLayers() + return layers +} + +// parseEther parses the bytes assuming that they start with an ethernet header +// and continues parsing further encapsulations. +func parseEther(b []byte) (Layer, layerParser) { + h := header.Ethernet(b) + ether := Ether{ + SrcAddr: LinkAddress(h.SourceAddress()), + DstAddr: LinkAddress(h.DestinationAddress()), + Type: NetworkProtocolNumber(h.Type()), + } + var nextParser layerParser + switch h.Type() { + case header.IPv4ProtocolNumber: + nextParser = parseIPv4 + default: + // Assume that the rest is a payload. + nextParser = parsePayload + } + return ðer, nextParser +} + +func (l *Ether) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *Ether) length() int { + return header.EthernetMinimumSize +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Ether) merge(other Layer) error { + return mergeLayer(l, other) +} + +// IPv4 can construct and match an IPv4 encapsulation. +type IPv4 struct { + LayerBase + IHL *uint8 + TOS *uint8 + TotalLength *uint16 + ID *uint16 + Flags *uint8 + FragmentOffset *uint16 + TTL *uint8 + Protocol *uint8 + Checksum *uint16 + SrcAddr *tcpip.Address + DstAddr *tcpip.Address +} + +func (l *IPv4) String() string { + return stringLayer(l) +} + +func (l *IPv4) toBytes() ([]byte, error) { + b := make([]byte, header.IPv4MinimumSize) + h := header.IPv4(b) + fields := &header.IPv4Fields{ + IHL: 20, + TOS: 0, + TotalLength: 0, + ID: 0, + Flags: 0, + FragmentOffset: 0, + TTL: 64, + Protocol: 0, + Checksum: 0, + SrcAddr: tcpip.Address(""), + DstAddr: tcpip.Address(""), + } + if l.TOS != nil { + fields.TOS = *l.TOS + } + if l.TotalLength != nil { + fields.TotalLength = *l.TotalLength + } else { + fields.TotalLength = uint16(l.length()) + current := l.next() + for current != nil { + fields.TotalLength += uint16(current.length()) + current = current.next() + } + } + if l.ID != nil { + fields.ID = *l.ID + } + if l.Flags != nil { + fields.Flags = *l.Flags + } + if l.FragmentOffset != nil { + fields.FragmentOffset = *l.FragmentOffset + } + if l.TTL != nil { + fields.TTL = *l.TTL + } + if l.Protocol != nil { + fields.Protocol = *l.Protocol + } else { + switch n := l.next().(type) { + case *TCP: + fields.Protocol = uint8(header.TCPProtocolNumber) + case *UDP: + fields.Protocol = uint8(header.UDPProtocolNumber) + default: + // TODO(b/150301488): Support more protocols as needed. + return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n) + } + } + if l.SrcAddr != nil { + fields.SrcAddr = *l.SrcAddr + } + if l.DstAddr != nil { + fields.DstAddr = *l.DstAddr + } + if l.Checksum != nil { + fields.Checksum = *l.Checksum + } + h.Encode(fields) + if l.Checksum == nil { + h.SetChecksum(^h.CalculateChecksum()) + } + return h, nil +} + +// Uint16 is a helper routine that allocates a new +// uint16 value to store v and returns a pointer to it. +func Uint16(v uint16) *uint16 { + return &v +} + +// Uint8 is a helper routine that allocates a new +// uint8 value to store v and returns a pointer to it. +func Uint8(v uint8) *uint8 { + return &v +} + +// Address is a helper routine that allocates a new tcpip.Address value to store +// v and returns a pointer to it. +func Address(v tcpip.Address) *tcpip.Address { + return &v +} + +// parseIPv4 parses the bytes assuming that they start with an ipv4 header and +// continues parsing further encapsulations. +func parseIPv4(b []byte) (Layer, layerParser) { + h := header.IPv4(b) + tos, _ := h.TOS() + ipv4 := IPv4{ + IHL: Uint8(h.HeaderLength()), + TOS: &tos, + TotalLength: Uint16(h.TotalLength()), + ID: Uint16(h.ID()), + Flags: Uint8(h.Flags()), + FragmentOffset: Uint16(h.FragmentOffset()), + TTL: Uint8(h.TTL()), + Protocol: Uint8(h.Protocol()), + Checksum: Uint16(h.Checksum()), + SrcAddr: Address(h.SourceAddress()), + DstAddr: Address(h.DestinationAddress()), + } + var nextParser layerParser + switch h.TransportProtocol() { + case header.TCPProtocolNumber: + nextParser = parseTCP + case header.UDPProtocolNumber: + nextParser = parseUDP + default: + // Assume that the rest is a payload. + nextParser = parsePayload + } + return &ipv4, nextParser +} + +func (l *IPv4) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *IPv4) length() int { + if l.IHL == nil { + return header.IPv4MinimumSize + } + return int(*l.IHL) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv4) merge(other Layer) error { + return mergeLayer(l, other) +} + +// TCP can construct and match a TCP encapsulation. +type TCP struct { + LayerBase + SrcPort *uint16 + DstPort *uint16 + SeqNum *uint32 + AckNum *uint32 + DataOffset *uint8 + Flags *uint8 + WindowSize *uint16 + Checksum *uint16 + UrgentPointer *uint16 +} + +func (l *TCP) String() string { + return stringLayer(l) +} + +func (l *TCP) toBytes() ([]byte, error) { + b := make([]byte, header.TCPMinimumSize) + h := header.TCP(b) + if l.SrcPort != nil { + h.SetSourcePort(*l.SrcPort) + } + if l.DstPort != nil { + h.SetDestinationPort(*l.DstPort) + } + if l.SeqNum != nil { + h.SetSequenceNumber(*l.SeqNum) + } + if l.AckNum != nil { + h.SetAckNumber(*l.AckNum) + } + if l.DataOffset != nil { + h.SetDataOffset(*l.DataOffset) + } else { + h.SetDataOffset(uint8(l.length())) + } + if l.Flags != nil { + h.SetFlags(*l.Flags) + } + if l.WindowSize != nil { + h.SetWindowSize(*l.WindowSize) + } else { + h.SetWindowSize(32768) + } + if l.UrgentPointer != nil { + h.SetUrgentPoiner(*l.UrgentPointer) + } + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + if err := setTCPChecksum(&h, l); err != nil { + return nil, err + } + return h, nil +} + +// totalLength returns the length of the provided layer and all following +// layers. +func totalLength(l Layer) int { + var totalLength int + for ; l != nil; l = l.next() { + totalLength += l.length() + } + return totalLength +} + +// layerChecksum calculates the checksum of the Layer header, including the +// peusdeochecksum of the layer before it and all the bytes after it.. +func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) { + totalLength := uint16(totalLength(l)) + var xsum uint16 + switch s := l.prev().(type) { + case *IPv4: + xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength) + default: + // TODO(b/150301488): Support more protocols, like IPv6. + return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s) + } + var payloadBytes buffer.VectorisedView + for current := l.next(); current != nil; current = current.next() { + payload, err := current.toBytes() + if err != nil { + return 0, fmt.Errorf("can't get bytes for next header: %s", payload) + } + payloadBytes.AppendView(payload) + } + xsum = header.ChecksumVV(payloadBytes, xsum) + return xsum, nil +} + +// setTCPChecksum calculates the checksum of the TCP header and sets it in h. +func setTCPChecksum(h *header.TCP, tcp *TCP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(tcp, header.TCPProtocolNumber) + if err != nil { + return err + } + h.SetChecksum(^h.CalculateChecksum(xsum)) + return nil +} + +// Uint32 is a helper routine that allocates a new +// uint32 value to store v and returns a pointer to it. +func Uint32(v uint32) *uint32 { + return &v +} + +// parseTCP parses the bytes assuming that they start with a tcp header and +// continues parsing further encapsulations. +func parseTCP(b []byte) (Layer, layerParser) { + h := header.TCP(b) + tcp := TCP{ + SrcPort: Uint16(h.SourcePort()), + DstPort: Uint16(h.DestinationPort()), + SeqNum: Uint32(h.SequenceNumber()), + AckNum: Uint32(h.AckNumber()), + DataOffset: Uint8(h.DataOffset()), + Flags: Uint8(h.Flags()), + WindowSize: Uint16(h.WindowSize()), + Checksum: Uint16(h.Checksum()), + UrgentPointer: Uint16(h.UrgentPointer()), + } + return &tcp, parsePayload +} + +func (l *TCP) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *TCP) length() int { + if l.DataOffset == nil { + return header.TCPMinimumSize + } + return int(*l.DataOffset) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *TCP) merge(other Layer) error { + return mergeLayer(l, other) +} + +// UDP can construct and match a UDP encapsulation. +type UDP struct { + LayerBase + SrcPort *uint16 + DstPort *uint16 + Length *uint16 + Checksum *uint16 +} + +func (l *UDP) String() string { + return stringLayer(l) +} + +func (l *UDP) toBytes() ([]byte, error) { + b := make([]byte, header.UDPMinimumSize) + h := header.UDP(b) + if l.SrcPort != nil { + h.SetSourcePort(*l.SrcPort) + } + if l.DstPort != nil { + h.SetDestinationPort(*l.DstPort) + } + if l.Length != nil { + h.SetLength(*l.Length) + } else { + h.SetLength(uint16(totalLength(l))) + } + if l.Checksum != nil { + h.SetChecksum(*l.Checksum) + return h, nil + } + if err := setUDPChecksum(&h, l); err != nil { + return nil, err + } + return h, nil +} + +// setUDPChecksum calculates the checksum of the UDP header and sets it in h. +func setUDPChecksum(h *header.UDP, udp *UDP) error { + h.SetChecksum(0) + xsum, err := layerChecksum(udp, header.UDPProtocolNumber) + if err != nil { + return err + } + h.SetChecksum(^h.CalculateChecksum(xsum)) + return nil +} + +// parseUDP parses the bytes assuming that they start with a udp header and +// returns the parsed layer and the next parser to use. +func parseUDP(b []byte) (Layer, layerParser) { + h := header.UDP(b) + udp := UDP{ + SrcPort: Uint16(h.SourcePort()), + DstPort: Uint16(h.DestinationPort()), + Length: Uint16(h.Length()), + Checksum: Uint16(h.Checksum()), + } + return &udp, parsePayload +} + +func (l *UDP) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *UDP) length() int { + if l.Length == nil { + return header.UDPMinimumSize + } + return int(*l.Length) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *UDP) merge(other Layer) error { + return mergeLayer(l, other) +} + +// Payload has bytes beyond OSI layer 4. +type Payload struct { + LayerBase + Bytes []byte +} + +func (l *Payload) String() string { + return stringLayer(l) +} + +// parsePayload parses the bytes assuming that they start with a payload and +// continue to the end. There can be no further encapsulations. +func parsePayload(b []byte) (Layer, layerParser) { + payload := Payload{ + Bytes: b, + } + return &payload, nil +} + +func (l *Payload) toBytes() ([]byte, error) { + return l.Bytes, nil +} + +func (l *Payload) match(other Layer) bool { + return equalLayer(l, other) +} + +func (l *Payload) length() int { + return len(l.Bytes) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *Payload) merge(other Layer) error { + return mergeLayer(l, other) +} + +// Layers is an array of Layer and supports similar functions to Layer. +type Layers []Layer + +// linkLayers sets the linked-list ponters in ls. +func (ls *Layers) linkLayers() { + for i, l := range *ls { + if i > 0 { + l.setPrev((*ls)[i-1]) + } else { + l.setPrev(nil) + } + if i+1 < len(*ls) { + l.setNext((*ls)[i+1]) + } else { + l.setNext(nil) + } + } +} + +func (ls *Layers) toBytes() ([]byte, error) { + ls.linkLayers() + outBytes := []byte{} + for _, l := range *ls { + layerBytes, err := l.toBytes() + if err != nil { + return nil, err + } + outBytes = append(outBytes, layerBytes...) + } + return outBytes, nil +} + +func (ls *Layers) match(other Layers) bool { + if len(*ls) > len(other) { + return false + } + for i, l := range *ls { + if !equalLayer(l, other[i]) { + return false + } + } + return true +} diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go new file mode 100644 index 000000000..c99cf6312 --- /dev/null +++ b/test/packetimpact/testbench/layers_test.go @@ -0,0 +1,206 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestLayerMatch(t *testing.T) { + var nilPayload *Payload + noPayload := &Payload{} + emptyPayload := &Payload{Bytes: []byte{}} + fullPayload := &Payload{Bytes: []byte{1, 2, 3}} + emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}} + fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}} + for _, tt := range []struct { + a, b Layer + want bool + }{ + {nilPayload, nilPayload, true}, + {nilPayload, noPayload, true}, + {nilPayload, emptyPayload, true}, + {nilPayload, fullPayload, true}, + {noPayload, noPayload, true}, + {noPayload, emptyPayload, true}, + {noPayload, fullPayload, true}, + {emptyPayload, emptyPayload, true}, + {emptyPayload, fullPayload, false}, + {fullPayload, fullPayload, true}, + {emptyTCP, fullTCP, true}, + } { + if got := tt.a.match(tt.b); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want) + } + if got := tt.b.match(tt.a); got != tt.want { + t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want) + } + } +} + +func TestLayerStringFormat(t *testing.T) { + for _, tt := range []struct { + name string + l Layer + want string + }{ + { + name: "TCP", + l: &TCP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + SeqNum: Uint32(3452155723), + AckNum: Uint32(2596996163), + DataOffset: Uint8(5), + Flags: Uint8(20), + WindowSize: Uint16(64240), + Checksum: Uint16(0x2e2b), + }, + want: "&testbench.TCP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "SeqNum:3452155723 " + + "AckNum:2596996163 " + + "DataOffset:5 " + + "Flags:20 " + + "WindowSize:64240 " + + "Checksum:11819" + + "}", + }, + { + name: "UDP", + l: &UDP{ + SrcPort: Uint16(34785), + DstPort: Uint16(47767), + Length: Uint16(12), + }, + want: "&testbench.UDP{" + + "SrcPort:34785 " + + "DstPort:47767 " + + "Length:12" + + "}", + }, + { + name: "IPv4", + l: &IPv4{ + IHL: Uint8(5), + TOS: Uint8(0), + TotalLength: Uint16(44), + ID: Uint16(0), + Flags: Uint8(2), + FragmentOffset: Uint16(0), + TTL: Uint8(64), + Protocol: Uint8(6), + Checksum: Uint16(0x2e2b), + SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})), + DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})), + }, + want: "&testbench.IPv4{" + + "IHL:5 " + + "TOS:0 " + + "TotalLength:44 " + + "ID:0 " + + "Flags:2 " + + "FragmentOffset:0 " + + "TTL:64 " + + "Protocol:6 " + + "Checksum:11819 " + + "SrcAddr:197.34.63.10 " + + "DstAddr:197.34.63.20" + + "}", + }, + { + name: "Ether", + l: &Ether{ + SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})), + DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})), + Type: NetworkProtocolNumber(4), + }, + want: "&testbench.Ether{" + + "SrcAddr:02:42:c5:22:3f:0a " + + "DstAddr:02:42:c5:22:3f:14 " + + "Type:4" + + "}", + }, + { + name: "Payload", + l: &Payload{ + Bytes: []byte("Hooray for packetimpact."), + }, + want: "&testbench.Payload{Bytes:\n" + + "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" + + "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" + + "}", + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.l.String(); got != tt.want { + t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want) + } + }) + } +} + +func TestConnectionMatch(t *testing.T) { + conn := Connection{ + layerStates: []layerState{ðerState{}}, + } + protoNum0 := tcpip.NetworkProtocolNumber(0) + protoNum1 := tcpip.NetworkProtocolNumber(1) + for _, tt := range []struct { + description string + override, received Layers + wantMatch bool + }{ + { + description: "shorter override", + override: []Layer{&Ether{}}, + received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, + wantMatch: true, + }, + { + description: "longer override", + override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}}, + received: []Layer{&Ether{}}, + wantMatch: false, + }, + { + description: "ether layer mismatch", + override: []Layer{&Ether{Type: &protoNum0}}, + received: []Layer{&Ether{Type: &protoNum1}}, + wantMatch: false, + }, + { + description: "both nil", + override: nil, + received: nil, + wantMatch: false, + }, + { + description: "nil override", + override: nil, + received: []Layer{&Ether{}}, + wantMatch: true, + }, + } { + t.Run(tt.description, func(t *testing.T) { + if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch { + t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch) + } + }) + } +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go new file mode 100644 index 000000000..ff722d4a6 --- /dev/null +++ b/test/packetimpact/testbench/rawsockets.go @@ -0,0 +1,183 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "encoding/binary" + "flag" + "fmt" + "math" + "net" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/usermem" +) + +var device = flag.String("device", "", "local device for test packets") + +// Sniffer can sniff raw packets on the wire. +type Sniffer struct { + t *testing.T + fd int +} + +func htons(x uint16) uint16 { + buf := [2]byte{} + binary.BigEndian.PutUint16(buf[:], x) + return usermem.ByteOrder.Uint16(buf[:]) +} + +// NewSniffer creates a Sniffer connected to *device. +func NewSniffer(t *testing.T) (Sniffer, error) { + flag.Parse() + snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) + if err != nil { + return Sniffer{}, err + } + if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil { + t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err) + } + if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil { + t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err) + } + return Sniffer{ + t: t, + fd: snifferFd, + }, nil +} + +// maxReadSize should be large enough for the maximum frame size in bytes. If a +// packet too large for the buffer arrives, the test will get a fatal error. +const maxReadSize int = 65536 + +// Recv tries to read one frame until the timeout is up. +func (s *Sniffer) Recv(timeout time.Duration) []byte { + deadline := time.Now().Add(timeout) + for { + timeout = deadline.Sub(time.Now()) + if timeout <= 0 { + return nil + } + whole, frac := math.Modf(timeout.Seconds()) + tv := unix.Timeval{ + Sec: int64(whole), + Usec: int64(frac * float64(time.Microsecond/time.Second)), + } + + if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { + s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) + } + + buf := make([]byte, maxReadSize) + nread, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN { + // There was a timeout. + continue + } + if err != nil { + s.t.Fatalf("can't read: %s", err) + } + if nread > maxReadSize { + s.t.Fatalf("received a truncated frame of %d bytes", nread) + } + return buf[:nread] + } +} + +// Drain drains the Sniffer's socket receive buffer by receiving until there's +// nothing else to receive. +func (s *Sniffer) Drain() { + s.t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) + if err != nil { + s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { + s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + } + for { + buf := make([]byte, maxReadSize) + _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC) + if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK { + break + } + } + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { + s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + } +} + +// close the socket that Sniffer is using. +func (s *Sniffer) close() error { + if err := unix.Close(s.fd); err != nil { + return fmt.Errorf("can't close sniffer socket: %w", err) + } + s.fd = -1 + return nil +} + +// Injector can inject raw frames. +type Injector struct { + t *testing.T + fd int +} + +// NewInjector creates a new injector on *device. +func NewInjector(t *testing.T) (Injector, error) { + flag.Parse() + ifInfo, err := net.InterfaceByName(*device) + if err != nil { + return Injector{}, err + } + + var haddr [8]byte + copy(haddr[:], ifInfo.HardwareAddr) + sa := unix.SockaddrLinklayer{ + Protocol: unix.ETH_P_IP, + Ifindex: ifInfo.Index, + Halen: uint8(len(ifInfo.HardwareAddr)), + Addr: haddr, + } + + injectFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) + if err != nil { + return Injector{}, err + } + if err := unix.Bind(injectFd, &sa); err != nil { + return Injector{}, err + } + return Injector{ + t: t, + fd: injectFd, + }, nil +} + +// Send a raw frame. +func (i *Injector) Send(b []byte) { + if _, err := unix.Write(i.fd, b); err != nil { + i.t.Fatalf("can't write: %s", err) + } +} + +// close the underlying socket. +func (i *Injector) close() error { + if err := unix.Close(i.fd); err != nil { + return fmt.Errorf("can't close sniffer socket: %w", err) + } + i.fd = -1 + return nil +} diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD new file mode 100644 index 000000000..47c722ccd --- /dev/null +++ b/test/packetimpact/tests/BUILD @@ -0,0 +1,102 @@ +load("defs.bzl", "packetimpact_go_test") + +package( + default_visibility = ["//test/packetimpact:__subpackages__"], + licenses = ["notice"], +) + +packetimpact_go_test( + name = "fin_wait2_timeout", + srcs = ["fin_wait2_timeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_recv_multicast", + srcs = ["udp_recv_multicast_test.go"], + # TODO(b/152813495): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_window_shrink", + srcs = ["tcp_window_shrink_test.go"], + # TODO(b/153202472): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_outside_the_window", + srcs = ["tcp_outside_the_window_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_noaccept_close_rst", + srcs = ["tcp_noaccept_close_rst_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_should_piggyback", + srcs = ["tcp_should_piggyback_test.go"], + # TODO(b/153680566): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_close_wait_ack", + srcs = ["tcp_close_wait_ack_test.go"], + # TODO(b/153574037): Fix netstack then remove the line below. + netstack = False, + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_user_timeout", + srcs = ["tcp_user_timeout_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +sh_binary( + name = "test_runner", + srcs = ["test_runner.sh"], +) diff --git a/test/packetimpact/tests/Dockerfile b/test/packetimpact/tests/Dockerfile new file mode 100644 index 000000000..9075bc555 --- /dev/null +++ b/test/packetimpact/tests/Dockerfile @@ -0,0 +1,17 @@ +FROM ubuntu:bionic + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + # iptables to disable OS native packet processing. + iptables \ + # nc to check that the posix_server is running. + netcat \ + # tcpdump to log brief packet sniffing. + tcpdump \ + # ip link show to display MAC addresses. + iproute2 \ + # tshark to log verbose packet sniffing. + tshark \ + # killall for cleanup. + psmisc +RUN hash -r +CMD /bin/bash diff --git a/test/packetimpact/tests/defs.bzl b/test/packetimpact/tests/defs.bzl new file mode 100644 index 000000000..27c5de375 --- /dev/null +++ b/test/packetimpact/tests/defs.bzl @@ -0,0 +1,137 @@ +"""Defines rules for packetimpact test targets.""" + +load("//tools:defs.bzl", "go_test") + +def _packetimpact_test_impl(ctx): + test_runner = ctx.executable._test_runner + bench = ctx.actions.declare_file("%s-bench" % ctx.label.name) + bench_content = "\n".join([ + "#!/bin/bash", + # This test will run part in a distinct user namespace. This can cause + # permission problems, because all runfiles may not be owned by the + # current user, and no other users will be mapped in that namespace. + # Make sure that everything is readable here. + "find . -type f -exec chmod a+rx {} \\;", + "find . -type d -exec chmod a+rx {} \\;", + "%s %s --posix_server_binary %s --testbench_binary %s $@\n" % ( + test_runner.short_path, + " ".join(ctx.attr.flags), + ctx.files._posix_server_binary[0].short_path, + ctx.files.testbench_binary[0].short_path, + ), + ]) + ctx.actions.write(bench, bench_content, is_executable = True) + + transitive_files = depset() + if hasattr(ctx.attr._test_runner, "data_runfiles"): + transitive_files = depset(ctx.attr._test_runner.data_runfiles.files) + runfiles = ctx.runfiles( + files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary, + transitive_files = transitive_files, + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = bench, runfiles = runfiles)] + +_packetimpact_test = rule( + attrs = { + "_test_runner": attr.label( + executable = True, + cfg = "target", + default = ":test_runner", + ), + "_posix_server_binary": attr.label( + cfg = "target", + default = "//test/packetimpact/dut:posix_server", + ), + "testbench_binary": attr.label( + cfg = "target", + mandatory = True, + ), + "flags": attr.string_list( + mandatory = False, + default = [], + ), + }, + test = True, + implementation = _packetimpact_test_impl, +) + +PACKETIMPACT_TAGS = ["local", "manual"] + +def packetimpact_linux_test( + name, + testbench_binary, + expect_failure = False, + **kwargs): + """Add a packetimpact test on linux. + + Args: + name: name of the test + testbench_binary: the testbench binary + **kwargs: all the other args, forwarded to _packetimpact_test + """ + expect_failure_flag = ["--expect_failure"] if expect_failure else [] + _packetimpact_test( + name = name + "_linux_test", + testbench_binary = testbench_binary, + flags = ["--dut_platform", "linux"] + expect_failure_flag, + tags = PACKETIMPACT_TAGS + ["packetimpact"], + **kwargs + ) + +def packetimpact_netstack_test( + name, + testbench_binary, + expect_failure = False, + **kwargs): + """Add a packetimpact test on netstack. + + Args: + name: name of the test + testbench_binary: the testbench binary + expect_failure: the test must fail + **kwargs: all the other args, forwarded to _packetimpact_test + """ + expect_failure_flag = [] + if expect_failure: + expect_failure_flag = ["--expect_failure"] + _packetimpact_test( + name = name + "_netstack_test", + testbench_binary = testbench_binary, + # This is the default runtime unless + # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. + flags = ["--dut_platform", "netstack", "--runtime=runsc-d"] + expect_failure_flag, + tags = PACKETIMPACT_TAGS + ["packetimpact"], + **kwargs + ) + +def packetimpact_go_test(name, size = "small", pure = True, linux = True, netstack = True, **kwargs): + """Add packetimpact tests written in go. + + Args: + name: name of the test + size: size of the test + pure: make a static go binary + linux: generate a linux test + netstack: generate a netstack test + **kwargs: all the other args, forwarded to go_test + """ + testbench_binary = name + "_test" + go_test( + name = testbench_binary, + size = size, + pure = pure, + tags = PACKETIMPACT_TAGS, + **kwargs + ) + packetimpact_linux_test( + name = name, + expect_failure = not linux, + testbench_binary = testbench_binary, + ) + packetimpact_netstack_test( + name = name, + expect_failure = not netstack, + testbench_binary = testbench_binary, + ) diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go new file mode 100644 index 000000000..b98594f94 --- /dev/null +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -0,0 +1,70 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fin_wait2_timeout_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestFinWait2Timeout(t *testing.T) { + for _, tt := range []struct { + description string + linger2 bool + }{ + {"WithLinger2", true}, + {"WithoutLinger2", false}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Handshake() + + acceptFd, _ := dut.Accept(listenFd) + if tt.linger2 { + tv := unix.Timeval{Sec: 1, Usec: 0} + dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) + } + dut.Close(acceptFd) + + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) + } + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + + time.Sleep(5 * time.Second) + conn.Drain() + + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if tt.linger2 { + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST packet within a second but got none: %s", err) + } + } else { + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, 10*time.Second); err == nil { + t.Fatalf("expected no RST packets within ten seconds but got one: %s", err) + } + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go new file mode 100644 index 000000000..eb4cc7a65 --- /dev/null +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -0,0 +1,102 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_close_wait_ack_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestCloseWaitAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {"OTW", GenerateOTWSeqSegment, 0, false}, + {"OTW", GenerateOTWSeqSegment, 1, true}, + {"OTW", GenerateOTWSeqSegment, 2, true}, + {"ACK", GenerateUnaccACKSegment, 0, false}, + {"ACK", GenerateUnaccACKSegment, 1, true}, + {"ACK", GenerateUnaccACKSegment, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + + // Send a FIN to DUT to intiate the active close + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) + } + + // Send a segment with OTW Seq / unacc ACK and expect an ACK back + conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset), &tb.Payload{Bytes: []byte("Sample Data")}) + gotAck, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Fatalf("expected an ack but got none: %s", err) + } + if !tt.expectAck && gotAck != nil { + t.Fatalf("expected no ack but got one: %s", gotAck) + } + + // Now let's verify DUT is indeed in CLOSE_WAIT + dut.Close(acceptFd) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + t.Fatalf("expected DUT to send a FIN: %s", err) + } + // Ack the FIN from DUT + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + // Send some extra data to DUT + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}, &tb.Payload{Bytes: []byte("Sample Data")}) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected DUT to send an RST: %s", err) + } + }) + } +} + +// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the +// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK +// is expected from the receiver. +func GenerateOTWSeqSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP { + windowSize := seqnum.Size(*conn.SynAck().WindowSize) + lastAcceptable := conn.LocalSeqNum().Add(windowSize - 1) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return tb.TCP{SeqNum: tb.Uint32(otwSeq), Flags: tb.Uint8(header.TCPFlagAck)} +} + +// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated +// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is +// expected from the receiver. +func GenerateUnaccACKSegment(conn *tb.TCPIPv4, seqNumOffset seqnum.Size) tb.TCP { + lastAcceptable := conn.RemoteSeqNum() + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return tb.TCP{AckNum: tb.Uint32(unaccAck), Flags: tb.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go new file mode 100644 index 000000000..7ebdd1950 --- /dev/null +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_noaccept_close_rst_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestTcpNoAcceptCloseReset(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + conn.Handshake() + defer conn.Close() + dut.Close(listenFd) + if _, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + t.Fatalf("expected a RST-ACK packet but got none: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go new file mode 100644 index 000000000..db3d3273b --- /dev/null +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_outside_the_window_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive +// that are inside or outside the TCP window. Packets that are outside the +// window should force an extra ACK, as described in RFC793 page 69: +// https://tools.ietf.org/html/rfc793#page-69 +func TestTCPOutsideTheWindow(t *testing.T) { + for _, tt := range []struct { + description string + tcpFlags uint8 + payload []tb.Layer + seqNumOffset seqnum.Size + expectACK bool + }{ + {"SYN", header.TCPFlagSyn, nil, 0, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true}, + {"ACK", header.TCPFlagAck, nil, 0, false}, + {"FIN", header.TCPFlagFin, nil, 0, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 0, true}, + + {"SYN", header.TCPFlagSyn, nil, 1, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true}, + {"ACK", header.TCPFlagAck, nil, 1, true}, + {"FIN", header.TCPFlagFin, nil, 1, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 1, true}, + + {"SYN", header.TCPFlagSyn, nil, 2, true}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true}, + {"ACK", header.TCPFlagAck, nil, 2, true}, + {"FIN", header.TCPFlagFin, nil, 2, false}, + {"Data", header.TCPFlagAck, []tb.Layer{&tb.Payload{Bytes: []byte("abc123")}}, 2, true}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Handshake() + acceptFD, _ := dut.Accept(listenFD) + defer dut.Close(acceptFD) + + windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset + conn.Drain() + // Ignore whatever incrementing that this out-of-order packet might cause + // to the AckNum. + localSeqNum := tb.Uint32(uint32(*conn.LocalSeqNum())) + conn.Send(tb.TCP{ + Flags: tb.Uint8(tt.tcpFlags), + SeqNum: tb.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))), + }, tt.payload...) + timeout := 3 * time.Second + gotACK, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + if tt.expectACK && err != nil { + t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) + } + if !tt.expectACK && gotACK != nil { + t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go new file mode 100644 index 000000000..b0be6ba23 --- /dev/null +++ b/test/packetimpact/tests/tcp_should_piggyback_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_should_piggyback_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestPiggyback(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort, WindowSize: tb.Uint16(12)}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + + dut.Send(acceptFd, sampleData, 0) + expectedTCP := tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} + expectedPayload := tb.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err) + } + + // Cause DUT to send us more data as soon as we ACK their first data segment because we have + // a small window. + dut.Send(acceptFd, sampleData, 0) + + // DUT should ACK our segment by piggybacking ACK to their outstanding data segment instead of + // sending a separate ACK packet. + conn.Send(expectedTCP, &expectedPayload) + if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + t.Fatalf("Expected %v but didn't get one: %s", tb.Layers{&expectedTCP, &expectedPayload}, err) + } +} diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go new file mode 100644 index 000000000..3cf82badb --- /dev/null +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_user_timeout_test + +import ( + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func sendPayload(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error { + sampleData := make([]byte, 100) + for i := range sampleData { + sampleData[i] = uint8(i) + } + conn.Drain() + dut.Send(fd, sampleData, 0) + if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &tb.Payload{Bytes: sampleData}, time.Second); err != nil { + return fmt.Errorf("expected data but got none: %w", err) + } + return nil +} + +func sendFIN(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error { + dut.Close(fd) + return nil +} + +func TestTCPUserTimeout(t *testing.T) { + for _, tt := range []struct { + description string + userTimeout time.Duration + sendDelay time.Duration + }{ + {"NoUserTimeout", 0, 3 * time.Second}, + {"ACKBeforeUserTimeout", 5 * time.Second, 4 * time.Second}, + {"ACKAfterUserTimeout", 5 * time.Second, 7 * time.Second}, + } { + for _, ttf := range []struct { + description string + f func(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error + }{ + {"AfterPayload", sendPayload}, + {"AfterFIN", sendFIN}, + } { + t.Run(tt.description+ttf.description, func(t *testing.T) { + // Create a socket, listen, TCP handshake, and accept. + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFD) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + conn.Handshake() + acceptFD, _ := dut.Accept(listenFD) + + if tt.userTimeout != 0 { + dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) + } + + if err := ttf.f(&conn, &dut, acceptFD); err != nil { + t.Fatal(err) + } + + time.Sleep(tt.sendDelay) + conn.Drain() + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + + // If TCP_USER_TIMEOUT was set and the above delay was longer than the + // TCP_USER_TIMEOUT then the DUT should send a RST in response to the + // testbench's packet. + expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout + expectTimeout := 5 * time.Second + got, err := conn.Expect(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, expectTimeout) + if expectRST && err != nil { + t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err) + } + if !expectRST && got != nil { + t.Errorf("expected no RST packet within %s but got one: %s", expectTimeout, got) + } + }) + } + } +} diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go new file mode 100644 index 000000000..c9354074e --- /dev/null +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -0,0 +1,68 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_window_shrink_test + +import ( + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestWindowShrink(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(listenFd) + conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) + defer conn.Close() + + conn.Handshake() + acceptFd, _ := dut.Accept(listenFd) + defer dut.Close(acceptFd) + + dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + sampleData := []byte("Sample Data") + samplePayload := &tb.Payload{Bytes: sampleData} + + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + + dut.Send(acceptFd, sampleData, 0) + dut.Send(acceptFd, sampleData, 0) + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + if _, err := conn.ExpectData(&tb.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + // We close our receiving window here + conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), WindowSize: tb.Uint16(0)}) + + dut.Send(acceptFd, []byte("Sample Data"), 0) + // Note: There is another kind of zero-window probing which Windows uses (by sending one + // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change + // the following lines. + expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1 + if _, err := conn.ExpectData(&tb.TCP{SeqNum: tb.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + t.Fatalf("expected a packet with sequence number %v: %s", expectedRemoteSeqNum, err) + } +} diff --git a/test/packetimpact/tests/test_runner.sh b/test/packetimpact/tests/test_runner.sh new file mode 100755 index 000000000..46d63d5e5 --- /dev/null +++ b/test/packetimpact/tests/test_runner.sh @@ -0,0 +1,285 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Run a packetimpact test. Two docker containers are made, one for the +# Device-Under-Test (DUT) and one for the test bench. Each is attached with +# two networks, one for control packets that aid the test and one for test +# packets which are sent as part of the test and observed for correctness. + +set -euxo pipefail + +function failure() { + local lineno=$1 + local msg=$2 + local filename="$0" + echo "FAIL: $filename:$lineno: $msg" +} +trap 'failure ${LINENO} "$BASH_COMMAND"' ERR + +declare -r LONGOPTS="dut_platform:,posix_server_binary:,testbench_binary:,runtime:,tshark,extra_test_arg:,expect_failure" + +# Don't use declare below so that the error from getopt will end the script. +PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@") + +eval set -- "$PARSED" + +declare -a EXTRA_TEST_ARGS + +while true; do + case "$1" in + --dut_platform) + # Either "linux" or "netstack". + declare -r DUT_PLATFORM="$2" + shift 2 + ;; + --posix_server_binary) + declare -r POSIX_SERVER_BINARY="$2" + shift 2 + ;; + --testbench_binary) + declare -r TESTBENCH_BINARY="$2" + shift 2 + ;; + --runtime) + # Not readonly because there might be multiple --runtime arguments and we + # want to use just the last one. Only used if --dut_platform is + # "netstack". + declare RUNTIME="$2" + shift 2 + ;; + --tshark) + declare -r TSHARK="1" + shift 1 + ;; + --extra_test_arg) + EXTRA_TEST_ARGS+="$2" + shift 2 + ;; + --expect_failure) + declare -r EXPECT_FAILURE="1" + shift 1 + ;; + --) + shift + break + ;; + *) + echo "Programming error" + exit 3 + esac +done + +# All the other arguments are scripts. +declare -r scripts="$@" + +# Check that the required flags are defined in a way that is safe for "set -u". +if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then + if [[ -z "${RUNTIME-}" ]]; then + echo "FAIL: Missing --runtime argument: ${RUNTIME-}" + exit 2 + fi + declare -r RUNTIME_ARG="--runtime ${RUNTIME}" +elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then + declare -r RUNTIME_ARG="" +else + echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}" + exit 2 +fi +if [[ ! -f "${POSIX_SERVER_BINARY-}" ]]; then + echo "FAIL: Bad or missing --posix_server_binary: ${POSIX_SERVER-}" + exit 2 +fi +if [[ ! -f "${TESTBENCH_BINARY-}" ]]; then + echo "FAIL: Bad or missing --testbench_binary: ${TESTBENCH_BINARY-}" + exit 2 +fi + +function new_net_prefix() { + # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. + echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)" +} + +# Variables specific to the control network and interface start with CTRL_. +# Variables specific to the test network and interface start with TEST_. +# Variables specific to the DUT start with DUT_. +# Variables specific to the test bench start with TESTBENCH_. +# Use random numbers so that test networks don't collide. +declare CTRL_NET="ctrl_net-${RANDOM}${RANDOM}" +declare CTRL_NET_PREFIX=$(new_net_prefix) +declare TEST_NET="test_net-${RANDOM}${RANDOM}" +declare TEST_NET_PREFIX=$(new_net_prefix) +# On both DUT and test bench, testing packets are on the eth2 interface. +declare -r TEST_DEVICE="eth2" +# Number of bits in the *_NET_PREFIX variables. +declare -r NET_MASK="24" +# Last bits of the DUT's IP address. +declare -r DUT_NET_SUFFIX=".10" +# Control port. +declare -r CTRL_PORT="40000" +# Last bits of the test bench's IP address. +declare -r TESTBENCH_NET_SUFFIX=".20" +declare -r TIMEOUT="60" +declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetimpact" + +# Make sure that docker is installed. +docker --version + +function finish { + local cleanup_success=1 + + if [[ -z "${TSHARK-}" ]]; then + # Kill tcpdump so that it will flush output. + docker exec -t "${TESTBENCH}" \ + killall tcpdump || \ + cleanup_success=0 + else + # Kill tshark so that it will flush output. + docker exec -t "${TESTBENCH}" \ + killall tshark || \ + cleanup_success=0 + fi + + for net in "${CTRL_NET}" "${TEST_NET}"; do + # Kill all processes attached to ${net}. + for docker_command in "kill" "rm"; do + (docker network inspect "${net}" \ + --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \ + | xargs -r docker "${docker_command}") || \ + cleanup_success=0 + done + # Remove the network. + docker network rm "${net}" || \ + cleanup_success=0 + done + + if ((!$cleanup_success)); then + echo "FAIL: Cleanup command failed" + exit 4 + fi +} +trap finish EXIT + +# Subnet for control packets between test bench and DUT. +while ! docker network create \ + "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do + sleep 0.1 + CTRL_NET_PREFIX=$(new_net_prefix) + CTRL_NET="ctrl_net-${RANDOM}${RANDOM}" +done + +# Subnet for the packets that are part of the test. +while ! docker network create \ + "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do + sleep 0.1 + TEST_NET_PREFIX=$(new_net_prefix) + TEST_NET="test_net-${RANDOM}${RANDOM}" +done + +docker pull "${IMAGE_TAG}" + +# Create the DUT container and connect to network. +DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \ + --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG}) +docker network connect "${CTRL_NET}" \ + --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \ + || (docker kill ${DUT}; docker rm ${DUT}; false) +docker network connect "${TEST_NET}" \ + --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \ + || (docker kill ${DUT}; docker rm ${DUT}; false) +docker start "${DUT}" + +# Create the test bench container and connect to network. +TESTBENCH=$(docker create --privileged --rm \ + --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG}) +docker network connect "${CTRL_NET}" \ + --ip "${CTRL_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" "${TESTBENCH}" \ + || (docker kill ${TESTBENCH}; docker rm ${TESTBENCH}; false) +docker network connect "${TEST_NET}" \ + --ip "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" "${TESTBENCH}" \ + || (docker kill ${TESTBENCH}; docker rm ${TESTBENCH}; false) +docker start "${TESTBENCH}" + +# Start the posix_server in the DUT. +declare -r DOCKER_POSIX_SERVER_BINARY="/$(basename ${POSIX_SERVER_BINARY})" +docker cp -L ${POSIX_SERVER_BINARY} "${DUT}:${DOCKER_POSIX_SERVER_BINARY}" + +docker exec -t "${DUT}" \ + /bin/bash -c "${DOCKER_POSIX_SERVER_BINARY} \ + --ip ${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \ + --port ${CTRL_PORT}" & + +# Because the Linux kernel receives the SYN-ACK but didn't send the SYN it will +# issue a RST. To prevent this IPtables can be used to filter those out. +docker exec "${TESTBENCH}" \ + iptables -A INPUT -i ${TEST_DEVICE} -j DROP + +# Wait for the DUT server to come up. Attempt to connect to it from the test +# bench every 100 milliseconds until success. +while ! docker exec "${TESTBENCH}" \ + nc -zv "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${CTRL_PORT}"; do + sleep 0.1 +done + +declare -r REMOTE_MAC=$(docker exec -t "${DUT}" ip link show \ + "${TEST_DEVICE}" | tail -1 | cut -d' ' -f6) +declare -r LOCAL_MAC=$(docker exec -t "${TESTBENCH}" ip link show \ + "${TEST_DEVICE}" | tail -1 | cut -d' ' -f6) + +declare -r DOCKER_TESTBENCH_BINARY="/$(basename ${TESTBENCH_BINARY})" +docker cp -L "${TESTBENCH_BINARY}" "${TESTBENCH}:${DOCKER_TESTBENCH_BINARY}" + +if [[ -z "${TSHARK-}" ]]; then + # Run tcpdump in the test bench unbuffered, without dns resolution, just on + # the interface with the test packets. + docker exec -t "${TESTBENCH}" \ + tcpdump -S -vvv -U -n -i "${TEST_DEVICE}" net "${TEST_NET_PREFIX}/24" & +else + # Run tshark in the test bench unbuffered, without dns resolution, just on the + # interface with the test packets. + docker exec -t "${TESTBENCH}" \ + tshark -V -l -n -i "${TEST_DEVICE}" \ + -o tcp.check_checksum:TRUE \ + -o udp.check_checksum:TRUE \ + host "${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX}" & +fi + +# tcpdump and tshark take time to startup +sleep 3 + +# Start a packetimpact test on the test bench. The packetimpact test sends and +# receives packets and also sends POSIX socket commands to the posix_server to +# be executed on the DUT. +docker exec -t "${TESTBENCH}" \ + /bin/bash -c "${DOCKER_TESTBENCH_BINARY} \ + ${EXTRA_TEST_ARGS[@]-} \ + --posix_server_ip=${CTRL_NET_PREFIX}${DUT_NET_SUFFIX} \ + --posix_server_port=${CTRL_PORT} \ + --remote_ipv4=${TEST_NET_PREFIX}${DUT_NET_SUFFIX} \ + --local_ipv4=${TEST_NET_PREFIX}${TESTBENCH_NET_SUFFIX} \ + --remote_mac=${REMOTE_MAC} \ + --local_mac=${LOCAL_MAC} \ + --device=${TEST_DEVICE}" && true +declare -r TEST_RESULT="${?}" +if [[ -z "${EXPECT_FAILURE-}" && "${TEST_RESULT}" != 0 ]]; then + echo 'FAIL: This test was expected to pass.' + exit ${TEST_RESULT} +fi +if [[ ! -z "${EXPECT_FAILURE-}" && "${TEST_RESULT}" == 0 ]]; then + echo 'FAIL: This test was expected to fail but passed. Enable the test and' \ + 'mark the corresponding bug as fixed.' + exit 1 +fi +echo PASS: No errors. diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go new file mode 100644 index 000000000..61fd17050 --- /dev/null +++ b/test/packetimpact/tests/udp_recv_multicast_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_recv_multicast_test + +import ( + "net" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + tb "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func TestUDPRecvMulticast(t *testing.T) { + dut := tb.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) + defer dut.Close(boundFD) + conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort}) + defer conn.Close() + frame := conn.CreateFrame(&tb.UDP{}, &tb.Payload{Bytes: []byte("hello world")}) + frame[1].(*tb.IPv4).DstAddr = tb.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4())) + conn.SendFrame(frame) + dut.Recv(boundFD, 100, 0) +} diff --git a/test/perf/BUILD b/test/perf/BUILD new file mode 100644 index 000000000..471d8c2ab --- /dev/null +++ b/test/perf/BUILD @@ -0,0 +1,117 @@ +load("//test/runner:defs.bzl", "syscall_test") + +package(licenses = ["notice"]) + +syscall_test( + test = "//test/perf/linux:clock_getres_benchmark", +) + +syscall_test( + test = "//test/perf/linux:clock_gettime_benchmark", +) + +syscall_test( + test = "//test/perf/linux:death_benchmark", +) + +syscall_test( + test = "//test/perf/linux:epoll_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:fork_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:futex_benchmark", +) + +syscall_test( + size = "enormous", + shard_count = 10, + tags = ["nogotsan"], + test = "//test/perf/linux:getdents_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:getpid_benchmark", +) + +syscall_test( + size = "enormous", + tags = ["nogotsan"], + test = "//test/perf/linux:gettid_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:mapping_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:open_benchmark", +) + +syscall_test( + test = "//test/perf/linux:pipe_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:randread_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:read_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:sched_yield_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:send_recv_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:seqwrite_benchmark", +) + +syscall_test( + size = "enormous", + test = "//test/perf/linux:signal_benchmark", +) + +syscall_test( + test = "//test/perf/linux:sleep_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:stat_benchmark", +) + +syscall_test( + size = "enormous", + add_overlay = True, + test = "//test/perf/linux:unlink_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:write_benchmark", +) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD new file mode 100644 index 000000000..b4e907826 --- /dev/null +++ b/test/perf/linux/BUILD @@ -0,0 +1,356 @@ +load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +cc_binary( + name = "getpid_benchmark", + testonly = 1, + srcs = [ + "getpid_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "send_recv_benchmark", + testonly = 1, + srcs = [ + "send_recv_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/syscalls/linux:socket_test_util", + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( + name = "gettid_benchmark", + testonly = 1, + srcs = [ + "gettid_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "sched_yield_benchmark", + testonly = 1, + srcs = [ + "sched_yield_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "clock_getres_benchmark", + testonly = 1, + srcs = [ + "clock_getres_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "clock_gettime_benchmark", + testonly = 1, + srcs = [ + "clock_gettime_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "open_benchmark", + testonly = 1, + srcs = [ + "open_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + ], +) + +cc_binary( + name = "read_benchmark", + testonly = 1, + srcs = [ + "read_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "randread_benchmark", + testonly = 1, + srcs = [ + "randread_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/random", + ], +) + +cc_binary( + name = "write_benchmark", + testonly = 1, + srcs = [ + "write_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "seqwrite_benchmark", + testonly = 1, + srcs = [ + "seqwrite_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/random", + ], +) + +cc_binary( + name = "pipe_benchmark", + testonly = 1, + srcs = [ + "pipe_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( + name = "fork_benchmark", + testonly = 1, + srcs = [ + "fork_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:cleanup", + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( + name = "futex_benchmark", + testonly = 1, + srcs = [ + "futex_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:thread_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "epoll_benchmark", + testonly = 1, + srcs = [ + "epoll_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:epoll_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "death_benchmark", + testonly = 1, + srcs = [ + "death_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + ], +) + +cc_binary( + name = "mapping_benchmark", + testonly = 1, + srcs = [ + "mapping_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:memory_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "signal_benchmark", + testonly = 1, + srcs = [ + "signal_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "getdents_benchmark", + testonly = 1, + srcs = [ + "getdents_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "sleep_benchmark", + testonly = 1, + srcs = [ + "sleep_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + ], +) + +cc_binary( + name = "stat_benchmark", + testonly = 1, + srcs = [ + "stat_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "unlink_benchmark", + testonly = 1, + srcs = [ + "unlink_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc new file mode 100644 index 000000000..b051293ad --- /dev/null +++ b/test/perf/linux/clock_getres_benchmark.cc @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <time.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +// clock_getres(1) is very nearly a no-op syscall, but it does require copying +// out to a userspace struct. It thus provides a nice small copy-out benchmark. +void BM_ClockGetRes(benchmark::State& state) { + struct timespec ts; + for (auto _ : state) { + clock_getres(CLOCK_MONOTONIC, &ts); + } +} + +BENCHMARK(BM_ClockGetRes); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc new file mode 100644 index 000000000..6691bebd9 --- /dev/null +++ b/test/perf/linux/clock_gettime_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <pthread.h> +#include <time.h> + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_ClockGettimeThreadCPUTime(benchmark::State& state) { + clockid_t clockid; + ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid)); + struct timespec tp; + + for (auto _ : state) { + clock_gettime(clockid, &tp); + } +} + +BENCHMARK(BM_ClockGettimeThreadCPUTime); + +void BM_VDSOClockGettime(benchmark::State& state) { + const clockid_t clock = state.range(0); + struct timespec tp; + absl::Time start = absl::Now(); + + // Don't benchmark the calibration phase. + while (absl::Now() < start + absl::Milliseconds(2100)) { + clock_gettime(clock, &tp); + } + + for (auto _ : state) { + clock_gettime(clock, &tp); + } +} + +BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc new file mode 100644 index 000000000..cb2b6fd07 --- /dev/null +++ b/test/perf/linux/death_benchmark.cc @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +namespace { + +// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing +// the ability of gVisor (on whatever platform) to execute all the related +// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH. +TEST(DeathTest, ZeroEqualsOne) { + EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), ""); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc new file mode 100644 index 000000000..0b121338a --- /dev/null +++ b/test/perf/linux/epoll_benchmark.cc @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/epoll.h> +#include <sys/eventfd.h> + +#include <atomic> +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <ctime> +#include <memory> + +#include "gtest/gtest.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" +#include "test/util/epoll_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Returns a new eventfd. +PosixErrorOr<FileDescriptor> NewEventFD() { + int fd = eventfd(0, /* flags = */ 0); + MaybeSave(); + if (fd < 0) { + return PosixError(errno, "eventfd"); + } + return FileDescriptor(fd); +} + +// Also stolen from epoll.cc unit tests. +void BM_EpollTimeout(benchmark::State& state) { + constexpr int kFDsPerEpoll = 3; + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + + std::vector<FileDescriptor> eventfds; + for (int i = 0; i < kFDsPerEpoll; i++) { + eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); + ASSERT_NO_ERRNO( + RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0)); + } + + struct epoll_event result[kFDsPerEpoll]; + int timeout_ms = state.range(0); + + for (auto _ : state) { + EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms)); + } +} + +BENCHMARK(BM_EpollTimeout)->Range(0, 8); + +// Also stolen from epoll.cc unit tests. +void BM_EpollAllEvents(benchmark::State& state) { + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + const int fds_per_epoll = state.range(0); + constexpr uint64_t kEventVal = 5; + + std::vector<FileDescriptor> eventfds; + for (int i = 0; i < fds_per_epoll; i++) { + eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); + ASSERT_NO_ERRNO( + RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0)); + + ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)), + SyscallSucceedsWithValue(sizeof(kEventVal))); + } + + std::vector<struct epoll_event> result(fds_per_epoll); + + for (auto _ : state) { + EXPECT_EQ(fds_per_epoll, + epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0)); + } +} + +BENCHMARK(BM_EpollAllEvents)->Range(2, 1024); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc new file mode 100644 index 000000000..84fdbc8a0 --- /dev/null +++ b/test/perf/linux/fork_benchmark.cc @@ -0,0 +1,350 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/synchronization/barrier.h" +#include "benchmark/benchmark.h" +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kBusyMax = 250; + +// Do some CPU-bound busy-work. +int busy(int max) { + // Prevent the compiler from optimizing this work away, + volatile int count = 0; + + for (int i = 1; i < max; i++) { + for (int j = 2; j < i / 2; j++) { + if (i % j == 0) { + count++; + } + } + } + + return count; +} + +void BM_CPUBoundUniprocess(benchmark::State& state) { + for (auto _ : state) { + busy(kBusyMax); + } +} + +BENCHMARK(BM_CPUBoundUniprocess); + +void BM_CPUBoundAsymmetric(benchmark::State& state) { + const size_t max = state.max_iterations; + pid_t child = fork(); + if (child == 0) { + for (int i = 0; i < max; i++) { + busy(kBusyMax); + } + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + ASSERT_TRUE(state.KeepRunningBatch(max)); + + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + ASSERT_FALSE(state.KeepRunning()); +} + +BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime(); + +void BM_CPUBoundSymmetric(benchmark::State& state) { + std::vector<pid_t> children; + auto child_cleanup = Cleanup([&] { + for (const pid_t child : children) { + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + } + ASSERT_FALSE(state.KeepRunning()); + }); + + const int processes = state.range(0); + for (int i = 0; i < processes; i++) { + size_t cur = (state.max_iterations + (processes - 1)) / processes; + if ((state.iterations() + cur) >= state.max_iterations) { + cur = state.max_iterations - state.iterations(); + } + pid_t child = fork(); + if (child == 0) { + for (int i = 0; i < cur; i++) { + busy(kBusyMax); + } + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + if (cur > 0) { + // We can have a zero cur here, depending. + ASSERT_TRUE(state.KeepRunningBatch(cur)); + } + children.push_back(child); + } +} + +BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime(); + +// Child routine for ProcessSwitch/ThreadSwitch. +// Reads from readfd and writes the result to writefd. +void SwitchChild(int readfd, int writefd) { + while (1) { + char buf; + int ret = ReadFd(readfd, &buf, 1); + if (ret == 0) { + break; + } + TEST_CHECK_MSG(ret == 1, "read failed"); + + ret = WriteFd(writefd, &buf, 1); + if (ret == -1) { + TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure"); + break; + } + TEST_CHECK_MSG(ret == 1, "write failed"); + } +} + +// Send bytes in a loop through a series of pipes, each passing through a +// different process. +// +// Proc 0 Proc 1 +// * ----------> * +// ^ Pipe 1 | +// | | +// | Pipe 0 | Pipe 2 +// | | +// | | +// | Pipe 3 v +// * <---------- * +// Proc 3 Proc 2 +// +// This exercises context switching through multiple processes. +void BM_ProcessSwitch(benchmark::State& state) { + // Code below assumes there are at least two processes. + const int num_processes = state.range(0); + ASSERT_GE(num_processes, 2); + + std::vector<pid_t> children; + auto child_cleanup = Cleanup([&] { + for (const pid_t child : children) { + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + } + }); + + // Must come after children, as the FDs must be closed before the children + // will exit. + std::vector<FileDescriptor> read_fds; + std::vector<FileDescriptor> write_fds; + + for (int i = 0; i < num_processes; i++) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + read_fds.emplace_back(fds[0]); + write_fds.emplace_back(fds[1]); + } + + // This process is one of the processes in the loop. It will be considered + // index 0. + for (int i = 1; i < num_processes; i++) { + // Read from current pipe index, write to next. + const int read_index = i; + const int read_fd = read_fds[read_index].get(); + + const int write_index = (i + 1) % num_processes; + const int write_fd = write_fds[write_index].get(); + + // std::vector isn't safe to use from the fork child. + FileDescriptor* read_array = read_fds.data(); + FileDescriptor* write_array = write_fds.data(); + + pid_t child = fork(); + if (!child) { + // Close all other FDs. + for (int j = 0; j < num_processes; j++) { + if (j != read_index) { + read_array[j].reset(); + } + if (j != write_index) { + write_array[j].reset(); + } + } + + SwitchChild(read_fd, write_fd); + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + children.push_back(child); + } + + // Read from current pipe index (0), write to next (1). + const int read_index = 0; + const int read_fd = read_fds[read_index].get(); + + const int write_index = 1; + const int write_fd = write_fds[write_index].get(); + + // Kick start the loop. + char buf = 'a'; + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + + for (auto _ : state) { + ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + } +} + +BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime(); + +// Equivalent to BM_ThreadSwitch using threads instead of processes. +void BM_ThreadSwitch(benchmark::State& state) { + // Code below assumes there are at least two threads. + const int num_threads = state.range(0); + ASSERT_GE(num_threads, 2); + + // Must come after threads, as the FDs must be closed before the children + // will exit. + std::vector<std::unique_ptr<ScopedThread>> threads; + std::vector<FileDescriptor> read_fds; + std::vector<FileDescriptor> write_fds; + + for (int i = 0; i < num_threads; i++) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + read_fds.emplace_back(fds[0]); + write_fds.emplace_back(fds[1]); + } + + // This thread is one of the threads in the loop. It will be considered + // index 0. + for (int i = 1; i < num_threads; i++) { + // Read from current pipe index, write to next. + // + // Transfer ownership of the FDs to the thread. + const int read_index = i; + const int read_fd = read_fds[read_index].release(); + + const int write_index = (i + 1) % num_threads; + const int write_fd = write_fds[write_index].release(); + + threads.emplace_back(std::make_unique<ScopedThread>([read_fd, write_fd] { + FileDescriptor read(read_fd); + FileDescriptor write(write_fd); + SwitchChild(read.get(), write.get()); + })); + } + + // Read from current pipe index (0), write to next (1). + const int read_index = 0; + const int read_fd = read_fds[read_index].get(); + + const int write_index = 1; + const int write_fd = write_fds[write_index].get(); + + // Kick start the loop. + char buf = 'a'; + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + + for (auto _ : state) { + ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + } + + // The two FDs still owned by this thread are closed, causing the next thread + // to exit its loop and close its FDs, and so on until all threads exit. +} + +BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime(); + +void BM_ThreadStart(benchmark::State& state) { + const int num_threads = state.range(0); + + for (auto _ : state) { + state.PauseTiming(); + + auto barrier = new absl::Barrier(num_threads + 1); + std::vector<std::unique_ptr<ScopedThread>> threads; + + state.ResumeTiming(); + + for (size_t i = 0; i < num_threads; ++i) { + threads.emplace_back(std::make_unique<ScopedThread>([barrier] { + if (barrier->Block()) { + delete barrier; + } + })); + } + + if (barrier->Block()) { + delete barrier; + } + + state.PauseTiming(); + + for (const auto& thread : threads) { + thread->Join(); + } + + state.ResumeTiming(); + } +} + +BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime(); + +// Benchmark the complete fork + exit + wait. +void BM_ProcessLifecycle(benchmark::State& state) { + const int num_procs = state.range(0); + + std::vector<pid_t> pids(num_procs); + for (auto _ : state) { + for (size_t i = 0; i < num_procs; ++i) { + int pid = fork(); + if (pid == 0) { + _exit(0); + } + ASSERT_THAT(pid, SyscallSucceeds()); + pids[i] = pid; + } + + for (const int pid : pids) { + ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0), + SyscallSucceedsWithValue(pid)); + } + } +} + +BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc new file mode 100644 index 000000000..241f39896 --- /dev/null +++ b/test/perf/linux/futex_benchmark.cc @@ -0,0 +1,198 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/futex.h> + +#include <atomic> +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <ctime> + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +inline int FutexWait(std::atomic<int32_t>* v, int32_t val) { + return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, nullptr); +} + +inline int FutexWaitMonotonicTimeout(std::atomic<int32_t>* v, int32_t val, + const struct timespec* timeout) { + return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, timeout); +} + +inline int FutexWaitMonotonicDeadline(std::atomic<int32_t>* v, int32_t val, + const struct timespec* deadline) { + return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE, val, deadline, + nullptr, FUTEX_BITSET_MATCH_ANY); +} + +inline int FutexWaitRealtimeDeadline(std::atomic<int32_t>* v, int32_t val, + const struct timespec* deadline) { + return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME, + val, deadline, nullptr, FUTEX_BITSET_MATCH_ANY); +} + +inline int FutexWake(std::atomic<int32_t>* v, int32_t count) { + return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count); +} + +// This just uses FUTEX_WAKE on an address with nothing waiting, very simple. +void BM_FutexWakeNop(benchmark::State& state) { + std::atomic<int32_t> v(0); + + for (auto _ : state) { + TEST_PCHECK(FutexWake(&v, 1) == 0); + } +} + +BENCHMARK(BM_FutexWakeNop)->MinTime(5); + +// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the +// syscall won't wait. +void BM_FutexWaitNop(benchmark::State& state) { + std::atomic<int32_t> v(0); + + for (auto _ : state) { + TEST_PCHECK(FutexWait(&v, 1) == -1 && errno == EAGAIN); + } +} + +BENCHMARK(BM_FutexWaitNop)->MinTime(5); + +// This uses FUTEX_WAIT with a timeout on an address whose value never +// changes, such that it always times out. Timeout overhead can be estimated by +// timer overruns for short timeouts. +void BM_FutexWaitMonotonicTimeout(benchmark::State& state) { + const int timeout_ns = state.range(0); + std::atomic<int32_t> v(0); + auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns)); + + for (auto _ : state) { + TEST_PCHECK(FutexWaitMonotonicTimeout(&v, 0, &ts) == -1 && + errno == ETIMEDOUT); + } +} + +BENCHMARK(BM_FutexWaitMonotonicTimeout) + ->MinTime(5) + ->UseRealTime() + ->Arg(1) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(10000); + +// This uses FUTEX_WAIT_BITSET with a deadline that is in the past. This allows +// estimation of the overhead of setting up a timer for a deadline (as opposed +// to a timeout as specified for FUTEX_WAIT). +void BM_FutexWaitMonotonicDeadline(benchmark::State& state) { + std::atomic<int32_t> v(0); + struct timespec ts = {}; + + for (auto _ : state) { + TEST_PCHECK(FutexWaitMonotonicDeadline(&v, 0, &ts) == -1 && + errno == ETIMEDOUT); + } +} + +BENCHMARK(BM_FutexWaitMonotonicDeadline)->MinTime(5); + +// This is equivalent to BM_FutexWaitMonotonicDeadline, but uses CLOCK_REALTIME +// instead of CLOCK_MONOTONIC for the deadline. +void BM_FutexWaitRealtimeDeadline(benchmark::State& state) { + std::atomic<int32_t> v(0); + struct timespec ts = {}; + + for (auto _ : state) { + TEST_PCHECK(FutexWaitRealtimeDeadline(&v, 0, &ts) == -1 && + errno == ETIMEDOUT); + } +} + +BENCHMARK(BM_FutexWaitRealtimeDeadline)->MinTime(5); + +int64_t GetCurrentMonotonicTimeNanos() { + struct timespec ts; + TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1); + return ts.tv_sec * 1000000000ULL + ts.tv_nsec; +} + +void SpinNanos(int64_t delay_ns) { + if (delay_ns <= 0) { + return; + } + const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns; + while (GetCurrentMonotonicTimeNanos() < end) { + // spin + } +} + +// Each iteration of FutexRoundtripDelayed involves a thread sending a futex +// wakeup to another thread, which spins for delay_us and then sends a futex +// wakeup back. The time per iteration is 2 * (delay_us + kBeforeWakeDelayNs + +// futex/scheduling overhead). +void BM_FutexRoundtripDelayed(benchmark::State& state) { + const int delay_us = state.range(0); + const int64_t delay_ns = delay_us * 1000; + // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce + // the probability that the wakeup comes before the wait, preventing the wait + // from ever taking effect and causing the benchmark to underestimate the + // actual wakeup time. + constexpr int64_t kBeforeWakeDelayNs = 500; + std::atomic<int32_t> v(0); + ScopedThread t([&] { + for (int i = 0; i < state.max_iterations; i++) { + SpinNanos(delay_ns); + while (v.load(std::memory_order_acquire) == 0) { + FutexWait(&v, 0); + } + SpinNanos(kBeforeWakeDelayNs + delay_ns); + v.store(0, std::memory_order_release); + FutexWake(&v, 1); + } + }); + for (auto _ : state) { + SpinNanos(kBeforeWakeDelayNs + delay_ns); + v.store(1, std::memory_order_release); + FutexWake(&v, 1); + SpinNanos(delay_ns); + while (v.load(std::memory_order_acquire) == 1) { + FutexWait(&v, 1); + } + } +} + +BENCHMARK(BM_FutexRoundtripDelayed) + ->MinTime(5) + ->UseRealTime() + ->Arg(0) + ->Arg(10) + ->Arg(20) + ->Arg(50) + ->Arg(100); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc new file mode 100644 index 000000000..d8e81fa8c --- /dev/null +++ b/test/perf/linux/getdents_benchmark.cc @@ -0,0 +1,149 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +#ifndef SYS_getdents64 +#if defined(__x86_64__) +#define SYS_getdents64 217 +#elif defined(__aarch64__) +#define SYS_getdents64 217 +#else +#error "Unknown architecture" +#endif +#endif // SYS_getdents64 + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kBufferSize = 65536; + +PosixErrorOr<TempPath> CreateDirectory(int count, + std::vector<std::string>* files) { + ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir()); + + ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, + Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + for (int i = 0; i < count; i++) { + auto file = NewTempRelPath(); + auto res = MknodAt(dfd, file, S_IFREG | 0644, 0); + RETURN_IF_ERRNO(res); + files->push_back(file); + } + + return std::move(dir); +} + +PosixError CleanupDirectory(const TempPath& dir, + std::vector<std::string>* files) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, + Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + for (auto it = files->begin(); it != files->end(); ++it) { + auto res = UnlinkAt(dfd, *it, 0); + RETURN_IF_ERRNO(res); + } + return NoError(); +} + +// Creates a directory containing `files` files, and reads all the directory +// entries from the directory using a single FD. +void BM_GetdentsSameFD(benchmark::State& state) { + // Create directory with given files. + const int count = state.range(0); + + // Keep a vector of all of the file TempPaths that is destroyed before dir. + // + // Normally, we'd simply allow dir to recursively clean up the contained + // files, but that recursive cleanup uses getdents, which may be very slow in + // extreme benchmarks. + TempPath dir; + std::vector<std::string> files; + dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); + char buffer[kBufferSize]; + + // We read all directory entries on each iteration, but report this as a + // "batch" iteration so that reported times are per file. + while (state.KeepRunningBatch(count)) { + ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds()); + + int ret; + do { + ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), + SyscallSucceeds()); + } while (ret > 0); + } + + ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime(); + +// Creates a directory containing `files` files, and reads all the directory +// entries from the directory using a new FD each time. +void BM_GetdentsNewFD(benchmark::State& state) { + // Create directory with given files. + const int count = state.range(0); + + // Keep a vector of all of the file TempPaths that is destroyed before dir. + // + // Normally, we'd simply allow dir to recursively clean up the contained + // files, but that recursive cleanup uses getdents, which may be very slow in + // extreme benchmarks. + TempPath dir; + std::vector<std::string> files; + dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); + char buffer[kBufferSize]; + + // We read all directory entries on each iteration, but report this as a + // "batch" iteration so that reported times are per file. + while (state.KeepRunningBatch(count)) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + int ret; + do { + ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), + SyscallSucceeds()); + } while (ret > 0); + } + + ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/testdata/sandbox.go b/test/perf/linux/getpid_benchmark.cc index 0db210370..db74cb264 100644 --- a/test/root/testdata/sandbox.go +++ b/test/perf/linux/getpid_benchmark.cc @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,19 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testdata - -// Sandbox is a default JSON config for a sandbox. -const Sandbox = ` -{ - "metadata": { - "name": "default-sandbox", - "namespace": "default", - "attempt": 1, - "uid": "hdishd83djaidwnduwk28bcsb" - }, - "linux": { - }, - "log_directory": "/tmp" +#include <sys/syscall.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Getpid(benchmark::State& state) { + for (auto _ : state) { + syscall(SYS_getpid); + } } -` + +BENCHMARK(BM_Getpid); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/testdata/busybox.go b/test/perf/linux/gettid_benchmark.cc index e4dbd2843..8f4961f5e 100644 --- a/test/root/testdata/busybox.go +++ b/test/perf/linux/gettid_benchmark.cc @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testdata - -// MountOverSymlink is a JSON config for a container that /etc/resolv.conf is a -// symlink to /tmp/resolv.conf. -var MountOverSymlink = ` -{ - "metadata": { - "name": "busybox" - }, - "image": { - "image": "k8s.gcr.io/busybox" - }, - "command": [ - "sleep", - "1000" - ] +#include <sys/syscall.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Gettid(benchmark::State& state) { + for (auto _ : state) { + syscall(SYS_gettid); + } } -` + +BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc new file mode 100644 index 000000000..39c30fe69 --- /dev/null +++ b/test/perf/linux/mapping_benchmark.cc @@ -0,0 +1,163 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdlib.h> +#include <sys/mman.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/memory_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Conservative value for /proc/sys/vm/max_map_count, which limits the number of +// VMAs, minus a safety margin for VMAs that already exist for the test binary. +// The default value for max_map_count is +// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530. +constexpr size_t kMaxVMAs = 64001; + +// Map then unmap pages without touching them. +void BM_MapUnmap(benchmark::State& state) { + // Number of pages to map. + const int pages = state.range(0); + + while (state.KeepRunning()) { + void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + int ret = munmap(addr, pages * kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } +} + +BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime(); + +// Map, touch, then unmap pages. +void BM_MapTouchUnmap(benchmark::State& state) { + // Number of pages to map. + const int pages = state.range(0); + + while (state.KeepRunning()) { + void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + char* c = reinterpret_cast<char*>(addr); + char* end = c + pages * kPageSize; + while (c < end) { + *c = 42; + c += kPageSize; + } + + int ret = munmap(addr, pages * kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } +} + +BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime(); + +// Map and touch many pages, unmapping all at once. +// +// NOTE(b/111429208): This is a regression test to ensure performant mapping and +// allocation even with tons of mappings. +void BM_MapTouchMany(benchmark::State& state) { + // Number of pages to map. + const int page_count = state.range(0); + + while (state.KeepRunning()) { + std::vector<void*> pages; + + for (int i = 0; i < page_count; i++) { + void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + char* c = reinterpret_cast<char*>(addr); + *c = 42; + + pages.push_back(addr); + } + + for (void* addr : pages) { + int ret = munmap(addr, kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } + } + + state.SetBytesProcessed(kPageSize * page_count * state.iterations()); +} + +BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime(); + +void BM_PageFault(benchmark::State& state) { + // Map the region in which we will take page faults. To ensure that each page + // fault maps only a single page, each page we touch must correspond to a + // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However, + // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that + // if there are background threads running, they can't inadvertently creating + // mappings in our gaps that are unmapped when the test ends. + size_t test_pages = kMaxVMAs; + // Ensure that test_pages is odd, since we want the test region to both + // begin and end with a mapped page. + if (test_pages % 2 == 0) { + test_pages--; + } + const size_t test_region_bytes = test_pages * kPageSize; + // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on + // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE + // so that Linux pre-allocates the shmem file used to back the mapping. + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE)); + for (size_t i = 0; i < test_pages / 2; i++) { + ASSERT_THAT( + mprotect(reinterpret_cast<void*>(m.addr() + ((2 * i + 1) * kPageSize)), + kPageSize, PROT_NONE), + SyscallSucceeds()); + } + + const size_t mapped_pages = test_pages / 2 + 1; + // "Start" at the end of the mapped region to force the mapped region to be + // reset, since we mapped it with MAP_POPULATE. + size_t cur_page = mapped_pages; + for (auto _ : state) { + if (cur_page >= mapped_pages) { + // We've reached the end of our mapped region and have to reset it to + // incur page faults again. + state.PauseTiming(); + ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED), + SyscallSucceeds()); + cur_page = 0; + state.ResumeTiming(); + } + const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize); + const char c = *reinterpret_cast<volatile char*>(addr); + benchmark::DoNotOptimize(c); + cur_page++; + } +} + +BENCHMARK(BM_PageFault)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc new file mode 100644 index 000000000..68008f6d5 --- /dev/null +++ b/test/perf/linux/open_benchmark.cc @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Open(benchmark::State& state) { + const int size = state.range(0); + std::vector<TempPath> cache; + for (int i = 0; i < size; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + cache.emplace_back(std::move(path)); + } + + unsigned int seed = 1; + for (auto _ : state) { + const int chosen = rand_r(&seed) % size; + int fd = open(cache[chosen].path().c_str(), O_RDONLY); + TEST_CHECK(fd != -1); + close(fd); + } +} + +BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc new file mode 100644 index 000000000..8f5f6a2a3 --- /dev/null +++ b/test/perf/linux/pipe_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <cerrno> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Pipe(benchmark::State& state) { + int fds[2]; + TEST_CHECK(pipe(fds) == 0); + + const int size = state.range(0); + std::vector<char> wbuf(size); + std::vector<char> rbuf(size); + RandomizeBuffer(wbuf.data(), size); + + ScopedThread t([&] { + auto const fd = fds[1]; + for (int i = 0; i < state.max_iterations; i++) { + TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size); + } + }); + + for (auto _ : state) { + TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size); + } + + t.Join(); + + close(fds[0]); + close(fds[1]); + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc new file mode 100644 index 000000000..b0eb8c24e --- /dev/null +++ b/test/perf/linux/randread_benchmark.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <sys/uio.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Create a 1GB file that will be read from at random positions. This should +// invalid any performance gains from caching. +const uint64_t kFileSize = 1ULL << 30; + +// How many bytes to write at once to initialize the file used to read from. +const uint32_t kWriteSize = 65536; + +// Largest benchmarked read unit. +const uint32_t kMaxRead = 1UL << 26; + +TempPath CreateFile(uint64_t file_size) { + auto path = TempPath::CreateFile().ValueOrDie(); + FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie(); + + // Try to minimize syscalls by using maximum size writev() requests. + std::vector<char> buffer(kWriteSize); + RandomizeBuffer(buffer.data(), buffer.size()); + const std::vector<std::vector<struct iovec>> iovecs_list = + GenerateIovecs(file_size, buffer.data(), buffer.size()); + for (const auto& iovecs : iovecs_list) { + TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0); + } + + return path; +} + +// Global test state, initialized once per process lifetime. +struct GlobalState { + const TempPath tmpfile; + explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {} +}; + +GlobalState& GetGlobalState() { + // This gets created only once throughout the lifetime of the process. + // Use a dynamically allocated object (that is never deleted) to avoid order + // of destruction of static storage variables issues. + static GlobalState* const state = + // The actual file size is the maximum random seek range (kFileSize) + the + // maximum read size so we can read that number of bytes at the end of the + // file. + new GlobalState(CreateFile(kFileSize + kMaxRead)); + return *state; +} + +void BM_RandRead(benchmark::State& state) { + const int size = state.range(0); + + GlobalState& global_state = GetGlobalState(); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY)); + std::vector<char> buf(size); + + unsigned int seed = 1; + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), + rand_r(&seed) % kFileSize) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc new file mode 100644 index 000000000..62445867d --- /dev/null +++ b/test/perf/linux/read_benchmark.cc @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Read(benchmark::State& state) { + const int size = state.range(0); + const std::string contents(size, 0); + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode)); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY)); + + std::vector<char> buf(size); + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/testdata/simple.go b/test/perf/linux/sched_yield_benchmark.cc index 1cca53f0c..6756b5575 100644 --- a/test/root/testdata/simple.go +++ b/test/perf/linux/sched_yield_benchmark.cc @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,30 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testdata - -import ( - "encoding/json" - "fmt" -) - -// SimpleSpec returns a JSON config for a simple container that runs the -// specified command in the specified image. -func SimpleSpec(name, image string, cmd []string) string { - cmds, err := json.Marshal(cmd) - if err != nil { - // This shouldn't happen. - panic(err) - } - return fmt.Sprintf(` -{ - "metadata": { - "name": %q - }, - "image": { - "image": %q - }, - "command": %s - } -`, name, image, cmds) +#include <sched.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Sched_yield(benchmark::State& state) { + for (auto ignored : state) { + TEST_CHECK(sched_yield() == 0); + } } + +BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc new file mode 100644 index 000000000..d73e49523 --- /dev/null +++ b/test/perf/linux/send_recv_benchmark.cc @@ -0,0 +1,372 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> + +#include <cstring> + +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" +#include "benchmark/benchmark.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr ssize_t kMessageSize = 1024; + +class Message { + public: + explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {} + + explicit Message(int byte, int sz) : Message(byte, sz, 0) {} + + explicit Message(int byte, int sz, int cmsg_sz) + : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) { + iov_.iov_base = buffer_.data(); + iov_.iov_len = sz; + hdr_.msg_iov = &iov_; + hdr_.msg_iovlen = 1; + hdr_.msg_control = cmsg_buffer_.data(); + hdr_.msg_controllen = cmsg_sz; + } + + struct msghdr* header() { + return &hdr_; + } + + private: + std::vector<char> buffer_; + std::vector<char> cmsg_buffer_; + struct iovec iov_ = {}; + struct msghdr hdr_ = {}; +}; + +void BM_Recvmsg(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + Message send_msg('a'), recv_msg; + + ScopedThread t([&send_msg, &send_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + sendmsg(send_socket.get(), send_msg.header(), 0); + } + }); + + int64_t bytes_received = 0; + for (auto ignored : state) { + int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_Recvmsg)->UseRealTime(); + +void BM_Sendmsg(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + Message send_msg('a'), recv_msg; + + ScopedThread t([&recv_msg, &recv_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + recvmsg(recv_socket.get(), recv_msg.header(), 0); + } + }); + + int64_t bytes_sent = 0; + for (auto ignored : state) { + int n = sendmsg(send_socket.get(), send_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_sent += n; + } + + notification.Notify(); + send_socket.reset(); + + state.SetBytesProcessed(bytes_sent); +} + +BENCHMARK(BM_Sendmsg)->UseRealTime(); + +void BM_Recvfrom(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + char send_buffer[kMessageSize], recv_buffer[kMessageSize]; + + ScopedThread t([&send_socket, &send_buffer, ¬ification] { + while (!notification.HasBeenNotified()) { + sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); + } + }); + + int bytes_received = 0; + for (auto ignored : state) { + int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, + nullptr); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_Recvfrom)->UseRealTime(); + +void BM_Sendto(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + char send_buffer[kMessageSize], recv_buffer[kMessageSize]; + + ScopedThread t([&recv_socket, &recv_buffer, ¬ification] { + while (!notification.HasBeenNotified()) { + recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, + nullptr); + } + }); + + int64_t bytes_sent = 0; + for (auto ignored : state) { + int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); + TEST_CHECK(n > 0); + bytes_sent += n; + } + + notification.Notify(); + send_socket.reset(); + + state.SetBytesProcessed(bytes_sent); +} + +BENCHMARK(BM_Sendto)->UseRealTime(); + +PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + addr.ss_family = family; + switch (family) { + case AF_INET: + reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = + htonl(INADDR_LOOPBACK); + break; + case AF_INET6: + reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = + in6addr_loopback; + break; + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } + return addr; +} + +// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate +// space for control messages. Note that we do not expect to receive any. +void BM_RecvmsgWithControlBuf(benchmark::State& state) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + absl::Notification notification; + Message send_msg('a'); + // Create a msghdr with a buffer allocated for control messages. + Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24); + + ScopedThread t([&send_msg, &send_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + sendmsg(send_socket.get(), send_msg.header(), 0); + } + }); + + int64_t bytes_received = 0; + for (auto ignored : state) { + int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime(); + +// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes. +// +// state.Args[0] indicates whether the underlying socket should be blocking or +// non-blocking w/ 0 indicating non-blocking and 1 to indicate blocking. +// state.Args[1] is the size of the payload to be used per sendmsg call. +void BM_SendmsgTCP(benchmark::State& state) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + // Check if we want to run the test w/ a blocking send socket + // or non-blocking. + const int blocking = state.range(0); + if (!blocking) { + // Set the send FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds()); + } + + absl::Notification notification; + + // Get the buffer size we should use for this iteration of the test. + const int buf_size = state.range(1); + Message send_msg('a', buf_size), recv_msg(0, buf_size); + + ScopedThread t([&recv_msg, &recv_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0); + } + }); + + int64_t bytes_sent = 0; + int ncalls = 0; + for (auto ignored : state) { + int sent = 0; + while (true) { + struct msghdr hdr = {}; + struct iovec iov = {}; + struct msghdr* snd_header = send_msg.header(); + iov.iov_base = static_cast<char*>(snd_header->msg_iov->iov_base) + sent; + iov.iov_len = snd_header->msg_iov->iov_len - sent; + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0); + ncalls++; + if (n > 0) { + sent += n; + if (sent == buf_size) { + break; + } + // n can be > 0 but less than requested size. In which case we don't + // poll. + continue; + } + // Poll the fd for it to become writable. + struct pollfd poll_fd = {send_socket.get(), POLL_OUT, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), + SyscallSucceedsWithValue(0)); + } + bytes_sent += static_cast<int64_t>(sent); + } + + notification.Notify(); + send_socket.reset(); + state.SetBytesProcessed(bytes_sent); +} + +void Args(benchmark::internal::Benchmark* benchmark) { + for (int blocking = 0; blocking < 2; blocking++) { + for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) { + benchmark->Args({blocking, buf_size}); + } + } +} + +BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc new file mode 100644 index 000000000..af49e4477 --- /dev/null +++ b/test/perf/linux/seqwrite_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// The maximum file size of the test file, when writes get beyond this point +// they wrap around. This should be large enough to blow away caches. +const uint64_t kMaxFile = 1 << 30; + +// Perform writes of various sizes sequentially to one file. Wraps around if it +// goes above a certain maximum file size. +void BM_SeqWrite(benchmark::State& state) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); + + const int size = state.range(0); + std::vector<char> buf(size); + RandomizeBuffer(buf.data(), buf.size()); + + // Start writes at offset 0. + uint64_t offset = 0; + for (auto _ : state) { + TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) == + buf.size()); + offset += buf.size(); + // Wrap around if going above the maximum file size. + if (offset >= kMaxFile) { + offset = 0; + } + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc new file mode 100644 index 000000000..cec679191 --- /dev/null +++ b/test/perf/linux/signal_benchmark.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <signal.h> +#include <string.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void FixupHandler(int sig, siginfo_t* si, void* void_ctx) { + static unsigned int dataval = 0; + + // Skip the offending instruction. + ucontext_t* ctx = reinterpret_cast<ucontext_t*>(void_ctx); + ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast<greg_t>(&dataval); +} + +void BM_FaultSignalFixup(benchmark::State& state) { + // Set up the signal handler. + struct sigaction sa = {}; + sigemptyset(&sa.sa_mask); + sa.sa_sigaction = FixupHandler; + sa.sa_flags = SA_SIGINFO; + TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0); + + // Fault, fault, fault. + for (auto _ : state) { + // Trigger the segfault. + asm volatile( + "movq $0, %%rax\n" + "movq $0x77777777, (%%rax)\n" + : + : + : "rax"); + } +} + +BENCHMARK(BM_FaultSignalFixup)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc new file mode 100644 index 000000000..99ef05117 --- /dev/null +++ b/test/perf/linux/sleep_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <sys/syscall.h> +#include <time.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Sleep for 'param' nanoseconds. +void BM_Sleep(benchmark::State& state) { + const int nanoseconds = state.range(0); + + for (auto _ : state) { + struct timespec ts; + ts.tv_sec = 0; + ts.tv_nsec = nanoseconds; + + int ret; + do { + ret = syscall(SYS_nanosleep, &ts, &ts); + if (ret < 0) { + TEST_CHECK(errno == EINTR); + } + } while (ret < 0); + } +} + +BENCHMARK(BM_Sleep) + ->Arg(0) + ->Arg(1) + ->Arg(1000) // 1us + ->Arg(1000 * 1000) // 1ms + ->Arg(10 * 1000 * 1000) // 10ms + ->Arg(50 * 1000 * 1000) // 50ms + ->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc new file mode 100644 index 000000000..f15424482 --- /dev/null +++ b/test/perf/linux/stat_benchmark.cc @@ -0,0 +1,62 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Creates a file in a nested directory hierarchy at least `depth` directories +// deep, and stats that file multiple times. +void BM_Stat(benchmark::State& state) { + // Create nested directories with given depth. + int depth = state.range(0); + const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string dir_path = top_dir.path(); + + while (depth-- > 0) { + // Don't use TempPath because it will make paths too long to use. + // + // The top_dir destructor will clean up this whole tree. + dir_path = JoinPath(dir_path, absl::StrCat(depth)); + ASSERT_NO_ERRNO(Mkdir(dir_path, 0755)); + } + + // Create the file that will be stat'd. + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path)); + + struct stat st; + for (auto _ : state) { + ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds()); + } +} + +BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc new file mode 100644 index 000000000..92243a042 --- /dev/null +++ b/test/perf/linux/unlink_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Creates a directory containing `files` files, and unlinks all the files. +void BM_Unlink(benchmark::State& state) { + // Create directory with given files. + const int file_count = state.range(0); + + // We unlink all files on each iteration, but report this as a "batch" + // iteration so that reported times are per file. + TempPath dir; + while (state.KeepRunningBatch(file_count)) { + state.PauseTiming(); + // N.B. dir is declared outside the loop so that destruction of the previous + // iteration's directory occurs here, inside of PauseTiming. + dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + std::vector<TempPath> files; + for (int i = 0; i < file_count; i++) { + TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + files.push_back(std::move(file)); + } + state.ResumeTiming(); + + while (!files.empty()) { + // Destructor unlinks. + files.pop_back(); + } + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc new file mode 100644 index 000000000..7b060c70e --- /dev/null +++ b/test/perf/linux/write_benchmark.cc @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Write(benchmark::State& state) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); + + const int size = state.range(0); + std::vector<char> buf(size); + RandomizeBuffer(buf.data(), size); + + for (auto _ : state) { + TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/BUILD b/test/root/BUILD index 23ce2a70f..17e51e66e 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -1,4 +1,5 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/vm:defs.bzl", "vm_test") package(licenses = ["notice"]) @@ -16,6 +17,7 @@ go_test( "crictl_test.go", "main_test.go", "oom_score_adj_test.go", + "runsc_test.go", ], data = [ "//runsc", @@ -23,21 +25,32 @@ go_test( library = ":root", tags = [ # Requires docker and runsc to be configured before the test runs. - # Also test only runs as root. + # Also, the test needs to be run as root. Note that below, the + # root_vm_test relies on the default runtime 'runsc' being installed by + # the default installer. "manual", "local", ], visibility = ["//:sandbox"], deps = [ - "//runsc/boot", + "//pkg/test/criutil", + "//pkg/test/dockerutil", + "//pkg/test/testutil", "//runsc/cgroup", "//runsc/container", - "//runsc/criutil", - "//runsc/dockerutil", "//runsc/specutils", - "//runsc/testutil", - "//test/root/testdata", + "@com_github_cenkalti_backoff//:go_default_library", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +vm_test( + name = "root_vm_test", + shard_count = 1, + targets = [ + "//tools/installers:shim", + ":root_test", ], ) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index 4038661cb..8876d0d61 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -26,9 +26,9 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/cgroup" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" ) func verifyPid(pid int, path string) error { @@ -53,68 +53,82 @@ func verifyPid(pid int, path string) error { if scanner.Err() != nil { return scanner.Err() } - return fmt.Errorf("got: %s, want: %d", gots, pid) + return fmt.Errorf("got: %v, want: %d", gots, pid) } -// TestCgroup sets cgroup options and checks that cgroup was properly configured. func TestMemCGroup(t *testing.T) { - allocMemSize := 128 << 20 - if err := dockerutil.Pull("python"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("memusage-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // Start a new container and allocate the specified about of memory. - args := []string{ - "--memory=256MB", - "python", - "python", - "-c", - fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize), - } - if err := d.Run(args...); err != nil { - t.Fatal("docker create failed:", err) + allocMemSize := 128 << 20 + allocMemLimit := 2 * allocMemSize + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/python", + Memory: allocMemLimit / 1024, // Must be in Kb. + }, "python", "-c", fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize)); err != nil { + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + // Extract the ID to lookup the cgroup. gid, err := d.ID() if err != nil { t.Fatalf("Docker.ID() failed: %v", err) } t.Logf("cgroup ID: %s", gid) - path := filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.usage_in_bytes") - memUsage := 0 - // Wait when the container will allocate memory. + memUsage := 0 start := time.Now() - for time.Now().Sub(start) < 30*time.Second { + for time.Since(start) < 30*time.Second { + // Sleep for a brief period of time after spawning the + // container (so that Docker can create the cgroup etc. + // or after looping below (so the application can start). + time.Sleep(100 * time.Millisecond) + + // Read the cgroup memory limit. + path := filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.limit_in_bytes") outRaw, err := ioutil.ReadFile(path) if err != nil { - t.Fatalf("failed to read %q: %v", path, err) + // It's possible that the container does not exist yet. + continue } out := strings.TrimSpace(string(outRaw)) + memLimit, err := strconv.Atoi(out) + if err != nil { + t.Fatalf("Atoi(%v): %v", out, err) + } + if memLimit != allocMemLimit { + // The group may not have had the correct limit set yet. + continue + } + + // Read the cgroup memory usage. + path = filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.max_usage_in_bytes") + outRaw, err = ioutil.ReadFile(path) + if err != nil { + t.Fatalf("error reading usage: %v", err) + } + out = strings.TrimSpace(string(outRaw)) memUsage, err = strconv.Atoi(out) if err != nil { t.Fatalf("Atoi(%v): %v", out, err) } + t.Logf("read usage: %v, wanted: %v", memUsage, allocMemSize) - if memUsage > allocMemSize { + // Are we done? + if memUsage >= allocMemSize { return } - - time.Sleep(100 * time.Millisecond) } - t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20) + t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20) } // TestCgroup sets cgroup options and checks that cgroup was properly configured. func TestCgroup(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() // This is not a comprehensive list of attributes. // @@ -179,10 +193,11 @@ func TestCgroup(t *testing.T) { want: "5", }, { - arg: "--blkio-weight=750", - ctrl: "blkio", - file: "blkio.weight", - want: "750", + arg: "--blkio-weight=750", + ctrl: "blkio", + file: "blkio.weight", + want: "750", + skipIfNotFound: true, // blkio groups may not be available. }, } @@ -191,12 +206,15 @@ func TestCgroup(t *testing.T) { args = append(args, attr.arg) } - args = append(args, "alpine", "sleep", "10000") - if err := d.Run(args...); err != nil { - t.Fatal("docker create failed:", err) + // Start the container. + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + Extra: args, // Cgroup arguments. + }, "sleep", "10000"); err != nil { + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + // Lookup the relevant cgroup ID. gid, err := d.ID() if err != nil { t.Fatalf("Docker.ID() failed: %v", err) @@ -245,17 +263,21 @@ func TestCgroup(t *testing.T) { } } +// TestCgroup sets cgroup options and checks that cgroup was properly configured. func TestCgroupParent(t *testing.T) { - if err := dockerutil.Pull("alpine"); err != nil { - t.Fatal("docker pull failed:", err) - } - d := dockerutil.MakeDocker("cgroup-test") + d := dockerutil.MakeDocker(t) + defer d.CleanUp() - parent := testutil.RandomName("runsc") - if err := d.Run("--cgroup-parent", parent, "alpine", "sleep", "10000"); err != nil { - t.Fatal("docker create failed:", err) + // Construct a known cgroup name. + parent := testutil.RandomID("runsc-") + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + Extra: []string{fmt.Sprintf("--cgroup-parent=%s", parent)}, + }, "sleep", "10000"); err != nil { + t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() + + // Extract the ID to look up the cgroup. gid, err := d.ID() if err != nil { t.Fatalf("Docker.ID() failed: %v", err) diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go index be0f63d18..a306132a4 100644 --- a/test/root/chroot_test.go +++ b/test/root/chroot_test.go @@ -24,17 +24,20 @@ import ( "strings" "testing" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" ) // TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned // up after the sandbox is destroyed. func TestChroot(t *testing.T) { - d := dockerutil.MakeDocker("chroot-test") - if err := d.Run("alpine", "sleep", "10000"); err != nil { + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "10000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() pid, err := d.SandboxPid() if err != nil { @@ -76,11 +79,14 @@ func TestChroot(t *testing.T) { } func TestChrootGofer(t *testing.T) { - d := dockerutil.MakeDocker("chroot-test") - if err := d.Run("alpine", "sleep", "10000"); err != nil { + d := dockerutil.MakeDocker(t) + defer d.CleanUp() + + if err := d.Spawn(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "10000"); err != nil { t.Fatalf("docker run failed: %v", err) } - defer d.CleanUp() // It's tricky to find gofers. Get sandbox PID first, then find parent. From // parent get all immediate children, remove the sandbox, and everything else diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go index 3f90c4c6a..85007dcce 100644 --- a/test/root/crictl_test.go +++ b/test/root/crictl_test.go @@ -16,6 +16,7 @@ package root import ( "bytes" + "encoding/json" "fmt" "io" "io/ioutil" @@ -29,16 +30,58 @@ import ( "testing" "time" - "gvisor.dev/gvisor/runsc/criutil" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/criutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/root/testdata" ) // Tests for crictl have to be run as root (rather than in a user namespace) // because crictl creates named network namespaces in /var/run/netns/. +// SimpleSpec returns a JSON config for a simple container that runs the +// specified command in the specified image. +func SimpleSpec(name, image string, cmd []string, extra map[string]interface{}) string { + s := map[string]interface{}{ + "metadata": map[string]string{ + "name": name, + }, + "image": map[string]string{ + "image": testutil.ImageByName(image), + }, + "log_path": fmt.Sprintf("%s.log", name), + } + if len(cmd) > 0 { // Omit if empty. + s["command"] = cmd + } + for k, v := range extra { + s[k] = v // Extra settings. + } + v, err := json.Marshal(s) + if err != nil { + // This shouldn't happen. + panic(err) + } + return string(v) +} + +// Sandbox is a default JSON config for a sandbox. +var Sandbox = `{ + "metadata": { + "name": "default-sandbox", + "namespace": "default", + "attempt": 1, + "uid": "hdishd83djaidwnduwk28bcsb" + }, + "linux": { + }, + "log_directory": "/tmp" +} +` + +// Httpd is a JSON config for an httpd container. +var Httpd = SimpleSpec("httpd", "basic/httpd", nil, nil) + // TestCrictlSanity refers to b/112433158. func TestCrictlSanity(t *testing.T) { // Setup containerd and crictl. @@ -47,9 +90,9 @@ func TestCrictlSanity(t *testing.T) { t.Fatalf("failed to setup crictl: %v", err) } defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("httpd", testdata.Sandbox, testdata.Httpd) + podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox, Httpd) if err != nil { - t.Fatal(err) + t.Fatalf("start failed: %v", err) } // Look for the httpd page. @@ -59,10 +102,38 @@ func TestCrictlSanity(t *testing.T) { // Stop everything. if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) + t.Fatalf("stop failed: %v", err) } } +// HttpdMountPaths is a JSON config for an httpd container with additional +// mounts. +var HttpdMountPaths = SimpleSpec("httpd", "basic/httpd", nil, map[string]interface{}{ + "mounts": []map[string]interface{}{ + map[string]interface{}{ + "container_path": "/var/run/secrets/kubernetes.io/serviceaccount", + "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/volumes/kubernetes.io~secret/default-token-2rpfx", + "readonly": true, + }, + map[string]interface{}{ + "container_path": "/etc/hosts", + "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/etc-hosts", + "readonly": false, + }, + map[string]interface{}{ + "container_path": "/dev/termination-log", + "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/containers/httpd/d1709580", + "readonly": false, + }, + map[string]interface{}{ + "container_path": "/usr/local/apache2/htdocs/test", + "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064", + "readonly": true, + }, + }, + "linux": map[string]interface{}{}, +}) + // TestMountPaths refers to b/117635704. func TestMountPaths(t *testing.T) { // Setup containerd and crictl. @@ -71,9 +142,9 @@ func TestMountPaths(t *testing.T) { t.Fatalf("failed to setup crictl: %v", err) } defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("httpd", testdata.Sandbox, testdata.HttpdMountPaths) + podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox, HttpdMountPaths) if err != nil { - t.Fatal(err) + t.Fatalf("start failed: %v", err) } // Look for the directory available at /test. @@ -83,7 +154,7 @@ func TestMountPaths(t *testing.T) { // Stop everything. if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) + t.Fatalf("stop failed: %v", err) } } @@ -95,14 +166,16 @@ func TestMountOverSymlinks(t *testing.T) { t.Fatalf("failed to setup crictl: %v", err) } defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, testdata.MountOverSymlink) + + spec := SimpleSpec("busybox", "basic/resolv", []string{"sleep", "1000"}, nil) + podID, contID, err := crictl.StartPodAndContainer("basic/resolv", Sandbox, spec) if err != nil { - t.Fatal(err) + t.Fatalf("start failed: %v", err) } out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf") if err != nil { - t.Fatal(err) + t.Fatalf("readlink failed: %v, out: %s", err, out) } if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) { t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out)) @@ -110,11 +183,11 @@ func TestMountOverSymlinks(t *testing.T) { etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf") if err != nil { - t.Fatal(err) + t.Fatalf("cat failed: %v, out: %s", err, etc) } tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf") if err != nil { - t.Fatal(err) + t.Fatalf("cat failed: %v, out: %s", err, out) } if tmp != etc { t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp)) @@ -122,7 +195,7 @@ func TestMountOverSymlinks(t *testing.T) { // Stop everything. if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) + t.Fatalf("stop failed: %v", err) } } @@ -135,16 +208,16 @@ func TestHomeDir(t *testing.T) { t.Fatalf("failed to setup crictl: %v", err) } defer cleanup() - contSpec := testdata.SimpleSpec("root", "k8s.gcr.io/busybox", []string{"sleep", "1000"}) - podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, contSpec) + contSpec := SimpleSpec("root", "basic/busybox", []string{"sleep", "1000"}, nil) + podID, contID, err := crictl.StartPodAndContainer("basic/busybox", Sandbox, contSpec) if err != nil { - t.Fatal(err) + t.Fatalf("start failed: %v", err) } t.Run("root container", func(t *testing.T) { out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") if err != nil { - t.Fatal(err) + t.Fatalf("exec failed: %v, out: %s", err, out) } if got, want := strings.TrimSpace(string(out)), "/root"; got != want { t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) @@ -153,32 +226,47 @@ func TestHomeDir(t *testing.T) { t.Run("sub-container", func(t *testing.T) { // Create a sub container in the same pod. - subContSpec := testdata.SimpleSpec("subcontainer", "k8s.gcr.io/busybox", []string{"sleep", "1000"}) - subContID, err := crictl.StartContainer(podID, "k8s.gcr.io/busybox", testdata.Sandbox, subContSpec) + subContSpec := SimpleSpec("subcontainer", "basic/busybox", []string{"sleep", "1000"}, nil) + subContID, err := crictl.StartContainer(podID, "basic/busybox", Sandbox, subContSpec) if err != nil { - t.Fatal(err) + t.Fatalf("start failed: %v", err) } out, err := crictl.Exec(subContID, "sh", "-c", "echo $HOME") if err != nil { - t.Fatal(err) + t.Fatalf("exec failed: %v, out: %s", err, out) } if got, want := strings.TrimSpace(string(out)), "/root"; got != want { t.Fatalf("Home directory invalid. Got %q, Want: %q", got, want) } if err := crictl.StopContainer(subContID); err != nil { - t.Fatal(err) + t.Fatalf("stop failed: %v", err) } }) // Stop everything. if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatal(err) + t.Fatalf("stop failed: %v", err) } } +// containerdConfigTemplate is a .toml config for containerd. It contains a +// formatting verb so the runtime field can be set via fmt.Sprintf. +const containerdConfigTemplate = ` +disabled_plugins = ["restart"] +[plugins.linux] + runtime = "%s" + runtime_root = "/tmp/test-containerd/runsc" + shim = "/usr/local/bin/gvisor-containerd-shim" + shim_debug = true + +[plugins.cri.containerd.runtimes.runsc] + runtime_type = "io.containerd.runtime.v1.linux" + runtime_engine = "%s" +` + // setup sets up before a test. Specifically it: // * Creates directories and a socket for containerd to utilize. // * Runs containerd and waits for it to reach a "ready" state for testing. @@ -213,50 +301,52 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) { if err != nil { t.Fatalf("error discovering runtime path: %v", err) } - config, err := testutil.WriteTmpFile("containerd-config", testdata.ContainerdConfig(runtime)) + config, configCleanup, err := testutil.WriteTmpFile("containerd-config", fmt.Sprintf(containerdConfigTemplate, runtime, runtime)) if err != nil { t.Fatalf("failed to write containerd config") } - cleanups = append(cleanups, func() { os.RemoveAll(config) }) + cleanups = append(cleanups, configCleanup) // Start containerd. - containerd := exec.Command(getContainerd(), + cmd := exec.Command(getContainerd(), "--config", config, "--log-level", "debug", "--root", containerdRoot, "--state", containerdState, "--address", sockAddr) + startupR, startupW := io.Pipe() + defer startupR.Close() + defer startupW.Close() + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + cmd.Stderr = io.MultiWriter(startupW, stderr) + cmd.Stdout = io.MultiWriter(startupW, stdout) cleanups = append(cleanups, func() { - if err := testutil.KillCommand(containerd); err != nil { - log.Printf("error killing containerd: %v", err) - } + t.Logf("containerd stdout: %s", stdout.String()) + t.Logf("containerd stderr: %s", stderr.String()) }) - containerdStderr, err := containerd.StderrPipe() - if err != nil { - t.Fatalf("failed to get containerd stderr: %v", err) - } - containerdStdout, err := containerd.StdoutPipe() - if err != nil { - t.Fatalf("failed to get containerd stdout: %v", err) - } - if err := containerd.Start(); err != nil { + + // Start the process. + if err := cmd.Start(); err != nil { t.Fatalf("failed running containerd: %v", err) } - // Wait for containerd to boot. Then put all containerd output into a - // buffer to be logged at the end of the test. - testutil.WaitUntilRead(containerdStderr, "Start streaming server", nil, 10*time.Second) - stdoutBuf := &bytes.Buffer{} - stderrBuf := &bytes.Buffer{} - go func() { io.Copy(stdoutBuf, containerdStdout) }() - go func() { io.Copy(stderrBuf, containerdStderr) }() + // Wait for containerd to boot. + if err := testutil.WaitUntilRead(startupR, "Start streaming server", nil, 10*time.Second); err != nil { + t.Fatalf("failed to start containerd: %v", err) + } + + // Kill must be the last cleanup (as it will be executed first). + cc := criutil.NewCrictl(t, sockAddr) cleanups = append(cleanups, func() { - t.Logf("containerd stdout: %s", string(stdoutBuf.Bytes())) - t.Logf("containerd stderr: %s", string(stderrBuf.Bytes())) + cc.CleanUp() // Remove tmp files, etc. + if err := testutil.KillCommand(cmd); err != nil { + log.Printf("error killing containerd: %v", err) + } }) cleanup.Release() - return criutil.NewCrictl(20*time.Second, sockAddr), cleanupFunc, nil + return cc, cleanupFunc, nil } // httpGet GETs the contents of a file served from a pod on port 80. diff --git a/test/root/main_test.go b/test/root/main_test.go index d74dec85f..9fb17e0dd 100644 --- a/test/root/main_test.go +++ b/test/root/main_test.go @@ -21,7 +21,7 @@ import ( "testing" "github.com/syndtr/gocapability/capability" - "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/runsc/specutils" ) diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 126f0975a..9a3cecd97 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -20,10 +20,9 @@ import ( "testing" specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" ) var ( @@ -40,15 +39,6 @@ var ( // TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a // single container sandbox. func TestOOMScoreAdjSingle(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -89,11 +79,11 @@ func TestOOMScoreAdjSingle(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.Name, func(t *testing.T) { - id := testutil.UniqueContainerID() + id := testutil.RandomContainerID() s := testutil.NewSpecWithArgs("sleep", "1000") s.Process.OOMScoreAdj = testCase.OOMScoreAdj - containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id}) + containers, cleanup, err := startContainers(t, []*specs.Spec{s}, []string{id}) if err != nil { t.Fatalf("error starting containers: %v", err) } @@ -131,15 +121,6 @@ func TestOOMScoreAdjSingle(t *testing.T) { // TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a // multi-container sandbox. func TestOOMScoreAdjMulti(t *testing.T) { - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfig() - conf.RootDir = rootDir - ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -257,7 +238,7 @@ func TestOOMScoreAdjMulti(t *testing.T) { } } - containers, cleanup, err := startContainers(conf, specs, ids) + containers, cleanup, err := startContainers(t, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) } @@ -321,7 +302,7 @@ func TestOOMScoreAdjMulti(t *testing.T) { func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { var specs []*specs.Spec var ids []string - rootID := testutil.UniqueContainerID() + rootID := testutil.RandomContainerID() for i, cmd := range cmds { spec := testutil.NewSpecWithArgs(cmd...) @@ -335,35 +316,48 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer, specutils.ContainerdSandboxIDAnnotation: rootID, } - ids = append(ids, testutil.UniqueContainerID()) + ids = append(ids, testutil.RandomContainerID()) } specs = append(specs, spec) } return specs, ids } -func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { - if len(conf.RootDir) == 0 { - panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") - } - - var containers []*container.Container - var bundles []string - cleanup := func() { +func startContainers(t *testing.T, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { + var ( + containers []*container.Container + cleanups []func() + ) + cleanups = append(cleanups, func() { for _, c := range containers { c.Destroy() } - for _, b := range bundles { - os.RemoveAll(b) + }) + cleanupAll := func() { + for _, c := range cleanups { + c() } } + localClean := specutils.MakeCleanup(cleanupAll) + defer localClean.Clean() + + // All containers must share the same root. + rootDir, cleanup, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + cleanups = append(cleanups, cleanup) + + // Point this to from the configuration. + conf := testutil.TestConfig(t) + conf.RootDir = rootDir + for i, spec := range specs { - bundleDir, err := testutil.SetupBundleDir(spec) + bundleDir, cleanup, err := testutil.SetupBundleDir(spec) if err != nil { - cleanup() - return nil, nil, fmt.Errorf("error setting up container: %v", err) + return nil, nil, fmt.Errorf("error setting up bundle: %v", err) } - bundles = append(bundles, bundleDir) + cleanups = append(cleanups, cleanup) args := container.Args{ ID: ids[i], @@ -372,15 +366,15 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*c } cont, err := container.New(conf, args) if err != nil { - cleanup() return nil, nil, fmt.Errorf("error creating container: %v", err) } containers = append(containers, cont) if err := cont.Start(conf); err != nil { - cleanup() return nil, nil, fmt.Errorf("error starting container: %v", err) } } - return containers, cleanup, nil + + localClean.Release() + return containers, cleanupAll, nil } diff --git a/test/root/runsc_test.go b/test/root/runsc_test.go new file mode 100644 index 000000000..25204bebb --- /dev/null +++ b/test/root/runsc_test.go @@ -0,0 +1,151 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package root + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/cenkalti/backoff" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/runsc/specutils" +) + +// TestDoKill checks that when "runsc do..." is killed, the sandbox process is +// also terminated. This ensures that parent death signal is propagate to the +// sandbox process correctly. +func TestDoKill(t *testing.T) { + // Make the sandbox process be reparented here when it's killed, so we can + // wait for it. + if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil { + t.Fatalf("prctl(PR_SET_CHILD_SUBREAPER): %v", err) + } + + cmd := exec.Command(specutils.ExePath, "do", "sleep", "10000") + buf := &bytes.Buffer{} + cmd.Stdout = buf + cmd.Stderr = buf + cmd.Start() + + var pid int + findSandbox := func() error { + var err error + pid, err = sandboxPid(cmd.Process.Pid) + if err != nil { + return &backoff.PermanentError{Err: err} + } + if pid == 0 { + return fmt.Errorf("sandbox process not found") + } + return nil + } + if err := testutil.Poll(findSandbox, 10*time.Second); err != nil { + t.Fatalf("failed to find sandbox: %v", err) + } + t.Logf("Found sandbox, pid: %d", pid) + + if err := cmd.Process.Kill(); err != nil { + t.Fatalf("failed to kill run process: %v", err) + } + cmd.Wait() + t.Logf("Parent process killed (%d). Output: %s", cmd.Process.Pid, buf.String()) + + ch := make(chan struct{}) + go func() { + defer func() { ch <- struct{}{} }() + t.Logf("Waiting for sandbox process (%d) termination", pid) + if _, err := unix.Wait4(pid, nil, 0, nil); err != nil { + t.Errorf("error waiting for sandbox process (%d): %v", pid, err) + } + }() + select { + case <-ch: + // Done + case <-time.After(5 * time.Second): + t.Fatalf("timeout waiting for sandbox process (%d) to exit", pid) + } +} + +// sandboxPid looks for the sandbox process inside the process tree starting +// from "pid". It returns 0 and no error if no sandbox process is found. It +// returns error if anything failed. +func sandboxPid(pid int) (int, error) { + cmd := exec.Command("pgrep", "-P", strconv.Itoa(pid)) + buf := &bytes.Buffer{} + cmd.Stdout = buf + if err := cmd.Start(); err != nil { + return 0, err + } + ps, err := cmd.Process.Wait() + if err != nil { + return 0, err + } + if ps.ExitCode() == 1 { + // pgrep returns 1 when no process is found. + return 0, nil + } + + var children []int + for _, line := range strings.Split(buf.String(), "\n") { + if len(line) == 0 { + continue + } + child, err := strconv.Atoi(line) + if err != nil { + return 0, err + } + + cmdline, err := ioutil.ReadFile(filepath.Join("/proc", line, "cmdline")) + if err != nil { + if os.IsNotExist(err) { + // Raced with process exit. + continue + } + return 0, err + } + args := strings.SplitN(string(cmdline), "\x00", 2) + if len(args) == 0 { + return 0, fmt.Errorf("malformed cmdline file: %q", cmdline) + } + // The sandbox process has the first argument set to "runsc-sandbox". + if args[0] == "runsc-sandbox" { + return child, nil + } + + children = append(children, child) + } + + // Sandbox process wasn't found, try another level down. + for _, pid := range children { + sand, err := sandboxPid(pid) + if err != nil { + return 0, err + } + if sand != 0 { + return sand, nil + } + // Not found, continue the search. + } + return 0, nil +} diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD deleted file mode 100644 index 6859541ad..000000000 --- a/test/root/testdata/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "testdata", - srcs = [ - "busybox.go", - "containerd_config.go", - "httpd.go", - "httpd_mount_paths.go", - "sandbox.go", - "simple.go", - ], - visibility = [ - "//:sandbox", - ], -) diff --git a/test/root/testdata/containerd_config.go b/test/root/testdata/containerd_config.go deleted file mode 100644 index e12f1ec88..000000000 --- a/test/root/testdata/containerd_config.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testdata contains data required for root tests. -package testdata - -import "fmt" - -// containerdConfigTemplate is a .toml config for containerd. It contains a -// formatting verb so the runtime field can be set via fmt.Sprintf. -const containerdConfigTemplate = ` -disabled_plugins = ["restart"] -[plugins.linux] - runtime = "%s" - runtime_root = "/tmp/test-containerd/runsc" - shim = "/usr/local/bin/gvisor-containerd-shim" - shim_debug = true - -[plugins.cri.containerd.runtimes.runsc] - runtime_type = "io.containerd.runtime.v1.linux" - runtime_engine = "%s" -` - -// ContainerdConfig returns a containerd config file with the specified -// runtime. -func ContainerdConfig(runtime string) string { - return fmt.Sprintf(containerdConfigTemplate, runtime, runtime) -} diff --git a/test/root/testdata/httpd_mount_paths.go b/test/root/testdata/httpd_mount_paths.go deleted file mode 100644 index ac3f4446a..000000000 --- a/test/root/testdata/httpd_mount_paths.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testdata - -// HttpdMountPaths is a JSON config for an httpd container with additional -// mounts. -const HttpdMountPaths = ` -{ - "metadata": { - "name": "httpd" - }, - "image":{ - "image": "httpd" - }, - "mounts": [ - { - "container_path": "/var/run/secrets/kubernetes.io/serviceaccount", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/volumes/kubernetes.io~secret/default-token-2rpfx", - "readonly": true - }, - { - "container_path": "/etc/hosts", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/etc-hosts", - "readonly": false - }, - { - "container_path": "/dev/termination-log", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/containers/httpd/d1709580", - "readonly": false - }, - { - "container_path": "/usr/local/apache2/htdocs/test", - "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064", - "readonly": true - } - ], - "linux": { - }, - "log_path": "httpd.log" -} -` diff --git a/test/runner/BUILD b/test/runner/BUILD new file mode 100644 index 000000000..6833c9986 --- /dev/null +++ b/test/runner/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "runner", + testonly = 1, + srcs = ["runner.go"], + data = [ + "//runsc", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//pkg/test/testutil", + "//runsc/specutils", + "//test/runner/gtest", + "//test/uds", + "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/test/syscalls/build_defs.bzl b/test/runner/defs.bzl index cbab85ef7..0a75b158f 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/runner/defs.bzl @@ -1,6 +1,119 @@ """Defines a rule for syscall test targets.""" -load("//tools:defs.bzl", "loopback") +load("//tools:defs.bzl", "default_platform", "loopback", "platforms") + +def _runner_test_impl(ctx): + # Generate a runner binary. + runner = ctx.actions.declare_file("%s-runner" % ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "set -euf -x -o pipefail", + "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then", + " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", + " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", + "fi", + "exec %s %s %s\n" % ( + ctx.files.runner[0].short_path, + " ".join(ctx.attr.runner_args), + ctx.files.test[0].short_path, + ), + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + + # Return with all transitive files. + runfiles = ctx.runfiles( + transitive_files = depset(transitive = [ + target.data_runfiles.files + for target in (ctx.attr.runner, ctx.attr.test) + if hasattr(target, "data_runfiles") + ]), + files = ctx.files.runner + ctx.files.test, + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = runner, runfiles = runfiles)] + +_runner_test = rule( + attrs = { + "runner": attr.label( + default = "//test/runner:runner", + ), + "test": attr.label( + mandatory = True, + ), + "runner_args": attr.string_list(), + "data": attr.label_list( + allow_files = True, + ), + }, + test = True, + implementation = _runner_test_impl, +) + +def _syscall_test( + test, + shard_count, + size, + platform, + use_tmpfs, + tags, + network = "none", + file_access = "exclusive", + overlay = False, + add_uds_tree = False): + # Prepend "runsc" to non-native platform names. + full_platform = platform if platform == "native" else "runsc_" + platform + + # Name the test appropriately. + name = test.split(":")[1] + "_" + full_platform + if file_access == "shared": + name += "_shared" + if overlay: + name += "_overlay" + if network != "none": + name += "_" + network + "net" + + # Apply all tags. + if tags == None: + tags = [] + + # Add the full_platform and file access in a tag to make it easier to run + # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. + tags += [full_platform, "file_" + file_access] + + # Hash this target into one of 15 buckets. This can be used to + # randomly split targets between different workflows. + hash15 = hash(native.package_name() + name) % 15 + tags.append("hash15:" + str(hash15)) + + # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until + # we figure out how to request ipv4 sockets on Guitar machines. + if network == "host": + tags.append("noguitar") + + # Disable off-host networking. + tags.append("requires-net:loopback") + + runner_args = [ + # Arguments are passed directly to runner binary. + "--platform=" + platform, + "--network=" + network, + "--use-tmpfs=" + str(use_tmpfs), + "--file-access=" + file_access, + "--overlay=" + str(overlay), + "--add-uds-tree=" + str(add_uds_tree), + ] + + # Call the rule above. + _runner_test( + name = name, + test = test, + runner_args = runner_args, + data = [loopback], + size = size, + tags = tags, + shard_count = shard_count, + ) def syscall_test( test, @@ -23,6 +136,8 @@ def syscall_test( add_hostinet: add a hostinet test. tags: starting test tags. """ + if not tags: + tags = [] _syscall_test( test = test, @@ -34,35 +149,26 @@ def syscall_test( tags = tags, ) - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "kvm", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) + for (platform, platform_tags) in platforms.items(): + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platform_tags + tags, + ) if add_overlay: _syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, overlay = True, ) @@ -72,10 +178,10 @@ def syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = use_tmpfs, add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, file_access = "shared", ) @@ -84,97 +190,9 @@ def syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = use_tmpfs, network = "host", add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, ) - -def _syscall_test( - test, - shard_count, - size, - platform, - use_tmpfs, - tags, - network = "none", - file_access = "exclusive", - overlay = False, - add_uds_tree = False): - test_name = test.split(":")[1] - - # Prepend "runsc" to non-native platform names. - full_platform = platform if platform == "native" else "runsc_" + platform - - name = test_name + "_" + full_platform - if file_access == "shared": - name += "_shared" - if overlay: - name += "_overlay" - if network != "none": - name += "_" + network + "net" - - if tags == None: - tags = [] - - # Add the full_platform and file access in a tag to make it easier to run - # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. - tags += [full_platform, "file_" + file_access] - - # Hash this target into one of 15 buckets. This can be used to - # randomly split targets between different workflows. - hash15 = hash(native.package_name() + name) % 15 - tags.append("hash15:" + str(hash15)) - - # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until - # we figure out how to request ipv4 sockets on Guitar machines. - if network == "host": - tags.append("noguitar") - - # Disable off-host networking. - tags.append("requires-net:loopback") - - # Add tag to prevent the tests from running in a Bazel sandbox. - # TODO(b/120560048): Make the tests run without this tag. - tags.append("no-sandbox") - - # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is - # more stable. - if platform == "kvm": - tags.append("manual") - tags.append("requires-kvm") - - # TODO(b/112165693): Remove when tests pass reliably. - tags.append("notap") - - args = [ - # Arguments are passed directly to syscall_test_runner binary. - "--test-name=" + test_name, - "--platform=" + platform, - "--network=" + network, - "--use-tmpfs=" + str(use_tmpfs), - "--file-access=" + file_access, - "--overlay=" + str(overlay), - "--add-uds-tree=" + str(add_uds_tree), - ] - - sh_test( - srcs = ["syscall_test_runner.sh"], - name = name, - data = [ - ":syscall_test_runner", - loopback, - test, - ], - args = args, - size = size, - tags = tags, - shard_count = shard_count, - ) - -def sh_test(**kwargs): - """Wraps the standard sh_test.""" - native.sh_test( - **kwargs - ) diff --git a/test/syscalls/gtest/BUILD b/test/runner/gtest/BUILD index de4b2727c..de4b2727c 100644 --- a/test/syscalls/gtest/BUILD +++ b/test/runner/gtest/BUILD diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go new file mode 100644 index 000000000..869169ad5 --- /dev/null +++ b/test/runner/gtest/gtest.go @@ -0,0 +1,168 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gtest contains helpers for running google-test tests from Go. +package gtest + +import ( + "fmt" + "os/exec" + "strings" +) + +var ( + // listTestFlag is the flag that will list tests in gtest binaries. + listTestFlag = "--gtest_list_tests" + + // filterTestFlag is the flag that will filter tests in gtest binaries. + filterTestFlag = "--gtest_filter" + + // listBechmarkFlag is the flag that will list benchmarks in gtest binaries. + listBenchmarkFlag = "--benchmark_list_tests" + + // filterBenchmarkFlag is the flag that will run specified benchmarks. + filterBenchmarkFlag = "--benchmark_filter" +) + +// TestCase is a single gtest test case. +type TestCase struct { + // Suite is the suite for this test. + Suite string + + // Name is the name of this individual test. + Name string + + // all indicates that this will run without flags. This takes + // precendence over benchmark below. + all bool + + // benchmark indicates that this is a benchmark. In this case, the + // suite will be empty, and we will use the appropriate test and + // benchmark flags. + benchmark bool +} + +// FullName returns the name of the test including the suite. It is suitable to +// pass to "-gtest_filter". +func (tc TestCase) FullName() string { + return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) +} + +// Args returns arguments to be passed when invoking the test. +func (tc TestCase) Args() []string { + if tc.all { + return []string{} // No arguments. + } + if tc.benchmark { + return []string{ + fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), + fmt.Sprintf("%s=", filterTestFlag), + } + } + return []string{ + fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()), + } +} + +// ParseTestCases calls a gtest test binary to list its test and returns a +// slice with the name and suite of each test. +// +// If benchmarks is true, then benchmarks will be included in the list of test +// cases provided. Note that this requires the binary to support the +// benchmarks_list_tests flag. +func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) { + // Run to extract test cases. + args := append([]string{listTestFlag}, extraArgs...) + cmd := exec.Command(testBin, args...) + out, err := cmd.Output() + if err != nil { + // We failed to list tests with the given flags. Just + // return something that will run the binary with no + // flags, which should execute all tests. + return []TestCase{ + TestCase{ + Suite: "Default", + Name: "All", + all: true, + }, + }, nil + } + + // Parse test output. + var t []TestCase + var suite string + for _, line := range strings.Split(string(out), "\n") { + // Strip comments. + line = strings.Split(line, "#")[0] + + // New suite? + if !strings.HasPrefix(line, " ") { + suite = strings.TrimSuffix(strings.TrimSpace(line), ".") + continue + } + + // Individual test. + name := strings.TrimSpace(line) + + // Do we have a suite yet? + if suite == "" { + return nil, fmt.Errorf("test without a suite: %v", name) + } + + // Add this individual test. + t = append(t, TestCase{ + Suite: suite, + Name: name, + }) + } + + // Finished? + if !benchmarks { + return t, nil + } + + // Run again to extract benchmarks. + args = append([]string{listBenchmarkFlag}, extraArgs...) + cmd = exec.Command(testBin, args...) + out, err = cmd.Output() + if err != nil { + // We were able to enumerate tests above, but not benchmarks? + // We requested them, so we return an error in this case. + exitErr, ok := err.(*exec.ExitError) + if !ok { + return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err) + } + return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) + } + + out = []byte(strings.Trim(string(out), "\n")) + + // Parse benchmark output. + for _, line := range strings.Split(string(out), "\n") { + // Strip comments. + line = strings.Split(line, "#")[0] + + // Single benchmark. + name := strings.TrimSpace(line) + + // Add the single benchmark. + t = append(t, TestCase{ + Suite: "Benchmarks", + Name: name, + benchmark: true, + }) + } + + return t, nil +} diff --git a/test/syscalls/syscall_test_runner.go b/test/runner/runner.go index ae342b68c..14c9cbc47 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/runner/runner.go @@ -32,17 +32,13 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/syscalls/gtest" + "gvisor.dev/gvisor/test/runner/gtest" "gvisor.dev/gvisor/test/uds" ) -// Location of syscall tests, relative to the repo root. -const testDir = "test/syscalls/linux" - var ( - testName = flag.String("test-name", "", "name of test binary to run") debug = flag.Bool("debug", false, "enable debug logs") strace = flag.Bool("strace", false, "enable strace logs") platform = flag.String("platform", "ptrace", "platform to run on") @@ -103,7 +99,7 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) } - cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName()) + cmd := exec.Command(testBin, tc.Args()...) cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -119,20 +115,20 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { // // Returns an error if the sandboxed application exits non-zero. func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { - bundleDir, err := testutil.SetupBundleDir(spec) + bundleDir, cleanup, err := testutil.SetupBundleDir(spec) if err != nil { return fmt.Errorf("SetupBundleDir failed: %v", err) } - defer os.RemoveAll(bundleDir) + defer cleanup() - rootDir, err := testutil.SetupRootDir() + rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { return fmt.Errorf("SetupRootDir failed: %v", err) } - defer os.RemoveAll(rootDir) + defer cleanup() name := tc.FullName() - id := testutil.UniqueContainerID() + id := testutil.RandomContainerID() log.Infof("Running test %q in container %q", name, id) specutils.LogSpec(spec) @@ -296,7 +292,7 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Run a new container with the test executable and filter for the // given test suite and name. - spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName()) + spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...) // Mark the root as writeable, as some tests attempt to // write to the rootfs, and expect EACCES, not EROFS. @@ -304,6 +300,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Test spec comes with pre-defined mounts that we don't want. Reset it. spec.Mounts = nil + testTmpDir := "/tmp" if *useTmpfs { // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on // features only available in gVisor's internal tmpfs may fail. @@ -329,11 +326,19 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { t.Fatalf("could not chmod temp dir: %v", err) } - spec.Mounts = append(spec.Mounts, specs.Mount{ - Type: "bind", - Destination: "/tmp", - Source: tmpDir, - }) + // "/tmp" is not replaced with a tmpfs mount inside the sandbox + // when it's not empty. This ensures that testTmpDir uses gofer + // in exclusive mode. + testTmpDir = tmpDir + if *fileAccess == "shared" { + // All external mounts except the root mount are shared. + spec.Mounts = append(spec.Mounts, specs.Mount{ + Type: "bind", + Destination: "/tmp", + Source: tmpDir, + }) + testTmpDir = "/tmp" + } } // Set environment variables that indicate we are @@ -353,12 +358,8 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to // be backed by tmpfs. - for i, kv := range env { - if strings.HasPrefix(kv, "TEST_TMPDIR=") { - env[i] = "TEST_TMPDIR=/tmp" - break - } - } + env = filterEnv(env, []string{"TEST_TMPDIR"}) + env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir)) spec.Process.Env = env @@ -404,9 +405,10 @@ func matchString(a, b string) (bool, error) { func main() { flag.Parse() - if *testName == "" { - fatalf("test-name flag must be provided") + if flag.NArg() != 1 { + fatalf("test must be provided") } + testBin := flag.Args()[0] // Only argument. log.SetLevel(log.Info) if *debug { @@ -436,15 +438,8 @@ func main() { } } - // Get path to test binary. - fullTestName := filepath.Join(testDir, *testName) - testBin, err := testutil.FindFile(fullTestName) - if err != nil { - fatalf("FindFile(%q) failed: %v", fullTestName, err) - } - // Get all test cases in each binary. - testCases, err := gtest.ParseTestCases(testBin) + testCases, err := gtest.ParseTestCases(testBin, true) if err != nil { fatalf("ParseTestCases(%q) failed: %v", testBin, err) } @@ -455,14 +450,19 @@ func main() { fatalf("TestsForShard() failed: %v", err) } + // Resolve the absolute path for the binary. + testBin, err = filepath.Abs(testBin) + if err != nil { + fatalf("Abs() failed: %v", err) + } + // Run the tests. var tests []testing.InternalTest for _, tci := range indices { // Capture tc. tc := testCases[tci] - testName := fmt.Sprintf("%s_%s", tc.Suite, tc.Name) tests = append(tests, testing.InternalTest{ - Name: testName, + Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name), F: func(t *testing.T) { if *parallel { t.Parallel() diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 2c472bf8d..4cd627222 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,20 +1,7 @@ -# These packages are used to run language runtime tests inside gVisor sandboxes. - -load("//tools:defs.bzl", "go_binary", "go_test") -load("//test/runtimes:build_defs.bzl", "runtime_test") +load("//test/runtimes:defs.bzl", "runtime_test") package(licenses = ["notice"]) -go_binary( - name = "runner", - testonly = 1, - srcs = ["runner.go"], - deps = [ - "//runsc/dockerutil", - "//runsc/testutil", - ], -) - runtime_test( name = "go1.12", blacklist_file = "blacklist_go1.12.csv", @@ -44,10 +31,3 @@ runtime_test( blacklist_file = "blacklist_python3.7.3.csv", lang = "python", ) - -go_test( - name = "blacklist_test", - size = "small", - srcs = ["blacklist_test.go"], - library = ":runner", -) diff --git a/test/runtimes/README.md b/test/runtimes/README.md deleted file mode 100644 index e41e78f77..000000000 --- a/test/runtimes/README.md +++ /dev/null @@ -1,41 +0,0 @@ -# Runtimes Tests Dockerfiles - -The Dockerfiles defined under this path are configured to host the execution of -the runtimes language tests. Each Dockerfile can support the language indicated -by its directory. - -The following runtimes are currently supported: - -- Go 1.12 -- Java 11 -- Node.js 12 -- PHP 7.3 -- Python 3.7 - -#### Prerequisites: - -1) [Install and configure Docker](https://docs.docker.com/install/) - -2) Build each Docker container from the runtimes/images directory: - -```bash -$ cd images -$ docker build -f Dockerfile_$LANG [-t $NAME] . -``` - -### Testing: - -If the prerequisites have been fulfilled, you can run the tests with the -following command: - -```bash -$ docker run --rm -it $NAME [FLAG] -``` - -Running the command with no flags will cause all the available tests to execute. - -Flags can be added for additional functionality: - -- --list: Print a list of all available tests -- --test <name>: Run a single test from the list of available tests -- --v: Print the language version diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl deleted file mode 100644 index 92e275a76..000000000 --- a/test/runtimes/build_defs.bzl +++ /dev/null @@ -1,75 +0,0 @@ -"""Defines a rule for runtime test targets.""" - -load("//tools:defs.bzl", "go_test", "loopback") - -def runtime_test( - name, - lang, - image_repo = "gcr.io/gvisor-presubmit", - image_name = None, - blacklist_file = None, - shard_count = 50, - size = "enormous"): - """Generates sh_test and blacklist test targets for a given runtime. - - Args: - name: The name of the runtime being tested. Typically, the lang + version. - This is used in the names of the generated test targets. - lang: The language being tested. - image_repo: The docker repository containing the proctor image to run. - i.e., the prefix to the fully qualified docker image id. - image_name: The name of the image in the image_repo. - Defaults to the test name. - blacklist_file: A test blacklist to pass to the runtime test's runner. - shard_count: See Bazel common test attributes. - size: See Bazel common test attributes. - """ - if image_name == None: - image_name = name - args = [ - "--lang", - lang, - "--image", - "/".join([image_repo, image_name]), - ] - data = [ - ":runner", - loopback, - ] - if blacklist_file: - args += ["--blacklist_file", "test/runtimes/" + blacklist_file] - data += [blacklist_file] - - # Add a test that the blacklist parses correctly. - blacklist_test(name, blacklist_file) - - sh_test( - name = name + "_test", - srcs = ["runner.sh"], - args = args, - data = data, - size = size, - shard_count = shard_count, - tags = [ - # Requires docker and runsc to be configured before the test runs. - "local", - # Don't include test target in wildcard target patterns. - "manual", - ], - ) - -def blacklist_test(name, blacklist_file): - """Test that a blacklist parses correctly.""" - go_test( - name = name + "_blacklist_test", - library = ":runner", - srcs = ["blacklist_test.go"], - args = ["--blacklist_file", "test/runtimes/" + blacklist_file], - data = [blacklist_file], - ) - -def sh_test(**kwargs): - """Wraps the standard sh_test.""" - native.sh_test( - **kwargs - ) diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl new file mode 100644 index 000000000..f836dd952 --- /dev/null +++ b/test/runtimes/defs.bzl @@ -0,0 +1,79 @@ +"""Defines a rule for runtime test targets.""" + +load("//tools:defs.bzl", "go_test") + +def _runtime_test_impl(ctx): + # Construct arguments. + args = [ + "--lang", + ctx.attr.lang, + "--image", + ctx.attr.image, + ] + if ctx.attr.blacklist_file: + args += [ + "--blacklist_file", + ctx.files.blacklist_file[0].short_path, + ] + + # Build a runner. + runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "%s %s\n" % (ctx.files._runner[0].short_path, " ".join(args)), + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + + # Return the runner. + return [DefaultInfo( + executable = runner, + runfiles = ctx.runfiles( + files = ctx.files._runner + ctx.files.blacklist_file + ctx.files._proctor, + collect_default = True, + collect_data = True, + ), + )] + +_runtime_test = rule( + implementation = _runtime_test_impl, + attrs = { + "image": attr.string( + mandatory = False, + ), + "lang": attr.string( + mandatory = True, + ), + "blacklist_file": attr.label( + mandatory = False, + allow_single_file = True, + ), + "_runner": attr.label( + default = "//test/runtimes/runner:runner", + ), + "_proctor": attr.label( + default = "//test/runtimes/proctor:proctor", + ), + }, + test = True, +) + +def runtime_test(name, **kwargs): + _runtime_test( + name = name, + image = name, # Resolved as images/runtimes/%s. + tags = [ + "local", + "manual", + ], + **kwargs + ) + +def blacklist_test(name, blacklist_file): + """Test that a blacklist parses correctly.""" + go_test( + name = name + "_blacklist_test", + library = ":runner", + srcs = ["blacklist_test.go"], + args = ["--blacklist_file", "test/runtimes/" + blacklist_file], + data = [blacklist_file], + ) diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/proctor/BUILD index 85e004c45..50a26d182 100644 --- a/test/runtimes/images/proctor/BUILD +++ b/test/runtimes/proctor/BUILD @@ -12,7 +12,8 @@ go_binary( "proctor.go", "python.go", ], - visibility = ["//test/runtimes/images:__subpackages__"], + pure = True, + visibility = ["//test/runtimes:__pkg__"], ) go_test( @@ -21,6 +22,6 @@ go_test( srcs = ["proctor_test.go"], library = ":proctor", deps = [ - "//runsc/testutil", + "//pkg/test/testutil", ], ) diff --git a/test/runtimes/images/proctor/go.go b/test/runtimes/proctor/go.go index 3e2d5d8db..3e2d5d8db 100644 --- a/test/runtimes/images/proctor/go.go +++ b/test/runtimes/proctor/go.go diff --git a/test/runtimes/images/proctor/java.go b/test/runtimes/proctor/java.go index 8b362029d..8b362029d 100644 --- a/test/runtimes/images/proctor/java.go +++ b/test/runtimes/proctor/java.go diff --git a/test/runtimes/images/proctor/nodejs.go b/test/runtimes/proctor/nodejs.go index bd57db444..bd57db444 100644 --- a/test/runtimes/images/proctor/nodejs.go +++ b/test/runtimes/proctor/nodejs.go diff --git a/test/runtimes/images/proctor/php.go b/test/runtimes/proctor/php.go index 9115040e1..9115040e1 100644 --- a/test/runtimes/images/proctor/php.go +++ b/test/runtimes/proctor/php.go diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/proctor/proctor.go index b54abe434..b54abe434 100644 --- a/test/runtimes/images/proctor/proctor.go +++ b/test/runtimes/proctor/proctor.go diff --git a/test/runtimes/images/proctor/proctor_test.go b/test/runtimes/proctor/proctor_test.go index 6bb61d142..6ef2de085 100644 --- a/test/runtimes/images/proctor/proctor_test.go +++ b/test/runtimes/proctor/proctor_test.go @@ -23,24 +23,24 @@ import ( "strings" "testing" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) func touch(t *testing.T, name string) { t.Helper() f, err := os.Create(name) if err != nil { - t.Fatal(err) + t.Fatalf("error creating file %q: %v", name, err) } if err := f.Close(); err != nil { - t.Fatal(err) + t.Fatalf("error closing file %q: %v", name, err) } } func TestSearchEmptyDir(t *testing.T) { td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest") if err != nil { - t.Fatal(err) + t.Fatalf("error creating searchtest: %v", err) } defer os.RemoveAll(td) @@ -60,7 +60,7 @@ func TestSearchEmptyDir(t *testing.T) { func TestSearch(t *testing.T) { td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest") if err != nil { - t.Fatal(err) + t.Fatalf("error creating searchtest: %v", err) } defer os.RemoveAll(td) @@ -101,14 +101,14 @@ func TestSearch(t *testing.T) { if strings.HasSuffix(item, "/") { // This item is a directory, create it. if err := os.MkdirAll(filepath.Join(td, item), 0755); err != nil { - t.Fatal(err) + t.Fatalf("error making directory: %v", err) } } else { // This item is a file, create the directory and touch file. // Create directory in which file should be created fullDirPath := filepath.Join(td, filepath.Dir(item)) if err := os.MkdirAll(fullDirPath, 0755); err != nil { - t.Fatal(err) + t.Fatalf("error making directory: %v", err) } // Create file with full path to file. touch(t, filepath.Join(td, item)) diff --git a/test/runtimes/images/proctor/python.go b/test/runtimes/proctor/python.go index b9e0fbe6f..b9e0fbe6f 100644 --- a/test/runtimes/images/proctor/python.go +++ b/test/runtimes/proctor/python.go diff --git a/test/runtimes/runner.sh b/test/runtimes/runner.sh deleted file mode 100755 index a8d9a3460..000000000 --- a/test/runtimes/runner.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -euf -x -o pipefail - -echo -- "$@" - -# Create outputs dir if it does not exist. -if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then - mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}" - chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}" -fi - -# Update the timestamp on the shard status file. Bazel looks for this. -touch "${TEST_SHARD_STATUS_FILE}" - -# Get location of runner binary. -readonly runner=$(find "${TEST_SRCDIR}" -name runner) - -# Pass the arguments of this script directly to the runner. -exec "${runner}" "$@" - diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD new file mode 100644 index 000000000..63924b9c5 --- /dev/null +++ b/test/runtimes/runner/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_binary", "go_test") + +package(licenses = ["notice"]) + +go_binary( + name = "runner", + testonly = 1, + srcs = ["main.go"], + visibility = ["//test/runtimes:__pkg__"], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) + +go_test( + name = "blacklist_test", + size = "small", + srcs = ["blacklist_test.go"], + library = ":runner", +) diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/runner/blacklist_test.go index 52f49b984..0ff69ab18 100644 --- a/test/runtimes/blacklist_test.go +++ b/test/runtimes/runner/blacklist_test.go @@ -32,6 +32,6 @@ func TestBlacklists(t *testing.T) { t.Fatalf("error parsing blacklist: %v", err) } if *blacklistFile != "" && len(bl) == 0 { - t.Errorf("got empty blacklist for file %q", blacklistFile) + t.Errorf("got empty blacklist for file %q", *blacklistFile) } } diff --git a/test/runtimes/runner.go b/test/runtimes/runner/main.go index ddb890dbc..57540e00e 100644 --- a/test/runtimes/runner.go +++ b/test/runtimes/runner/main.go @@ -26,8 +26,8 @@ import ( "testing" "time" - "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" ) var ( @@ -45,7 +45,6 @@ func main() { fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") os.Exit(1) } - os.Exit(runTests()) } @@ -60,8 +59,8 @@ func runTests() int { return 1 } - // Create a single docker container that will be used for all tests. - d := dockerutil.MakeDocker("gvisor-" + *lang) + // Construct the shared docker instance. + d := dockerutil.MakeDocker(testutil.DefaultLogger(*lang)) defer d.CleanUp() // Get a slice of tests to run. This will also start a single Docker @@ -77,21 +76,18 @@ func runTests() int { return m.Run() } -// getTests returns a slice of tests to run, subject to the shard size and -// index. -func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) { - // Pull the image. - if err := dockerutil.Pull(*image); err != nil { - return nil, fmt.Errorf("docker pull %q failed: %v", *image, err) - } - - // Run proctor with --pause flag to keep container alive forever. - if err := d.Run(*image, "--pause"); err != nil { +// getTests executes all tests as table tests. +func getTests(d *dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) { + // Start the container. + d.CopyFiles("/proctor", "test/runtimes/proctor/proctor") + if err := d.Spawn(dockerutil.RunOpts{ + Image: fmt.Sprintf("runtimes/%s", *image), + }, "/proctor/proctor", "--pause"); err != nil { return nil, fmt.Errorf("docker run failed: %v", err) } // Get a list of all tests in the image. - list, err := d.Exec("/proctor", "--runtime", *lang, "--list") + list, err := d.Exec(dockerutil.RunOpts{}, "/proctor/proctor", "--runtime", *lang, "--list") if err != nil { return nil, fmt.Errorf("docker exec failed: %v", err) } @@ -114,7 +110,7 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int F: func(t *testing.T) { // Is the test blacklisted? if _, ok := blacklist[tc]; ok { - t.Skip("SKIP: blacklisted test %q", tc) + t.Skipf("SKIP: blacklisted test %q", tc) } var ( @@ -126,7 +122,7 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int go func() { fmt.Printf("RUNNING %s...\n", tc) - output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc) + output, err = d.Exec(dockerutil.RunOpts{}, "/proctor/proctor", "--runtime", *lang, "--test", tc) close(done) }() @@ -143,6 +139,7 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int }, }) } + return itests, nil } @@ -153,11 +150,7 @@ func getBlacklist() (map[string]struct{}, error) { if *blacklistFile == "" { return blacklist, nil } - file, err := testutil.FindFile(*blacklistFile) - if err != nil { - return nil, err - } - f, err := os.Open(file) + f, err := os.Open(*blacklistFile) if err != nil { return nil, err } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 31d239c0e..9800a0cdf 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1,5 +1,4 @@ -load("//tools:defs.bzl", "go_binary") -load("//test/syscalls:build_defs.bzl", "syscall_test") +load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) @@ -259,6 +258,8 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:munmap_test") +syscall_test(test = "//test/syscalls/linux:network_namespace_test") + syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_create_test", @@ -317,10 +318,14 @@ syscall_test( test = "//test/syscalls/linux:proc_test", ) -syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test") - syscall_test(test = "//test/syscalls/linux:proc_net_test") +syscall_test(test = "//test/syscalls/linux:proc_pid_oomscore_test") + +syscall_test(test = "//test/syscalls/linux:proc_pid_smaps_test") + +syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test") + syscall_test( size = "medium", test = "//test/syscalls/linux:pselect_test", @@ -677,6 +682,13 @@ syscall_test( test = "//test/syscalls/linux:truncate_test", ) +syscall_test(test = "//test/syscalls/linux:tuntap_test") + +syscall_test( + add_hostinet = True, + test = "//test/syscalls/linux:tuntap_hostinet_test", +) + syscall_test(test = "//test/syscalls/linux:udp_bind_test") syscall_test( @@ -726,21 +738,3 @@ syscall_test(test = "//test/syscalls/linux:proc_net_unix_test") syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test") syscall_test(test = "//test/syscalls/linux:proc_net_udp_test") - -go_binary( - name = "syscall_test_runner", - testonly = 1, - srcs = ["syscall_test_runner.go"], - data = [ - "//runsc", - ], - deps = [ - "//pkg/log", - "//runsc/specutils", - "//runsc/testutil", - "//test/syscalls/gtest", - "//test/uds", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/test/syscalls/gtest/gtest.go b/test/syscalls/gtest/gtest.go deleted file mode 100644 index bdec8eb07..000000000 --- a/test/syscalls/gtest/gtest.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gtest contains helpers for running google-test tests from Go. -package gtest - -import ( - "fmt" - "os/exec" - "strings" -) - -var ( - // ListTestFlag is the flag that will list tests in gtest binaries. - ListTestFlag = "--gtest_list_tests" - - // FilterTestFlag is the flag that will filter tests in gtest binaries. - FilterTestFlag = "--gtest_filter" -) - -// TestCase is a single gtest test case. -type TestCase struct { - // Suite is the suite for this test. - Suite string - - // Name is the name of this individual test. - Name string -} - -// FullName returns the name of the test including the suite. It is suitable to -// pass to "-gtest_filter". -func (tc TestCase) FullName() string { - return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) -} - -// ParseTestCases calls a gtest test binary to list its test and returns a -// slice with the name and suite of each test. -func ParseTestCases(testBin string, extraArgs ...string) ([]TestCase, error) { - args := append([]string{ListTestFlag}, extraArgs...) - cmd := exec.Command(testBin, args...) - out, err := cmd.Output() - if err != nil { - exitErr, ok := err.(*exec.ExitError) - if !ok { - return nil, fmt.Errorf("could not enumerate gtest tests: %v", err) - } - return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr) - } - - var t []TestCase - var suite string - for _, line := range strings.Split(string(out), "\n") { - // Strip comments. - line = strings.Split(line, "#")[0] - - // New suite? - if !strings.HasPrefix(line, " ") { - suite = strings.TrimSuffix(strings.TrimSpace(line), ".") - continue - } - - // Individual test. - name := strings.TrimSpace(line) - - // Do we have a suite yet? - if suite == "" { - return nil, fmt.Errorf("test without a suite: %v", name) - } - - // Add this individual test. - t = append(t, TestCase{ - Suite: suite, - Name: name, - }) - - } - - if len(t) == 0 { - return nil, fmt.Errorf("no tests parsed from %v", testBin) - } - return t, nil -} diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc index c47a05181..3c825477c 100644 --- a/test/syscalls/linux/32bit.cc +++ b/test/syscalls/linux/32bit.cc @@ -74,7 +74,7 @@ void ExitGroup32(const char instruction[2], int code) { "int $3\n" : : [ code ] "m"(code), [ ip ] "d"(m.ptr()) - : "rax", "rbx", "rsp"); + : "rax", "rbx"); } constexpr int kExitCode = 42; diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index ca1af209a..d9095c95f 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -10,13 +10,20 @@ exports_files( "socket.cc", "socket_inet_loopback.cc", "socket_ip_loopback_blocking.cc", + "socket_ip_tcp_generic_loopback.cc", "socket_ip_tcp_loopback.cc", + "socket_ip_tcp_loopback_blocking.cc", + "socket_ip_tcp_loopback_nonblock.cc", + "socket_ip_tcp_udp_generic.cc", "socket_ip_udp_loopback.cc", + "socket_ip_udp_loopback_blocking.cc", + "socket_ip_udp_loopback_nonblock.cc", "socket_ip_unbound.cc", "socket_ipv4_tcp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_loopback.cc", "tcp_socket.cc", + "udp_bind.cc", "udp_socket.cc", ], visibility = ["//:sandbox"], @@ -125,6 +132,16 @@ cc_library( ) cc_library( + name = "socket_netlink_route_util", + testonly = 1, + srcs = ["socket_netlink_route_util.cc"], + hdrs = ["socket_netlink_route_util.h"], + deps = [ + ":socket_netlink_util", + ], +) + +cc_library( name = "socket_test_util", testonly = 1, srcs = [ @@ -149,11 +166,6 @@ cc_library( ) cc_library( - name = "temp_umask", - hdrs = ["temp_umask.h"], -) - -cc_library( name = "unix_domain_socket_test_util", testonly = 1, srcs = ["unix_domain_socket_test_util.cc"], @@ -590,7 +602,10 @@ cc_binary( cc_binary( name = "exceptions_test", testonly = 1, - srcs = ["exceptions.cc"], + srcs = select_arch( + amd64 = ["exceptions.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ gtest, @@ -647,10 +662,7 @@ cc_binary( cc_binary( name = "exec_binary_test", testonly = 1, - srcs = select_arch( - amd64 = ["exec_binary.cc"], - arm64 = [], - ), + srcs = ["exec_binary.cc"], linkstatic = 1, deps = [ "//test/util:cleanup", @@ -1122,11 +1134,11 @@ cc_binary( srcs = ["mkdir.cc"], linkstatic = 1, deps = [ - ":temp_umask", "//test/util:capability_util", "//test/util:fs_util", gtest, "//test/util:temp_path", + "//test/util:temp_umask", "//test/util:test_main", "//test/util:test_util", ], @@ -1281,12 +1293,12 @@ cc_binary( srcs = ["open_create.cc"], linkstatic = 1, deps = [ - ":temp_umask", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", gtest, "//test/util:temp_path", + "//test/util:temp_umask", "//test/util:test_main", "//test/util:test_util", ], @@ -1457,7 +1469,10 @@ cc_binary( cc_binary( name = "arch_prctl_test", testonly = 1, - srcs = ["arch_prctl.cc"], + srcs = select_arch( + amd64 = ["arch_prctl.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ "//test/util:file_descriptor", @@ -1613,6 +1628,19 @@ cc_binary( ) cc_binary( + name = "proc_pid_oomscore_test", + testonly = 1, + srcs = ["proc_pid_oomscore.cc"], + linkstatic = 1, + deps = [ + "//test/util:fs_util", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( name = "proc_pid_smaps_test", testonly = 1, srcs = ["proc_pid_smaps.cc"], @@ -1994,6 +2022,8 @@ cc_binary( "//test/util:file_descriptor", "@com_google_absl//absl/strings", gtest, + ":ip_socket_test_util", + ":unix_domain_socket_test_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -2770,13 +2800,13 @@ cc_binary( srcs = ["socket_netlink_route.cc"], linkstatic = 1, deps = [ + ":socket_netlink_route_util", ":socket_netlink_util", ":socket_test_util", "//test/util:capability_util", "//test/util:cleanup", "//test/util:file_descriptor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:optional", gtest, "//test/util:test_main", "//test/util:test_util", @@ -3423,6 +3453,37 @@ cc_binary( ], ) +cc_binary( + name = "tuntap_test", + testonly = 1, + srcs = ["tuntap.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + "//test/syscalls/linux:socket_netlink_route_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "tuntap_hostinet_test", + testonly = 1, + srcs = ["tuntap_hostinet.cc"], + linkstatic = 1, + deps = [ + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + cc_library( name = "udp_socket_test_cases", testonly = 1, @@ -3633,6 +3694,22 @@ cc_binary( ) cc_binary( + name = "network_namespace_test", + testonly = 1, + srcs = ["network_namespace.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + "//test/util:capability_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( name = "semaphore_test", testonly = 1, srcs = ["semaphore.cc"], diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc index a33daff17..806d5729e 100644 --- a/test/syscalls/linux/aio.cc +++ b/test/syscalls/linux/aio.cc @@ -89,6 +89,7 @@ class AIOTest : public FileTest { FileTest::TearDown(); if (ctx_ != 0) { ASSERT_THAT(DestroyContext(), SyscallSucceeds()); + ctx_ = 0; } } @@ -188,14 +189,19 @@ TEST_F(AIOTest, BadWrite) { } TEST_F(AIOTest, ExitWithPendingIo) { - // Setup a context that is 5 entries deep. - ASSERT_THAT(SetupContext(5), SyscallSucceeds()); + // Setup a context that is 100 entries deep. + ASSERT_THAT(SetupContext(100), SyscallSucceeds()); struct iocb cb = CreateCallback(); struct iocb* cbs[] = {&cb}; // Submit a request but don't complete it to make it pending. - EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); + for (int i = 0; i < 100; ++i) { + EXPECT_THAT(Submit(1, cbs), SyscallSucceeds()); + } + + ASSERT_THAT(DestroyContext(), SyscallSucceeds()); + ctx_ = 0; } int Submitter(void* arg) { diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc index d89269985..940c97285 100644 --- a/test/syscalls/linux/alarm.cc +++ b/test/syscalls/linux/alarm.cc @@ -188,6 +188,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc index adfb149df..a26fc6af3 100644 --- a/test/syscalls/linux/bad.cc +++ b/test/syscalls/linux/bad.cc @@ -28,7 +28,7 @@ namespace { constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms; #elif __aarch64__ // Use the last of arch_specific_syscalls which are not implemented on arm64. -constexpr uint32_t kNotImplementedSyscall = SYS_arch_specific_syscall + 15; +constexpr uint32_t kNotImplementedSyscall = __NR_arch_specific_syscall + 15; #endif TEST(BadSyscallTest, NotImplemented) { diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index a4f8f3cec..f57d38dc7 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -56,10 +56,6 @@ TEST(EpollTest, AllWritable) { struct epoll_event result[kFDsPerEpoll]; ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1), SyscallSucceedsWithValue(kFDsPerEpoll)); - // TODO(edahlgren): Why do some tests check epoll_event::data, and others - // don't? Does Linux actually guarantee that, in any of these test cases, - // epoll_wait will necessarily write out the epoll_events in the order that - // they were registered? for (int i = 0; i < kFDsPerEpoll; i++) { ASSERT_EQ(result[i].events, EPOLLOUT); } diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index b5e0a512b..12c9b05ca 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -812,26 +812,28 @@ void ExecFromThread() { bool ValidateProcCmdlineVsArgv(const int argc, const char* const* argv) { auto contents_or = GetContents("/proc/self/cmdline"); if (!contents_or.ok()) { - std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error(); + std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error() + << std::endl; return false; } auto contents = contents_or.ValueOrDie(); if (contents.back() != '\0') { - std::cerr << "Non-null terminated /proc/self/cmdline!"; + std::cerr << "Non-null terminated /proc/self/cmdline!" << std::endl; return false; } contents.pop_back(); std::vector<std::string> procfs_cmdline = absl::StrSplit(contents, '\0'); if (static_cast<int>(procfs_cmdline.size()) != argc) { - std::cerr << "argc = " << argc << " != " << procfs_cmdline.size(); + std::cerr << "argc = " << argc << " != " << procfs_cmdline.size() + << std::endl; return false; } for (int i = 0; i < argc; ++i) { if (procfs_cmdline[i] != argv[i]) { std::cerr << "Procfs command line argument " << i << " mismatch " - << procfs_cmdline[i] << " != " << argv[i]; + << procfs_cmdline[i] << " != " << argv[i] << std::endl; return false; } } @@ -868,6 +870,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 736452b0c..1a9f203b9 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -48,10 +48,17 @@ namespace { using ::testing::AnyOf; using ::testing::Eq; -#ifndef __x86_64__ +#if !defined(__x86_64__) && !defined(__aarch64__) // The assembly stub and ELF internal details must be ported to other arches. -#error "Test only supported on x86-64" -#endif // __x86_64__ +#error "Test only supported on x86-64/arm64" +#endif // __x86_64__ || __aarch64__ + +#if defined(__x86_64__) +#define EM_TYPE EM_X86_64 +#define IP_REG(p) ((p).rip) +#define RAX_REG(p) ((p).rax) +#define RDI_REG(p) ((p).rdi) +#define RETURN_REG(p) ((p).rax) // amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP. const char kPtraceCode[] = { @@ -139,6 +146,76 @@ const char kPtraceCode[] = { // Size of a syscall instruction. constexpr int kSyscallSize = 2; +#elif defined(__aarch64__) +#define EM_TYPE EM_AARCH64 +#define IP_REG(p) ((p).pc) +#define RAX_REG(p) ((p).regs[8]) +#define RDI_REG(p) ((p).regs[0]) +#define RETURN_REG(p) ((p).regs[0]) + +const char kPtraceCode[] = { + // MOVD $117, R8 /* ptrace */ + '\xa8', + '\x0e', + '\x80', + '\xd2', + // MOVD $0, R0 /* PTRACE_TRACEME */ + '\x00', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R1 /* pid */ + '\x01', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R2 /* addr */ + '\x02', + '\x00', + '\x80', + '\xd2', + // MOVD $0, R3 /* data */ + '\x03', + '\x00', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', + // MOVD $172, R8 /* getpid */ + '\x88', + '\x15', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', + // MOVD $129, R8 /* kill, R0=pid */ + '\x28', + '\x10', + '\x80', + '\xd2', + // MOVD $19, R1 /* SIGSTOP */ + '\x61', + '\x02', + '\x80', + '\xd2', + // SVC + '\x01', + '\x00', + '\x00', + '\xd4', +}; +// Size of a syscall instruction. +constexpr int kSyscallSize = 4; +#else +#error "Unknown architecture" +#endif + // This test suite tests executable loading in the kernel (ELF and interpreter // scripts). @@ -281,7 +358,7 @@ ElfBinary<64> StandardElf() { elf.header.e_ident[EI_DATA] = ELFDATA2LSB; elf.header.e_ident[EI_VERSION] = EV_CURRENT; elf.header.e_type = ET_EXEC; - elf.header.e_machine = EM_X86_64; + elf.header.e_machine = EM_TYPE; elf.header.e_version = EV_CURRENT; elf.header.e_phoff = sizeof(elf.header); elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr); @@ -327,9 +404,15 @@ TEST(ElfTest, Execute) { ASSERT_NO_ERRNO(WaitStopped(child)); struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); - // RIP is just beyond the final syscall instruction. - EXPECT_EQ(regs.rip, elf.header.e_entry + sizeof(kPtraceCode)); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); + // RIP/PC is just beyond the final syscall instruction. + EXPECT_EQ(IP_REG(regs), elf.header.e_entry + sizeof(kPtraceCode)); EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, @@ -718,9 +801,16 @@ TEST(ElfTest, PIE) { // RIP tells us which page the first segment was loaded into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ // text page. @@ -787,9 +877,15 @@ TEST(ElfTest, PIENonZeroStart) { // RIP tells us which page the first segment was loaded into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1); // The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr. // @@ -910,9 +1006,15 @@ TEST(ElfTest, ELFInterpreter) { // RIP tells us which page the first segment of the interpreter was loaded // into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT( child, ContainsMappings(std::vector<ProcMapsEntry>({ @@ -1084,9 +1186,15 @@ TEST(ElfTest, ELFInterpreterRelative) { // RIP tells us which page the first segment of the interpreter was loaded // into. struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); - const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); + const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1); EXPECT_THAT( child, ContainsMappings(std::vector<ProcMapsEntry>({ @@ -1480,14 +1588,21 @@ TEST(ExecveTest, BrkAfterBinary) { ASSERT_NO_ERRNO(WaitStopped(child)); struct user_regs_struct regs; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); // RIP is just beyond the final syscall instruction. Rewind to execute a brk // syscall. - regs.rip -= kSyscallSize; - regs.rax = __NR_brk; - regs.rdi = 0; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child, 0, ®s), SyscallSucceeds()); + IP_REG(regs) -= kSyscallSize; + RAX_REG(regs) = __NR_brk; + RDI_REG(regs) = 0; + ASSERT_THAT(ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); // Resume the child, waiting for syscall entry. ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds()); @@ -1504,7 +1619,12 @@ TEST(ExecveTest, BrkAfterBinary) { ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP) << "status = " << status; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, ®s), SyscallSucceeds()); + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov), + SyscallSucceeds()); + // Read exactly the full register set. + EXPECT_EQ(iov.iov_len, sizeof(regs)); // brk is after the text page. // @@ -1512,7 +1632,7 @@ TEST(ExecveTest, BrkAfterBinary) { // address will be, but it is always beyond the final page in the binary. // i.e., it does not start immediately after memsz in the middle of a page. // Userspace may expect to use that space. - EXPECT_GE(regs.rax, 0x41000); + EXPECT_GE(RETURN_REG(regs), 0x41000); } } // namespace diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index 1c3d00287..7819f4ac3 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -33,7 +33,7 @@ namespace testing { namespace { int fallocate(int fd, int mode, off_t offset, off_t len) { - return syscall(__NR_fallocate, fd, mode, offset, len); + return RetryEINTR(syscall)(__NR_fallocate, fd, mode, offset, len); } class AllocateTest : public FileTest { @@ -47,27 +47,27 @@ TEST_F(AllocateTest, Fallocate) { EXPECT_EQ(buf.st_size, 0); // Grow to ten bytes. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 10); // Allocate to a smaller size should be noop. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 10); // Grow again. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 20); // Grow with offset. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 30); // Grow with offset beyond EOF. - EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds()); + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 40); } diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 421c15b87..c7cc5816e 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -1128,5 +1128,5 @@ int main(int argc, char** argv) { exit(err); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h index 6f80bc97c..fb418e052 100644 --- a/test/syscalls/linux/file_base.h +++ b/test/syscalls/linux/file_base.h @@ -52,17 +52,6 @@ class FileTest : public ::testing::Test { test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE( Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR)); - // FIXME(edahlgren): enable when mknod syscall is supported. - // test_fifo_name_ = NewTempAbsPath(); - // ASSERT_THAT(mknod(test_fifo_name_.c_str()), S_IFIFO|0644, 0, - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[1] = open(test_fifo_name_.c_str(), - // O_WRONLY), - // SyscallSucceeds()); - // ASSERT_THAT(test_fifo_[0] = open(test_fifo_name_.c_str(), - // O_RDONLY), - // SyscallSucceeds()); - ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds()); ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds()); } @@ -96,18 +85,12 @@ class FileTest : public ::testing::Test { CloseFile(); UnlinkFile(); ClosePipes(); - - // FIXME(edahlgren): enable when mknod syscall is supported. - // close(test_fifo_[0]); - // close(test_fifo_[1]); - // unlink(test_fifo_name_.c_str()); } + protected: std::string test_file_name_; - std::string test_fifo_name_; FileDescriptor test_file_fd_; - int test_fifo_[2]; int test_pipe_[2]; }; diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc index ff8bdfeb0..853f6231a 100644 --- a/test/syscalls/linux/fork.cc +++ b/test/syscalls/linux/fork.cc @@ -431,7 +431,6 @@ TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) { << "status = " << status; } -#ifdef __x86_64__ // Clone with CLONE_SETTLS and a non-canonical TLS address is rejected. TEST(CloneTest, NonCanonicalTLS) { constexpr uintptr_t kNonCanonical = 1ull << 48; @@ -440,11 +439,25 @@ TEST(CloneTest, NonCanonicalTLS) { // on this. char stack; + // The raw system call interface on x86-64 is: + // long clone(unsigned long flags, void *stack, + // int *parent_tid, int *child_tid, + // unsigned long tls); + // + // While on arm64, the order of the last two arguments is reversed: + // long clone(unsigned long flags, void *stack, + // int *parent_tid, unsigned long tls, + // int *child_tid); +#if defined(__x86_64__) EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, nullptr, kNonCanonical), SyscallFailsWithErrno(EPERM)); -} +#elif defined(__aarch64__) + EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr, + kNonCanonical, nullptr), + SyscallFailsWithErrno(EPERM)); #endif +} } // namespace } // namespace testing diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc index f97f60029..f87cdd7a1 100644 --- a/test/syscalls/linux/getrandom.cc +++ b/test/syscalls/linux/getrandom.cc @@ -29,6 +29,8 @@ namespace { #define SYS_getrandom 318 #elif defined(__i386__) #define SYS_getrandom 355 +#elif defined(__aarch64__) +#define SYS_getrandom 278 #else #error "Unknown architecture" #endif diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index bba022a41..98d07ae85 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -16,7 +16,6 @@ #include <net/if.h> #include <netinet/in.h> -#include <sys/ioctl.h> #include <sys/socket.h> #include <cstring> @@ -35,12 +34,11 @@ uint16_t PortFromInetSockaddr(const struct sockaddr* addr) { } PosixErrorOr<int> InterfaceIndex(std::string name) { - // TODO(igudger): Consider using netlink. - ifreq req = {}; - memcpy(req.ifr_name, name.c_str(), name.size()); - ASSIGN_OR_RETURN_ERRNO(auto sock, Socket(AF_INET, SOCK_DGRAM, 0)); - RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(sock.get(), SIOCGIFINDEX, &req)); - return req.ifr_ifindex; + int index = if_nametoindex(name.c_str()); + if (index) { + return index; + } + return PosixError(errno); } namespace { @@ -177,17 +175,17 @@ SocketKind IPv6TCPUnboundSocket(int type) { PosixError IfAddrHelper::Load() { Release(); RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_)); - return PosixError(0); + return NoError(); } void IfAddrHelper::Release() { if (ifaddr_) { freeifaddrs(ifaddr_); + ifaddr_ = nullptr; } - ifaddr_ = nullptr; } -std::vector<std::string> IfAddrHelper::InterfaceList(int family) { +std::vector<std::string> IfAddrHelper::InterfaceList(int family) const { std::vector<std::string> names; for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { @@ -198,7 +196,7 @@ std::vector<std::string> IfAddrHelper::InterfaceList(int family) { return names; } -sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { +const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const { for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) { if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) { continue; @@ -210,7 +208,7 @@ sockaddr* IfAddrHelper::GetAddr(int family, std::string name) { return nullptr; } -PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) { +PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const { return InterfaceIndex(name); } diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 083ebbcf0..9c3859fcd 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -84,20 +84,20 @@ SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type); // SocketPairs created with AF_INET and the given type. SocketPairKind IPv4UDPUnboundSocketPair(int type); -// IPv4UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type. +// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_DGRAM, and the given type. SocketKind IPv4UDPUnboundSocket(int type); -// IPv6UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type. +// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_DGRAM, and the given type. SocketKind IPv6UDPUnboundSocket(int type); -// IPv4TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type. +// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_STREAM and the given type. SocketKind IPv4TCPUnboundSocket(int type); -// IPv6TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type. +// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_STREAM and the given type. SocketKind IPv6TCPUnboundSocket(int type); // IfAddrHelper is a helper class that determines the local interfaces present @@ -110,10 +110,10 @@ class IfAddrHelper { PosixError Load(); void Release(); - std::vector<std::string> InterfaceList(int family); + std::vector<std::string> InterfaceList(int family) const; - struct sockaddr* GetAddr(int family, std::string name); - PosixErrorOr<int> GetIndex(std::string name); + const sockaddr* GetAddr(int family, std::string name) const; + PosixErrorOr<int> GetIndex(std::string name) const; private: struct ifaddrs* ifaddr_; diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index b77e4cbd1..dd981a278 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -246,7 +246,7 @@ int TestSIGPROFFairness(absl::Duration sleep) { // The number of samples on the main thread should be very low as it did // nothing. - TEST_CHECK(result.main_thread_samples < 60); + TEST_CHECK(result.main_thread_samples < 80); // Both workers should get roughly equal number of samples. TEST_CHECK(result.worker_samples.size() == 2); @@ -349,6 +349,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc index a8af8e545..6ce1e6cc3 100644 --- a/test/syscalls/linux/lseek.cc +++ b/test/syscalls/linux/lseek.cc @@ -53,7 +53,7 @@ TEST(LseekTest, NegativeOffset) { // A 32-bit off_t is not large enough to represent an offset larger than // maximum file size on standard file systems, so it isn't possible to cause // overflow. -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(__aarch64__) TEST(LseekTest, Overflow) { // HA! Classic Linux. We really should have an EOVERFLOW // here, since we're seeking to something that cannot be diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc index e57b49a4a..f8b7f7938 100644 --- a/test/syscalls/linux/memfd.cc +++ b/test/syscalls/linux/memfd.cc @@ -16,6 +16,7 @@ #include <fcntl.h> #include <linux/magic.h> #include <linux/memfd.h> +#include <linux/unistd.h> #include <string.h> #include <sys/mman.h> #include <sys/statfs.h> diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index cf138d328..4036a9275 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -18,10 +18,10 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/temp_umask.h" #include "test/util/capability_util.h" #include "test/util/fs_util.h" #include "test/util/temp_path.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { @@ -36,21 +36,12 @@ class MkdirTest : public ::testing::Test { // TearDown unlinks created files. void TearDown() override { - // FIXME(edahlgren): We don't currently implement rmdir. - // We do this unconditionally because there's no harm in trying. - rmdir(dirname_.c_str()); + EXPECT_THAT(rmdir(dirname_.c_str()), SyscallSucceeds()); } std::string dirname_; }; -TEST_F(MkdirTest, DISABLED_CanCreateReadbleDir) { - ASSERT_THAT(mkdir(dirname_.c_str(), 0444), SyscallSucceeds()); - ASSERT_THAT( - open(JoinPath(dirname_, "anything").c_str(), O_RDWR | O_CREAT, 0666), - SyscallFailsWithErrno(EACCES)); -} - TEST_F(MkdirTest, CanCreateWritableDir) { ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); std::string filename = JoinPath(dirname_, "anything"); @@ -84,10 +75,11 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - auto parent = ASSERT_NO_ERRNO_AND_VALUE( - TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555)); - auto dir = JoinPath(parent.path(), "foo"); - ASSERT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES)); + ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds()); + auto dir = JoinPath(dirname_.c_str(), "foo"); + EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(open(JoinPath(dirname_, "file").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EACCES)); } } // namespace diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc index 367a90fe1..78ac96bed 100644 --- a/test/syscalls/linux/mlock.cc +++ b/test/syscalls/linux/mlock.cc @@ -199,8 +199,10 @@ TEST(MunlockallTest, Basic) { } #ifndef SYS_mlock2 -#ifdef __x86_64__ +#if defined(__x86_64__) #define SYS_mlock2 325 +#elif defined(__aarch64__) +#define SYS_mlock2 284 #endif #endif diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 11fb1b457..6d3227ab6 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -361,7 +361,7 @@ TEST_F(MMapTest, MapFixed) { } // 64-bit addresses work too -#ifdef __x86_64__ +#if defined(__x86_64__) || defined(__aarch64__) TEST_F(MMapTest, MapFixed64) { EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0), @@ -571,6 +571,12 @@ const uint8_t machine_code[] = { 0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax 0xc3, // retq }; +#elif defined(__aarch64__) +const uint8_t machine_code[] = { + 0x40, 0x05, 0x80, 0x52, // mov w0, #42 + 0xc0, 0x03, 0x5f, 0xd6, // ret +}; +#endif // PROT_EXEC allows code execution TEST_F(MMapTest, ProtExec) { @@ -605,7 +611,6 @@ TEST_F(MMapTest, NoProtExecDeath) { EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), ""); } -#endif TEST_F(MMapTest, NoExceedLimitData) { void* prevbrk; @@ -1644,6 +1649,7 @@ TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) { } // Conditional on MAP_32BIT. +// This flag is supported only on x86-64, for 64-bit programs. #ifdef __x86_64__ TEST(MMapNoFixtureTest, Map32Bit) { diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc new file mode 100644 index 000000000..133fdecf0 --- /dev/null +++ b/test/syscalls/linux/network_namespace.cc @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <net/if.h> +#include <sched.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { +namespace { + +TEST(NetworkNamespaceTest, LoopbackExists) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + ScopedThread t([&] { + ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0)); + + // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists. + // Check loopback device exists. + int sock = socket(AF_INET, SOCK_DGRAM, 0); + ASSERT_THAT(sock, SyscallSucceeds()); + struct ifreq ifr; + strncpy(ifr.ifr_name, "lo", IFNAMSIZ); + EXPECT_THAT(ioctl(sock, SIOCGIFINDEX, &ifr), SyscallSucceeds()) + << "lo cannot be found"; + }); +} + +} // namespace +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 267ae19f6..640fe6bfc 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -186,6 +186,28 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) { ASSERT_NO_ERRNO_AND_VALUE(Open(path_via_symlink, O_RDONLY | O_NOFOLLOW)); } +// Test that open(2) can follow symlinks that point back to the same tree. +// Test sets up files as follows: +// root/child/symlink => redirects to ../.. +// root/child/target => regular file +// +// open("root/child/symlink/root/child/file") +TEST_F(OpenTest, SymlinkRecurse) { + auto root = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir())); + auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); + auto symlink = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(child.path(), "../..")); + auto target = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(child.path(), "abc", 0644)); + auto path_via_symlink = + JoinPath(symlink.path(), Basename(root.path()), Basename(child.path()), + Basename(target.path())); + const auto contents = + ASSERT_NO_ERRNO_AND_VALUE(GetContents(path_via_symlink)); + ASSERT_EQ(contents, "abc"); +} + TEST_F(OpenTest, Fault) { char* totally_not_null = nullptr; ASSERT_THAT(open(totally_not_null, O_RDONLY), SyscallFailsWithErrno(EFAULT)); diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 902d0a0dc..51eacf3f2 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -19,11 +19,11 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/temp_umask.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/temp_path.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 92ae55eec..5ac68feb4 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <arpa/inet.h> +#include <ifaddrs.h> #include <linux/capability.h> #include <linux/if_arp.h> #include <linux/if_packet.h> @@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() { return ifr.ifr_ifindex; } -// Receive via a packet socket. -TEST_P(CookedPacketTest, Receive) { - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - +// Receive and verify the message via packet socket on interface. +void ReceiveMessage(int sock, int ifindex) { // Wait for the socket to become readable. struct pollfd pfd = {}; - pfd.fd = socket_; + pfd.fd = sock; pfd.events = POLLIN; EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); @@ -182,9 +178,10 @@ TEST_P(CookedPacketTest, Receive) { char buf[64]; struct sockaddr_ll src = {}; socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns @@ -194,7 +191,7 @@ TEST_P(CookedPacketTest, Receive) { // TODO(b/129292371): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_ifindex, ifindex); EXPECT_EQ(src.sll_halen, ETH_ALEN); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { @@ -222,6 +219,18 @@ TEST_P(CookedPacketTest, Receive) { EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); } +// Receive via a packet socket. +TEST_P(CookedPacketTest, Receive) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + int loopback_index = GetLoopbackIndex(); + ReceiveMessage(socket_, loopback_index); +} + // Send via a packet socket. TEST_P(CookedPacketTest, Send) { // TODO(b/129292371): Remove once we support packet socket writing. @@ -313,6 +322,115 @@ TEST_P(CookedPacketTest, Send) { EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); } +// Bind and receive via packet socket. +TEST_P(CookedPacketTest, BindReceive) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + ReceiveMessage(socket_, bind_addr.sll_ifindex); +} + +// Double Bind socket. +TEST_P(CookedPacketTest, DoubleBind) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Binding socket again should fail. + ASSERT_THAT( + bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched + // to EADDRINUSE. + AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL))); +} + +// Bind and verify we do not receive data on interface which is not bound +TEST_P(CookedPacketTest, BindDrop) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + + // Bind to packet socket requires only family, protocol and ifindex. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = ifr.ifr_ifindex; + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Send to loopback interface. + struct sockaddr_in dest = {}; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket never receives any data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + +// Bind with invalid address. +TEST_P(CookedPacketTest, BindFail) { + // Null address. + ASSERT_THAT( + bind(socket_, nullptr, sizeof(struct sockaddr)), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL))); + + // Address of size 1. + uint8_t addr = 0; + ASSERT_THAT( + bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallFailsWithErrno(EINVAL)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index d8e19e910..67228b66b 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -265,6 +265,8 @@ TEST_P(PipeTest, OffsetCalls) { SyscallFailsWithErrno(ESPIPE)); struct iovec iov; + iov.iov_base = &buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); } diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc index c42472474..1e35a4a8b 100644 --- a/test/syscalls/linux/poll.cc +++ b/test/syscalls/linux/poll.cc @@ -266,7 +266,7 @@ TEST_F(PollTest, Nfds) { } rlim_t max_fds = rlim.rlim_cur; - std::cout << "Using limit: " << max_fds; + std::cout << "Using limit: " << max_fds << std::endl; // Create an eventfd. Since its value is initially zero, it is writable. FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index d07571a5f..04c5161f5 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -226,5 +226,5 @@ int main(int argc, char** argv) { prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc index 30f0d75b3..c4e9cf528 100644 --- a/test/syscalls/linux/prctl_setuid.cc +++ b/test/syscalls/linux/prctl_setuid.cc @@ -264,5 +264,5 @@ int main(int argc, char** argv) { prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index 2cecf2e5f..bcdbbb044 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <linux/unistd.h> #include <sys/mman.h> #include <sys/socket.h> #include <sys/types.h> @@ -118,6 +119,21 @@ TEST_F(Pread64Test, EndOfFile) { EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST_F(Pread64Test, Overflow) { + int f = memfd_create("negative", 0); + const FileDescriptor fd(f); + + EXPECT_THAT(ftruncate(fd.get(), 0x7fffffffffffffffull), SyscallSucceeds()); + + char buf[10]; + EXPECT_THAT(pread64(fd.get(), buf, sizeof(buf), 0x7fffffffffffffffull), + SyscallFailsWithErrno(EINVAL)); +} + TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) { int sock_fds[2]; EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds()); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index a23fdb58d..79a625ebc 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -994,7 +994,7 @@ constexpr uint64_t kMappingSize = 100 << 20; // Tolerance on RSS comparisons to account for background thread mappings, // reclaimed pages, newly faulted pages, etc. -constexpr uint64_t kRSSTolerance = 5 << 20; +constexpr uint64_t kRSSTolerance = 10 << 20; // Capture RSS before and after an anonymous mapping with passed prot. void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) { @@ -1326,8 +1326,6 @@ TEST(ProcPidSymlink, SubprocessRunning) { SyscallSucceedsWithValue(sizeof(buf))); } -// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux -// on proc files. TEST(ProcPidSymlink, SubprocessZombied) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); @@ -1337,7 +1335,7 @@ TEST(ProcPidSymlink, SubprocessZombied) { int want = EACCES; if (!IsRunningOnGvisor()) { auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); - if (version.major == 4 && version.minor > 3) { + if (version.major > 4 || (version.major == 4 && version.minor > 3)) { want = ENOENT; } } @@ -1350,30 +1348,25 @@ TEST(ProcPidSymlink, SubprocessZombied) { SyscallFailsWithErrno(want)); } - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between linux on proc + // files. // // ~4.3: Syscall fails with EACCES. - // 4.17 & gVisor: Syscall succeeds and returns 1. + // 4.17: Syscall succeeds and returns 1. // - // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); + if (!IsRunningOnGvisor()) { + return; + } - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux - // on proc files. - // - // ~4.3: Syscall fails with EACCES. - // 4.17 & gVisor: Syscall succeeds and returns 1. - // - // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), - // SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), + SyscallFailsWithErrno(want)); + + EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), + SyscallFailsWithErrno(want)); } // Test whether /proc/PID/ symlinks can be read for an exited process. TEST(ProcPidSymlink, SubprocessExited) { - // FIXME(gvisor.dev/issue/164): These all succeed on gVisor. - SKIP_IF(IsRunningOnGvisor()); - char buf[1]; EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)), @@ -1431,6 +1424,12 @@ TEST(ProcPidFile, SubprocessRunning) { EXPECT_THAT(ReadWhileRunning("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileRunning("oom_score", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileRunning("oom_score_adj", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); } // Test whether /proc/PID/ files can be read for a zombie process. @@ -1466,6 +1465,12 @@ TEST(ProcPidFile, SubprocessZombie) { EXPECT_THAT(ReadWhileZombied("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + EXPECT_THAT(ReadWhileZombied("oom_score", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileZombied("oom_score_adj", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. // @@ -1527,6 +1532,15 @@ TEST(ProcPidFile, SubprocessExited) { EXPECT_THAT(ReadWhileExited("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + + if (!IsRunningOnGvisor()) { + // FIXME(gvisor.dev/issue/164): Succeeds on gVisor. + EXPECT_THAT(ReadWhileExited("oom_score", buf, sizeof(buf)), + SyscallFailsWithErrno(ESRCH)); + } + + EXPECT_THAT(ReadWhileExited("oom_score_adj", buf, sizeof(buf)), + SyscallFailsWithErrno(ESRCH)); } PosixError DirContainsImpl(absl::string_view path, @@ -2076,5 +2090,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 3a611a86f..cac394910 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -33,6 +33,31 @@ namespace gvisor { namespace testing { namespace { +constexpr const char kProcNet[] = "/proc/net"; + +TEST(ProcNetSymlinkTarget, FileMode) { + struct stat s; + ASSERT_THAT(stat(kProcNet, &s), SyscallSucceeds()); + EXPECT_EQ(s.st_mode & S_IFMT, S_IFDIR); + EXPECT_EQ(s.st_mode & 0777, 0555); +} + +TEST(ProcNetSymlink, FileMode) { + struct stat s; + ASSERT_THAT(lstat(kProcNet, &s), SyscallSucceeds()); + EXPECT_EQ(s.st_mode & S_IFMT, S_IFLNK); + EXPECT_EQ(s.st_mode & 0777, 0777); +} + +TEST(ProcNetSymlink, Contents) { + char buf[40] = {}; + int n = readlink(kProcNet, buf, sizeof(buf)); + ASSERT_THAT(n, SyscallSucceeds()); + + buf[n] = 0; + EXPECT_STREQ(buf, "self/net"); +} + TEST(ProcNetIfInet6, Format) { auto ifinet6 = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/if_inet6")); EXPECT_THAT(ifinet6, @@ -67,6 +92,59 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) { EXPECT_EQ(buf, to_write); } +// DeviceEntry is an entry in /proc/net/dev +struct DeviceEntry { + std::string name; + uint64_t stats[16]; +}; + +PosixErrorOr<std::vector<DeviceEntry>> GetDeviceMetricsFromProc( + const std::string dev) { + std::vector<std::string> lines = absl::StrSplit(dev, '\n'); + std::vector<DeviceEntry> entries; + + // /proc/net/dev prints 2 lines of headers followed by a line of metrics for + // each network interface. + for (unsigned i = 2; i < lines.size(); i++) { + // Ignore empty lines. + if (lines[i].empty()) { + continue; + } + + std::vector<std::string> values = + absl::StrSplit(lines[i], ' ', absl::SkipWhitespace()); + + // Interface name + 16 values. + if (values.size() != 17) { + return PosixError(EINVAL, "invalid line: " + lines[i]); + } + + DeviceEntry entry; + entry.name = values[0]; + // Skip the interface name and read only the values. + for (unsigned j = 1; j < 17; j++) { + uint64_t num; + if (!absl::SimpleAtoi(values[j], &num)) { + return PosixError(EINVAL, "invalid value: " + values[j]); + } + entry.stats[j - 1] = num; + } + + entries.push_back(entry); + } + + return entries; +} + +// TEST(ProcNetDev, Format) tests that /proc/net/dev is parsable and +// contains at least one entry. +TEST(ProcNetDev, Format) { + auto dev = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/dev")); + auto entries = ASSERT_NO_ERRNO_AND_VALUE(GetDeviceMetricsFromProc(dev)); + + EXPECT_GT(entries.size(), 0); +} + PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp, const std::string& type, const std::string& item) { @@ -275,7 +353,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { EXPECT_EQ(oldNoPorts, newNoPorts - 1); } -TEST(ProcNetSnmp, UdpIn) { +TEST(ProcNetSnmp, UdpIn_NoRandomSave) { // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. const DisableSave ds; diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 66db0acaa..a63067586 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -106,7 +106,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { std::vector<UnixEntry> entries; std::vector<std::string> lines = absl::StrSplit(content, '\n'); std::cerr << "<contents of /proc/net/unix>" << std::endl; - for (std::string line : lines) { + for (const std::string& line : lines) { // Emit the proc entry to the test output to provide context for the test // results. std::cerr << line << std::endl; @@ -374,7 +374,7 @@ TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { // corresponding entries, as they don't have an address yet. if (IsRunningOnGvisor()) { ASSERT_EQ(entries.size(), 2); - for (auto e : entries) { + for (const auto& e : entries) { ASSERT_EQ(e.state, SS_DISCONNECTING); } } @@ -403,7 +403,7 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { // corresponding entries, as they don't have an address yet. if (IsRunningOnGvisor()) { ASSERT_EQ(entries.size(), 2); - for (auto e : entries) { + for (const auto& e : entries) { ASSERT_EQ(e.state, SS_DISCONNECTING); } } diff --git a/test/syscalls/linux/proc_pid_oomscore.cc b/test/syscalls/linux/proc_pid_oomscore.cc new file mode 100644 index 000000000..707821a3f --- /dev/null +++ b/test/syscalls/linux/proc_pid_oomscore.cc @@ -0,0 +1,72 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> + +#include <exception> +#include <iostream> +#include <string> + +#include "test/util/fs_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +PosixErrorOr<int> ReadProcNumber(std::string path) { + ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents(path)); + EXPECT_EQ(contents[contents.length() - 1], '\n'); + + int num; + if (!absl::SimpleAtoi(contents, &num)) { + return PosixError(EINVAL, "invalid value: " + contents); + } + + return num; +} + +TEST(ProcPidOomscoreTest, BasicRead) { + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score")); + EXPECT_LE(oom_score, 1000); + EXPECT_GE(oom_score, -1000); +} + +TEST(ProcPidOomscoreAdjTest, BasicRead) { + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj")); + + // oom_score_adj defaults to 0. + EXPECT_EQ(oom_score, 0); +} + +TEST(ProcPidOomscoreAdjTest, BasicWrite) { + constexpr int test_value = 7; + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/oom_score_adj", O_WRONLY)); + ASSERT_THAT( + RetryEINTR(write)(fd.get(), std::to_string(test_value).c_str(), 1), + SyscallSucceeds()); + + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj")); + EXPECT_EQ(oom_score, test_value); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc index 7f2e8f203..9fb1b3a2c 100644 --- a/test/syscalls/linux/proc_pid_smaps.cc +++ b/test/syscalls/linux/proc_pid_smaps.cc @@ -173,7 +173,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( return; } unknown_fields.insert(std::string(key)); - std::cerr << "skipping unknown smaps field " << key; + std::cerr << "skipping unknown smaps field " << key << std::endl; }; auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty()); @@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( // amount of whitespace). if (!entry) { std::cerr << "smaps line not considered a maps line: " - << maybe_maps_entry.error_message(); + << maybe_maps_entry.error_message() << std::endl; return PosixError( EINVAL, absl::StrCat("smaps field line without preceding maps line: ", l)); diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index 4dd5cf27b..926690eb8 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -400,9 +400,11 @@ TEST(PtraceTest, GetRegSet) { // Read exactly the full register set. EXPECT_EQ(iov.iov_len, sizeof(regs)); -#ifdef __x86_64__ +#if defined(__x86_64__) // Child called kill(2), with SIGSTOP as arg 2. EXPECT_EQ(regs.rsi, SIGSTOP); +#elif defined(__aarch64__) + EXPECT_EQ(regs.regs[1], SIGSTOP); #endif // Suppress SIGSTOP and resume the child. @@ -752,15 +754,23 @@ TEST(PtraceTest, SyscallSucceeds()); EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80)) << "si_code = " << siginfo.si_code; -#ifdef __x86_64__ + { struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); +#if defined(__x86_64__) EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone) << "orig_rax = " << regs.orig_rax; EXPECT_EQ(grandchild_pid, regs.rax); - } +#elif defined(__aarch64__) + EXPECT_TRUE(regs.regs[8] == SYS_clone) << "regs[8] = " << regs.regs[8]; + EXPECT_EQ(grandchild_pid, regs.regs[0]); #endif // defined(__x86_64__) + } // After this point, the child will be making wait4 syscalls that will be // interrupted by saving, so saving is not permitted. Note that this is @@ -805,14 +815,21 @@ TEST(PtraceTest, SyscallSucceedsWithValue(child_pid)); EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80)) << " status " << status; -#ifdef __x86_64__ { struct user_regs_struct regs = {}; - ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, ®s), SyscallSucceeds()); + struct iovec iov; + iov.iov_base = ®s; + iov.iov_len = sizeof(regs); + EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); +#if defined(__x86_64__) EXPECT_EQ(SYS_wait4, regs.orig_rax); EXPECT_EQ(grandchild_pid, regs.rax); - } +#elif defined(__aarch64__) + EXPECT_EQ(SYS_wait4, regs.regs[8]); + EXPECT_EQ(grandchild_pid, regs.regs[0]); #endif // defined(__x86_64__) + } // Detach from the child and wait for it to exit. ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds()); @@ -1188,7 +1205,7 @@ TEST(PtraceTest, SeizeSetOptions) { // gVisor is not susceptible to this race because // kernel.Task.waitCollectTraceeStopLocked() checks specifically for an // active ptraceStop, which is not initiated if SIGKILL is pending. - std::cout << "Observed syscall-exit after SIGKILL"; + std::cout << "Observed syscall-exit after SIGKILL" << std::endl; ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceedsWithValue(child_pid)); } @@ -1208,5 +1225,5 @@ int main(int argc, char** argv) { gvisor::testing::RunExecveChild(); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index dafe64d20..b8a0159ba 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -1126,7 +1126,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) { std::string kExpected = "GO\nBLUE\n!"; // Write each line. - for (std::string input : kInputs) { + for (const std::string& input : kInputs) { ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()), SyscallSucceedsWithValue(input.size())); } diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index b48fe540d..e69794910 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <linux/unistd.h> #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> @@ -27,14 +28,7 @@ namespace testing { namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is not incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. +// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary. class Pwrite64 : public ::testing::Test { void SetUp() override { name_ = NewTempAbsPath(); @@ -72,6 +66,17 @@ TEST_F(Pwrite64, InvalidArgs) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(Pwrite64, Overflow) { + int fd; + ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds()); + constexpr int64_t kBufSize = 1024; + std::vector<char> buf(kBufSize); + std::fill(buf.begin(), buf.end(), 'a'); + EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0x7fffffffffffffffull), + SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(close(fd), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD index ed488dbc2..853258b04 100644 --- a/test/syscalls/linux/rseq/BUILD +++ b/test/syscalls/linux/rseq/BUILD @@ -1,7 +1,7 @@ # This package contains a standalone rseq test binary. This binary must not # depend on libc, which might use rseq itself. -load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain") +load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain", "select_arch") package(licenses = ["notice"]) @@ -9,32 +9,35 @@ genrule( name = "rseq_binary", srcs = [ "critical.h", - "critical.S", + "critical_amd64.S", + "critical_arm64.S", "rseq.cc", "syscalls.h", - "start.S", + "start_amd64.S", + "start_arm64.S", "test.h", "types.h", "uapi.h", ], outs = ["rseq"], - cmd = " ".join([ - "$(CC)", - "$(CC_FLAGS) ", - "-I.", - "-Wall", - "-Werror", - "-O2", - "-std=c++17", - "-static", - "-nostdlib", - "-ffreestanding", - "-o", - "$(location rseq)", - "$(location critical.S)", - "$(location rseq.cc)", - "$(location start.S)", - ]), + cmd = "$(CC) " + + "$(CC_FLAGS) " + + "-I. " + + "-Wall " + + "-Werror " + + "-O2 " + + "-std=c++17 " + + "-static " + + "-nostdlib " + + "-ffreestanding " + + "-o " + + "$(location rseq) " + + select_arch( + amd64 = "$(location critical_amd64.S) $(location start_amd64.S) ", + arm64 = "$(location critical_arm64.S) $(location start_arm64.S) ", + no_match_error = "unsupported architecture", + ) + + "$(location rseq.cc)", toolchains = [ cc_toolchain, ":no_pie_cc_flags", diff --git a/test/syscalls/linux/rseq/critical.S b/test/syscalls/linux/rseq/critical_amd64.S index 8c0687e6d..8c0687e6d 100644 --- a/test/syscalls/linux/rseq/critical.S +++ b/test/syscalls/linux/rseq/critical_amd64.S diff --git a/test/syscalls/linux/rseq/critical_arm64.S b/test/syscalls/linux/rseq/critical_arm64.S new file mode 100644 index 000000000..bfe7e8307 --- /dev/null +++ b/test/syscalls/linux/rseq/critical_arm64.S @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Restartable sequences critical sections. + +// Loops continuously until aborted. +// +// void rseq_loop(struct rseq* r, struct rseq_cs* cs) + + .text + .globl rseq_loop + .type rseq_loop, @function + +rseq_loop: + b begin + + // Abort block before the critical section. + // Abort signature. + .byte 0x90, 0x90, 0x90, 0x90 + .globl rseq_loop_early_abort +rseq_loop_early_abort: + ret + +begin: + // r->rseq_cs = cs + str x1, [x0, #8] + + // N.B. rseq_cs will be cleared by any preempt, even outside the critical + // section. Thus it must be set in or immediately before the critical section + // to ensure it is not cleared before the section begins. + .globl rseq_loop_start +rseq_loop_start: + b rseq_loop_start + + // "Pre-commit": extra instructions inside the critical section. These are + // used as the abort point in TestAbortPreCommit, which is not valid. + .globl rseq_loop_pre_commit +rseq_loop_pre_commit: + // Extra abort signature + nop for TestAbortPostCommit. + .byte 0x90, 0x90, 0x90, 0x90 + nop + + // "Post-commit": never reached in this case. + .globl rseq_loop_post_commit +rseq_loop_post_commit: + + // Abort signature. + .byte 0x90, 0x90, 0x90, 0x90 + + .globl rseq_loop_abort +rseq_loop_abort: + ret + + .size rseq_loop,.-rseq_loop + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/start.S b/test/syscalls/linux/rseq/start_amd64.S index b9611b276..b9611b276 100644 --- a/test/syscalls/linux/rseq/start.S +++ b/test/syscalls/linux/rseq/start_amd64.S diff --git a/test/syscalls/linux/rseq/start_arm64.S b/test/syscalls/linux/rseq/start_arm64.S new file mode 100644 index 000000000..693c1c6eb --- /dev/null +++ b/test/syscalls/linux/rseq/start_arm64.S @@ -0,0 +1,45 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + + .text + .align 4 + .type _start,@function + .globl _start + +_start: + mov x29, sp + bl __init + wfi + + .size _start,.-_start + .section .note.GNU-stack,"",@progbits + + .text + .globl raw_syscall + .type raw_syscall, @function + +raw_syscall: + mov x8,x0 // syscall # + mov x0,x1 // arg0 + mov x1,x2 // arg1 + mov x2,x3 // arg2 + mov x3,x4 // arg3 + mov x4,x5 // arg4 + mov x5,x6 // arg5 + svc #0 + ret + + .size raw_syscall,.-raw_syscall + .section .note.GNU-stack,"",@progbits diff --git a/test/syscalls/linux/rseq/syscalls.h b/test/syscalls/linux/rseq/syscalls.h index e5299c188..c4118e6c5 100644 --- a/test/syscalls/linux/rseq/syscalls.h +++ b/test/syscalls/linux/rseq/syscalls.h @@ -17,10 +17,13 @@ #include "test/syscalls/linux/rseq/types.h" -#ifdef __x86_64__ // Syscall numbers. +#if defined(__x86_64__) constexpr int kGetpid = 39; constexpr int kExitGroup = 231; +#elif defined(__aarch64__) +constexpr int kGetpid = 172; +constexpr int kExitGroup = 94; #else #error "Unknown architecture" #endif diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h index e3ff0579a..d3e60d0a4 100644 --- a/test/syscalls/linux/rseq/uapi.h +++ b/test/syscalls/linux/rseq/uapi.h @@ -15,37 +15,34 @@ #ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ #define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_ -// User-kernel ABI for restartable sequences. +#include <stdint.h> -// Standard types. -// -// N.B. This header will be included in targets that do have the standard -// library, so we can't shadow the standard type names. -using __u32 = __UINT32_TYPE__; -using __u64 = __UINT64_TYPE__; +// User-kernel ABI for restartable sequences. -#ifdef __x86_64__ // Syscall numbers. +#if defined(__x86_64__) constexpr int kRseqSyscall = 334; +#elif defined(__aarch64__) +constexpr int kRseqSyscall = 293; #else #error "Unknown architecture" #endif // __x86_64__ struct rseq_cs { - __u32 version; - __u32 flags; - __u64 start_ip; - __u64 post_commit_offset; - __u64 abort_ip; -} __attribute__((aligned(4 * sizeof(__u64)))); + uint32_t version; + uint32_t flags; + uint64_t start_ip; + uint64_t post_commit_offset; + uint64_t abort_ip; +} __attribute__((aligned(4 * sizeof(uint64_t)))); // N.B. alignment is enforced by the kernel. struct rseq { - __u32 cpu_id_start; - __u32 cpu_id; + uint32_t cpu_id_start; + uint32_t cpu_id; struct rseq_cs* rseq_cs; - __u32 flags; -} __attribute__((aligned(4 * sizeof(__u64)))); + uint32_t flags; +} __attribute__((aligned(4 * sizeof(uint64_t)))); constexpr int kRseqFlagUnregister = 1 << 0; diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc index 81d193ffd..ed27e2566 100644 --- a/test/syscalls/linux/rtsignal.cc +++ b/test/syscalls/linux/rtsignal.cc @@ -167,6 +167,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc index 2c947feb7..ce88d90dd 100644 --- a/test/syscalls/linux/seccomp.cc +++ b/test/syscalls/linux/seccomp.cc @@ -53,7 +53,7 @@ namespace { constexpr uint32_t kFilteredSyscall = SYS_vserver; #elif __aarch64__ // Use the last of arch_specific_syscalls which are not implemented on arm64. -constexpr uint32_t kFilteredSyscall = SYS_arch_specific_syscall + 15; +constexpr uint32_t kFilteredSyscall = __NR_arch_specific_syscall + 15; #endif // Applies a seccomp-bpf filter that returns `filtered_result` for @@ -70,20 +70,27 @@ void ApplySeccompFilter(uint32_t sysno, uint32_t filtered_result, MaybeSave(); struct sock_filter filter[] = { - // A = seccomp_data.arch - BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4), - // if (A != AUDIT_ARCH_X86_64) goto kill - BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4), - // A = seccomp_data.nr - BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0), - // if (A != sysno) goto allow - BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1), - // return filtered_result - BPF_STMT(BPF_RET | BPF_K, filtered_result), - // allow: return SECCOMP_RET_ALLOW - BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW), - // kill: return SECCOMP_RET_KILL - BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL), + // A = seccomp_data.arch + BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4), +#if defined(__x86_64__) + // if (A != AUDIT_ARCH_X86_64) goto kill + BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4), +#elif defined(__aarch64__) + // if (A != AUDIT_ARCH_AARCH64) goto kill + BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_AARCH64, 0, 4), +#else +#error "Unknown architecture" +#endif + // A = seccomp_data.nr + BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0), + // if (A != sysno) goto allow + BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1), + // return filtered_result + BPF_STMT(BPF_RET | BPF_K, filtered_result), + // allow: return SECCOMP_RET_ALLOW + BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW), + // kill: return SECCOMP_RET_KILL + BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL), }; struct sock_fprog prog; prog.len = ABSL_ARRAYSIZE(filter); @@ -179,9 +186,12 @@ TEST(SeccompTest, RetTrapCausesSIGSYS) { TEST_CHECK(info->si_errno == kTrapValue); TEST_CHECK(info->si_call_addr != nullptr); TEST_CHECK(info->si_syscall == kFilteredSyscall); -#ifdef __x86_64__ +#if defined(__x86_64__) TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64); TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == kFilteredSyscall); +#elif defined(__aarch64__) + TEST_CHECK(info->si_arch == AUDIT_ARCH_AARCH64); + TEST_CHECK(uc->uc_mcontext.regs[8] == kFilteredSyscall); #endif // defined(__x86_64__) _exit(0); }); @@ -411,5 +421,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 580ab5193..64123e904 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/unistd.h> #include <sys/eventfd.h> #include <sys/sendfile.h> #include <unistd.h> @@ -70,6 +71,28 @@ TEST(SendFileTest, InvalidOffset) { SyscallFailsWithErrno(EINVAL)); } +int memfd_create(const std::string& name, unsigned int flags) { + return syscall(__NR_memfd_create, name.c_str(), flags); +} + +TEST(SendFileTest, Overflow) { + // Create input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Open the output file. + int fd; + EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds()); + const FileDescriptor outf(fd); + + // out_offset + kSize overflows INT64_MAX. + loff_t out_offset = 0x7ffffffffffffffeull; + constexpr int kSize = 3; + EXPECT_THAT(sendfile(outf.get(), inf.get(), &out_offset, kSize), + SyscallFailsWithErrno(EINVAL)); +} + TEST(SendFileTest, SendTrivially) { // Create temp files. constexpr char kData[] = "To be, or not to be, that is the question:"; @@ -530,6 +553,34 @@ TEST(SendFileTest, SendToSpecialFile) { SyscallSucceedsWithValue(kSize & (~7))); } +TEST(SendFileTest, SendFileToPipe) { + // Create temp file. + constexpr char kData[] = "<insert-quote-here>"; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Create a pipe for sending to a pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Expect to read up to the given size. + std::vector<char> buf(kDataSize); + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kDataSize)); + }); + + // Send with twice the size of the file, which should hit EOF. + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize * 2), + SyscallSucceedsWithValue(kDataSize)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc index 8f7ee4163..c101fe9d2 100644 --- a/test/syscalls/linux/sendfile_socket.cc +++ b/test/syscalls/linux/sendfile_socket.cc @@ -23,6 +23,7 @@ #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" @@ -35,61 +36,39 @@ namespace { class SendFileTest : public ::testing::TestWithParam<int> { protected: - PosixErrorOr<std::tuple<int, int>> Sockets() { + PosixErrorOr<std::unique_ptr<SocketPair>> Sockets(int type) { // Bind a server socket. int family = GetParam(); - struct sockaddr server_addr = {}; switch (family) { case AF_INET: { - struct sockaddr_in* server_addr_in = - reinterpret_cast<struct sockaddr_in*>(&server_addr); - server_addr_in->sin_family = family; - server_addr_in->sin_addr.s_addr = INADDR_ANY; - break; + if (type == SOCK_STREAM) { + return SocketPairKind{ + "TCP", AF_INET, type, 0, + TCPAcceptBindSocketPairCreator(AF_INET, type, 0, false)} + .Create(); + } else { + return SocketPairKind{ + "UDP", AF_INET, type, 0, + UDPBidirectionalBindSocketPairCreator(AF_INET, type, 0, false)} + .Create(); + } } case AF_UNIX: { - struct sockaddr_un* server_addr_un = - reinterpret_cast<struct sockaddr_un*>(&server_addr); - server_addr_un->sun_family = family; - server_addr_un->sun_path[0] = '\0'; - break; + if (type == SOCK_STREAM) { + return SocketPairKind{ + "UNIX", AF_UNIX, type, 0, + FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)} + .Create(); + } else { + return SocketPairKind{ + "UNIX", AF_UNIX, type, 0, + FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)} + .Create(); + } } default: return PosixError(EINVAL); } - int server = socket(family, SOCK_STREAM, 0); - if (bind(server, &server_addr, sizeof(server_addr)) < 0) { - return PosixError(errno); - } - if (listen(server, 1) < 0) { - close(server); - return PosixError(errno); - } - - // Fetch the address; both are anonymous. - socklen_t length = sizeof(server_addr); - if (getsockname(server, &server_addr, &length) < 0) { - close(server); - return PosixError(errno); - } - - // Connect the client. - int client = socket(family, SOCK_STREAM, 0); - if (connect(client, &server_addr, length) < 0) { - close(server); - close(client); - return PosixError(errno); - } - - // Accept on the server. - int server_client = accept(server, nullptr, 0); - if (server_client < 0) { - close(server); - close(client); - return PosixError(errno); - } - close(server); - return std::make_tuple(client, server_client); } }; @@ -106,9 +85,7 @@ TEST_P(SendFileTest, SendMultiple) { const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); // Create sockets. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor server(std::get<0>(fds)); - FileDescriptor client(std::get<1>(fds)); // non-const, reset is used. + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM)); // Thread that reads data from socket and dumps to a file. ScopedThread th([&] { @@ -118,7 +95,7 @@ TEST_P(SendFileTest, SendMultiple) { // Read until socket is closed. char buf[10240]; for (int cnt = 0;; cnt++) { - int r = RetryEINTR(read)(server.get(), buf, sizeof(buf)); + int r = RetryEINTR(read)(socks->first_fd(), buf, sizeof(buf)); // We cannot afford to save on every read() call. if (cnt % 1000 == 0) { ASSERT_THAT(r, SyscallSucceeds()); @@ -149,10 +126,10 @@ TEST_P(SendFileTest, SendMultiple) { for (size_t sent = 0; sent < data.size(); cnt++) { const size_t remain = data.size() - sent; std::cout << "sendfile, size=" << data.size() << ", sent=" << sent - << ", remain=" << remain; + << ", remain=" << remain << std::endl; // Send data and verify that sendfile returns the correct value. - int res = sendfile(client.get(), inf.get(), nullptr, remain); + int res = sendfile(socks->second_fd(), inf.get(), nullptr, remain); // We cannot afford to save on every sendfile() call. if (cnt % 120 == 0) { MaybeSave(); @@ -169,7 +146,7 @@ TEST_P(SendFileTest, SendMultiple) { } // Close socket to stop thread. - client.reset(); + close(socks->release_second_fd()); th.Join(); // Verify that the output file has the correct data. @@ -183,9 +160,7 @@ TEST_P(SendFileTest, SendMultiple) { TEST_P(SendFileTest, Shutdown) { // Create a socket. - std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); - const FileDescriptor client(std::get<0>(fds)); - FileDescriptor server(std::get<1>(fds)); // non-const, reset below. + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM)); // If this is a TCP socket, then turn off linger. if (GetParam() == AF_INET) { @@ -193,7 +168,7 @@ TEST_P(SendFileTest, Shutdown) { sl.l_onoff = 1; sl.l_linger = 0; ASSERT_THAT( - setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + setsockopt(socks->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), SyscallSucceeds()); } @@ -212,12 +187,12 @@ TEST_P(SendFileTest, Shutdown) { ScopedThread t([&]() { size_t done = 0; while (done < data.size()) { - int n = RetryEINTR(read)(server.get(), data.data(), data.size()); + int n = RetryEINTR(read)(socks->first_fd(), data.data(), data.size()); ASSERT_THAT(n, SyscallSucceeds()); done += n; } // Close the server side socket. - server.reset(); + close(socks->release_first_fd()); }); // Continuously stream from the file to the socket. Note we do not assert @@ -225,7 +200,7 @@ TEST_P(SendFileTest, Shutdown) { // data is written. Eventually, we should get a connection reset error. while (1) { off_t offset = 0; // Always read from the start. - int n = sendfile(client.get(), inf.get(), &offset, data.size()); + int n = sendfile(socks->second_fd(), inf.get(), &offset, data.size()); EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET), SyscallFailsWithErrno(EPIPE), SyscallSucceeds())); if (n <= 0) { @@ -234,6 +209,20 @@ TEST_P(SendFileTest, Shutdown) { } } +TEST_P(SendFileTest, SendpageFromEmptyFileToUDP) { + auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_DGRAM)); + + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + // The value to the count argument has to be so that it is impossible to + // allocate a buffer of this size. In Linux, sendfile transfer at most + // 0x7ffff000 (MAX_RW_COUNT) bytes. + EXPECT_THAT(sendfile(socks->first_fd(), fd.get(), 0x0, 0x8000000000004), + SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest, ::testing::Values(AF_UNIX, AF_INET)); diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc index 4deb1ae95..6227774a4 100644 --- a/test/syscalls/linux/sigiret.cc +++ b/test/syscalls/linux/sigiret.cc @@ -132,6 +132,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc index 95be4b66c..389e5fca2 100644 --- a/test/syscalls/linux/signalfd.cc +++ b/test/syscalls/linux/signalfd.cc @@ -369,5 +369,5 @@ int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc index 7db57d968..b2fcedd62 100644 --- a/test/syscalls/linux/sigstop.cc +++ b/test/syscalls/linux/sigstop.cc @@ -147,5 +147,5 @@ int main(int argc, char** argv) { return 1; } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc index 1e5bf5942..4f8afff15 100644 --- a/test/syscalls/linux/sigtimedwait.cc +++ b/test/syscalls/linux/sigtimedwait.cc @@ -319,6 +319,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index e8f24a59e..f7d6139f1 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -447,6 +447,60 @@ TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) { SyscallFailsWithErrno(EAGAIN)); } +TEST_P(AllSocketPairTest, SendTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + timeval actual_tv = {.tv_sec = -1, .tv_usec = -1}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv_usec, 0); +} + +TEST_P(AllSocketPairTest, SetGetSendTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + timeval tv = {.tv_sec = 89, .tv_usec = 42000}; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + timeval actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 89); + EXPECT_EQ(actual_tv.tv_usec, 42000); +} + +TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval_with_extra { + struct timeval tv; + int64_t extra_data; + } ABSL_ATTRIBUTE_PACKED; + + timeval_with_extra tv_extra = { + .tv = {.tv_sec = 0, .tv_usec = 123000}, + }; + + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &tv_extra, sizeof(tv_extra)), + SyscallSucceeds()); + + timeval_with_extra actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv.tv_usec, 123000); +} + TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -491,18 +545,36 @@ TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) { ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf))); } -TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) { +TEST_P(AllSocketPairTest, RecvTimeoutDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - struct timeval tv { - .tv_sec = 0, .tv_usec = 35 - }; + timeval actual_tv = {.tv_sec = -1, .tv_usec = -1}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv_usec, 0); +} + +TEST_P(AllSocketPairTest, SetGetRecvTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + timeval tv = {.tv_sec = 123, .tv_usec = 456000}; EXPECT_THAT( setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), SyscallSucceeds()); + + timeval actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 123); + EXPECT_EQ(actual_tv.tv_usec, 456000); } -TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) { +TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct timeval_with_extra { @@ -510,13 +582,21 @@ TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) { int64_t extra_data; } ABSL_ATTRIBUTE_PACKED; - timeval_with_extra tv_extra; - tv_extra.tv.tv_sec = 0; - tv_extra.tv.tv_usec = 25; + timeval_with_extra tv_extra = { + .tv = {.tv_sec = 0, .tv_usec = 432000}, + }; EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv_extra, sizeof(tv_extra)), SyscallSucceeds()); + + timeval_with_extra actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv.tv_usec, 432000); } TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) { diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index b24618a88..9400ffaeb 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -234,7 +234,7 @@ TEST_P(DualStackSocketTest, AddressOperations) { } } -// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny. +// TODO(gvisor.dev/issue/1556): uncomment V4MappedAny. INSTANTIATE_TEST_SUITE_P( All, DualStackSocketTest, ::testing::Combine( @@ -319,17 +319,57 @@ TEST_P(SocketInetLoopbackTest, TCPListenUnbound) { tcpSimpleConnectTest(listener, connector, false); } -TEST_P(SocketInetLoopbackTest, TCPListenClose) { +TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { + const auto& param = GetParam(); + + const TestAddress& listener = param.listener; + const TestAddress& connector = param.connector; + + constexpr int kBacklog = 5; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + for (int i = 0; i < kBacklog; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + } + for (int i = 0; i < kBacklog; i++) { + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds()); + } +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; - constexpr int kAcceptCount = 32; - constexpr int kBacklog = kAcceptCount * 2; - constexpr int kFDs = 128; - constexpr int kThreadCount = 4; - constexpr int kFDsPerThread = kFDs / kThreadCount; + constexpr int kBacklog = 2; + constexpr int kFDs = kBacklog + 1; // Create the listening socket. FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( @@ -348,39 +388,167 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - DisableSave ds; // Too many system calls. sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - FileDescriptor clients[kFDs]; - std::unique_ptr<ScopedThread> threads[kThreadCount]; + + // Shutdown the write of the listener, expect to not have any effect. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds()); + for (int i = 0; i < kFDs; i++) { - clients[i] = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds()); } - for (int i = 0; i < kThreadCount; i++) { - threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr, - &clients, i]() { - for (int j = 0; j < kFDsPerThread; j++) { - int k = i * kFDsPerThread + j; - int ret = - connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - } - } - }); + + // Shutdown the read of the listener, expect to fail subsequent + // server accepts, binds and client connects. + ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EINVAL)); + + // Check that shutdown did not release the port. + FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Check that subsequent connection attempts receive a RST. + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallFailsWithErrno(ECONNREFUSED)); } - for (int i = 0; i < kThreadCount; i++) { - threads[i]->Join(); +} + +TEST_P(SocketInetLoopbackTest, TCPListenClose) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kAcceptCount = 2; + constexpr int kBacklog = kAcceptCount + 2; + constexpr int kFDs = kBacklog * 3; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kFDs; i++) { + auto client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + clients.push_back(std::move(client)); } for (int i = 0; i < kAcceptCount; i++) { auto accepted = ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); } - // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked - // before function end. - // ds.reset(); +} + +void TestListenWhileConnect(const TestParam& param, + void (*stopListen)(FileDescriptor&)) { + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + constexpr int kBacklog = 2; + constexpr int kClients = kBacklog + 1; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + std::vector<FileDescriptor> clients; + for (int i = 0; i < kClients; i++) { + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + clients.push_back(std::move(client)); + } + } + + stopListen(listen_fd); + + for (auto& client : clients) { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = client.get(), + .events = POLLIN, + }; + // When the listening socket is closed, then we expect the remote to reset + // the connection. + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR); + char c; + // Subsequent read can fail with: + // ECONNRESET: If the client connection was established and was reset by the + // remote. + // ECONNREFUSED: If the client connection failed to be established. + ASSERT_THAT(read(client.get(), &c, sizeof(c)), + AnyOf(SyscallFailsWithErrno(ECONNRESET), + SyscallFailsWithErrno(ECONNREFUSED))); + } +} + +TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(close(f.release()), SyscallSucceeds()); + }); +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { + TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); + }); } TEST_P(SocketInetLoopbackTest, TCPbacklog) { @@ -605,15 +773,23 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { &conn_addrlen), SyscallSucceeds()); - constexpr int kTCPLingerTimeout = 5; - EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, - &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), - SyscallSucceedsWithValue(0)); + // Disable cooperative saves after this point as TCP timers are not restored + // across a S/R. + { + DisableSave ds; + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); - // close the connecting FD to trigger FIN_WAIT2 on the connected fd. - conn_fd.reset(); + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + // ds going out of scope will Re-enable S/R's since at this point the timer + // must have fired and cleaned up the endpoint. + } // Now bind and connect a new socket and verify that we can immediately // rebind the address bound by the conn_fd as it never entered TIME_WAIT. @@ -1082,6 +1258,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { if (connects_received >= kConnectAttempts) { // Another thread have shutdown our read side causing the // accept to fail. + ASSERT_EQ(errno, EINVAL); break; } ASSERT_NO_ERRNO(fd); @@ -1149,7 +1326,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -1262,7 +1439,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -2138,8 +2315,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(connect(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), bound_addr_len), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), + reinterpret_cast<sockaddr*>(&bound_addr), + bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. @@ -2204,7 +2382,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)), SyscallSucceeds()); - std::cout << portreuse1 << " " << portreuse2; + std::cout << portreuse1 << " " << portreuse2 << std::endl; int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen); // Verify that two sockets can be bound to the same port only if diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 53290bed7..1c533fdf2 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -14,6 +14,7 @@ #include "test/syscalls/linux/socket_ip_udp_generic.h" +#include <errno.h> #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -209,46 +210,6 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } -// Ensure that Receiving TOS is off by default. -TEST_P(UDPSocketPairTest, RecvTosDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -// Test that setting and getting IP_RECVTOS works as expected. -TEST_P(UDPSocketPairTest, SetRecvTos) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -357,5 +318,141 @@ TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) { EXPECT_EQ(get, kSockOptOn); } +// Test getsockopt for a socket which is not set with IP_PKTINFO option. +TEST_P(UDPSocketPairTest, IPPKTINFODefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test setsockopt and getsockopt for a socket with IP_PKTINFO option. +TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int level = SOL_IP; + int type = IP_PKTINFO; + + // Check getsockopt before IP_PKTINFO is set. + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOn); + EXPECT_EQ(get_len, sizeof(get)); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOff); + EXPECT_EQ(get_len, sizeof(get)); +} + +// Holds TOS or TClass information for IPv4 or IPv6 respectively. +struct RecvTosOption { + int level; + int option; +}; + +RecvTosOption GetRecvTosOption(int domain) { + TEST_CHECK(domain == AF_INET || domain == AF_INET6); + RecvTosOption opt; + switch (domain) { + case AF_INET: + opt.level = IPPROTO_IP; + opt.option = IP_RECVTOS; + break; + case AF_INET6: + opt.level = IPPROTO_IPV6; + opt.option = IPV6_RECVTCLASS; + break; + } + return opt; +} + +// Ensure that Receiving TOS or TCLASS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as +// expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + +// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this +// mirrors behavior in linux. +TEST_P(UDPSocketPairTest, TOSRecvMismatch) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(AF_INET); + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); +} + +// Test that an IPv4 socket does not support the IPv6 TClass option. +TEST_P(UDPSocketPairTest, TClassRecvMismatch) { + // This should only test AF_INET sockets for the mismatch behavior. + SKIP_IF(GetParam().domain != AF_INET); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS, + &get, &get_len), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 990ccf23c..bc4b07a62 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -15,6 +15,7 @@ #include "test/syscalls/linux/socket_ipv4_udp_unbound.h" #include <arpa/inet.h> +#include <net/if.h> #include <sys/ioctl.h> #include <sys/socket.h> #include <sys/un.h> @@ -2128,5 +2129,88 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { SyscallSucceedsWithValue(kMessageSize)); } +// Test that socket will receive packet info control message. +TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF((IsRunningWithHostinet())); + + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto sender_addr = V4Loopback(); + int level = SOL_IP; + int type = IP_PKTINFO; + + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + auto receiver_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port; + ASSERT_THAT( + connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + + // Allow socket to receive control message. + ASSERT_THAT( + setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + msghdr sent_msg = {}; + iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = sent_data; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + sent_msg.msg_flags = 0; + + ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + msghdr received_msg = {}; + iovec received_iov = {}; + char received_data[kDataLength]; + char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; + size_t cmsg_data_len = sizeof(in_pktinfo); + received_iov.iov_base = received_data; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + received_msg.msg_control = received_cmsg_buf; + + ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, level); + EXPECT_EQ(cmsg->cmsg_type, type); + + // Get loopback index. + ifreq ifr = {}; + absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); + ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds()); + ASSERT_NE(ifr.ifr_ifindex, 0); + + // Check the data + in_pktinfo received_pktinfo = {}; + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK)); +} } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index 40e673625..d690d9564 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -45,37 +45,31 @@ void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { got_if_infos_ = false; // Get interface list. - std::vector<std::string> if_names; ASSERT_NO_ERRNO(if_helper_.Load()); - if_names = if_helper_.InterfaceList(AF_INET); + std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET); if (if_names.size() != 2) { return; } // Figure out which interface is where. - int lo = 0, eth = 1; - if (if_names[lo] != "lo") { - lo = 1; - eth = 0; - } - - if (if_names[lo] != "lo") { - return; - } - - lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[lo])); - lo_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[lo]); - if (lo_if_addr_ == nullptr) { + std::string lo = if_names[0]; + std::string eth = if_names[1]; + if (lo != "lo") std::swap(lo, eth); + if (lo != "lo") return; + + lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo)); + auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo); + if (lo_if_addr == nullptr) { return; } - lo_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(lo_if_addr_)->sin_addr; + lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr); - eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[eth])); - eth_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[eth]); - if (eth_if_addr_ == nullptr) { + eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth)); + auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth); + if (eth_if_addr == nullptr) { return; } - eth_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(eth_if_addr_)->sin_addr; + eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr); got_if_infos_ = true; } @@ -242,7 +236,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the non-receiving socket to the unicast ethernet address. auto norecv_addr = rcv1_addr; reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr = - eth_if_sin_addr_; + eth_if_addr_.sin_addr; ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), norecv_addr.addr_len), SyscallSucceedsWithValue(0)); @@ -1028,7 +1022,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); ip_mreqn iface = {}; iface.imr_ifindex = lo_if_idx_; - iface.imr_address = eth_if_sin_addr_; + iface.imr_address = eth_if_addr_.sin_addr; ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, sizeof(iface)), SyscallSucceeds()); @@ -1058,7 +1052,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SKIP_IF(IsRunningOnGvisor()); // Verify the received source address. - EXPECT_EQ(eth_if_sin_addr_.s_addr, src_addr_in->sin_addr.s_addr); + EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr); } // Check that when we are bound to one interface we can set IP_MULTICAST_IF to @@ -1075,7 +1069,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Create sender and bind to eth interface. auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - ASSERT_THAT(bind(sender->get(), eth_if_addr_, sizeof(sockaddr_in)), + ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(ð_if_addr_), + sizeof(eth_if_addr_)), SyscallSucceeds()); // Run through all possible combinations of index and address for @@ -1085,9 +1080,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, struct in_addr imr_address; } test_data[] = { {lo_if_idx_, {}}, - {0, lo_if_sin_addr_}, - {lo_if_idx_, lo_if_sin_addr_}, - {lo_if_idx_, eth_if_sin_addr_}, + {0, lo_if_addr_.sin_addr}, + {lo_if_idx_, lo_if_addr_.sin_addr}, + {lo_if_idx_, eth_if_addr_.sin_addr}, }; for (auto t : test_data) { ip_mreqn iface = {}; diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h index bec2e96ee..10b90b1e0 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h @@ -36,10 +36,8 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest { // Interface infos. int lo_if_idx_; int eth_if_idx_; - sockaddr* lo_if_addr_; - sockaddr* eth_if_addr_; - in_addr lo_if_sin_addr_; - in_addr eth_if_sin_addr_; + sockaddr_in lo_if_addr_; + sockaddr_in eth_if_addr_; }; } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index e5aed1eec..fbe61c5a0 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -26,7 +26,7 @@ #include "gtest/gtest.h" #include "absl/strings/str_format.h" -#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" #include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/capability_util.h" @@ -118,24 +118,6 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { // TODO(mpratt): Check ifinfomsg contents and following attrs. } -PosixError DumpLinks( - const FileDescriptor& fd, uint32_t seq, - const std::function<void(const struct nlmsghdr* hdr)>& fn) { - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - }; - - struct request req = {}; - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; - req.hdr.nlmsg_seq = seq; - req.ifm.ifi_family = AF_UNSPEC; - - return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); -} - TEST(NetlinkRouteTest, GetLinkDump) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -152,7 +134,7 @@ TEST(NetlinkRouteTest, GetLinkDump) { const struct ifinfomsg* msg = reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); std::cout << "Found interface idx=" << msg->ifi_index - << ", type=" << std::hex << msg->ifi_type; + << ", type=" << std::hex << msg->ifi_type << std::endl; if (msg->ifi_type == ARPHRD_LOOPBACK) { loopbackFound = true; EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); @@ -161,37 +143,6 @@ TEST(NetlinkRouteTest, GetLinkDump) { EXPECT_TRUE(loopbackFound); } -struct Link { - int index; - std::string name; -}; - -PosixErrorOr<absl::optional<Link>> FindLoopbackLink() { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - absl::optional<Link> link; - RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { - if (hdr->nlmsg_type != RTM_NEWLINK || - hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { - return; - } - const struct ifinfomsg* msg = - reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); - if (msg->ifi_type == ARPHRD_LOOPBACK) { - const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); - if (rta == nullptr) { - // Ignore links that do not have a name. - return; - } - - link = Link(); - link->index = msg->ifi_index; - link->name = std::string(reinterpret_cast<const char*>(RTA_DATA(rta))); - } - })); - return link; -} - // CheckLinkMsg checks a netlink message against an expected link. void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK)); @@ -209,9 +160,7 @@ void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { } TEST(NetlinkRouteTest, GetLinkByIndex) { - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -227,13 +176,13 @@ TEST(NetlinkRouteTest, GetLinkByIndex) { req.hdr.nlmsg_flags = NLM_F_REQUEST; req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; - req.ifm.ifi_index = loopback_link->index; + req.ifm.ifi_index = loopback_link.index; bool found = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckLinkMsg(hdr, *loopback_link); + CheckLinkMsg(hdr, loopback_link); found = true; }, false)); @@ -241,9 +190,7 @@ TEST(NetlinkRouteTest, GetLinkByIndex) { } TEST(NetlinkRouteTest, GetLinkByName) { - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -262,8 +209,8 @@ TEST(NetlinkRouteTest, GetLinkByName) { req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; req.rtattr.rta_type = IFLA_IFNAME; - req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1); - strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname)); + req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1); + strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname)); req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); @@ -271,7 +218,7 @@ TEST(NetlinkRouteTest, GetLinkByName) { ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckLinkMsg(hdr, *loopback_link); + CheckLinkMsg(hdr, loopback_link); found = true; }, false)); @@ -523,9 +470,7 @@ TEST(NetlinkRouteTest, LookupAll) { TEST(NetlinkRouteTest, AddAddr) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - absl::optional<Link> loopback_link = - ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); - ASSERT_TRUE(loopback_link.has_value()); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); @@ -545,7 +490,7 @@ TEST(NetlinkRouteTest, AddAddr) { req.ifa.ifa_prefixlen = 24; req.ifa.ifa_flags = 0; req.ifa.ifa_scope = 0; - req.ifa.ifa_index = loopback_link->index; + req.ifa.ifa_index = loopback_link.index; req.rtattr.rta_type = IFA_LOCAL; req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); inet_pton(AF_INET, "10.0.0.1", &req.addr); diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc new file mode 100644 index 000000000..bde1dbb4d --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -0,0 +1,162 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_netlink_route_util.h" + +#include <linux/if.h> +#include <linux/netlink.h> +#include <linux/rtnetlink.h> + +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr uint32_t kSeq = 12345; + +} // namespace + +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function<void(const struct nlmsghdr* hdr)>& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + +PosixErrorOr<std::vector<Link>> DumpLinks() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + std::vector<Link> links; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr)); + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + links.emplace_back(); + links.back().index = msg->ifi_index; + links.back().type = msg->ifi_type; + links.back().name = + std::string(reinterpret_cast<const char*>(RTA_DATA(rta))); + })); + return links; +} + +PosixErrorOr<Link> LoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.type == ARPHRD_LOOPBACK) { + return link; + } + } + return PosixError(ENOENT, "loopback link not found"); +} + +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char pad[NLMSG_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + req.ifinfo.ifi_flags = flags; + req.ifinfo.ifi_change = change; + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFLA_ADDRESS; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h new file mode 100644 index 000000000..149c4a7f6 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -0,0 +1,55 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ + +#include <linux/netlink.h> +#include <linux/rtnetlink.h> + +#include <vector> + +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { + +struct Link { + int index; + int16_t type; + std::string name; +}; + +PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, + const std::function<void(const struct nlmsghdr* hdr)>& fn); + +PosixErrorOr<std::vector<Link>> DumpLinks(); + +// Returns the loopback link on the system. ENOENT if not found. +PosixErrorOr<Link> LoopbackLink(); + +// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + +// LinkChangeFlags changes interface flags. E.g. IFF_UP. +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); + +// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface. +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 5d3a39868..53b678e94 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -364,11 +364,6 @@ CreateTCPConnectAcceptSocketPair(int bound, int connected, int type, } MaybeSave(); // Successful accept. - // FIXME(b/110484944) - if (connect_result == -1) { - absl::SleepFor(absl::Seconds(1)); - } - T extra_addr = {}; LocalhostAddr(&extra_addr, dual_stack); return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr, diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 4cf1f76f1..8bf663e8b 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -257,6 +257,8 @@ TEST_P(UnixSocketPairTest, ShutdownWrite) { TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) { // TODO(b/122310852): We should be returning ENXIO and NOT EIO. + // TODO(github.dev/issue/1624): This should be resolved in VFS2. Verify + // that this is the case and delete the SKIP_IF once we delete VFS1. SKIP_IF(IsRunningOnGvisor()); auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index faa1247f6..f103e2e56 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/unistd.h> #include <sys/eventfd.h> #include <sys/resource.h> #include <sys/sendfile.h> diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index c951ac3b3..2503960f3 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -34,6 +34,13 @@ #include "test/util/temp_path.h" #include "test/util/test_util.h" +#ifndef AT_STATX_FORCE_SYNC +#define AT_STATX_FORCE_SYNC 0x2000 +#endif +#ifndef AT_STATX_DONT_SYNC +#define AT_STATX_DONT_SYNC 0x4000 +#endif + namespace gvisor { namespace testing { @@ -607,7 +614,7 @@ int statx(int dirfd, const char* pathname, int flags, unsigned int mask, } TEST_F(StatTest, StatxAbsPath) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); struct kernel_statx stx; @@ -617,7 +624,7 @@ TEST_F(StatTest, StatxAbsPath) { } TEST_F(StatTest, StatxRelPathDirFD) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); struct kernel_statx stx; @@ -631,7 +638,7 @@ TEST_F(StatTest, StatxRelPathDirFD) { } TEST_F(StatTest, StatxRelPathCwd) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); @@ -643,7 +650,7 @@ TEST_F(StatTest, StatxRelPathCwd) { } TEST_F(StatTest, StatxEmptyPath) { - SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 && + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY)); @@ -653,6 +660,60 @@ TEST_F(StatTest, StatxEmptyPath) { EXPECT_TRUE(S_ISREG(stx.stx_mode)); } +TEST_F(StatTest, StatxDoesNotRejectExtraneousMaskBits) { + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && + errno == ENOSYS); + + struct kernel_statx stx; + // Set all mask bits except for STATX__RESERVED. + uint mask = 0xffffffff & ~0x80000000; + EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, mask, &stx), + SyscallSucceeds()); + EXPECT_TRUE(S_ISREG(stx.stx_mode)); +} + +TEST_F(StatTest, StatxRejectsReservedMaskBit) { + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && + errno == ENOSYS); + + struct kernel_statx stx; + // Set STATX__RESERVED in the mask. + EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, 0x80000000, &stx), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_F(StatTest, StatxSymlink) { + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && + errno == ENOSYS); + + std::string parent_dir = "/tmp"; + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(parent_dir, test_file_name_)); + std::string p = link.path(); + + struct kernel_statx stx; + EXPECT_THAT(statx(AT_FDCWD, p.c_str(), AT_SYMLINK_NOFOLLOW, STATX_ALL, &stx), + SyscallSucceeds()); + EXPECT_TRUE(S_ISLNK(stx.stx_mode)); + EXPECT_THAT(statx(AT_FDCWD, p.c_str(), 0, STATX_ALL, &stx), + SyscallSucceeds()); + EXPECT_TRUE(S_ISREG(stx.stx_mode)); +} + +TEST_F(StatTest, StatxInvalidFlags) { + SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && + errno == ENOSYS); + + struct kernel_statx stx; + EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(), 12345, 0, &stx), + SyscallFailsWithErrno(EINVAL)); + + // Sync flags are mutually exclusive. + EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(), + AT_STATX_FORCE_SYNC | AT_STATX_DONT_SYNC, 0, &stx), + SyscallFailsWithErrno(EINVAL)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc index 7e73325bf..92eec0449 100644 --- a/test/syscalls/linux/sticky.cc +++ b/test/syscalls/linux/sticky.cc @@ -42,8 +42,9 @@ TEST(StickyTest, StickyBitPermDenied) { auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY)); + ASSERT_THAT(mkdirat(dirfd.get(), "NewDir", 0755), SyscallSucceeds()); // Drop privileges and change IDs only in child thread, or else this parent // thread won't be able to open some log files after the test ends. @@ -61,7 +62,8 @@ TEST(StickyTest, StickyBitPermDenied) { syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1), SyscallSucceeds()); - EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(unlinkat(dirfd.get(), "NewDir", AT_REMOVEDIR), + SyscallFailsWithErrno(EPERM)); }); } @@ -96,8 +98,9 @@ TEST(StickyTest, StickyBitCapFOWNER) { auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY)); + ASSERT_THAT(mkdirat(dirfd.get(), "NewDir", 0755), SyscallSucceeds()); // Drop privileges and change IDs only in child thread, or else this parent // thread won't be able to open some log files after the test ends. @@ -114,7 +117,8 @@ TEST(StickyTest, StickyBitCapFOWNER) { SyscallSucceeds()); EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true)); - EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds()); + EXPECT_THAT(unlinkat(dirfd.get(), "NewDir", AT_REMOVEDIR), + SyscallSucceeds()); }); } } // namespace diff --git a/test/syscalls/linux/sysret.cc b/test/syscalls/linux/sysret.cc index 819fa655a..19ffbd85b 100644 --- a/test/syscalls/linux/sysret.cc +++ b/test/syscalls/linux/sysret.cc @@ -14,6 +14,8 @@ // Tests to verify that the behavior of linux and gvisor matches when // 'sysret' returns to bad (aka non-canonical) %rip or %rsp. + +#include <linux/elf.h> #include <sys/ptrace.h> #include <sys/user.h> @@ -32,6 +34,7 @@ constexpr uint64_t kNonCanonicalRsp = 0xFFFF000000000000; class SysretTest : public ::testing::Test { protected: struct user_regs_struct regs_; + struct iovec iov; pid_t child_; void SetUp() override { @@ -48,10 +51,15 @@ class SysretTest : public ::testing::Test { // Parent. int status; + memset(&iov, 0, sizeof(iov)); ASSERT_THAT(pid, SyscallSucceeds()); // Might still be < 0. ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid)); EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP); - ASSERT_THAT(ptrace(PTRACE_GETREGS, pid, 0, ®s_), SyscallSucceeds()); + + iov.iov_base = ®s_; + iov.iov_len = sizeof(regs_); + ASSERT_THAT(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov), + SyscallSucceeds()); child_ = pid; } @@ -61,13 +69,27 @@ class SysretTest : public ::testing::Test { } void SetRip(uint64_t newrip) { +#if defined(__x86_64__) regs_.rip = newrip; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, ®s_), SyscallSucceeds()); +#elif defined(__aarch64__) + regs_.pc = newrip; +#else +#error "Unknown architecture" +#endif + ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov), + SyscallSucceeds()); } void SetRsp(uint64_t newrsp) { +#if defined(__x86_64__) regs_.rsp = newrsp; - ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, ®s_), SyscallSucceeds()); +#elif defined(__aarch64__) + regs_.sp = newrsp; +#else +#error "Unknown architecture" +#endif + ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov), + SyscallSucceeds()); } // Wait waits for the child pid and returns the exit status. @@ -104,8 +126,15 @@ TEST_F(SysretTest, BadRsp) { SetRsp(kNonCanonicalRsp); Detach(); int status = Wait(); +#if defined(__x86_64__) EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGBUS) << "status = " << status; +#elif defined(__aarch64__) + EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) + << "status = " << status; +#else +#error "Unknown architecture" +#endif } } // namespace diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index c4591a3b9..d9c1ac0e1 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -143,6 +143,20 @@ TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) { SyscallFailsWithErrno(EISCONN)); } +TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) { + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds()); + absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT. + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) { + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2. + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); +} + TEST_P(TcpSocketTest, DataCoalesced) { char buf[10]; @@ -1349,6 +1363,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) { SyscallFailsWithErrno(ENOTCONN)); } +TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) { + auto s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), + addrlen); + int buf_sz = 1 << 18; + EXPECT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc index 1ccb95733..e75bba669 100644 --- a/test/syscalls/linux/time.cc +++ b/test/syscalls/linux/time.cc @@ -26,6 +26,7 @@ namespace { constexpr long kFudgeSeconds = 5; +#if defined(__x86_64__) || defined(__i386__) // Mimics the time(2) wrapper from glibc prior to 2.15. time_t vsyscall_time(time_t* t) { constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; @@ -98,6 +99,7 @@ TEST(TimeTest, VsyscallGettimeofday_InvalidAddressSIGSEGV) { reinterpret_cast<struct timezone*>(0x1)), ::testing::KilledBySignal(SIGSEGV), ""); } +#endif } // namespace diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index 2f92c27da..4b3c44527 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -658,5 +658,5 @@ int main(int argc, char** argv) { } } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc new file mode 100644 index 000000000..6195b11e1 --- /dev/null +++ b/test/syscalls/linux/tuntap.cc @@ -0,0 +1,402 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_ether.h> +#include <linux/if_tun.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr int kIPLen = 4; + +constexpr const char kDevNetTun[] = "/dev/net/tun"; +constexpr const char kTapName[] = "tap0"; + +constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}; +constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}; + +PosixErrorOr<std::set<std::string>> DumpLinkNames() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + std::set<std::string> names; + for (const auto& link : links) { + names.emplace(link.name); + } + return names; +} + +PosixErrorOr<Link> GetLinkByName(const std::string& name) { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.name == name) { + return link; + } + } + return PosixError(ENOENT, "interface not found"); +} + +struct pihdr { + uint16_t pi_flags; + uint16_t pi_protocol; +} __attribute__((packed)); + +struct ping_pkt { + pihdr pi; + struct ethhdr eth; + struct iphdr ip; + struct icmphdr icmp; + char payload[64]; +} __attribute__((packed)); + +ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + ping_pkt pkt = {}; + + pkt.pi.pi_protocol = htons(ETH_P_IP); + + memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest)); + memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source)); + pkt.eth.h_proto = htons(ETH_P_IP); + + pkt.ip.ihl = 5; + pkt.ip.version = 4; + pkt.ip.tos = 0; + pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) + + sizeof(pkt.payload)); + pkt.ip.id = 1; + pkt.ip.frag_off = 1 << 6; // Do not fragment + pkt.ip.ttl = 64; + pkt.ip.protocol = IPPROTO_ICMP; + inet_pton(AF_INET, dstip, &pkt.ip.daddr); + inet_pton(AF_INET, srcip, &pkt.ip.saddr); + pkt.ip.check = IPChecksum(pkt.ip); + + pkt.icmp.type = ICMP_ECHO; + pkt.icmp.code = 0; + pkt.icmp.checksum = 0; + pkt.icmp.un.echo.sequence = 1; + pkt.icmp.un.echo.id = 1; + + strncpy(pkt.payload, "abcd", sizeof(pkt.payload)); + pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload)); + + return pkt; +} + +struct arp_pkt { + pihdr pi; + struct ethhdr eth; + struct arphdr arp; + uint8_t arp_sha[ETH_ALEN]; + uint8_t arp_spa[kIPLen]; + uint8_t arp_tha[ETH_ALEN]; + uint8_t arp_tpa[kIPLen]; +} __attribute__((packed)); + +std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + std::string buffer; + buffer.resize(sizeof(arp_pkt)); + + arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]); + { + pkt->pi.pi_protocol = htons(ETH_P_ARP); + + memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest)); + memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source)); + pkt->eth.h_proto = htons(ETH_P_ARP); + + pkt->arp.ar_hrd = htons(ARPHRD_ETHER); + pkt->arp.ar_pro = htons(ETH_P_IP); + pkt->arp.ar_hln = ETH_ALEN; + pkt->arp.ar_pln = kIPLen; + pkt->arp.ar_op = htons(ARPOP_REPLY); + + memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha)); + inet_pton(AF_INET, srcip, pkt->arp_spa); + memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha)); + inet_pton(AF_INET, dstip, pkt->arp_tpa); + } + return buffer; +} + +} // namespace + +TEST(TuntapStaticTest, NetTunExists) { + struct stat statbuf; + ASSERT_THAT(stat(kDevNetTun, &statbuf), SyscallSucceeds()); + // Check that it's a character device with rw-rw-rw- permissions. + EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); +} + +class TuntapTest : public ::testing::Test { + protected: + void TearDown() override { + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) { + // Bring back capability if we had dropped it in test case. + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); + } + } +}; + +TEST_F(TuntapTest, CreateInterfaceNoCap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + strncpy(ifr.ifr_name, kTapName, IFNAMSIZ); + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM)); +} + +TEST_F(TuntapTest, CreateFixedNameInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_expect = ifr_set; + // See __tun_chr_ioctl() in net/drivers/tun.c. + ifr_expect.ifr_flags |= IFF_NOFILTER; + + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(kTapName))); + EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0)); +} + +TEST_F(TuntapTest, CreateInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + // Empty ifr.ifr_name. Let kernel assign. + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + std::string ifname = ifr_get.ifr_name; + EXPECT_THAT(ifname, ::testing::StartsWith("tap")); + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(ifname))); +} + +TEST_F(TuntapTest, InvalidReadWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + char buf[128] = {}; + EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); +} + +TEST_F(TuntapTest, WriteToDownDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + // Device created should be down by default. + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + char buf[128] = {}; + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO)); +} + +PosixErrorOr<FileDescriptor> OpenAndAttachTap( + const std::string& dev_name, const std::string& dev_ipv4_addr) { + // Interface creation. + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, dev_name.c_str(), IFNAMSIZ); + if (ioctl(fd.get(), TUNSETIFF, &ifr_set) < 0) { + return PosixError(errno); + } + + ASSIGN_OR_RETURN_ERRNO(auto link, GetLinkByName(dev_name)); + + // Interface setup. + struct in_addr addr; + inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr); + EXPECT_NO_ERRNO(LinkAddLocalAddr(link.index, AF_INET, /*prefixlen=*/24, &addr, + sizeof(addr))); + + if (!IsRunningOnGvisor()) { + // FIXME(b/110961832): gVisor doesn't support setting MAC address on + // interfaces yet. + RETURN_IF_ERRNO(LinkSetMacAddr(link.index, kMacA, sizeof(kMacA))); + + // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces. + RETURN_IF_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP)); + } + + return fd; +} + +// This test sets up a TAP device and pings kernel by sending ICMP echo request. +// +// It works as the following: +// * Open /dev/net/tun, and create kTapName interface. +// * Use rtnetlink to do initial setup of the interface: +// * Assign IP address 10.0.0.1/24 to kernel. +// * MAC address: kMacA +// * Bring up the interface. +// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel. +// * Loop to receive packets from TAP device/fd: +// * If packet is an ICMP echo reply, it stops and passes the test. +// * If packet is an ARP request, it responds with canned reply and resends +// the +// ICMP request packet. +TEST_F(TuntapTest, PingKernel) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1")); + ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + + // Send ping, this would trigger an ARP request on Linux. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + + // Receive loop to process inbound packets. + struct inpkt { + union { + pihdr pi; + ping_pkt ping; + arp_pkt arp; + }; + }; + while (1) { + inpkt r = {}; + int n = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(n, SyscallSucceeds()); + + if (n < sizeof(pihdr)) { + std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol + << " len: " << n << std::endl; + continue; + } + + // Process ARP packet. + if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { + // Respond with canned ARP reply. + EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()), + SyscallSucceedsWithValue(arp_rep.size())); + // First ping request might have been dropped due to mac address not in + // ARP cache. Send it again. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + } + + // Process ping response packet. + if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol && + r.ping.ip.protocol == ping_req.ip.protocol && + !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) && + !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) && + r.ping.icmp.type == 0 && r.ping.icmp.code == 0) { + // Ends and passes the test. + break; + } + } +} + +TEST_F(TuntapTest, SendUdpTriggersArpResolution) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1")); + + // Send a UDP packet to remote. + int sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_IP); + ASSERT_THAT(sock, SyscallSucceeds()); + + struct sockaddr_in remote = {}; + remote.sin_family = AF_INET; + remote.sin_port = htons(42); + inet_pton(AF_INET, "10.0.0.2", &remote.sin_addr); + int ret = sendto(sock, "hello", 5, 0, reinterpret_cast<sockaddr*>(&remote), + sizeof(remote)); + ASSERT_THAT(ret, ::testing::AnyOf(SyscallSucceeds(), + SyscallFailsWithErrno(EHOSTDOWN))); + + struct inpkt { + union { + pihdr pi; + arp_pkt arp; + }; + }; + while (1) { + inpkt r = {}; + int n = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(n, SyscallSucceeds()); + + if (n < sizeof(pihdr)) { + std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol + << " len: " << n << std::endl; + continue; + } + + if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { + break; + } + } +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/tuntap_hostinet.cc b/test/syscalls/linux/tuntap_hostinet.cc new file mode 100644 index 000000000..1513fb9d5 --- /dev/null +++ b/test/syscalls/linux/tuntap_hostinet.cc @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +TEST(TuntapHostInetTest, NoNetTun) { + SKIP_IF(!IsRunningOnGvisor()); + SKIP_IF(!IsRunningWithHostinet()); + + struct stat statbuf; + ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallFailsWithErrno(ENOENT)); +} + +} // namespace +} // namespace testing + +} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index a2f6ef8cc..740c7986d 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -21,6 +21,10 @@ #include <sys/socket.h> #include <sys/types.h> +#ifndef SIOCGSTAMP +#include <linux/sockios.h> +#endif + #include "gtest/gtest.h" #include "absl/base/macros.h" #include "absl/time/clock.h" @@ -1349,9 +1353,6 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack. - SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && - !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1422,7 +1423,6 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack. // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1495,6 +1495,5 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) { memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); EXPECT_EQ(received_tos, sent_tos); } - } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index 6218fbce1..64d6d0b8f 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <grp.h> +#include <sys/resource.h> #include <sys/types.h> #include <unistd.h> @@ -249,6 +250,26 @@ TEST(UidGidRootTest, Setgroups) { SyscallFailsWithErrno(EFAULT)); } +TEST(UidGidRootTest, Setuid_prlimit) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); + + // Do seteuid in a separate thread so that after finishing this test, the + // process can still open files the test harness created before starting this + // test. Otherwise, the files are created by root (UID before the test), but + // cannot be opened by the `uid` set below after the test. + ScopedThread([&] { + // Use syscall instead of glibc setuid wrapper because we want this seteuid + // call to only apply to this task. POSIX threads, however, require that all + // threads have the same UIDs, so using the seteuid wrapper sets all + // threads' UID. + EXPECT_THAT(syscall(SYS_setreuid, -1, 65534), SyscallSucceeds()); + + // Despite the UID change, we should be able to get our own limits. + struct rlimit rl = {}; + EXPECT_THAT(prlimit(0, RLIMIT_NOFILE, NULL, &rl), SyscallSucceeds()); + }); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc index 3a927a430..22e6d1a85 100644 --- a/test/syscalls/linux/utimes.cc +++ b/test/syscalls/linux/utimes.cc @@ -34,17 +34,10 @@ namespace testing { namespace { -// TODO(b/36516566): utimes(nullptr) does not pick the "now" time in the -// application's time domain, so when asserting that times are within a window, -// we expand the window to allow for differences between the time domains. -constexpr absl::Duration kClockSlack = absl::Milliseconds(100); - // TimeBoxed runs fn, setting before and after to (coarse realtime) times // guaranteed* to come before and after fn started and completed, respectively. // // fn may be called more than once if the clock is adjusted. -// -// * See the comment on kClockSlack. gVisor breaks this guarantee. void TimeBoxed(absl::Time* before, absl::Time* after, std::function<void()> const& fn) { do { @@ -69,12 +62,6 @@ void TimeBoxed(absl::Time* before, absl::Time* after, // which could lead to test failures, but that is very unlikely to happen. continue; } - - if (IsRunningOnGvisor()) { - // See comment on kClockSlack. - *before -= kClockSlack; - *after += kClockSlack; - } } while (*after < *before); } @@ -235,10 +222,7 @@ void TestUtimensat(int dirFd, std::string const& path) { EXPECT_GE(mtime3, before); EXPECT_LE(mtime3, after); - if (!IsRunningOnGvisor()) { - // FIXME(b/36516566): Gofers set atime and mtime to different "now" times. - EXPECT_EQ(atime3, mtime3); - } + EXPECT_EQ(atime3, mtime3); } TEST(UtimensatTest, OnAbsPath) { diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc index 0aaba482d..19d05998e 100644 --- a/test/syscalls/linux/vfork.cc +++ b/test/syscalls/linux/vfork.cc @@ -191,5 +191,5 @@ int main(int argc, char** argv) { return gvisor::testing::RunChild(); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/vsyscall.cc b/test/syscalls/linux/vsyscall.cc index 2c2303358..ae4377108 100644 --- a/test/syscalls/linux/vsyscall.cc +++ b/test/syscalls/linux/vsyscall.cc @@ -24,6 +24,7 @@ namespace testing { namespace { +#if defined(__x86_64__) || defined(__i386__) time_t vsyscall_time(time_t* t) { constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t); @@ -37,6 +38,7 @@ TEST(VsyscallTest, VsyscallAlwaysAvailableOnGvisor) { time_t t; EXPECT_THAT(vsyscall_time(&t), SyscallSucceeds()); } +#endif } // namespace diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 9b219cfd6..39b5b2f56 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -31,14 +31,8 @@ namespace gvisor { namespace testing { namespace { -// This test is currently very rudimentary. -// -// TODO(edahlgren): -// * bad buffer states (EFAULT). -// * bad fds (wrong permission, wrong type of file, EBADF). -// * check offset is incremented. -// * check for EOF. -// * writing to pipes, symlinks, special files. + +// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary. class WriteTest : public ::testing::Test { public: ssize_t WriteBytes(int fd, int bytes) { diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index 8b00ef44c..3231732ec 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -41,12 +41,12 @@ class XattrTest : public FileTest {}; TEST_F(XattrTest, XattrNonexistentFile) { const char* path = "/does/not/exist"; - EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0), - SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(getxattr(path, nullptr, nullptr, 0), + const char* name = "user.test"; + EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENOENT)); EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT)); - EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(ENOENT)); + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENOENT)); } TEST_F(XattrTest, XattrNullName) { diff --git a/test/syscalls/syscall_test_runner.sh b/test/syscalls/syscall_test_runner.sh deleted file mode 100755 index 864bb2de4..000000000 --- a/test/syscalls/syscall_test_runner.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# syscall_test_runner.sh is a simple wrapper around the go syscall test runner. -# It exists so that we can build the syscall test runner once, and use it for -# all syscall tests, rather than build it for each test run. - -set -euf -x -o pipefail - -echo -- "$@" - -if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then - mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}" - chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}" -fi - -# Get location of syscall_test_runner binary. -readonly runner=$(find "${TEST_SRCDIR}" -name syscall_test_runner) - -# Pass the arguments of this script directly to the runner. -exec "${runner}" "$@" diff --git a/test/util/BUILD b/test/util/BUILD index 1f22ebe29..2a17c33ee 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_library", "cc_test", "gtest", "select_system") +load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system") package( default_visibility = ["//:sandbox"], @@ -260,6 +260,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", gtest, + gbenchmark, ], ) @@ -349,3 +350,9 @@ cc_library( ":save_util", ], ) + +cc_library( + name = "temp_umask", + testonly = 1, + hdrs = ["temp_umask.h"], +) diff --git a/test/util/capability_util.cc b/test/util/capability_util.cc index 9fee52fbb..a1b994c45 100644 --- a/test/util/capability_util.cc +++ b/test/util/capability_util.cc @@ -63,13 +63,13 @@ PosixErrorOr<bool> CanCreateUserNamespace() { // is in a chroot environment (i.e., the caller's root directory does // not match the root directory of the mount namespace in which it // resides)." - std::cerr << "clone(CLONE_NEWUSER) failed with EPERM"; + std::cerr << "clone(CLONE_NEWUSER) failed with EPERM" << std::endl; return false; } else if (errno == EUSERS) { // "(since Linux 3.11) CLONE_NEWUSER was specified in flags, and the call // would cause the limit on the number of nested user namespaces to be // exceeded. See user_namespaces(7)." - std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS"; + std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS" << std::endl; return false; } else { // Unexpected error code; indicate an actual error. diff --git a/test/syscalls/linux/temp_umask.h b/test/util/temp_umask.h index 81a25440c..e7de84a54 100644 --- a/test/syscalls/linux/temp_umask.h +++ b/test/util/temp_umask.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ -#define GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ +#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_ +#define GVISOR_TEST_UTIL_TEMP_UMASK_H_ #include <sys/stat.h> #include <sys/types.h> @@ -36,4 +36,4 @@ class TempUmask { } // namespace testing } // namespace gvisor -#endif // GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ +#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_ diff --git a/test/util/test_main.cc b/test/util/test_main.cc index 5c7ee0064..1f389e58f 100644 --- a/test/util/test_main.cc +++ b/test/util/test_main.cc @@ -16,5 +16,5 @@ int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/util/test_util.h b/test/util/test_util.h index 2d22b0eb8..c5cb9d6d6 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -771,6 +771,7 @@ std::string RunfilePath(std::string path); #endif void TestInit(int* argc, char*** argv); +int RunAllTests(void); } // namespace testing } // namespace gvisor diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc index ba7c0a85b..7e1ad9e66 100644 --- a/test/util/test_util_impl.cc +++ b/test/util/test_util_impl.cc @@ -17,8 +17,12 @@ #include "gtest/gtest.h" #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "benchmark/benchmark.h" #include "test/util/logging.h" +extern bool FLAGS_benchmark_list_tests; +extern std::string FLAGS_benchmark_filter; + namespace gvisor { namespace testing { @@ -26,6 +30,7 @@ void SetupGvisorDeathTest() {} void TestInit(int* argc, char*** argv) { ::testing::InitGoogleTest(argc, *argv); + benchmark::Initialize(argc, *argv); ::absl::ParseCommandLine(*argc, *argv); // Always mask SIGPIPE as it's common and tests aren't expected to handle it. @@ -34,5 +39,14 @@ void TestInit(int* argc, char*** argv) { TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); } +int RunAllTests() { + if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") { + benchmark::RunSpecifiedBenchmarks(); + return 0; + } else { + return RUN_ALL_TESTS(); + } +} + } // namespace testing } // namespace gvisor |