diff options
Diffstat (limited to 'test')
59 files changed, 3032 insertions, 535 deletions
diff --git a/test/BUILD b/test/BUILD index 8e1dc5228..01fa01f2e 100644 --- a/test/BUILD +++ b/test/BUILD @@ -39,6 +39,6 @@ toolchain( ], target_compatible_with = [ ], - toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/1.2/bazel_0.23.0/default:cc-compiler-k8", + toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/9.0.0/bazel_0.28.0/cc:cc-compiler-k8", toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", ) diff --git a/test/runtimes/README.md b/test/runtimes/README.md index 4e5a950bc..34d3507be 100644 --- a/test/runtimes/README.md +++ b/test/runtimes/README.md @@ -16,7 +16,7 @@ The following runtimes are currently supported: 1) [Install and configure Docker](https://docs.docker.com/install/) -2) Build each Docker container: +2) Build each Docker container from the runtimes directory: ```bash $ docker build -f $LANG/Dockerfile [-t $NAME] . diff --git a/test/runtimes/common/BUILD b/test/runtimes/common/BUILD new file mode 100644 index 000000000..1b39606b8 --- /dev/null +++ b/test/runtimes/common/BUILD @@ -0,0 +1,20 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "common", + srcs = ["common.go"], + importpath = "gvisor.dev/gvisor/test/runtimes/common", + visibility = ["//:sandbox"], +) + +go_test( + name = "common_test", + size = "small", + srcs = ["common_test.go"], + deps = [ + ":common", + "//runsc/test/testutil", + ], +) diff --git a/test/runtimes/common/common.go b/test/runtimes/common/common.go new file mode 100644 index 000000000..0ff87fa8b --- /dev/null +++ b/test/runtimes/common/common.go @@ -0,0 +1,114 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package common executes functions for proctor binaries. +package common + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "regexp" +) + +var ( + list = flag.Bool("list", false, "list all available tests") + test = flag.String("test", "", "run a single test from the list of available tests") + version = flag.Bool("v", false, "print out the version of node that is installed") +) + +// TestRunner is an interface to be implemented in each proctor binary. +type TestRunner interface { + // ListTests returns a string slice of tests available to run. + ListTests() ([]string, error) + + // RunTest runs a single test. + RunTest(test string) error +} + +// LaunchFunc parses flags passed by a proctor binary and calls the requested behavior. +func LaunchFunc(tr TestRunner) error { + flag.Parse() + + if *list && *test != "" { + flag.PrintDefaults() + return fmt.Errorf("cannot specify 'list' and 'test' flags simultaneously") + } + if *list { + tests, err := tr.ListTests() + if err != nil { + return fmt.Errorf("failed to list tests: %v", err) + } + for _, test := range tests { + fmt.Println(test) + } + return nil + } + if *version { + fmt.Println(os.Getenv("LANG_NAME"), "version:", os.Getenv("LANG_VER"), "is installed.") + return nil + } + if *test != "" { + if err := tr.RunTest(*test); err != nil { + return fmt.Errorf("test %q failed to run: %v", *test, err) + } + return nil + } + + if err := runAllTests(tr); err != nil { + return fmt.Errorf("error running all tests: %v", err) + } + return nil +} + +// Search uses filepath.Walk to perform a search of the disk for test files +// and returns a string slice of tests. +func Search(root string, testFilter *regexp.Regexp) ([]string, error) { + var testSlice []string + + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + name := filepath.Base(path) + + if info.IsDir() || !testFilter.MatchString(name) { + return nil + } + + relPath, err := filepath.Rel(root, path) + if err != nil { + return err + } + testSlice = append(testSlice, relPath) + return nil + }) + + if err != nil { + return nil, fmt.Errorf("walking %q: %v", root, err) + } + + return testSlice, nil +} + +func runAllTests(tr TestRunner) error { + tests, err := tr.ListTests() + if err != nil { + return fmt.Errorf("failed to list tests: %v", err) + } + for _, test := range tests { + if err := tr.RunTest(test); err != nil { + return fmt.Errorf("test %q failed to run: %v", test, err) + } + } + return nil +} diff --git a/test/runtimes/common/common_test.go b/test/runtimes/common/common_test.go new file mode 100644 index 000000000..4fb1e482a --- /dev/null +++ b/test/runtimes/common/common_test.go @@ -0,0 +1,128 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common_test + +import ( + "io/ioutil" + "os" + "path/filepath" + "reflect" + "regexp" + "strings" + "testing" + + "gvisor.dev/gvisor/runsc/test/testutil" + "gvisor.dev/gvisor/test/runtimes/common" +) + +func touch(t *testing.T, name string) { + t.Helper() + f, err := os.Create(name) + if err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } +} + +func TestSearchEmptyDir(t *testing.T) { + td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(td) + + var want []string + + testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) + got, err := common.Search(td, testFilter) + if err != nil { + t.Errorf("Search error: %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("Found %#v; want %#v", got, want) + } +} + +func TestSearch(t *testing.T) { + td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(td) + + // Creating various files similar to the test filter regex. + files := []string{ + "emp/", + "tee/", + "test-foo.tc", + "test-foo.tc", + "test-bar.tc", + "test-sam.tc", + "Test-que.tc", + "test-brett", + "test--abc.tc", + "test---xyz.tc", + "test-bool.TC", + "--test-gvs.tc", + " test-pew.tc", + "dir/test_baz.tc", + "dir/testsnap.tc", + "dir/test-luk.tc", + "dir/nest/test-ok.tc", + "dir/dip/diz/goog/test-pack.tc", + "dir/dip/diz/wobble/thud/test-cas.e", + "dir/dip/diz/wobble/thud/test-cas.tc", + } + want := []string{ + "dir/dip/diz/goog/test-pack.tc", + "dir/dip/diz/wobble/thud/test-cas.tc", + "dir/nest/test-ok.tc", + "dir/test-luk.tc", + "test-bar.tc", + "test-foo.tc", + "test-sam.tc", + } + + for _, item := range files { + if strings.HasSuffix(item, "/") { + // This item is a directory, create it. + if err := os.MkdirAll(filepath.Join(td, item), 0755); err != nil { + t.Fatal(err) + } + } else { + // This item is a file, create the directory and touch file. + // Create directory in which file should be created + fullDirPath := filepath.Join(td, filepath.Dir(item)) + if err := os.MkdirAll(fullDirPath, 0755); err != nil { + t.Fatal(err) + } + // Create file with full path to file. + touch(t, filepath.Join(td, item)) + } + } + + testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) + got, err := common.Search(td, testFilter) + if err != nil { + t.Errorf("Search error: %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("Found %#v; want %#v", got, want) + } +} diff --git a/test/runtimes/go/BUILD b/test/runtimes/go/BUILD index c34f49ea6..ce971ee9d 100644 --- a/test/runtimes/go/BUILD +++ b/test/runtimes/go/BUILD @@ -5,4 +5,5 @@ package(licenses = ["notice"]) go_binary( name = "proctor-go", srcs = ["proctor-go.go"], + deps = ["//test/runtimes/common"], ) diff --git a/test/runtimes/go/Dockerfile b/test/runtimes/go/Dockerfile index cd55608cd..2d3477392 100644 --- a/test/runtimes/go/Dockerfile +++ b/test/runtimes/go/Dockerfile @@ -23,9 +23,13 @@ ENV LANG_DIR=${GOROOT} WORKDIR ${LANG_DIR}/src RUN ./make.bash +# Pre-compile the tests for faster execution +RUN ["/root/go/bin/go", "tool", "dist", "test", "-compile-only"] WORKDIR ${LANG_DIR} -COPY proctor-go.go ${LANG_DIR} +COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common +COPY go/proctor-go.go ${LANG_DIR} +RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-go.go"] -ENTRYPOINT ["/root/go/bin/go", "run", "proctor-go.go"] +ENTRYPOINT ["/root/go/bin/proctor"] diff --git a/test/runtimes/go/proctor-go.go b/test/runtimes/go/proctor-go.go index c5387e21d..3eb24576e 100644 --- a/test/runtimes/go/proctor-go.go +++ b/test/runtimes/go/proctor-go.go @@ -21,7 +21,6 @@ package main import ( - "flag" "fmt" "log" "os" @@ -29,133 +28,78 @@ import ( "path/filepath" "regexp" "strings" + + "gvisor.dev/gvisor/test/runtimes/common" ) var ( - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - version = flag.Bool("v", false, "print out the version of node that is installed") - dir = os.Getenv("LANG_DIR") + goBin = filepath.Join(dir, "bin/go") testDir = filepath.Join(dir, "test") testRegEx = regexp.MustCompile(`^.+\.go$`) // Directories with .dir contain helper files for tests. // Exclude benchmarks and stress tests. - exclDirs = regexp.MustCompile(`^.+\/(bench|stress)\/.+$|^.+\.dir.+$`) + dirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`) ) -func main() { - flag.Parse() +type goRunner struct { +} - if *list && *test != "" { - flag.PrintDefaults() - os.Exit(1) - } - if *list { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - if *version { - fmt.Println("Go version: ", os.Getenv("LANG_VER"), " is installed.") - return - } - if *test != "" { - runTest(*test) - return +func main() { + if err := common.LaunchFunc(goRunner{}); err != nil { + log.Fatalf("Failed to start: %v", err) } - runAllTests() } -func listTests() ([]string, error) { +func (g goRunner) ListTests() ([]string, error) { // Go tool dist test tests. args := []string{"tool", "dist", "test", "-list"} cmd := exec.Command(filepath.Join(dir, "bin/go"), args...) cmd.Stderr = os.Stderr out, err := cmd.Output() if err != nil { - log.Fatalf("Failed to list: %v", err) + return nil, fmt.Errorf("failed to list: %v", err) } - var testSlice []string + var toolSlice []string for _, test := range strings.Split(string(out), "\n") { - testSlice = append(testSlice, test) + toolSlice = append(toolSlice, test) } // Go tests on disk. - if err := filepath.Walk(testDir, func(path string, info os.FileInfo, err error) error { - name := filepath.Base(path) - - if info.IsDir() { - return nil - } - - if !testRegEx.MatchString(name) { - return nil - } - - if exclDirs.MatchString(path) { - return nil + diskSlice, err := common.Search(testDir, testRegEx) + if err != nil { + return nil, err + } + // Remove items from /bench/, /stress/ and .dir files + diskFiltered := diskSlice[:0] + for _, file := range diskSlice { + if !dirFilter.MatchString(file) { + diskFiltered = append(diskFiltered, file) } - - testSlice = append(testSlice, path) - return nil - }); err != nil { - return nil, fmt.Errorf("walking %q: %v", testDir, err) } - return testSlice, nil + return append(toolSlice, diskFiltered...), nil } -func runTest(test string) { - toolArgs := []string{ - "tool", - "dist", - "test", - } - diskArgs := []string{ - "run", - "run.go", - "-v", - } +func (g goRunner) RunTest(test string) error { // Check if test exists on disk by searching for file of the same name. // This will determine whether or not it is a Go test on disk. - if _, err := os.Stat(test); err == nil { - relPath, err := filepath.Rel(testDir, test) - if err != nil { - log.Fatalf("Failed to get rel path: %v", err) - } - diskArgs = append(diskArgs, "--", relPath) - cmd := exec.Command(filepath.Join(dir, "bin/go"), diskArgs...) + if strings.HasSuffix(test, ".go") { + // Test has suffix ".go" which indicates a disk test, run it as such. + cmd := exec.Command(goBin, "run", "run.go", "-v", "--", test) cmd.Dir = testDir cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr if err := cmd.Run(); err != nil { - log.Fatalf("Failed to run: %v", err) + return fmt.Errorf("failed to run test: %v", err) } - } else if os.IsNotExist(err) { - // File was not found, try running as Go tool test. - toolArgs = append(toolArgs, "-run", test) - cmd := exec.Command(filepath.Join(dir, "bin/go"), toolArgs...) + } else { + // No ".go" suffix, run as a tool test. + cmd := exec.Command(goBin, "tool", "dist", "test", "-run", test) cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr if err := cmd.Run(); err != nil { - log.Fatalf("Failed to run: %v", err) + return fmt.Errorf("failed to run test: %v", err) } - } else { - log.Fatalf("Error searching for test: %v", err) - } -} - -func runAllTests() { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - runTest(test) } + return nil } diff --git a/test/runtimes/java/BUILD b/test/runtimes/java/BUILD index 7e2808ece..8c39d39ec 100644 --- a/test/runtimes/java/BUILD +++ b/test/runtimes/java/BUILD @@ -5,4 +5,5 @@ package(licenses = ["notice"]) go_binary( name = "proctor-java", srcs = ["proctor-java.go"], + deps = ["//test/runtimes/common"], ) diff --git a/test/runtimes/java/Dockerfile b/test/runtimes/java/Dockerfile index e162d7218..1a61d9d8f 100644 --- a/test/runtimes/java/Dockerfile +++ b/test/runtimes/java/Dockerfile @@ -1,25 +1,16 @@ FROM ubuntu:bionic # This hash is associated with a specific JDK release and needed for ensuring # the same version is downloaded every time. -ENV LANG_HASH=af47e0398606 -ENV LANG_VER=11u-dev +ENV LANG_HASH=76072a077ee1 +ENV LANG_VER=11 ENV LANG_NAME=Java RUN apt-get update && apt-get install -y \ autoconf \ build-essential \ - curl\ - file \ - libasound2-dev \ - libcups2-dev \ - libfontconfig1-dev \ - libx11-dev \ - libxext-dev \ - libxrandr-dev \ - libxrender-dev \ - libxt-dev \ - libxtst-dev \ + curl \ make \ + openjdk-${LANG_VER}-jdk \ unzip \ zip @@ -27,26 +18,19 @@ WORKDIR /root RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz RUN tar -zxf go.tar.gz -# Use curl instead of ADD to prevent redownload every time. -RUN curl -o jdk.tar.gz http://hg.openjdk.java.net/jdk-updates/jdk${LANG_VER}/archive/${LANG_HASH}.tar.gz -# Download Java version N-1 to be used as the Boot JDK to build Java version N. -RUN curl -o bootjdk.tar.gz https://download.java.net/openjdk/jdk10/ri/openjdk-10+44_linux-x64_bin_ri.tar.gz - -RUN tar -zxf jdk.tar.gz -RUN tar -zxf bootjdk.tar.gz - -# Specify the JDK to be used by jtreg. -ENV JT_JAVA=/root/jdk${LANG_VER}-${LANG_HASH}/build/linux-x86_64-normal-server-release/jdk -ENV LANG_DIR=/root/jdk${LANG_VER}-${LANG_HASH} - -WORKDIR ${LANG_DIR} +# Download the JDK test library. +RUN set -ex \ + && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk${LANG_VER}/archive/${LANG_HASH}.tar.gz/test \ + && tar -xzf /tmp/jdktests.tar.gz -C /root \ + && rm -f /tmp/jdktests.tar.gz RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz RUN tar -xzf jtreg.tar.gz -RUN bash configure --with-boot-jdk=/root/jdk-10 --with-jtreg=${LANG_DIR}/jtreg -RUN make clean -RUN make images -COPY proctor-java.go ${LANG_DIR} +ENV LANG_DIR=/root + +COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common +COPY java/proctor-java.go ${LANG_DIR} +RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-java.go"] -ENTRYPOINT ["/root/go/bin/go", "run", "proctor-java.go"] +ENTRYPOINT ["/root/go/bin/proctor"] diff --git a/test/runtimes/java/proctor-java.go b/test/runtimes/java/proctor-java.go index 0177f421d..7f6a66f4f 100644 --- a/test/runtimes/java/proctor-java.go +++ b/test/runtimes/java/proctor-java.go @@ -16,7 +16,6 @@ package main import ( - "flag" "fmt" "log" "os" @@ -24,49 +23,29 @@ import ( "path/filepath" "regexp" "strings" + + "gvisor.dev/gvisor/test/runtimes/common" ) var ( - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - version = flag.Bool("v", false, "print out the version of node that is installed") - dir = os.Getenv("LANG_DIR") + hash = os.Getenv("LANG_HASH") jtreg = filepath.Join(dir, "jtreg/bin/jtreg") exclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`) ) -func main() { - flag.Parse() +type javaRunner struct { +} - if *list && *test != "" { - flag.PrintDefaults() - os.Exit(1) - } - if *list { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - if *version { - fmt.Println("Java version: ", os.Getenv("LANG_VER"), " is installed.") - return - } - if *test != "" { - runTest(*test) - return +func main() { + if err := common.LaunchFunc(javaRunner{}); err != nil { + log.Fatalf("Failed to start: %v", err) } - runAllTests() } -func listTests() ([]string, error) { +func (j javaRunner) ListTests() ([]string, error) { args := []string{ - "-dir:test/jdk", + "-dir:/root/jdk11-" + hash + "/test/jdk", "-ignore:quiet", "-a", "-listtests", @@ -90,21 +69,12 @@ func listTests() ([]string, error) { return testSlice, nil } -func runTest(test string) { - args := []string{"-dir:test/jdk/", test} +func (j javaRunner) RunTest(test string) error { + args := []string{"-noreport", "-dir:/root/jdk11-" + hash + "/test/jdk", test} cmd := exec.Command(jtreg, args...) cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr if err := cmd.Run(); err != nil { - log.Fatalf("Failed to run: %v", err) - } -} - -func runAllTests() { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - runTest(test) + return fmt.Errorf("failed to run: %v", err) } + return nil } diff --git a/test/runtimes/nodejs/BUILD b/test/runtimes/nodejs/BUILD index 0fe5ff83e..0594c250b 100644 --- a/test/runtimes/nodejs/BUILD +++ b/test/runtimes/nodejs/BUILD @@ -5,4 +5,5 @@ package(licenses = ["notice"]) go_binary( name = "proctor-nodejs", srcs = ["proctor-nodejs.go"], + deps = ["//test/runtimes/common"], ) diff --git a/test/runtimes/nodejs/Dockerfile b/test/runtimes/nodejs/Dockerfile index b2416cce8..ce2943af8 100644 --- a/test/runtimes/nodejs/Dockerfile +++ b/test/runtimes/nodejs/Dockerfile @@ -22,8 +22,10 @@ RUN ./configure RUN make RUN make test-build -COPY proctor-nodejs.go ${LANG_DIR} +COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common +COPY nodejs/proctor-nodejs.go ${LANG_DIR} +RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-nodejs.go"] # Including dumb-init emulates the Linux "init" process, preventing the failure # of tests involving worker processes. -ENTRYPOINT ["/usr/bin/dumb-init", "/root/go/bin/go", "run", "proctor-nodejs.go"] +ENTRYPOINT ["/usr/bin/dumb-init", "/root/go/bin/proctor"] diff --git a/test/runtimes/nodejs/proctor-nodejs.go b/test/runtimes/nodejs/proctor-nodejs.go index 8ddfb67fe..0624f6a0d 100644 --- a/test/runtimes/nodejs/proctor-nodejs.go +++ b/test/runtimes/nodejs/proctor-nodejs.go @@ -16,93 +16,45 @@ package main import ( - "flag" "fmt" "log" "os" "os/exec" "path/filepath" "regexp" + + "gvisor.dev/gvisor/test/runtimes/common" ) var ( - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - version = flag.Bool("v", false, "print out the version of node that is installed") - dir = os.Getenv("LANG_DIR") - testRegEx = regexp.MustCompile(`^test-.+\.js$`) + testDir = filepath.Join(dir, "test") + testRegEx = regexp.MustCompile(`^test-[^-].+\.js$`) ) -func main() { - flag.Parse() +type nodejsRunner struct { +} - if *list && *test != "" { - flag.PrintDefaults() - os.Exit(1) - } - if *list { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - if *version { - fmt.Println("Node.js version: ", os.Getenv("LANG_VER"), " is installed.") - return - } - if *test != "" { - runTest(*test) - return +func main() { + if err := common.LaunchFunc(nodejsRunner{}); err != nil { + log.Fatalf("Failed to start: %v", err) } - runAllTests() } -func listTests() ([]string, error) { - var testSlice []string - root := filepath.Join(dir, "test") - - err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { - name := filepath.Base(path) - - if info.IsDir() || !testRegEx.MatchString(name) { - return nil - } - - relPath, err := filepath.Rel(root, path) - if err != nil { - return err - } - testSlice = append(testSlice, relPath) - return nil - }) - +func (n nodejsRunner) ListTests() ([]string, error) { + testSlice, err := common.Search(testDir, testRegEx) if err != nil { - return nil, fmt.Errorf("walking %q: %v", root, err) + return nil, err } - return testSlice, nil } -func runTest(test string) { +func (n nodejsRunner) RunTest(test string) error { args := []string{filepath.Join(dir, "tools", "test.py"), test} cmd := exec.Command("/usr/bin/python", args...) cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr if err := cmd.Run(); err != nil { - log.Fatalf("Failed to run: %v", err) - } -} - -func runAllTests() { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - runTest(test) + return fmt.Errorf("failed to run: %v", err) } + return nil } diff --git a/test/runtimes/php/BUILD b/test/runtimes/php/BUILD index 22aef7ba4..31799b77a 100644 --- a/test/runtimes/php/BUILD +++ b/test/runtimes/php/BUILD @@ -5,4 +5,5 @@ package(licenses = ["notice"]) go_binary( name = "proctor-php", srcs = ["proctor-php.go"], + deps = ["//test/runtimes/common"], ) diff --git a/test/runtimes/php/Dockerfile b/test/runtimes/php/Dockerfile index 1f8959b50..d79babe58 100644 --- a/test/runtimes/php/Dockerfile +++ b/test/runtimes/php/Dockerfile @@ -24,6 +24,8 @@ WORKDIR ${LANG_DIR} RUN ./configure RUN make -COPY proctor-php.go ${LANG_DIR} +COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common +COPY php/proctor-php.go ${LANG_DIR} +RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-php.go"] -ENTRYPOINT ["/root/go/bin/go", "run", "proctor-php.go"] +ENTRYPOINT ["/root/go/bin/proctor"] diff --git a/test/runtimes/php/proctor-php.go b/test/runtimes/php/proctor-php.go index 9dfb33b04..e6c5fabdf 100644 --- a/test/runtimes/php/proctor-php.go +++ b/test/runtimes/php/proctor-php.go @@ -16,92 +16,43 @@ package main import ( - "flag" "fmt" "log" "os" "os/exec" - "path/filepath" "regexp" + + "gvisor.dev/gvisor/test/runtimes/common" ) var ( - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - version = flag.Bool("v", false, "print out the version of node that is installed") - dir = os.Getenv("LANG_DIR") testRegEx = regexp.MustCompile(`^.+\.phpt$`) ) -func main() { - flag.Parse() +type phpRunner struct { +} - if *list && *test != "" { - flag.PrintDefaults() - os.Exit(1) - } - if *list { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - if *version { - fmt.Println("PHP version: ", os.Getenv("LANG_VER"), " is installed.") - return - } - if *test != "" { - runTest(*test) - return +func main() { + if err := common.LaunchFunc(phpRunner{}); err != nil { + log.Fatalf("Failed to start: %v", err) } - runAllTests() } -func listTests() ([]string, error) { - var testSlice []string - - err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { - name := filepath.Base(path) - - if info.IsDir() || !testRegEx.MatchString(name) { - return nil - } - - relPath, err := filepath.Rel(dir, path) - if err != nil { - return err - } - testSlice = append(testSlice, relPath) - return nil - }) - +func (p phpRunner) ListTests() ([]string, error) { + testSlice, err := common.Search(dir, testRegEx) if err != nil { - return nil, fmt.Errorf("walking %q: %v", dir, err) + return nil, err } - return testSlice, nil } -func runTest(test string) { +func (p phpRunner) RunTest(test string) error { args := []string{"test", "TESTS=" + test} cmd := exec.Command("make", args...) cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr if err := cmd.Run(); err != nil { - log.Fatalf("Failed to run: %v", err) - } -} - -func runAllTests() { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - runTest(test) + return fmt.Errorf("failed to run: %v", err) } + return nil } diff --git a/test/runtimes/python/BUILD b/test/runtimes/python/BUILD index 501f77d63..37fd6a0f2 100644 --- a/test/runtimes/python/BUILD +++ b/test/runtimes/python/BUILD @@ -5,4 +5,5 @@ package(licenses = ["notice"]) go_binary( name = "proctor-python", srcs = ["proctor-python.go"], + deps = ["//test/runtimes/common"], ) diff --git a/test/runtimes/python/Dockerfile b/test/runtimes/python/Dockerfile index 811f48f8a..5ae328890 100644 --- a/test/runtimes/python/Dockerfile +++ b/test/runtimes/python/Dockerfile @@ -26,6 +26,8 @@ WORKDIR ${LANG_DIR} RUN ./configure --with-pydebug RUN make -s -j2 -COPY proctor-python.go ${LANG_DIR} +COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common +COPY python/proctor-python.go ${LANG_DIR} +RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-python.go"] -ENTRYPOINT ["/root/go/bin/go", "run", "proctor-python.go"] +ENTRYPOINT ["/root/go/bin/proctor"] diff --git a/test/runtimes/python/proctor-python.go b/test/runtimes/python/proctor-python.go index 73c8deb49..35e28a7df 100644 --- a/test/runtimes/python/proctor-python.go +++ b/test/runtimes/python/proctor-python.go @@ -16,93 +16,50 @@ package main import ( - "flag" "fmt" "log" "os" "os/exec" "path/filepath" - "regexp" + "strings" + + "gvisor.dev/gvisor/test/runtimes/common" ) var ( - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - version = flag.Bool("v", false, "print out the version of node that is installed") - - dir = os.Getenv("LANG_DIR") - testRegEx = regexp.MustCompile(`^test_.+\.py$`) + dir = os.Getenv("LANG_DIR") ) -func main() { - flag.Parse() +type pythonRunner struct { +} - if *list && *test != "" { - flag.PrintDefaults() - os.Exit(1) - } - if *list { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - if *version { - fmt.Println("Python version: ", os.Getenv("LANG_VER"), " is installed.") - return - } - if *test != "" { - runTest(*test) - return +func main() { + if err := common.LaunchFunc(pythonRunner{}); err != nil { + log.Fatalf("Failed to start: %v", err) } - runAllTests() } -func listTests() ([]string, error) { - var testSlice []string - root := filepath.Join(dir, "Lib/test") - - err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { - name := filepath.Base(path) - - if info.IsDir() || !testRegEx.MatchString(name) { - return nil - } - - relPath, err := filepath.Rel(root, path) - if err != nil { - return err - } - testSlice = append(testSlice, relPath) - return nil - }) - +func (p pythonRunner) ListTests() ([]string, error) { + args := []string{"-m", "test", "--list-tests"} + cmd := exec.Command(filepath.Join(dir, "python"), args...) + cmd.Stderr = os.Stderr + out, err := cmd.Output() if err != nil { - return nil, fmt.Errorf("walking %q: %v", root, err) + return nil, fmt.Errorf("failed to list: %v", err) } - - return testSlice, nil + var toolSlice []string + for _, test := range strings.Split(string(out), "\n") { + toolSlice = append(toolSlice, test) + } + return toolSlice, nil } -func runTest(test string) { +func (p pythonRunner) RunTest(test string) error { args := []string{"-m", "test", test} cmd := exec.Command(filepath.Join(dir, "python"), args...) cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr if err := cmd.Run(); err != nil { - log.Fatalf("Failed to run: %v", err) - } -} - -func runAllTests() { - tests, err := listTests() - if err != nil { - log.Fatalf("Failed to list tests: %v", err) - } - for _, test := range tests { - runTest(test) + return fmt.Errorf("failed to run: %v", err) } + return nil } diff --git a/test/runtimes/runtimes_test.go b/test/runtimes/runtimes_test.go index 6bf954e78..9421021a1 100644 --- a/test/runtimes/runtimes_test.go +++ b/test/runtimes/runtimes_test.go @@ -22,8 +22,16 @@ import ( "gvisor.dev/gvisor/runsc/test/testutil" ) -func TestNodeJS(t *testing.T) { - const img = "gcr.io/gvisor-proctor/nodejs" +// Wait time for each test to run. +const timeout = 180 * time.Second + +// Helper function to execute the docker container associated with the +// language passed. Captures the output of the list function and executes +// each test individually, supplying any errors recieved. +func testLang(t *testing.T, lang string) { + t.Helper() + + img := "gcr.io/gvisor-presubmit/" + lang if err := testutil.Pull(img); err != nil { t.Fatalf("docker pull failed: %v", err) } @@ -41,15 +49,13 @@ func TestNodeJS(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc, func(t *testing.T) { - t.Parallel() - d := testutil.MakeDocker("gvisor-test") if err := d.Run(img, "--test", tc); err != nil { t.Fatalf("docker test %q failed to run: %v", tc, err) } defer d.CleanUp() - status, err := d.Wait(60 * time.Second) + status, err := d.Wait(timeout) if err != nil { t.Fatalf("docker test %q failed to wait: %v", tc, err) } @@ -65,3 +71,23 @@ func TestNodeJS(t *testing.T) { }) } } + +func TestGo(t *testing.T) { + testLang(t, "go") +} + +func TestJava(t *testing.T) { + testLang(t, "java") +} + +func TestNodejs(t *testing.T) { + testLang(t, "nodejs") +} + +func TestPhp(t *testing.T) { + testLang(t, "php") +} + +func TestPython(t *testing.T) { + testLang(t, "python") +} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 841a0f2e1..ccae4925f 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -27,6 +27,7 @@ syscall_test( syscall_test( size = "medium", + shard_count = 5, test = "//test/syscalls/linux:alarm_test", ) @@ -180,7 +181,12 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:iptables_test", +) + +syscall_test( size = "medium", + shard_count = 5, test = "//test/syscalls/linux:itimer_test", ) @@ -249,6 +255,10 @@ syscall_test( test = "//test/syscalls/linux:open_test", ) +syscall_test(test = "//test/syscalls/linux:packet_socket_raw_test") + +syscall_test(test = "//test/syscalls/linux:packet_socket_test") + syscall_test(test = "//test/syscalls/linux:partial_bad_buffer_test") syscall_test(test = "//test/syscalls/linux:pause_test") @@ -290,8 +300,6 @@ syscall_test(test = "//test/syscalls/linux:priority_test") syscall_test( size = "medium", - # We don't want our proc changing out from under us. - parallel = False, test = "//test/syscalls/linux:proc_test", ) @@ -306,10 +314,15 @@ syscall_test(test = "//test/syscalls/linux:ptrace_test") syscall_test( size = "medium", + shard_count = 5, test = "//test/syscalls/linux:pty_test", ) syscall_test( + test = "//test/syscalls/linux:pty_root_test", +) + +syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwritev2_test", ) @@ -332,6 +345,7 @@ syscall_test( syscall_test( size = "medium", + shard_count = 5, test = "//test/syscalls/linux:readv_socket_test", ) @@ -436,7 +450,8 @@ syscall_test( ) syscall_test( - size = "medium", + size = "large", + shard_count = 10, test = "//test/syscalls/linux:socket_inet_loopback_test", ) @@ -476,8 +491,6 @@ syscall_test( syscall_test( size = "medium", - # Multicast packets can be received by the wrong test if run in parallel. - parallel = False, test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test", ) @@ -516,6 +529,7 @@ syscall_test( syscall_test( # NOTE(b/116636318): Large sendmsg may stall a long time. size = "enormous", + shard_count = 5, test = "//test/syscalls/linux:socket_unix_dgram_local_test", ) @@ -534,6 +548,7 @@ syscall_test( syscall_test( # NOTE(b/116636318): Large sendmsg may stall a long time. size = "enormous", + shard_count = 5, test = "//test/syscalls/linux:socket_unix_seqpacket_local_test", ) @@ -665,6 +680,7 @@ syscall_test(test = "//test/syscalls/linux:vfork_test") syscall_test( size = "medium", + shard_count = 5, test = "//test/syscalls/linux:wait_test", ) diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index 9f2fc9109..60df47798 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -4,12 +4,11 @@ # on the host (native) and runsc. def syscall_test( test, - shard_count = 1, + shard_count = 5, size = "small", use_tmpfs = False, add_overlay = False, - tags = None, - parallel = True): + tags = None): _syscall_test( test = test, shard_count = shard_count, @@ -17,7 +16,6 @@ def syscall_test( platform = "native", use_tmpfs = False, tags = tags, - parallel = parallel, ) _syscall_test( @@ -27,7 +25,6 @@ def syscall_test( platform = "kvm", use_tmpfs = use_tmpfs, tags = tags, - parallel = parallel, ) _syscall_test( @@ -37,7 +34,6 @@ def syscall_test( platform = "ptrace", use_tmpfs = use_tmpfs, tags = tags, - parallel = parallel, ) if add_overlay: @@ -48,7 +44,6 @@ def syscall_test( platform = "ptrace", use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. tags = tags, - parallel = parallel, overlay = True, ) @@ -61,7 +56,6 @@ def syscall_test( platform = "ptrace", use_tmpfs = use_tmpfs, tags = tags, - parallel = parallel, file_access = "shared", ) @@ -72,7 +66,6 @@ def _syscall_test( platform, use_tmpfs, tags, - parallel, file_access = "exclusive", overlay = False): test_name = test.split(":")[1] @@ -111,9 +104,6 @@ def _syscall_test( "--overlay=" + str(overlay), ] - if parallel: - args += ["--parallel=true"] - sh_test( srcs = ["syscall_test_runner.sh"], name = name, @@ -132,3 +122,6 @@ def sh_test(**kwargs): native.sh_test( **kwargs ) + +def select_for_linux(for_linux, for_others = []): + return for_linux diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 40fc73812..88f3bfcb3 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,3 +1,5 @@ +load("//test/syscalls:build_defs.bzl", "select_for_linux") + package( default_visibility = ["//:sandbox"], licenses = ["notice"], @@ -108,20 +110,27 @@ cc_library( cc_library( name = "socket_test_util", testonly = 1, - srcs = ["socket_test_util.cc"], + srcs = [ + "socket_test_util.cc", + ] + select_for_linux( + [ + "socket_test_util_impl.cc", + ], + ), hdrs = ["socket_test_util.h"], deps = [ + "@com_google_googletest//:gtest", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", "//test/util:file_descriptor", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest", - ], + ] + select_for_linux([ + ]), ) cc_library( @@ -913,6 +922,24 @@ cc_library( ) cc_binary( + name = "iptables_test", + testonly = 1, + srcs = [ + "iptables.cc", + ], + linkstatic = 1, + deps = [ + ":iptables_types", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "itimer_test", testonly = 1, srcs = ["itimer.cc"], @@ -1209,13 +1236,51 @@ cc_binary( ) cc_binary( + name = "packet_socket_raw_test", + testonly = 1, + srcs = ["packet_socket_raw.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:endian", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "packet_socket_test", + testonly = 1, + srcs = ["packet_socket.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:endian", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "pty_test", testonly = 1, srcs = ["pty.cc"], linkstatic = 1, deps = [ + "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:posix_error", + "//test/util:pty_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", @@ -1228,15 +1293,36 @@ cc_binary( ) cc_binary( + name = "pty_root_test", + testonly = 1, + srcs = ["pty_root.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:posix_error", + "//test/util:pty_util", + "//test/util:test_main", + "//test/util:thread_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "partial_bad_buffer_test", testonly = 1, srcs = ["partial_bad_buffer.cc"], linkstatic = 1, deps = [ + "//test/syscalls/linux:socket_test_util", + "//test/util:file_descriptor", "//test/util:fs_util", + "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) diff --git a/test/syscalls/linux/affinity.cc b/test/syscalls/linux/affinity.cc index f2d8375b6..128364c34 100644 --- a/test/syscalls/linux/affinity.cc +++ b/test/syscalls/linux/affinity.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <sched.h> +#include <sys/syscall.h> #include <sys/types.h> #include <unistd.h> diff --git a/test/syscalls/linux/base_poll_test.h b/test/syscalls/linux/base_poll_test.h index 088831f9f..0d4a6701e 100644 --- a/test/syscalls/linux/base_poll_test.h +++ b/test/syscalls/linux/base_poll_test.h @@ -56,7 +56,7 @@ class TimerThread { private: mutable absl::Mutex mu_; - bool cancel_ GUARDED_BY(mu_) = false; + bool cancel_ ABSL_GUARDED_BY(mu_) = false; // Must be last to ensure that the destructor for the thread is run before // any other member of the object is destroyed. diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc index aacbb5e70..d3e3f998c 100644 --- a/test/syscalls/linux/futex.cc +++ b/test/syscalls/linux/futex.cc @@ -125,6 +125,10 @@ int futex_lock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int zero = 0; + if (uaddr->compare_exchange_strong(zero, gettid())) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -133,6 +137,10 @@ int futex_trylock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int zero = 0; + if (uaddr->compare_exchange_strong(zero, gettid())) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -141,6 +149,10 @@ int futex_unlock_pi(bool priv, std::atomic<int>* uaddr) { if (priv) { op |= FUTEX_PRIVATE_FLAG; } + int tid = gettid(); + if (uaddr->compare_exchange_strong(tid, 0)) { + return 0; + } return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr); } @@ -689,11 +701,11 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) { std::atomic<int> a = ATOMIC_VAR_INIT(0); const bool is_priv = IsPrivate(); - std::unique_ptr<ScopedThread> threads[100]; + std::unique_ptr<ScopedThread> threads[10]; for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) { threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] { for (size_t j = 0; j < 10;) { - if (futex_trylock_pi(is_priv, &a) >= 0) { + if (futex_trylock_pi(is_priv, &a) == 0) { ++j; EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid()); SleepSafe(absl::Milliseconds(5)); diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc new file mode 100644 index 000000000..b8e4ece64 --- /dev/null +++ b/test/syscalls/linux/iptables.cc @@ -0,0 +1,204 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/iptables.h" + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/netfilter/x_tables.h> +#include <net/if.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <stdio.h> +#include <sys/poll.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <algorithm> + +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kNatTablename[] = "nat"; +constexpr char kErrorTarget[] = "ERROR"; +constexpr size_t kEmptyStandardEntrySize = + sizeof(struct ipt_entry) + sizeof(struct ipt_standard_target); +constexpr size_t kEmptyErrorEntrySize = + sizeof(struct ipt_entry) + sizeof(struct ipt_error_target); + +TEST(IPTablesBasic, CreateSocket) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), + SyscallSucceeds()); + + ASSERT_THAT(close(sock), SyscallSucceeds()); +} + +TEST(IPTablesBasic, FailSockoptNonRaw) { + // Even if the user has CAP_NET_RAW, they shouldn't be able to use the + // iptables sockopts with a non-raw socket. + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + EXPECT_THAT(getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(ENOPROTOOPT)); + + ASSERT_THAT(close(sock), SyscallSucceeds()); +} + +// Fixture for iptables tests. +class IPTablesTest : public ::testing::Test { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // The socket via which to manipulate iptables. + int s_; +}; + +void IPTablesTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); +} + +void IPTablesTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + EXPECT_THAT(close(s_), SyscallSucceeds()); +} + +// This tests the initial state of a machine with empty iptables. We don't have +// a guarantee that the iptables are empty when running in native, but we can +// test that gVisor has the same initial state that a newly-booted Linux machine +// would have. +TEST_F(IPTablesTest, InitialState) { + SKIP_IF(!IsRunningOnGvisor()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // + // Get info via sockopt. + // + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT(getsockopt(s_, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // The nat table supports PREROUTING, and OUTPUT. + unsigned int valid_hooks = (1 << NF_IP_PRE_ROUTING) | (1 << NF_IP_LOCAL_OUT) | + (1 << NF_IP_POST_ROUTING) | (1 << NF_IP_LOCAL_IN); + + EXPECT_EQ(info.valid_hooks, valid_hooks); + + // Each chain consists of an empty entry with a standard target.. + EXPECT_EQ(info.hook_entry[NF_IP_PRE_ROUTING], 0); + EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.hook_entry[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.hook_entry[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // The underflow points are the same as the entry points. + EXPECT_EQ(info.underflow[NF_IP_PRE_ROUTING], 0); + EXPECT_EQ(info.underflow[NF_IP_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.underflow[NF_IP_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.underflow[NF_IP_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // One entry for each chain, plus an error entry at the end. + EXPECT_EQ(info.num_entries, 5); + + EXPECT_EQ(info.size, 4 * kEmptyStandardEntrySize + kEmptyErrorEntrySize); + EXPECT_EQ(strcmp(info.name, kNatTablename), 0); + + // + // Use info to get entries. + // + socklen_t entries_size = sizeof(struct ipt_get_entries) + info.size; + struct ipt_get_entries* entries = + static_cast<struct ipt_get_entries*>(malloc(entries_size)); + snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + entries->size = info.size; + ASSERT_THAT( + getsockopt(s_, IPPROTO_IP, SO_GET_ENTRIES, entries, &entries_size), + SyscallSucceeds()); + + // Verify the name and size. + ASSERT_EQ(info.size, entries->size); + ASSERT_EQ(strcmp(entries->name, kNatTablename), 0); + + // Verify that the entrytable is 4 entries with accept targets and no matches + // followed by a single error target. + size_t entry_offset = 0; + while (entry_offset < entries->size) { + struct ipt_entry* entry = reinterpret_cast<struct ipt_entry*>( + reinterpret_cast<char*>(entries->entrytable) + entry_offset); + + // ip should be zeroes. + struct ipt_ip zeroed = {}; + EXPECT_EQ(memcmp(static_cast<void*>(&zeroed), + static_cast<void*>(&entry->ip), sizeof(zeroed)), + 0); + + // target_offset should be zero. + EXPECT_EQ(entry->target_offset, sizeof(ipt_entry)); + + if (entry_offset < kEmptyStandardEntrySize * 4) { + // The first 4 entries are standard targets + struct ipt_standard_target* target = + reinterpret_cast<struct ipt_standard_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyStandardEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, ""), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + // This is what's returned for an accept verdict. I don't know why. + EXPECT_EQ(target->verdict, -NF_ACCEPT - 1); + } else { + // The last entry is an error target + struct ipt_error_target* target = + reinterpret_cast<struct ipt_error_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyErrorEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, kErrorTarget), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + EXPECT_EQ(strcmp(target->errorname, kErrorTarget), 0); + } + + entry_offset += entry->next_offset; + } + + free(entries); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc new file mode 100644 index 000000000..7a3379b9e --- /dev/null +++ b/test/syscalls/linux/packet_socket.cc @@ -0,0 +1,299 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_packet.h> +#include <net/ethernet.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/udp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/base/internal/endian.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Some of these tests involve sending packets via AF_PACKET sockets and the +// loopback interface. Because AF_PACKET circumvents so much of the networking +// stack, Linux sees these packets as "martian", i.e. they claim to be to/from +// localhost but don't have the usual associated data. Thus Linux drops them by +// default. You can see where this happens by following the code at: +// +// - net/ipv4/ip_input.c:ip_rcv_finish, which calls +// - net/ipv4/route.c:ip_route_input_noref, which calls +// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian +// packets. +// +// To tell Linux not to drop these packets, you need to tell it to accept our +// funny packets (which are completely valid and correct, but lack associated +// in-kernel data because we use AF_PACKET): +// +// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local +// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet +// +// These tests require CAP_NET_RAW to run. + +// TODO(gvisor.dev/issue/173): gVisor support. + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kMessage[] = "soweoneul malhaebwa"; +constexpr in_port_t kPort = 0x409c; // htons(40000) + +// +// "Cooked" tests. Cooked AF_PACKET sockets do not contain link layer +// headers, and provide link layer destination/source information via a +// returned struct sockaddr_ll. +// + +// Send kMessage via sock to loopback +void SendUDPMessage(int sock) { + struct sockaddr_in dest = {}; + dest.sin_port = kPort; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); +} + +// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up. +TEST(BasicCookedPacketTest, WrongType) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP)); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait and make sure the socket never becomes readable. + struct pollfd pfd = {}; + pfd.fd = sock.get(); + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + +// Tests for "cooked" (SOCK_DGRAM) packet(7) sockets. +class CookedPacketTest : public ::testing::TestWithParam<int> { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // Gets the device index of the loopback device. + int GetLoopbackIndex(); + + // The socket used for both reading and writing. + int socket_; +}; + +void CookedPacketTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), + SyscallSucceeds()); +} + +void CookedPacketTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + EXPECT_THAT(close(socket_), SyscallSucceeds()); +} + +int CookedPacketTest::GetLoopbackIndex() { + struct ifreq ifr; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + return ifr.ifr_ifindex; +} + +// Receive via a packet socket. +TEST_P(CookedPacketTest, Receive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait for the socket to become readable. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); + + // Read and verify the data. + constexpr size_t packet_size = + sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + ASSERT_EQ(src_len, sizeof(src)); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); + EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + // This came from the loopback device, so the address is all 0s. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], 0); + } + + // Verify the IP header. We memcpy to deal with pointer aligment. + struct iphdr ip = {}; + memcpy(&ip, buf, sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size)); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. We memcpy to deal with pointer aligment. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + +// Send via a packet socket. +TEST_P(CookedPacketTest, Send) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's send a UDP packet and receive it using a regular UDP socket. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in bind_addr = {}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + bind_addr.sin_port = kPort; + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Set up the destination physical address. + struct sockaddr_ll dest = {}; + dest.sll_family = AF_PACKET; + dest.sll_halen = ETH_ALEN; + dest.sll_ifindex = GetLoopbackIndex(); + dest.sll_protocol = htons(ETH_P_IP); + // We're sending to the loopback device, so the address is all 0s. + memset(dest.sll_addr, 0x00, ETH_ALEN); + + // Set up the IP header. + struct iphdr iphdr = {0}; + iphdr.ihl = 5; + iphdr.version = 4; + iphdr.tos = 0; + iphdr.tot_len = + htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage)); + // Get a pseudo-random ID. If we clash with an in-use ID the test will fail, + // but we have no way of getting an ID we know to be good. + srand(*reinterpret_cast<unsigned int*>(&iphdr)); + iphdr.id = rand(); + // Linux sets this bit ("do not fragment") for small packets. + iphdr.frag_off = 1 << 6; + iphdr.ttl = 64; + iphdr.protocol = IPPROTO_UDP; + iphdr.daddr = htonl(INADDR_LOOPBACK); + iphdr.saddr = htonl(INADDR_LOOPBACK); + iphdr.check = IPChecksum(iphdr); + + // Set up the UDP header. + struct udphdr udphdr = {}; + udphdr.source = kPort; + udphdr.dest = kPort; + udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage)); + udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage)); + + // Copy both headers and the payload into our packet buffer. + char send_buf[sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)]; + memcpy(send_buf, &iphdr, sizeof(iphdr)); + memcpy(send_buf + sizeof(iphdr), &udphdr, sizeof(udphdr)); + memcpy(send_buf + sizeof(iphdr) + sizeof(udphdr), kMessage, sizeof(kMessage)); + + // Send it. + ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Wait for the packet to become available on both sockets. + struct pollfd pfd = {}; + pfd.fd = udp_sock.get(); + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + pfd.fd = socket_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + + // Receive on the packet socket. + char recv_buf[sizeof(send_buf)]; + ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); + + // Receive on the UDP socket. + struct sockaddr_in src; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(sizeof(kMessage))); + // Check src and payload. + EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0); + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(src.sin_port, kPort); + EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); +} + +INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, + ::testing::Values(ETH_P_IP, ETH_P_ALL)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc new file mode 100644 index 000000000..9e96460ee --- /dev/null +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_arp.h> +#include <linux/if_packet.h> +#include <net/ethernet.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/udp.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/base/internal/endian.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Some of these tests involve sending packets via AF_PACKET sockets and the +// loopback interface. Because AF_PACKET circumvents so much of the networking +// stack, Linux sees these packets as "martian", i.e. they claim to be to/from +// localhost but don't have the usual associated data. Thus Linux drops them by +// default. You can see where this happens by following the code at: +// +// - net/ipv4/ip_input.c:ip_rcv_finish, which calls +// - net/ipv4/route.c:ip_route_input_noref, which calls +// - net/ipv4/route.c:ip_route_input_slow, which finds and drops martian +// packets. +// +// To tell Linux not to drop these packets, you need to tell it to accept our +// funny packets (which are completely valid and correct, but lack associated +// in-kernel data because we use AF_PACKET): +// +// echo 1 >> /proc/sys/net/ipv4/conf/lo/accept_local +// echo 1 >> /proc/sys/net/ipv4/conf/lo/route_localnet +// +// These tests require CAP_NET_RAW to run. + +// TODO(gvisor.dev/issue/173): gVisor support. + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kMessage[] = "soweoneul malhaebwa"; +constexpr in_port_t kPort = 0x409c; // htons(40000) + +// Send kMessage via sock to loopback +void SendUDPMessage(int sock) { + struct sockaddr_in dest = {}; + dest.sin_port = kPort; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + EXPECT_THAT(sendto(sock, kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); +} + +// +// Raw tests. Packets sent with raw AF_PACKET sockets always include link layer +// headers. +// + +// Tests for "raw" (SOCK_RAW) packet(7) sockets. +class RawPacketTest : public ::testing::TestWithParam<int> { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // Gets the device index of the loopback device. + int GetLoopbackIndex(); + + // The socket used for both reading and writing. + int socket_; +}; + +void RawPacketTest::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + if (!IsRunningOnGvisor()) { + FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); + FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); + char enabled; + ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } + + ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())), + SyscallSucceeds()); +} + +void RawPacketTest::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + EXPECT_THAT(close(socket_), SyscallSucceeds()); +} + +int RawPacketTest::GetLoopbackIndex() { + struct ifreq ifr; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + return ifr.ifr_ifindex; +} + +// Receive via a packet socket. +TEST_P(RawPacketTest, Receive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Wait for the socket to become readable. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); + + // Read and verify the data. + constexpr size_t packet_size = sizeof(struct ethhdr) + sizeof(struct iphdr) + + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + // sizeof(src) is the size of a struct sockaddr_ll. sockaddr_ll ends with an 8 + // byte physical address field, but ethernet (MAC) addresses only use 6 bytes. + // Thus src_len should get modified to be 2 less than the size of sockaddr_ll. + ASSERT_EQ(src_len, sizeof(src) - 2); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); + EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + // This came from the loopback device, so the address is all 0s. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], 0); + } + + // Verify the ethernet header. We memcpy to deal with pointer alignment. + struct ethhdr eth = {}; + memcpy(ð, buf, sizeof(eth)); + // The destination and source address should be 0, for loopback. + for (int i = 0; i < ETH_ALEN; i++) { + EXPECT_EQ(eth.h_dest[i], 0); + EXPECT_EQ(eth.h_source[i], 0); + } + EXPECT_EQ(eth.h_proto, htons(ETH_P_IP)); + + // Verify the IP header. We memcpy to deal with pointer aligment. + struct iphdr ip = {}; + memcpy(&ip, buf + sizeof(ethhdr), sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size - sizeof(eth))); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. We memcpy to deal with pointer aligment. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(eth) + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(eth) + sizeof(iphdr) + + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + +// Send via a packet socket. +TEST_P(RawPacketTest, Send) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + SKIP_IF(IsRunningOnGvisor()); + + // Let's send a UDP packet and receive it using a regular UDP socket. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in bind_addr = {}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + bind_addr.sin_port = kPort; + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Set up the destination physical address. + struct sockaddr_ll dest = {}; + dest.sll_family = AF_PACKET; + dest.sll_halen = ETH_ALEN; + dest.sll_ifindex = GetLoopbackIndex(); + dest.sll_protocol = htons(ETH_P_IP); + // We're sending to the loopback device, so the address is all 0s. + memset(dest.sll_addr, 0x00, ETH_ALEN); + + // Set up the ethernet header. The kernel takes care of the footer. + // We're sending to and from hardware address 0 (loopback). + struct ethhdr eth = {}; + eth.h_proto = htons(ETH_P_IP); + + // Set up the IP header. + struct iphdr iphdr = {}; + iphdr.ihl = 5; + iphdr.version = 4; + iphdr.tos = 0; + iphdr.tot_len = + htons(sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage)); + // Get a pseudo-random ID. If we clash with an in-use ID the test will fail, + // but we have no way of getting an ID we know to be good. + srand(*reinterpret_cast<unsigned int*>(&iphdr)); + iphdr.id = rand(); + // Linux sets this bit ("do not fragment") for small packets. + iphdr.frag_off = 1 << 6; + iphdr.ttl = 64; + iphdr.protocol = IPPROTO_UDP; + iphdr.daddr = htonl(INADDR_LOOPBACK); + iphdr.saddr = htonl(INADDR_LOOPBACK); + iphdr.check = IPChecksum(iphdr); + + // Set up the UDP header. + struct udphdr udphdr = {}; + udphdr.source = kPort; + udphdr.dest = kPort; + udphdr.len = htons(sizeof(udphdr) + sizeof(kMessage)); + udphdr.check = UDPChecksum(iphdr, udphdr, kMessage, sizeof(kMessage)); + + // Copy both headers and the payload into our packet buffer. + char + send_buf[sizeof(eth) + sizeof(iphdr) + sizeof(udphdr) + sizeof(kMessage)]; + memcpy(send_buf, ð, sizeof(eth)); + memcpy(send_buf + sizeof(ethhdr), &iphdr, sizeof(iphdr)); + memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr), &udphdr, sizeof(udphdr)); + memcpy(send_buf + sizeof(ethhdr) + sizeof(iphdr) + sizeof(udphdr), kMessage, + sizeof(kMessage)); + + // Send it. + ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Wait for the packet to become available on both sockets. + struct pollfd pfd = {}; + pfd.fd = udp_sock.get(); + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + pfd.fd = socket_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); + + // Receive on the packet socket. + char recv_buf[sizeof(send_buf)]; + ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); + + // Receive on the UDP socket. + struct sockaddr_in src; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(sizeof(kMessage))); + // Check src and payload. + EXPECT_EQ(strncmp(recv_buf, kMessage, sizeof(kMessage)), 0); + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(src.sin_port, kPort); + EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); +} + +INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, + ::testing::Values(ETH_P_IP /*, ETH_P_ALL*/)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc index 83b1ad4e4..33822ee57 100644 --- a/test/syscalls/linux/partial_bad_buffer.cc +++ b/test/syscalls/linux/partial_bad_buffer.cc @@ -14,13 +14,20 @@ #include <errno.h> #include <fcntl.h> +#include <netinet/in.h> +#include <netinet/tcp.h> #include <sys/mman.h> +#include <sys/socket.h> #include <sys/syscall.h> #include <sys/uio.h> #include <unistd.h> #include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -299,6 +306,109 @@ TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) { EXPECT_STREQ(buf, kMessage); } +PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + addr.ss_family = family; + switch (family) { + case AF_INET: + reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr = + htonl(INADDR_LOOPBACK); + break; + case AF_INET6: + reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr = + in6addr_loopback; + break; + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } + return addr; +} + +// SendMsgTCP verifies that calling sendmsg with a bad address returns an +// EFAULT. It also verifies that passing a buffer which is made up of 2 +// pages one valid and one guard page succeeds as long as the write is +// for exactly the size of 1 page. +TEST_F(PartialBadBufferTest, SendMsgTCP) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + // TODO(gvisor.dev/issue/674): Update this once Netstack matches linux + // behaviour on a setsockopt of SO_RCVBUF/SO_SNDBUF. + // + // Set SO_SNDBUF for socket to exactly kPageSize+1. + // + // gVisor does not double the value passed in SO_SNDBUF like linux does so we + // just increase it by 1 byte here for gVisor so that we can test writing 1 + // byte past the valid page and check that it triggers an EFAULT + // correctly. Otherwise in gVisor the sendmsg call will just return with no + // error with kPageSize bytes written successfully. + const uint32_t buf_size = kPageSize + 1; + ASSERT_THAT(setsockopt(send_socket.get(), SOL_SOCKET, SO_SNDBUF, &buf_size, + sizeof(buf_size)), + SyscallSucceedsWithValue(0)); + + struct msghdr hdr = {}; + struct iovec iov = {}; + iov.iov_base = bad_buffer_; + iov.iov_len = kPageSize; + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallFailsWithErrno(EFAULT)); + + // Now assert that writing kPageSize from addr_ succeeds. + iov.iov_base = addr_; + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallSucceedsWithValue(kPageSize)); + // Read all the data out so that we drain the socket SND_BUF on the sender. + std::vector<char> buffer(kPageSize); + ASSERT_THAT(RetryEINTR(read)(recv_socket.get(), buffer.data(), kPageSize), + SyscallSucceedsWithValue(kPageSize)); + + // Sleep for a shortwhile to ensure that we have time to process the + // ACKs. This is not strictly required unless running under gotsan which is a + // lot slower and can result in the next write to write only 1 byte instead of + // our intended kPageSize + 1. + absl::SleepFor(absl::Milliseconds(50)); + + // Now assert that writing > kPageSize results in EFAULT. + iov.iov_len = kPageSize + 1; + ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0), + SyscallFailsWithErrno(EFAULT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index b440ba0df..2b753b7d1 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1602,9 +1602,9 @@ class BlockingChild { } mutable absl::Mutex mu_; - bool stop_ GUARDED_BY(mu_) = false; + bool stop_ ABSL_GUARDED_BY(mu_) = false; pid_t tid_; - bool tid_ready_ GUARDED_BY(mu_) = false; + bool tid_ready_ ABSL_GUARDED_BY(mu_) = false; // Must be last to ensure that the destructor for the thread is run before // any other member of the object is destroyed. diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index 578b20680..498f62d9c 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -187,9 +187,9 @@ TEST(ProcNetTCP, EntryUID) { std::vector<TCPEntry> entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); TCPEntry e; - EXPECT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr())); + ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr())); EXPECT_EQ(e.uid, geteuid()); - EXPECT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr())); + ASSERT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr())); EXPECT_EQ(e.uid, geteuid()); } @@ -249,7 +249,8 @@ TEST(ProcNetTCP, State) { std::unique_ptr<FileDescriptor> client = ASSERT_NO_ERRNO_AND_VALUE(IPv4TCPUnboundSocket(0).Create()); - ASSERT_THAT(connect(client->get(), &addr, addrlen), SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(client->get(), &addr, addrlen), + SyscallSucceeds()); entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries()); ASSERT_TRUE(FindByLocalAddr(entries, &listen_entry, &addr)); EXPECT_EQ(listen_entry.state, TCP_LISTEN); diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index d1ab4703f..bd6907876 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -13,13 +13,17 @@ // limitations under the License. #include <fcntl.h> +#include <linux/capability.h> #include <linux/major.h> #include <poll.h> +#include <sched.h> +#include <signal.h> #include <sys/ioctl.h> #include <sys/mman.h> #include <sys/stat.h> #include <sys/sysmacros.h> #include <sys/types.h> +#include <sys/wait.h> #include <termios.h> #include <unistd.h> @@ -31,8 +35,10 @@ #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" +#include "test/util/pty_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -370,25 +376,6 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, return PosixError(ETIMEDOUT, "Poll timed out"); } -// Opens the slave end of the passed master as R/W and nonblocking. -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { - // Get pty index. - int n; - int ret = ioctl(master.get(), TIOCGPTN, &n); - if (ret < 0) { - return PosixError(errno, "ioctl(TIOCGPTN) failed"); - } - - // Unlock pts. - int unlock = 0; - ret = ioctl(master.get(), TIOCSPTLCK, &unlock); - if (ret < 0) { - return PosixError(errno, "ioctl(TIOSPTLCK) failed"); - } - - return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK); -} - TEST(BasicPtyTest, StatUnopenedMaster) { struct stat s; ASSERT_THAT(stat("/dev/ptmx", &s), SyscallSucceeds()); @@ -1233,6 +1220,340 @@ TEST_F(PtyTest, SetMasterWindowSize) { EXPECT_EQ(retrieved_ws.ws_col, kCols); } +class JobControlTest : public ::testing::Test { + protected: + void SetUp() override { + master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); + slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + + // Make this a session leader, which also drops the controlling terminal. + // In the gVisor test environment, this test will be run as the session + // leader already (as the sentry init process). + if (!IsRunningOnGvisor()) { + ASSERT_THAT(setsid(), SyscallSucceeds()); + } + } + + // Master and slave ends of the PTY. Non-blocking. + FileDescriptor master_; + FileDescriptor slave_; +}; + +TEST_F(JobControlTest, SetTTYMaster) { + ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTYNonLeader) { + // Fork a process that won't be the session leader. + pid_t child = fork(); + if (!child) { + // We shouldn't be able to set the terminal. + TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, SetTTYBadArg) { + // Despite the man page saying arg should be 0 here, Linux doesn't actually + // check. + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTYDifferentSession) { + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should fail. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal. + TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, ReleaseTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Make sure we're ignoring SIGHUP, which will be sent to this process once we + // disconnect they TTY. + struct sigaction sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + struct sigaction old_sa; + EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds()); + EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); +} + +TEST_F(JobControlTest, ReleaseUnsetTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, ReleaseWrongTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, ReleaseTTYNonLeader) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, ReleaseTTYDifferentSession) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t child = fork(); + if (!child) { + // Join a new session, then try to disconnect. + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +// Used by the child process spawned in ReleaseTTYSignals to track received +// signals. +static int received; + +void sig_handler(int signum) { received |= signum; } + +// When the session leader releases its controlling terminal, the foreground +// process group gets SIGHUP, then SIGCONT. This test: +// - Spawns 2 threads +// - Has thread 1 return 0 if it gets both SIGHUP and SIGCONT +// - Has thread 2 leave the foreground process group, and return non-zero if it +// receives any signals. +// - Has the parent thread release its controlling terminal +// - Checks that thread 1 got both signals +// - Checks that thread 2 didn't get any signals. +TEST_F(JobControlTest, ReleaseTTYSignals) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + received = 0; + struct sigaction sa = { + .sa_handler = sig_handler, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + sigaddset(&sa.sa_mask, SIGHUP); + sigaddset(&sa.sa_mask, SIGCONT); + sigprocmask(SIG_BLOCK, &sa.sa_mask, NULL); + + pid_t same_pgrp_child = fork(); + if (!same_pgrp_child) { + // The child will wait for SIGHUP and SIGCONT, then return 0. It begins with + // SIGHUP and SIGCONT blocked. We install signal handlers for those signals, + // then use sigsuspend to wait for those specific signals. + TEST_PCHECK(!sigaction(SIGHUP, &sa, NULL)); + TEST_PCHECK(!sigaction(SIGCONT, &sa, NULL)); + sigset_t mask; + sigfillset(&mask); + sigdelset(&mask, SIGHUP); + sigdelset(&mask, SIGCONT); + while (received != (SIGHUP | SIGCONT)) { + sigsuspend(&mask); + } + _exit(0); + } + + // We don't want to block these anymore. + sigprocmask(SIG_UNBLOCK, &sa.sa_mask, NULL); + + // This child will return non-zero if either SIGHUP or SIGCONT are received. + pid_t diff_pgrp_child = fork(); + if (!diff_pgrp_child) { + TEST_PCHECK(!setpgid(0, 0)); + TEST_PCHECK(pause()); + _exit(1); + } + + EXPECT_THAT(setpgid(diff_pgrp_child, diff_pgrp_child), SyscallSucceeds()); + + // Make sure we're ignoring SIGHUP, which will be sent to this process once we + // disconnect they TTY. + struct sigaction sighup_sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sighup_sa.sa_mask); + struct sigaction old_sa; + EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds()); + + // Release the controlling terminal, sending SIGHUP and SIGCONT to all other + // processes in this process group. + EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + + EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); + + // The child in the same process group will get signaled. + int wstatus; + EXPECT_THAT(waitpid(same_pgrp_child, &wstatus, 0), + SyscallSucceedsWithValue(same_pgrp_child)); + EXPECT_EQ(wstatus, 0); + + // The other child will not get signaled. + EXPECT_THAT(waitpid(diff_pgrp_child, &wstatus, WNOHANG), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(kill(diff_pgrp_child, SIGKILL), SyscallSucceeds()); +} + +TEST_F(JobControlTest, GetForegroundProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + pid_t foreground_pgid; + pid_t pid; + ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + SyscallSucceeds()); + ASSERT_THAT(pid = getpid(), SyscallSucceeds()); + + ASSERT_EQ(foreground_pgid, pid); +} + +TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { + // At this point there's no controlling terminal, so TIOCGPGRP should fail. + pid_t foreground_pgid; + ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + SyscallFailsWithErrno(ENOTTY)); +} + +// This test: +// - sets itself as the foreground process group +// - creates a child process in a new process group +// - sets that child as the foreground process group +// - kills its child and sets itself as the foreground process group. +TEST_F(JobControlTest, SetForegroundProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. + struct sigaction sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + sigaction(SIGTTOU, &sa, NULL); + + // Set ourself as the foreground process group. + ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds()); + + // Create a new process that just waits to be signaled. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!pause()); + // We should never reach this. + _exit(1); + } + + // Make the child its own process group, then make it the controlling process + // group of the terminal. + ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); + ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds()); + + // Sanity check - we're still the controlling session. + ASSERT_EQ(getsid(0), getsid(child)); + + // Signal the child, wait for it to exit, then retake the terminal. + ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds()); + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_TRUE(WIFSIGNALED(wstatus)); + ASSERT_EQ(WTERMSIG(wstatus), SIGTERM); + + // Set ourself as the foreground process. + pid_t pgid; + ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds()); + ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { + pid_t pid = getpid(); + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t pid = -1; + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Create a new process, put it in a new process group, make that group the + // foreground process group, then have the process wait. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!setpgid(0, 0)); + _exit(0); + } + + // Wait for the child to exit. + int wstatus; + EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + // The child's process group doesn't exist anymore - this should fail. + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), + SyscallFailsWithErrno(ESRCH)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Create a new process and put it in a new session. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(setsid() >= 0); + // Tell the parent we're in a new session. + TEST_PCHECK(!raise(SIGSTOP)); + TEST_PCHECK(!pause()); + _exit(1); + } + + // Wait for the child to tell us it's in a new session. + int wstatus; + EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED), + SyscallSucceedsWithValue(child)); + EXPECT_TRUE(WSTOPSIG(wstatus)); + + // Child is in a new session, so we can't make it the foregroup process group. + EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), + SyscallFailsWithErrno(EPERM)); + + EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds()); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc new file mode 100644 index 000000000..d2a321a6e --- /dev/null +++ b/test/syscalls/linux/pty_root.cc @@ -0,0 +1,68 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/ioctl.h> +#include <termios.h> + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/pty_util.h" + +namespace gvisor { +namespace testing { + +// These tests should be run as root. +namespace { + +TEST(JobControlRootTest, StealTTY) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + // Make this a session leader, which also drops the controlling terminal. + // In the gVisor test environment, this test will be run as the session + // leader already (as the sentry init process). + if (!IsRunningOnGvisor()) { + ASSERT_THAT(setsid(), SyscallSucceeds()); + } + + FileDescriptor master = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); + FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + + // Make slave the controlling terminal. + ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg + // of 1. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal with the wrong arg value. + TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); + // We should be able to steal it here. + TEST_PCHECK(!ioctl(slave.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +} // namespace +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index db519f4e0..f6a0fc96c 100644 --- a/test/syscalls/linux/pwritev2.cc +++ b/test/syscalls/linux/pwritev2.cc @@ -244,8 +244,10 @@ TEST(Pwritev2Test, TestInvalidOffset) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + char buf[16]; struct iovec iov; - iov.iov_base = nullptr; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/static_cast<off_t>(-8), /*flags=*/0), @@ -286,8 +288,10 @@ TEST(Pwritev2Test, TestUnseekableFileInValid) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); int pipe_fds[2]; + char buf[16]; struct iovec iov; - iov.iov_base = nullptr; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); @@ -307,8 +311,10 @@ TEST(Pwritev2Test, TestReadOnlyFile) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + char buf[16]; struct iovec iov; - iov.iov_base = nullptr; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/0), @@ -324,8 +330,10 @@ TEST(Pwritev2Test, TestInvalidFlag) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR | O_DIRECT)); + char buf[16]; struct iovec iov; - iov.iov_base = nullptr; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/0xF0), diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index ad19120d5..971592d7d 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -35,32 +35,6 @@ namespace testing { namespace { -// Compute the internet checksum of the ICMP header (assuming no payload). -static uint16_t Checksum(struct icmphdr* icmp) { - uint32_t total = 0; - uint16_t* num = reinterpret_cast<uint16_t*>(icmp); - - // This is just the ICMP header, so there's an even number of bytes. - static_assert( - sizeof(*icmp) % sizeof(*num) == 0, - "sizeof(struct icmphdr) is not an integer multiple of sizeof(uint16_t)"); - for (unsigned int i = 0; i < sizeof(*icmp); i += sizeof(*num)) { - total += *num; - num++; - } - - // Combine the upper and lower 16 bits. This happens twice in case the first - // combination causes a carry. - unsigned short upper = total >> 16; - unsigned short lower = total & 0xffff; - total = upper + lower; - upper = total >> 16; - lower = total & 0xffff; - total = upper + lower; - - return ~total; -} - // The size of an empty ICMP packet and IP header together. constexpr size_t kEmptyICMPSize = 28; @@ -164,7 +138,7 @@ TEST_F(RawSocketICMPTest, SendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2012; icmp.un.echo.id = 2014; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -187,7 +161,7 @@ TEST_F(RawSocketICMPTest, MultipleSocketReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2016; icmp.un.echo.id = 2018; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); // Both sockets will receive the echo request and reply in indeterminate @@ -297,7 +271,7 @@ TEST_F(RawSocketICMPTest, ShortEchoRawAndPingSockets) { icmp.un.echo.sequence = 0; icmp.un.echo.id = 6789; icmp.checksum = 0; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); // Omit 2 bytes from ICMP packet. constexpr int kShortICMPSize = sizeof(icmp) - 2; @@ -338,7 +312,7 @@ TEST_F(RawSocketICMPTest, ShortEchoReplyRawAndPingSockets) { icmp.un.echo.sequence = 0; icmp.un.echo.id = 6789; icmp.checksum = 0; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); // Omit 2 bytes from ICMP packet. constexpr int kShortICMPSize = sizeof(icmp) - 2; @@ -381,7 +355,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveViaConnect) { icmp.checksum = 0; icmp.un.echo.sequence = 2003; icmp.un.echo.id = 2004; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_THAT(send(s_, &icmp, sizeof(icmp), 0), SyscallSucceedsWithValue(sizeof(icmp))); @@ -405,7 +379,7 @@ TEST_F(RawSocketICMPTest, BindSendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2004; icmp.un.echo.id = 2007; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -431,7 +405,7 @@ TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) { icmp.checksum = 0; icmp.un.echo.sequence = 2010; icmp.un.echo.id = 7; - icmp.checksum = Checksum(&icmp); + icmp.checksum = ICMPChecksum(icmp, NULL, 0); ASSERT_NO_FATAL_FAILURE(SendEmptyICMP(icmp)); ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); @@ -470,9 +444,8 @@ void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { EXPECT_EQ(recvd_icmp->un.echo.id, icmp.un.echo.id); // A couple are different. EXPECT_EQ(recvd_icmp->type, ICMP_ECHOREPLY); - // The checksum is computed in such a way that it is guaranteed to have - // changed. - EXPECT_NE(recvd_icmp->checksum, icmp.checksum); + // The checksum computed over the reply should still be valid. + EXPECT_EQ(ICMPChecksum(*recvd_icmp, NULL, 0), 0); break; } } diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index e5d72e28a..9167ab066 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -299,10 +299,30 @@ TEST(SendFileTest, DoNotSendfileIfOutfileIsAppendOnly) { // Open the output file as append only. const FileDescriptor outf = - ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_APPEND)); + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY | O_APPEND)); // Send data and verify that sendfile returns the correct errno. EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(SendFileTest, AppendCheckOrdering) { + constexpr char kData[] = "And by opposing end them: to die, to sleep"; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + const FileDescriptor read = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + const FileDescriptor write = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + const FileDescriptor append = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_APPEND)); + + // Check that read/write file mode is verified before append. + EXPECT_THAT(sendfile(append.get(), read.get(), nullptr, kDataSize), + SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(sendfile(write.get(), write.get(), nullptr, kDataSize), SyscallFailsWithErrno(EBADF)); } diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index 0404190a0..caae215b8 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -30,12 +30,25 @@ TEST(SocketTest, UnixSocketPairProtocol) { close(socks[1]); } -TEST(SocketTest, Protocol) { +TEST(SocketTest, ProtocolUnix) { struct { int domain, type, protocol; } tests[] = { - {AF_UNIX, SOCK_STREAM, PF_UNIX}, {AF_UNIX, SOCK_SEQPACKET, PF_UNIX}, - {AF_UNIX, SOCK_DGRAM, PF_UNIX}, {AF_INET, SOCK_DGRAM, IPPROTO_UDP}, + {AF_UNIX, SOCK_STREAM, PF_UNIX}, + {AF_UNIX, SOCK_SEQPACKET, PF_UNIX}, + {AF_UNIX, SOCK_DGRAM, PF_UNIX}, + }; + for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { + ASSERT_NO_ERRNO_AND_VALUE( + Socket(tests[i].domain, tests[i].type, tests[i].protocol)); + } +} + +TEST(SocketTest, ProtocolInet) { + struct { + int domain, type, protocol; + } tests[] = { + {AF_INET, SOCK_DGRAM, IPPROTO_UDP}, {AF_INET, SOCK_STREAM, IPPROTO_TCP}, }; for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index df31d25b5..322ee07ad 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -145,6 +145,67 @@ TEST_P(SocketInetLoopbackTest, TCP) { ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds()); } +TEST_P(SocketInetLoopbackTest, TCPListenClose) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), 1001), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + DisableSave ds; // Too many system calls. + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + constexpr int kFDs = 2048; + constexpr int kThreadCount = 4; + constexpr int kFDsPerThread = kFDs / kThreadCount; + FileDescriptor clients[kFDs]; + std::unique_ptr<ScopedThread> threads[kThreadCount]; + for (int i = 0; i < kFDs; i++) { + clients[i] = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + } + for (int i = 0; i < kThreadCount; i++) { + threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr, + &clients, i]() { + for (int j = 0; j < kFDsPerThread; j++) { + int k = i * kFDsPerThread + j; + int ret = + connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + } + }); + } + for (int i = 0; i < kThreadCount; i++) { + threads[i]->Join(); + } + for (int i = 0; i < 32; i++) { + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + } + // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked + // before function end. + // ds.reset() +} + TEST_P(SocketInetLoopbackTest, TCPbacklog) { auto const& param = GetParam(); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index d9aa7ff3f..67d29af0a 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -30,6 +30,7 @@ namespace gvisor { namespace testing { constexpr char kMulticastAddress[] = "224.0.2.1"; +constexpr char kBroadcastAddress[] = "255.255.255.255"; TestAddress V4Multicast() { TestAddress t("V4Multicast"); @@ -40,6 +41,15 @@ TestAddress V4Multicast() { return t; } +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + inet_addr(kBroadcastAddress); + return t; +} + // Check that packets are not received without a group membership. Default send // interface configured by bind. TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) { @@ -1426,5 +1436,249 @@ TEST_P(IPv4UDPUnboundSocketPairTest, } } +// Check that a receiving socket can bind to the multicast address before +// joining the group and receive data once the group has been joined. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the multicast address and won't +// receive multicast data if it hasn't joined the group. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the multicast address. + auto receiver_addr = V4Multicast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + // Update receiver_addr with the correct port number. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a multicast packet on the first socket out the loopback interface. + ip_mreq iface = {}; + iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, + &iface, sizeof(iface)), + SyscallSucceeds()); + auto sendto_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we don't receive the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Check that a socket can bind to a multicast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the multicast address. + auto sender_addr = V4Multicast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a receiving socket can bind to the broadcast address and receive +// broadcast packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the broadcast address. + auto receiver_addr = V4Broadcast(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Send a broadcast packet on the first socket out the loopback interface. + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + // Note: Binding to the loopback interface makes the broadcast go out of it. + auto sender_bind_addr = V4Loopback(); + ASSERT_THAT(bind(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), + sender_bind_addr.addr_len), + SyscallSucceeds()); + auto sendto_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + +// Check that a socket can bind to the broadcast address and still send out +// packets. +TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Bind second socket (receiver) to the ANY address. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(sockets->second_fd(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Bind the first socket (sender) to the broadcast address. + auto sender_addr = V4Broadcast(); + ASSERT_THAT( + bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(sockets->first_fd(), + reinterpret_cast<sockaddr*>(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + // Send a packet on the first socket to the loopback address. + auto sendto_addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the packet. + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 6a5fa8965..765f8e0e4 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -89,7 +89,8 @@ TEST(NetdeviceTest, Netmask) { // (i.e. netmask) for the loopback device. int prefixlen = -1; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr *hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr *hdr) { EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) @@ -111,7 +112,8 @@ TEST(NetdeviceTest, Netmask) { ifaddrmsg->ifa_family == AF_INET) { prefixlen = ifaddrmsg->ifa_prefixlen; } - })); + }, + false)); ASSERT_GE(prefixlen, 0); diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index b5c38f27e..32fe0d6d1 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <arpa/inet.h> #include <ifaddrs.h> #include <linux/netlink.h> #include <linux/rtnetlink.h> @@ -237,7 +238,8 @@ TEST(NetlinkRouteTest, GetLinkDump) { // Loopback is common among all tests, check that it's found. bool loopbackFound = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { CheckGetLinkResponse(hdr, kSeq, port); if (hdr->nlmsg_type != RTM_NEWLINK) { return; @@ -251,10 +253,44 @@ TEST(NetlinkRouteTest, GetLinkDump) { loopbackFound = true; EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); } - })); + }, + false)); EXPECT_TRUE(loopbackFound); } +TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + // If type & 0x3 is equal to 0x2, this means a get request + // which doesn't require CAP_SYS_ADMIN. + req.hdr.nlmsg_type = ((__RTM_MAX + 1024) & (~0x3)) | 0x2; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR)); + EXPECT_EQ(hdr->nlmsg_seq, kSeq); + EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); + + const struct nlmsgerr* msg = + reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr)); + EXPECT_EQ(msg->error, -EOPNOTSUPP); + }, + true)); +} + TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); @@ -363,9 +399,11 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { req.ifm.ifi_family = AF_UNSPEC; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { CheckGetLinkResponse(hdr, kSeq, port); - })); + }, + false)); } TEST(NetlinkRouteTest, GetAddrDump) { @@ -387,7 +425,8 @@ TEST(NetlinkRouteTest, GetAddrDump) { req.rgm.rtgen_family = AF_UNSPEC; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) @@ -404,7 +443,8 @@ TEST(NetlinkRouteTest, GetAddrDump) { EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg)); // TODO(mpratt): Check ifaddrmsg contents and following attrs. - })); + }, + false)); } TEST(NetlinkRouteTest, LookupAll) { @@ -425,6 +465,80 @@ TEST(NetlinkRouteTest, LookupAll) { ASSERT_GT(count, 0); } +// GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. +TEST(NetlinkRouteTest, GetRouteDump) { + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); + + struct request { + struct nlmsghdr hdr; + struct rtmsg rtm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETROUTE; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rtm.rtm_family = AF_UNSPEC; + + bool routeFound = false; + bool dstFound = true; + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + // Validate the reponse to RTM_GETROUTE + NLM_F_DUMP. + EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWROUTE), Eq(NLMSG_DONE))); + + EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) + << std::hex << hdr->nlmsg_flags; + + EXPECT_EQ(hdr->nlmsg_seq, kSeq); + EXPECT_EQ(hdr->nlmsg_pid, port); + + // The test should not proceed if it's not a RTM_NEWROUTE message. + if (hdr->nlmsg_type != RTM_NEWROUTE) { + return; + } + + // RTM_NEWROUTE contains at least the header and rtmsg. + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct rtmsg))); + const struct rtmsg* msg = + reinterpret_cast<const struct rtmsg*>(NLMSG_DATA(hdr)); + // NOTE: rtmsg fields are char fields. + std::cout << "Found route table=" << static_cast<int>(msg->rtm_table) + << ", protocol=" << static_cast<int>(msg->rtm_protocol) + << ", scope=" << static_cast<int>(msg->rtm_scope) + << ", type=" << static_cast<int>(msg->rtm_type); + + int len = RTM_PAYLOAD(hdr); + bool rtDstFound = false; + for (struct rtattr* attr = RTM_RTA(msg); RTA_OK(attr, len); + attr = RTA_NEXT(attr, len)) { + if (attr->rta_type == RTA_DST) { + char address[INET_ADDRSTRLEN] = {}; + inet_ntop(AF_INET, RTA_DATA(attr), address, sizeof(address)); + std::cout << ", dst=" << address; + rtDstFound = true; + } + } + + std::cout << std::endl; + + if (msg->rtm_table == RT_TABLE_MAIN) { + routeFound = true; + dstFound = rtDstFound && dstFound; + } + }, + false)); + // At least one route found in main route table. + EXPECT_TRUE(routeFound); + // Found RTA_DST for each route in main table. + EXPECT_TRUE(dstFound); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index 728d25434..36b6560c2 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -54,7 +54,8 @@ PosixErrorOr<uint32_t> NetlinkPortID(int fd) { PosixError NetlinkRequestResponse( const FileDescriptor& fd, void* request, size_t len, - const std::function<void(const struct nlmsghdr* hdr)>& fn) { + const std::function<void(const struct nlmsghdr* hdr)>& fn, + bool expect_nlmsgerr) { struct iovec iov = {}; iov.iov_base = request; iov.iov_len = len; @@ -93,7 +94,11 @@ PosixError NetlinkRequestResponse( } } while (type != NLMSG_DONE && type != NLMSG_ERROR); - EXPECT_EQ(type, NLMSG_DONE); + if (expect_nlmsgerr) { + EXPECT_EQ(type, NLMSG_ERROR); + } else { + EXPECT_EQ(type, NLMSG_DONE); + } return NoError(); } diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index bea449107..db8639a2f 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -34,7 +34,8 @@ PosixErrorOr<uint32_t> NetlinkPortID(int fd); // Send the passed request and call fn will all response netlink messages. PosixError NetlinkRequestResponse( const FileDescriptor& fd, void* request, size_t len, - const std::function<void(const struct nlmsghdr* hdr)>& fn); + const std::function<void(const struct nlmsghdr* hdr)>& fn, + bool expect_nlmsgerr); } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 4f65cf5ae..eff7d577e 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -588,8 +588,9 @@ ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets, return RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0); } -PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type, - bool reuse_addr) { +namespace internal { +PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, + SocketType type, bool reuse_addr) { if (port < 0) { return PosixError(EINVAL, "Invalid port"); } @@ -664,10 +665,7 @@ PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type, return available_port; } - -PosixError FreeAvailablePort(int port) { - return NoError(); -} +} // namespace internal PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) { struct iovec iov; @@ -744,5 +742,74 @@ TestAddress V6Loopback() { return t; } +// Checksum computes the internet checksum of a buffer. +uint16_t Checksum(uint16_t* buf, ssize_t buf_size) { + // Add up the 16-bit values in the buffer. + uint32_t total = 0; + for (unsigned int i = 0; i < buf_size; i += sizeof(*buf)) { + total += *buf; + buf++; + } + + // If buf has an odd size, add the remaining byte. + if (buf_size % 2) { + total += *(reinterpret_cast<unsigned char*>(buf) - 1); + } + + // This carries any bits past the lower 16 until everything fits in 16 bits. + while (total >> 16) { + uint16_t lower = total & 0xffff; + uint16_t upper = total >> 16; + total = lower + upper; + } + + return ~total; +} + +uint16_t IPChecksum(struct iphdr ip) { + return Checksum(reinterpret_cast<uint16_t*>(&ip), sizeof(ip)); +} + +// The pseudo-header defined in RFC 768 for calculating the UDP checksum. +struct udp_pseudo_hdr { + uint32_t srcip; + uint32_t destip; + char zero; + char protocol; + uint16_t udplen; +}; + +uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, + const char* payload, ssize_t payload_len) { + struct udp_pseudo_hdr phdr = {}; + phdr.srcip = iphdr.saddr; + phdr.destip = iphdr.daddr; + phdr.zero = 0; + phdr.protocol = IPPROTO_UDP; + phdr.udplen = udphdr.len; + + ssize_t buf_size = sizeof(phdr) + sizeof(udphdr) + payload_len; + char* buf = static_cast<char*>(malloc(buf_size)); + memcpy(buf, &phdr, sizeof(phdr)); + memcpy(buf + sizeof(phdr), &udphdr, sizeof(udphdr)); + memcpy(buf + sizeof(phdr) + sizeof(udphdr), payload, payload_len); + + uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size); + free(buf); + return csum; +} + +uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, + ssize_t payload_len) { + ssize_t buf_size = sizeof(icmphdr) + payload_len; + char* buf = static_cast<char*>(malloc(buf_size)); + memcpy(buf, &icmphdr, sizeof(icmphdr)); + memcpy(buf + sizeof(icmphdr), payload, payload_len); + + uint16_t csum = Checksum(reinterpret_cast<uint16_t*>(buf), buf_size); + free(buf); + return csum; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 4fd59767a..6efa8055f 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -17,9 +17,12 @@ #include <errno.h> #include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <netinet/udp.h> #include <sys/socket.h> #include <sys/types.h> #include <sys/un.h> + #include <functional> #include <memory> #include <string> @@ -478,6 +481,22 @@ TestAddress V4MappedLoopback(); TestAddress V6Any(); TestAddress V6Loopback(); +// Compute the internet checksum of an IP header. +uint16_t IPChecksum(struct iphdr ip); + +// Compute the internet checksum of a UDP header. +uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, + const char* payload, ssize_t payload_len); + +// Compute the internet checksum of an ICMP header. +uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, + ssize_t payload_len); + +namespace internal { +PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, + SocketType type, bool reuse_addr); +} // namespace internal + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util_impl.cc b/test/syscalls/linux/socket_test_util_impl.cc new file mode 100644 index 000000000..ef661a0e3 --- /dev/null +++ b/test/syscalls/linux/socket_test_util_impl.cc @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type, + bool reuse_addr) { + return internal::TryPortAvailable(port, family, type, reuse_addr); +} + +PosixError FreeAvailablePort(int port) { return NoError(); } + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index 1875f4533..e25f264f6 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <sys/eventfd.h> #include <sys/sendfile.h> #include <unistd.h> @@ -135,6 +136,80 @@ TEST(SpliceTest, PipeOffsets) { SyscallFailsWithErrno(ESPIPE)); } +// Event FDs may be used with splice without an offset. +TEST(SpliceTest, FromEventFD) { + // Open the input eventfd with an initial value so that it is readable. + constexpr uint64_t kEventFDValue = 1; + int efd; + ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds()); + const FileDescriptor inf(efd); + + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Splice 8-byte eventfd value to pipe. + constexpr int kEventFDSize = 8; + EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), + SyscallSucceedsWithValue(kEventFDSize)); + + // Contents should be equal. + std::vector<char> rbuf(kEventFDSize); + ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(kEventFDSize)); + EXPECT_EQ(memcmp(rbuf.data(), &kEventFDValue, rbuf.size()), 0); +} + +// Event FDs may not be used with splice with an offset. +TEST(SpliceTest, FromEventFDOffset) { + int efd; + ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); + const FileDescriptor inf(efd); + + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Attempt to splice 8-byte eventfd value to pipe with offset. + // + // This is not allowed because eventfd doesn't support pread. + constexpr int kEventFDSize = 8; + loff_t in_off = 0; + EXPECT_THAT(splice(inf.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), + SyscallFailsWithErrno(EINVAL)); +} + +// Event FDs may not be used with splice with an offset. +TEST(SpliceTest, ToEventFDOffset) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill with a value. + constexpr int kEventFDSize = 8; + std::vector<char> buf(kEventFDSize); + buf[0] = 1; + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kEventFDSize)); + + int efd; + ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); + const FileDescriptor outf(efd); + + // Attempt to splice 8-byte eventfd value to pipe with offset. + // + // This is not allowed because eventfd doesn't support pwrite. + loff_t out_off = 0; + EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_off, kEventFDSize, 0), + SyscallFailsWithErrno(EINVAL)); +} + TEST(SpliceTest, ToPipe) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 1bb0307c4..111dbacdf 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -39,7 +39,7 @@ constexpr int TestPort = 40000; // Fixture for tests parameterized by the address family to use (AF_INET and // AF_INET6) when creating sockets. -class UdpSocketTest : public ::testing::TestWithParam<int> { +class UdpSocketTest : public ::testing::TestWithParam<AddressFamily> { protected: // Creates two sockets that will be used by test cases. void SetUp() override; @@ -97,31 +97,32 @@ uint16_t* Port(struct sockaddr_storage* addr) { } void UdpSocketTest::SetUp() { - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + int type; + if (GetParam() == AddressFamily::kIpv4) { + type = AF_INET; + auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); + addrlen_ = sizeof(*sin); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + type = AF_INET6; + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); + addrlen_ = sizeof(*sin6); + if (GetParam() == AddressFamily::kIpv6) { + sin6->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + sin6->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - ASSERT_THAT(t_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_)); anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_); - anyaddr_->sa_family = GetParam(); - - // Initialize address-family-specific values. - switch (GetParam()) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin6); - sin6->sin6_addr = in6addr_any; - break; - } - } + anyaddr_->sa_family = type; if (gvisor::testing::IsRunningOnGvisor()) { for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { @@ -154,9 +155,9 @@ void UdpSocketTest::SetUp() { memset(&addr_storage_[i], 0, sizeof(addr_storage_[i])); addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]); - addr_[i]->sa_family = GetParam(); + addr_[i]->sa_family = type; - switch (GetParam()) { + switch (type) { case AF_INET: { auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]); sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); @@ -174,17 +175,20 @@ void UdpSocketTest::SetUp() { } TEST_P(UdpSocketTest, Creation) { + int type = AF_INET6; + if (GetParam() == AddressFamily::kIpv4) { + type = AF_INET; + } + int s_; - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, IPPROTO_UDP), - SyscallSucceeds()); + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); EXPECT_THAT(close(s_), SyscallSucceeds()); - ASSERT_THAT(s_ = socket(GetParam(), SOCK_DGRAM, 0), SyscallSucceeds()); + ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds()); EXPECT_THAT(close(s_), SyscallSucceeds()); - ASSERT_THAT(s_ = socket(GetParam(), SOCK_STREAM, IPPROTO_UDP), - SyscallFails()); + ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails()); } TEST_P(UdpSocketTest, Getsockname) { @@ -374,6 +378,178 @@ TEST_P(UdpSocketTest, Connect) { EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); } +void ConnectAny(AddressFamily family, int sockfd, uint16_t port) { + struct sockaddr_storage addr = {}; + + // Precondition check. + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr any = IN6ADDR_ANY_INIT; + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0); + } + + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } + + struct sockaddr_storage baddr = {}; + if (family == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin_family = AF_INET; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + addr_in->sin_port = port; + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + if (family == AddressFamily::kIpv6) { + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + addr_in->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + + // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. + if (port == 0) { + SKIP_IF(IsRunningOnGvisor()); + } + + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), + SyscallSucceeds()); + } + + // Postcondition check. + { + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback; + if (family == AddressFamily::kIpv6) { + loopback = IN6ADDR_LOOPBACK_INIT; + } else { + TestAddress const& v4_mapped_loopback = V4MappedLoopback(); + loopback = reinterpret_cast<const struct sockaddr_in6*>( + &v4_mapped_loopback.addr) + ->sin6_addr; + } + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } + + addrlen = sizeof(addr); + if (port == 0) { + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + } else { + EXPECT_THAT( + getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + } + } +} + +TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); } + +TEST_P(UdpSocketTest, ConnectAnyWithPort) { + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + ConnectAny(GetParam(), s_, port); +} + +void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) { + struct sockaddr_storage addr = {}; + + socklen_t addrlen = sizeof(addr); + struct sockaddr_storage baddr = {}; + if (family == AddressFamily::kIpv4) { + auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin_family = AF_INET; + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); + addr_in->sin_port = port; + } else { + auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); + addrlen = sizeof(*addr_in); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + if (family == AddressFamily::kIpv6) { + addr_in->sin6_addr = IN6ADDR_ANY_INIT; + } else { + TestAddress const& v4_mapped_any = V4MappedAny(); + addr_in->sin6_addr = + reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) + ->sin6_addr; + } + } + + // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. + if (port == 0) { + SKIP_IF(IsRunningOnGvisor()); + } + + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), + SyscallSucceeds()); + // Now the socket is bound to the loopback address. + + // Disconnect + addrlen = sizeof(addr); + addr.ss_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + if (family == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + DisconnectAfterConnectAny(GetParam(), s_, 0); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); + DisconnectAfterConnectAny(GetParam(), s_, port); +} + TEST_P(UdpSocketTest, DisconnectAfterBind) { ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); // Connect the socket. @@ -402,19 +578,17 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { struct sockaddr_storage baddr = {}; socklen_t addrlen; auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - if (addr_[0]->sa_family == AF_INET) { + if (GetParam() == AddressFamily::kIpv4) { auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); addr_in->sin_family = AF_INET; addr_in->sin_port = port; - inet_pton(AF_INET, "0.0.0.0", - reinterpret_cast<void*>(&addr_in->sin_addr.s_addr)); + addr_in->sin_addr.s_addr = htonl(INADDR_ANY); } else { auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); addr_in->sin6_family = AF_INET6; addr_in->sin6_port = port; - inet_pton(AF_INET6, - "::", reinterpret_cast<void*>(&addr_in->sin6_addr.s6_addr)); addr_in->sin6_scope_id = 0; + addr_in->sin6_addr = IN6ADDR_ANY_INIT; } ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), SyscallSucceeds()); @@ -1165,7 +1339,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { } INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, - ::testing::Values(AF_INET, AF_INET6)); + ::testing::Values(AddressFamily::kIpv4, + AddressFamily::kIpv6, + AddressFamily::kDualStack)); } // namespace diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go index 5936d66ff..32408f021 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/syscalls/syscall_test_runner.go @@ -23,11 +23,13 @@ import ( "math" "os" "os/exec" + "os/signal" "path/filepath" "strconv" "strings" "syscall" "testing" + "time" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" @@ -189,6 +191,8 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { "-log-format=text", "-TESTONLY-unsafe-nonroot=true", "-net-raw=true", + fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM), + "-watchdog-action=panic", } if *overlay { args = append(args, "-overlay") @@ -220,8 +224,8 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Current process doesn't have CAP_SYS_ADMIN, create user namespace and run // as root inside that namespace to get it. - args = append(args, "run", "--bundle", bundleDir, id) - cmd := exec.Command(*runscPath, args...) + rArgs := append(args, "run", "--bundle", bundleDir, id) + cmd := exec.Command(*runscPath, rArgs...) cmd.SysProcAttr = &syscall.SysProcAttr{ Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS, // Set current user/group as root inside the namespace. @@ -239,9 +243,45 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { } cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGTERM) + go func() { + s, ok := <-sig + if !ok { + return + } + t.Errorf("%s: Got signal: %v", tc.FullName(), s) + done := make(chan bool) + go func() { + dArgs := append(args, "-alsologtostderr=true", "debug", "--stacks", id) + cmd := exec.Command(*runscPath, dArgs...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Run() + done <- true + }() + + timeout := time.Tick(3 * time.Second) + select { + case <-timeout: + t.Logf("runsc debug --stacks is timeouted") + case <-done: + } + + t.Logf("Send SIGTERM to the sandbox process") + dArgs := append(args, "debug", + fmt.Sprintf("--signal=%d", syscall.SIGTERM), + id) + cmd = exec.Command(*runscPath, dArgs...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Run() + }() if err = cmd.Run(); err != nil { t.Errorf("test %q exited with status %v, want 0", tc.FullName(), err) } + signal.Stop(sig) + close(sig) } // filterEnv returns an environment with the blacklisted variables removed. @@ -277,7 +317,7 @@ func main() { fatalf("test-name flag must be provided") } - log.SetLevel(log.Warning) + log.SetLevel(log.Info) if *debug { log.SetLevel(log.Debug) } diff --git a/test/util/BUILD b/test/util/BUILD index a1b9ff526..cfea029b2 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,3 +1,5 @@ +load("//test/syscalls:build_defs.bzl", "select_for_linux") + package( default_visibility = ["//:sandbox"], licenses = ["notice"], @@ -139,7 +141,11 @@ cc_library( cc_library( name = "save_util", testonly = 1, - srcs = ["save_util.cc"], + srcs = ["save_util.cc"] + + select_for_linux( + ["save_util_linux.cc"], + ["save_util_other.cc"], + ), hdrs = ["save_util.h"], ) @@ -184,6 +190,17 @@ cc_test( ) cc_library( + name = "pty_util", + testonly = 1, + srcs = ["pty_util.cc"], + hdrs = ["pty_util.h"], + deps = [ + ":file_descriptor", + ":posix_error", + ], +) + +cc_library( name = "signal_util", testonly = 1, srcs = ["signal_util.cc"], diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc new file mode 100644 index 000000000..c0fd9a095 --- /dev/null +++ b/test/util/pty_util.cc @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/pty_util.h" + +#include <sys/ioctl.h> +#include <termios.h> + +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { + // Get pty index. + int n; + int ret = ioctl(master.get(), TIOCGPTN, &n); + if (ret < 0) { + return PosixError(errno, "ioctl(TIOCGPTN) failed"); + } + + // Unlock pts. + int unlock = 0; + ret = ioctl(master.get(), TIOCSPTLCK, &unlock); + if (ret < 0) { + return PosixError(errno, "ioctl(TIOSPTLCK) failed"); + } + + return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/pty_util.h b/test/util/pty_util.h new file mode 100644 index 000000000..367b14f15 --- /dev/null +++ b/test/util/pty_util.h @@ -0,0 +1,30 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_UTIL_PTY_UTIL_H_ +#define GVISOR_TEST_UTIL_PTY_UTIL_H_ + +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +// Opens the slave end of the passed master as R/W and nonblocking. +PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_UTIL_PTY_UTIL_H_ diff --git a/test/util/save_util.cc b/test/util/save_util.cc index 05f52b80d..384d626f0 100644 --- a/test/util/save_util.cc +++ b/test/util/save_util.cc @@ -16,8 +16,8 @@ #include <stddef.h> #include <stdlib.h> -#include <sys/syscall.h> #include <unistd.h> + #include <atomic> #include <cerrno> @@ -61,13 +61,11 @@ void DisableSave::reset() { } } -void MaybeSave() { - if (CooperativeSaveEnabled() && !save_disable.load()) { - int orig_errno = errno; - syscall(SYS_create_module, nullptr, 0); - errno = orig_errno; - } +namespace internal { +bool ShouldSave() { + return CooperativeSaveEnabled() && (save_disable.load() == 0); } +} // namespace internal } // namespace testing } // namespace gvisor diff --git a/test/util/save_util.h b/test/util/save_util.h index 90460701e..bddad6120 100644 --- a/test/util/save_util.h +++ b/test/util/save_util.h @@ -41,6 +41,11 @@ class DisableSave { // // errno is guaranteed to be preserved. void MaybeSave(); + +namespace internal { +bool ShouldSave(); +} // namespace internal + } // namespace testing } // namespace gvisor diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc new file mode 100644 index 000000000..7a0f14342 --- /dev/null +++ b/test/util/save_util_linux.cc @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <sys/syscall.h> +#include <unistd.h> + +#include "test/util/save_util.h" + +namespace gvisor { +namespace testing { + +void MaybeSave() { + if (internal::ShouldSave()) { + int orig_errno = errno; + syscall(SYS_create_module, nullptr, 0); + errno = orig_errno; + } +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc new file mode 100644 index 000000000..1aca663b7 --- /dev/null +++ b/test/util/save_util_other.cc @@ -0,0 +1,23 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace gvisor { +namespace testing { + +void MaybeSave() { + // Saving is never available in a non-linux environment. +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/thread_util.h b/test/util/thread_util.h index 860e77531..923c4fe10 100644 --- a/test/util/thread_util.h +++ b/test/util/thread_util.h @@ -16,7 +16,9 @@ #define GVISOR_TEST_UTIL_THREAD_UTIL_H_ #include <pthread.h> +#ifdef __linux__ #include <sys/syscall.h> +#endif #include <unistd.h> #include <functional> @@ -66,13 +68,13 @@ class ScopedThread { private: void CreateThread() { - TEST_PCHECK_MSG( - pthread_create(&pt_, /* attr = */ nullptr, - +[](void* arg) -> void* { - return static_cast<ScopedThread*>(arg)->f_(); - }, - this) == 0, - "thread creation failed"); + TEST_PCHECK_MSG(pthread_create( + &pt_, /* attr = */ nullptr, + +[](void* arg) -> void* { + return static_cast<ScopedThread*>(arg)->f_(); + }, + this) == 0, + "thread creation failed"); } std::function<void*()> f_; @@ -81,7 +83,9 @@ class ScopedThread { void* retval_ = nullptr; }; +#ifdef __linux__ inline pid_t gettid() { return syscall(SYS_gettid); } +#endif } // namespace testing } // namespace gvisor |