From 7436ea247bc946b36a7e5e6ca6019796ef76d85c Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Tue, 4 Jun 2019 11:06:13 -0700 Subject: Fix Kokoro revision and 'go get usage' As a convenience for debugging, also factor the scripts such that can be run without Kokoro. In the future, this may be used to add additional presubmit hooks that run without Kokoro. PiperOrigin-RevId: 251474868 --- kokoro/run_build.sh | 43 +-------- kokoro/run_tests.sh | 266 +------------------------------------------------- tools/run_build.sh | 44 +++++++++ tools/run_tests.sh | 273 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 319 insertions(+), 307 deletions(-) mode change 100755 => 120000 kokoro/run_build.sh mode change 100755 => 120000 kokoro/run_tests.sh create mode 100755 tools/run_build.sh create mode 100755 tools/run_tests.sh diff --git a/kokoro/run_build.sh b/kokoro/run_build.sh deleted file mode 100755 index 63fffda48..000000000 --- a/kokoro/run_build.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Fail on any error. -set -e -# Display commands to stderr. -set -x - -# Install the latest version of Bazel. -use_bazel.sh latest - -# Log the bazel path and version. -which bazel -bazel version - -cd git/repo - -# Build runsc. -bazel build //runsc - -# Move the runsc binary into "latest" directory, and also a directory with the -# current date. -latest_dir="${KOKORO_ARTIFACTS_DIR}"/latest -today_dir="${KOKORO_ARTIFACTS_DIR}"/"$(date -Idate)" -mkdir -p "${latest_dir}" "${today_dir}" -cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc "${latest_dir}" -sha512sum "${latest_dir}"/runsc | awk '{print $1 " runsc"}' > "${latest_dir}"/runsc.sha512 -cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc "${today_dir}" -sha512sum "${today_dir}"/runsc | awk '{print $1 " runsc"}' > "${today_dir}"/runsc.sha512 diff --git a/kokoro/run_build.sh b/kokoro/run_build.sh new file mode 120000 index 000000000..9deafe9bb --- /dev/null +++ b/kokoro/run_build.sh @@ -0,0 +1 @@ +../tools/run_build.sh \ No newline at end of file diff --git a/kokoro/run_tests.sh b/kokoro/run_tests.sh deleted file mode 100755 index 6ff72ce1d..000000000 --- a/kokoro/run_tests.sh +++ /dev/null @@ -1,265 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Fail on any error. Treat unset variables as error. Print commands as executed. -set -eux - - -################### -# GLOBAL ENV VARS # -################### - -readonly WORKSPACE_DIR="${PWD}/git/repo" - -# Used to configure RBE. -readonly CLOUD_PROJECT_ID="gvisor-rbe" -readonly RBE_PROJECT_ID="projects/${CLOUD_PROJECT_ID}/instances/default_instance" - -# Random runtime name to avoid collisions. -readonly RUNTIME="runsc_test_$((RANDOM))" - -# Packages that will be built and tested. -readonly BUILD_PACKAGES=("//...") -readonly TEST_PACKAGES=("//pkg/..." "//runsc/..." "//tools/...") - -####################### -# BAZEL CONFIGURATION # -####################### - -# Install the latest version of Bazel, and log the location and version. -use_bazel.sh latest -which bazel -bazel version - -# Load the kvm module -sudo -n -E modprobe kvm - -# General Bazel build/test flags. -BAZEL_BUILD_FLAGS=( - "--show_timestamps" - "--test_output=errors" - "--keep_going" - "--verbose_failures=true" -) - -# Bazel build/test for RBE, a super-set of BAZEL_BUILD_FLAGS. -BAZEL_BUILD_RBE_FLAGS=( - "${BAZEL_BUILD_FLAGS[@]}" - "--config=remote" - "--project_id=${CLOUD_PROJECT_ID}" - "--remote_instance_name=${RBE_PROJECT_ID}" - "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" -) - -#################### -# Helper Functions # -#################### - -sanity_checks() { - cd ${WORKSPACE_DIR} - bazel run //:gazelle -- update-repos -from_file=go.mod - git diff --exit-code WORKSPACE -} - -build_everything() { - FLAVOR="${1}" - - cd ${WORKSPACE_DIR} - bazel build \ - -c "${FLAVOR}" "${BAZEL_BUILD_RBE_FLAGS[@]}" \ - "${BUILD_PACKAGES[@]}" -} - -# Run simple tests runs the tests that require no special setup or -# configuration. -run_simple_tests() { - cd ${WORKSPACE_DIR} - bazel test \ - "${BAZEL_BUILD_FLAGS[@]}" \ - "${TEST_PACKAGES[@]}" -} - -install_runtime() { - cd ${WORKSPACE_DIR} - sudo -n ${WORKSPACE_DIR}/runsc/test/install.sh --runtime ${RUNTIME} -} - -# Install dependencies for the crictl tests. -install_crictl_test_deps() { - # Install containerd. - sudo -n -E apt-get update - sudo -n -E apt-get install -y btrfs-tools libseccomp-dev - # go get will exit with a status of 1 despite succeeding, so ignore errors. - go get -d github.com/containerd/containerd || true - cd ${GOPATH}/src/github.com/containerd/containerd - git checkout v1.2.2 - make - sudo -n -E make install - - # Install crictl. - # go get will exit with a status of 1 despite succeeding, so ignore errors. - go get -d github.com/kubernetes-sigs/cri-tools || true - cd ${GOPATH}/src/github.com/kubernetes-sigs/cri-tools - git checkout tags/v1.11.0 - make - sudo -n -E make install - - # Install gvisor-containerd-shim. - local latest=/tmp/gvisor-containerd-shim-latest - local shim_path=/tmp/gvisor-containerd-shim - wget --no-verbose https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim/latest -O ${latest} - wget --no-verbose https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} - chmod +x ${shim_path} - sudo -n -E mv ${shim_path} /usr/local/bin - - # Configure containerd-shim. - local shim_config_path=/etc/containerd - local shim_config_tmp_path=/tmp/gvisor-containerd-shim.toml - sudo -n -E mkdir -p ${shim_config_path} - cat > ${shim_config_tmp_path} <<-EOF - runc_shim = "/usr/local/bin/containerd-shim" - - [runsc_config] - debug = "true" - debug-log = "/tmp/runsc-logs/" - strace = "true" - file-access = "shared" -EOF - sudo mv ${shim_config_tmp_path} ${shim_config_path} - - # Configure CNI. - sudo -n -E env PATH=${PATH} ${GOPATH}/src/github.com/containerd/containerd/script/setup/install-cni -} - -# Run the tests that require docker. -run_docker_tests() { - cd ${WORKSPACE_DIR} - - # Run tests with a default runtime (runc). - bazel test \ - "${BAZEL_BUILD_FLAGS[@]}" \ - --test_env=RUNSC_RUNTIME="" \ - --test_output=all \ - //runsc/test/image:image_test - - # These names are used to exclude tests not supported in certain - # configuration, e.g. save/restore not supported with hostnet. - declare -a variations=("" "-kvm" "-hostnet" "-overlay") - for v in "${variations[@]}"; do - # Run runsc tests with docker that are tagged manual. - bazel test \ - "${BAZEL_BUILD_FLAGS[@]}" \ - --test_env=RUNSC_RUNTIME="${RUNTIME}${v}" \ - --test_output=all \ - //runsc/test/image:image_test \ - //runsc/test/integration:integration_test - done -} - -# Run the tests that require root. -run_root_tests() { - cd ${WORKSPACE_DIR} - bazel build //runsc/test/root:root_test - local root_test=$(find -L ./bazel-bin/ -executable -type f -name root_test | grep __main__) - if [[ ! -f "${root_test}" ]]; then - echo "root_test executable not found" - exit 1 - fi - sudo -n -E RUNSC_RUNTIME="${RUNTIME}" RUNSC_EXEC=/tmp/"${RUNTIME}"/runsc ${root_test} -} - -# Run syscall unit tests. -run_syscall_tests() { - cd ${WORKSPACE_DIR} - bazel test "${BAZEL_BUILD_RBE_FLAGS[@]}" \ - --test_tag_filters=runsc_ptrace //test/syscalls/... -} - -run_runsc_do_tests() { - local runsc=$(find bazel-bin/runsc -type f -executable -name "runsc" | head -n1) - - # run runsc do without root privileges. - unshare -Ur ${runsc} --network=none --TESTONLY-unsafe-nonroot do true - unshare -Ur ${runsc} --TESTONLY-unsafe-nonroot --network=host do --netns=false true - - # run runsc do with root privileges. - sudo -n -E ${runsc} do true -} - -# Find and rename all test xml and log files so that Sponge can pick them up. -# XML files must be named sponge_log.xml, and log files must be named -# sponge_log.log. We move all such files into KOKORO_ARTIFACTS_DIR, in a -# subdirectory named with the test name. -upload_test_artifacts() { - cd ${WORKSPACE_DIR} - find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | - tar --create --files-from - --transform 's/test\./sponge_log./' | - tar --extract --directory ${KOKORO_ARTIFACTS_DIR} - if [[ -d "/tmp/${RUNTIME}/logs" ]]; then - tar --create --gzip "--file=${KOKORO_ARTIFACTS_DIR}/runsc-logs.tar.gz" -C /tmp/ ${RUNTIME}/logs - fi -} - -# Finish runs at exit, even in the event of an error, and uploads all test -# artifacts. -finish() { - # Grab the last exit code, we will return it. - local exit_code=${?} - upload_test_artifacts - exit ${exit_code} -} - -# Run bazel in a docker container -build_in_docker() { - cd ${WORKSPACE_DIR} - bazel clean - bazel shutdown - make - make runsc - make bazel-shutdown -} - -######## -# MAIN # -######## - -main() { - # Register finish to run at exit. - trap finish EXIT - - # Build and run the simple tests. - sanity_checks - build_everything opt - run_simple_tests - - # So far so good. Install more deps and run the integration tests. - install_runtime - install_crictl_test_deps - run_docker_tests - run_root_tests - - run_syscall_tests - run_runsc_do_tests - - # Build other flavors too. - build_everything dbg - - build_in_docker - # No need to call "finish" here, it will happen at exit. -} - -# Kick it off. -main diff --git a/kokoro/run_tests.sh b/kokoro/run_tests.sh new file mode 120000 index 000000000..931cd2622 --- /dev/null +++ b/kokoro/run_tests.sh @@ -0,0 +1 @@ +../tools/run_tests.sh \ No newline at end of file diff --git a/tools/run_build.sh b/tools/run_build.sh new file mode 100755 index 000000000..b6b446690 --- /dev/null +++ b/tools/run_build.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Fail on any error. +set -e +# Display commands to stderr. +set -x + +# Install the latest version of Bazel and log the version. +(which use_bazel.sh && use_bazel.sh latest) || which bazel +bazel version + +# Switch into the workspace and checkout the appropriate commit. +if [[ -v KOKORO_GIT_COMMIT ]]; then + cd git/repo && git checkout "${KOKORO_GIT_COMMIT}" +fi + +# Build runsc. +bazel build //runsc + +# Move the runsc binary into "latest" directory, and also a directory with the +# current date. +if [[ -v KOKORO_ARTIFACTS_DIR ]]; then + latest_dir="${KOKORO_ARTIFACTS_DIR}"/latest + today_dir="${KOKORO_ARTIFACTS_DIR}"/"$(date -Idate)" + mkdir -p "${latest_dir}" "${today_dir}" + cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc "${latest_dir}" + sha512sum "${latest_dir}"/runsc | awk '{print $1 " runsc"}' > "${latest_dir}"/runsc.sha512 + cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc "${today_dir}" + sha512sum "${today_dir}"/runsc | awk '{print $1 " runsc"}' > "${today_dir}"/runsc.sha512 +fi diff --git a/tools/run_tests.sh b/tools/run_tests.sh new file mode 100755 index 000000000..c6e97dc95 --- /dev/null +++ b/tools/run_tests.sh @@ -0,0 +1,273 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Fail on any error. Treat unset variables as error. Print commands as executed. +set -eux + +################### +# GLOBAL ENV VARS # +################### + +if [[ -v KOKORO_GIT_COMMIT ]]; then + readonly WORKSPACE_DIR="${PWD}/git/repo" +else + readonly WORKSPACE_DIR="${PWD}" +fi + +# Used to configure RBE. +readonly CLOUD_PROJECT_ID="gvisor-rbe" +readonly RBE_PROJECT_ID="projects/${CLOUD_PROJECT_ID}/instances/default_instance" + +# Random runtime name to avoid collisions. +readonly RUNTIME="runsc_test_$((RANDOM))" + +# Packages that will be built and tested. +readonly BUILD_PACKAGES=("//...") +readonly TEST_PACKAGES=("//pkg/..." "//runsc/..." "//tools/...") + +####################### +# BAZEL CONFIGURATION # +####################### + +# Install the latest version of Bazel and log the version. +(which use_bazel.sh && use_bazel.sh latest) || which bazel +bazel version + +# Checkout the appropriate commit. +if [[ -v KOKORO_GIT_COMMIT ]]; then + (cd "${WORKSPACE_DIR}" && git checkout "${KOKORO_GIT_COMMIT}") +fi + +# Load the kvm module. +sudo -n -E modprobe kvm + +# General Bazel build/test flags. +BAZEL_BUILD_FLAGS=( + "--show_timestamps" + "--test_output=errors" + "--keep_going" + "--verbose_failures=true" +) + +# Bazel build/test for RBE, a super-set of BAZEL_BUILD_FLAGS. +BAZEL_BUILD_RBE_FLAGS=( + "${BAZEL_BUILD_FLAGS[@]}" + "--config=remote" + "--project_id=${CLOUD_PROJECT_ID}" + "--remote_instance_name=${RBE_PROJECT_ID}" +) +if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then + BAZEL_BUILD_RBE_FLAGS=( + "${BAZEL_BUILD_RBE_FLAGS[@]}" + "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" + ) +fi + +#################### +# Helper Functions # +#################### + +sanity_checks() { + cd ${WORKSPACE_DIR} + bazel run //:gazelle -- update-repos -from_file=go.mod + git diff --exit-code WORKSPACE +} + +build_everything() { + FLAVOR="${1}" + + cd ${WORKSPACE_DIR} + bazel build \ + -c "${FLAVOR}" "${BAZEL_BUILD_RBE_FLAGS[@]}" \ + "${BUILD_PACKAGES[@]}" +} + +# Run simple tests runs the tests that require no special setup or +# configuration. +run_simple_tests() { + cd ${WORKSPACE_DIR} + bazel test \ + "${BAZEL_BUILD_FLAGS[@]}" \ + "${TEST_PACKAGES[@]}" +} + +install_runtime() { + cd ${WORKSPACE_DIR} + sudo -n ${WORKSPACE_DIR}/runsc/test/install.sh --runtime ${RUNTIME} +} + +# Install dependencies for the crictl tests. +install_crictl_test_deps() { + sudo -n -E apt-get update + sudo -n -E apt-get install -y btrfs-tools libseccomp-dev + + # Install containerd. + [[ -d containerd ]] || git clone https://github.com/containerd/containerd + (cd containerd && git checkout v1.2.2 && make && sudo -n -E make install) + + # Install crictl. + [[ -d cri-tools ]] || git clone https://github.com/kubernetes-sigs/cri-tools + (cd cri-tools && git checkout tags/v1.11.0 && make && sudo -n -E make install) + + # Install gvisor-containerd-shim. + local latest=/tmp/gvisor-containerd-shim-latest + local shim_path=/tmp/gvisor-containerd-shim + wget --no-verbose https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim/latest -O ${latest} + wget --no-verbose https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} + chmod +x ${shim_path} + sudo -n -E mv ${shim_path} /usr/local/bin + + # Configure containerd-shim. + local shim_config_path=/etc/containerd + local shim_config_tmp_path=/tmp/gvisor-containerd-shim.toml + sudo -n -E mkdir -p ${shim_config_path} + cat > ${shim_config_tmp_path} <<-EOF + runc_shim = "/usr/local/bin/containerd-shim" + + [runsc_config] + debug = "true" + debug-log = "/tmp/runsc-logs/" + strace = "true" + file-access = "shared" +EOF + sudo mv ${shim_config_tmp_path} ${shim_config_path} + + # Configure CNI. + sudo -n -E env PATH=${PATH} containerd/script/setup/install-cni +} + +# Run the tests that require docker. +run_docker_tests() { + cd ${WORKSPACE_DIR} + + # Run tests with a default runtime (runc). + bazel test \ + "${BAZEL_BUILD_FLAGS[@]}" \ + --test_env=RUNSC_RUNTIME="" \ + --test_output=all \ + //runsc/test/image:image_test + + # These names are used to exclude tests not supported in certain + # configuration, e.g. save/restore not supported with hostnet. + declare -a variations=("" "-kvm" "-hostnet" "-overlay") + for v in "${variations[@]}"; do + # Run runsc tests with docker that are tagged manual. + bazel test \ + "${BAZEL_BUILD_FLAGS[@]}" \ + --test_env=RUNSC_RUNTIME="${RUNTIME}${v}" \ + --test_output=all \ + //runsc/test/image:image_test \ + //runsc/test/integration:integration_test + done +} + +# Run the tests that require root. +run_root_tests() { + cd ${WORKSPACE_DIR} + bazel build //runsc/test/root:root_test + local root_test=$(find -L ./bazel-bin/ -executable -type f -name root_test | grep __main__) + if [[ ! -f "${root_test}" ]]; then + echo "root_test executable not found" + exit 1 + fi + sudo -n -E RUNSC_RUNTIME="${RUNTIME}" RUNSC_EXEC=/tmp/"${RUNTIME}"/runsc ${root_test} +} + +# Run syscall unit tests. +run_syscall_tests() { + cd ${WORKSPACE_DIR} + bazel test "${BAZEL_BUILD_RBE_FLAGS[@]}" \ + --test_tag_filters=runsc_ptrace //test/syscalls/... +} + +run_runsc_do_tests() { + local runsc=$(find bazel-bin/runsc -type f -executable -name "runsc" | head -n1) + + # run runsc do without root privileges. + unshare -Ur ${runsc} --network=none --TESTONLY-unsafe-nonroot do true + unshare -Ur ${runsc} --TESTONLY-unsafe-nonroot --network=host do --netns=false true + + # run runsc do with root privileges. + sudo -n -E ${runsc} do true +} + +# Find and rename all test xml and log files so that Sponge can pick them up. +# XML files must be named sponge_log.xml, and log files must be named +# sponge_log.log. We move all such files into KOKORO_ARTIFACTS_DIR, in a +# subdirectory named with the test name. +upload_test_artifacts() { + # Skip if no kokoro directory. + [[ -v KOKORO_ARTIFACTS_DIR ]] || return + + cd ${WORKSPACE_DIR} + find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | + tar --create --files-from - --transform 's/test\./sponge_log./' | + tar --extract --directory ${KOKORO_ARTIFACTS_DIR} + if [[ -d "/tmp/${RUNTIME}/logs" ]]; then + tar --create --gzip "--file=${KOKORO_ARTIFACTS_DIR}/runsc-logs.tar.gz" -C /tmp/ ${RUNTIME}/logs + fi +} + +# Finish runs at exit, even in the event of an error, and uploads all test +# artifacts. +finish() { + # Grab the last exit code, we will return it. + local exit_code=${?} + upload_test_artifacts + exit ${exit_code} +} + +# Run bazel in a docker container +build_in_docker() { + cd ${WORKSPACE_DIR} + bazel clean + bazel shutdown + make + make runsc + make bazel-shutdown +} + +######## +# MAIN # +######## + +main() { + # Register finish to run at exit. + trap finish EXIT + + # Build and run the simple tests. + sanity_checks + build_everything opt + run_simple_tests + + # So far so good. Install more deps and run the integration tests. + install_runtime + install_crictl_test_deps + run_docker_tests + run_root_tests + + run_syscall_tests + run_runsc_do_tests + + # Build other flavors too. + build_everything dbg + + build_in_docker + # No need to call "finish" here, it will happen at exit. +} + +# Kick it off. +main -- cgit v1.2.3 From 0c292cdaab5c226bcf90c3376a0f3942cb266eed Mon Sep 17 00:00:00 2001 From: Nicolas Lacasse Date: Tue, 4 Jun 2019 12:57:41 -0700 Subject: Remove the Dirent field from Pipe. Dirents are ref-counted, but Pipes are not. Holding a Dirent inside of a Pipe raises difficult questions about the lifecycle of the Pipe and Dirent. Fortunately, we can side-step those questions by removing the Dirent field from Pipe entirely. We only need the Dirent when constructing fs.Files (which are ref-counted), and in GetFile (when a Dirent is passed to us anyways). PiperOrigin-RevId: 251497628 --- pkg/sentry/kernel/pipe/node.go | 12 ++++++++---- pkg/sentry/kernel/pipe/node_test.go | 8 ++++++-- pkg/sentry/kernel/pipe/pipe.go | 39 +++++++++++++++++-------------------- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 926c4c623..dc7da529e 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -38,7 +38,11 @@ type inodeOperations struct { fsutil.InodeNotMappable `state:"nosave"` fsutil.InodeNotSocket `state:"nosave"` fsutil.InodeNotSymlink `state:"nosave"` - fsutil.InodeNotVirtual `state:"nosave"` + + // Marking pipe inodes as virtual allows them to be saved and restored + // even if they have been unlinked. We can get away with this because + // their state exists entirely within the sentry. + fsutil.InodeVirtual `state:"nosave"` fsutil.InodeSimpleAttributes @@ -86,7 +90,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi switch { case flags.Read && !flags.Write: // O_RDONLY. - r := i.p.Open(ctx, flags) + r := i.p.Open(ctx, d, flags) i.newHandleLocked(&i.rWakeup) if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() { @@ -102,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi return r, nil case flags.Write && !flags.Read: // O_WRONLY. - w := i.p.Open(ctx, flags) + w := i.p.Open(ctx, d, flags) i.newHandleLocked(&i.wWakeup) if i.p.isNamed && !i.p.HasReaders() { @@ -122,7 +126,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Read && flags.Write: // O_RDWR. // Pipes opened for read-write always succeeds without blocking. - rw := i.p.Open(ctx, flags) + rw := i.p.Open(ctx, d, flags) i.newHandleLocked(&i.rWakeup) i.newHandleLocked(&i.wWakeup) return rw, nil diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index 31d9b0443..9a946b380 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -62,7 +62,9 @@ var perms fs.FilePermissions = fs.FilePermissions{ } func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, doneChan chan<- struct{}) (*fs.File, error) { - file, err := n.GetFile(ctx, nil, flags) + inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) + d := fs.NewDirent(inode, "pipe") + file, err := n.GetFile(ctx, d, flags) if err != nil { t.Fatalf("open with flags %+v failed: %v", flags, err) } @@ -73,7 +75,9 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag } func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.FileFlags, resChan chan<- openResult) (*fs.File, error) { - file, err := n.GetFile(ctx, nil, flags) + inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe}) + d := fs.NewDirent(inode, "pipe") + file, err := n.GetFile(ctx, d, flags) if resChan != nil { resChan <- openResult{file, err} } diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index b65204492..73438dc62 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -71,11 +71,6 @@ type Pipe struct { // This value is immutable. atomicIOBytes int64 - // The dirent backing this pipe. Shared by all readers and writers. - // - // This value is immutable. - Dirent *fs.Dirent - // The number of active readers for this pipe. // // Access atomically. @@ -130,14 +125,20 @@ func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) if atomicIOBytes > sizeBytes { atomicIOBytes = sizeBytes } - p := &Pipe{ + return &Pipe{ isNamed: isNamed, max: sizeBytes, atomicIOBytes: atomicIOBytes, } +} - // Build the fs.Dirent of this pipe, shared by all fs.Files associated - // with this pipe. +// NewConnectedPipe initializes a pipe and returns a pair of objects +// representing the read and write ends of the pipe. +func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { + p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes) + + // Build an fs.Dirent for the pipe which will be shared by both + // returned files. perms := fs.FilePermissions{ User: fs.PermMask{Read: true, Write: true}, } @@ -150,36 +151,32 @@ func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) BlockSize: int64(atomicIOBytes), } ms := fs.NewPseudoMountSource() - p.Dirent = fs.NewDirent(fs.NewInode(iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino)) - return p -} - -// NewConnectedPipe initializes a pipe and returns a pair of objects -// representing the read and write ends of the pipe. -func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { - p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes) - return p.Open(ctx, fs.FileFlags{Read: true}), p.Open(ctx, fs.FileFlags{Write: true}) + d := fs.NewDirent(fs.NewInode(iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino)) + // The p.Open calls below will each take a reference on the Dirent. We + // must drop the one we already have. + defer d.DecRef() + return p.Open(ctx, d, fs.FileFlags{Read: true}), p.Open(ctx, d, fs.FileFlags{Write: true}) } // Open opens the pipe and returns a new file. // // Precondition: at least one of flags.Read or flags.Write must be set. -func (p *Pipe) Open(ctx context.Context, flags fs.FileFlags) *fs.File { +func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.File { switch { case flags.Read && flags.Write: p.rOpen() p.wOpen() - return fs.NewFile(ctx, p.Dirent, flags, &ReaderWriter{ + return fs.NewFile(ctx, d, flags, &ReaderWriter{ Pipe: p, }) case flags.Read: p.rOpen() - return fs.NewFile(ctx, p.Dirent, flags, &Reader{ + return fs.NewFile(ctx, d, flags, &Reader{ ReaderWriter: ReaderWriter{Pipe: p}, }) case flags.Write: p.wOpen() - return fs.NewFile(ctx, p.Dirent, flags, &Writer{ + return fs.NewFile(ctx, d, flags, &Writer{ ReaderWriter: ReaderWriter{Pipe: p}, }) default: -- cgit v1.2.3 From 6f92038ce0d2062c3dfd84fe65141ee09deeabfc Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Tue, 4 Jun 2019 14:42:25 -0700 Subject: Use github directory if it exists. Unfortunately, kokoro names the top-level directory per the SCM type. This means there's no way to make the job names match; we simply need to probe for the existence of the correct directory. PiperOrigin-RevId: 251519409 --- tools/run_build.sh | 8 +++++--- tools/run_tests.sh | 9 +++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tools/run_build.sh b/tools/run_build.sh index b6b446690..d49a1d4be 100755 --- a/tools/run_build.sh +++ b/tools/run_build.sh @@ -23,9 +23,11 @@ set -x (which use_bazel.sh && use_bazel.sh latest) || which bazel bazel version -# Switch into the workspace and checkout the appropriate commit. -if [[ -v KOKORO_GIT_COMMIT ]]; then - cd git/repo && git checkout "${KOKORO_GIT_COMMIT}" +# Switch into the workspace. +if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then + cd git/repo +elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then + cd github/repo fi # Build runsc. diff --git a/tools/run_tests.sh b/tools/run_tests.sh index c6e97dc95..dc282c142 100755 --- a/tools/run_tests.sh +++ b/tools/run_tests.sh @@ -21,8 +21,10 @@ set -eux # GLOBAL ENV VARS # ################### -if [[ -v KOKORO_GIT_COMMIT ]]; then +if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then readonly WORKSPACE_DIR="${PWD}/git/repo" +elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then + readonly WORKSPACE_DIR="${PWD}/github/repo" else readonly WORKSPACE_DIR="${PWD}" fi @@ -46,11 +48,6 @@ readonly TEST_PACKAGES=("//pkg/..." "//runsc/..." "//tools/...") (which use_bazel.sh && use_bazel.sh latest) || which bazel bazel version -# Checkout the appropriate commit. -if [[ -v KOKORO_GIT_COMMIT ]]; then - (cd "${WORKSPACE_DIR}" && git checkout "${KOKORO_GIT_COMMIT}") -fi - # Load the kvm module. sudo -n -E modprobe kvm -- cgit v1.2.3 From 7398f013f043cfe43b5fc615bd24b641df17e6bc Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 4 Jun 2019 15:39:24 -0700 Subject: Drop one dirent reference after referenced by file When pipe is created, a dirent of pipe will be created and its initial reference is set as 0. Cause all dirent will only be destroyed when the reference decreased to -1, so there is already a 'initial reference' of dirent after it created. For destroying dirent after all reference released, the correct way is to drop the 'initial reference' once someone hold a reference to the dirent, such as fs.NewFile, otherwise the reference of dirent will stay 0 all the time, and will cause memory leak of dirent. Except pipe, timerfd/eventfd/epoll has the same problem Here is a simple case to create memory leak of dirent for pipe/timerfd/eventfd/epoll in C langange, after run the case, pprof the runsc process, you will find lots dirents of pipe/timerfd/eventfd/epoll not freed: int main(int argc, char *argv[]) { int i; int n; int pipefd[2]; if (argc != 3) { printf("Usage: %s epoll|timerfd|eventfd|pipe \n", argv[0]); } n = strtol(argv[2], NULL, 10); if (strcmp(argv[1], "epoll") == 0) { for (i = 0; i < n; ++i) close(epoll_create(1)); } else if (strcmp(argv[1], "timerfd") == 0) { for (i = 0; i < n; ++i) close(timerfd_create(CLOCK_REALTIME, 0)); } else if (strcmp(argv[1], "eventfd") == 0) { for (i = 0; i < n; ++i) close(eventfd(0, 0)); } else if (strcmp(argv[1], "pipe") == 0) { for (i = 0; i < n; ++i) if (pipe(pipefd) == 0) { close(pipefd[0]); close(pipefd[1]); } } printf("%s %s test finished\r\n",argv[1],argv[2]); return 0; } Change-Id: Ia1b8a1fb9142edb00c040e44ec644d007f81f5d2 PiperOrigin-RevId: 251531096 --- pkg/sentry/fs/timerfd/timerfd.go | 2 ++ pkg/sentry/kernel/epoll/epoll.go | 2 ++ pkg/sentry/kernel/eventfd/eventfd.go | 2 ++ 3 files changed, 6 insertions(+) diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index bce5f091d..c1721f434 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -54,6 +54,8 @@ type TimerOperations struct { // NewFile returns a timerfd File that receives time from c. func NewFile(ctx context.Context, c ktime.Clock) *fs.File { dirent := fs.NewDirent(anon.NewInode(ctx), "anon_inode:[timerfd]") + // Release the initial dirent reference after NewFile takes a reference. + defer dirent.DecRef() tops := &TimerOperations{} tops.timer = ktime.NewTimer(c, tops) // Timerfds reject writes, but the Write flag must be set in order to diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index bbacba1f4..43ae22a5d 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -156,6 +156,8 @@ var cycleMu sync.Mutex func NewEventPoll(ctx context.Context) *fs.File { // name matches fs/eventpoll.c:epoll_create1. dirent := fs.NewDirent(anon.NewInode(ctx), fmt.Sprintf("anon_inode:[eventpoll]")) + // Release the initial dirent reference after NewFile takes a reference. + defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{ files: make(map[FileIdentifier]*pollEntry), }) diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 2f900be38..fe474cbf0 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -69,6 +69,8 @@ type EventOperations struct { func New(ctx context.Context, initVal uint64, semMode bool) *fs.File { // name matches fs/eventfd.c:eventfd_file_create. dirent := fs.NewDirent(anon.NewInode(ctx), "anon_inode:[eventfd]") + // Release the initial dirent reference after NewFile takes a reference. + defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{ val: initVal, semMode: semMode, -- cgit v1.2.3 From e0fb921205b79f401375544652e4de8077162292 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 4 Jun 2019 16:16:24 -0700 Subject: Fix data race in synRcvdState. When checking the length of the acceptedChan we should hold the endpoint mutex otherwise a syn received while the listening socket is being closed can result in a data race where the cleanupLocked routine sets acceptedChan to nil while a handshake goroutine in progress could try and check it at the same time. PiperOrigin-RevId: 251537697 --- pkg/tcpip/transport/tcp/connect.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 2aed6f286..371d2ed29 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -284,14 +284,19 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { // listenContext is also used by a tcp.Forwarder and in that // context we do not have a listening endpoint to check the // backlog. So skip this check if listenEP is nil. - if h.listenEP != nil && len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) { - // If there is no space in the accept queue to accept - // this endpoint then silently drop this ACK. The peer - // will anyway resend the ack and we can complete the - // connection the next time it's retransmitted. - h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment() - h.ep.stack.Stats().DroppedPackets.Increment() - return nil + if h.listenEP != nil { + h.listenEP.mu.Lock() + if len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) { + h.listenEP.mu.Unlock() + // If there is no space in the accept queue to accept + // this endpoint then silently drop this ACK. The peer + // will anyway resend the ack and we can complete the + // connection the next time it's retransmitted. + h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + h.listenEP.mu.Unlock() } // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped -- cgit v1.2.3 From cecb71dc37a77d8e4e88cdfada92a37a72c67602 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Tue, 4 Jun 2019 23:08:20 -0700 Subject: Building containerd with go modules is broken, use GOPATH. PiperOrigin-RevId: 251583707 --- tools/run_tests.sh | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tools/run_tests.sh b/tools/run_tests.sh index dc282c142..b35d2e4b8 100755 --- a/tools/run_tests.sh +++ b/tools/run_tests.sh @@ -106,18 +106,31 @@ install_runtime() { sudo -n ${WORKSPACE_DIR}/runsc/test/install.sh --runtime ${RUNTIME} } +install_helper() { + PACKAGE="${1}" + TAG="${2}" + GOPATH="${3}" + + # Clone the repository. + mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \ + git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}" + + # Checkout and build the repository. + (cd "${GOPATH}"/src/"${PACKAGE}" && \ + git checkout "${TAG}" && \ + GOPATH="${GOPATH}" make && \ + sudo -n -E env GOPATH="${GOPATH}" make install) +} + # Install dependencies for the crictl tests. install_crictl_test_deps() { sudo -n -E apt-get update sudo -n -E apt-get install -y btrfs-tools libseccomp-dev - # Install containerd. - [[ -d containerd ]] || git clone https://github.com/containerd/containerd - (cd containerd && git checkout v1.2.2 && make && sudo -n -E make install) - - # Install crictl. - [[ -d cri-tools ]] || git clone https://github.com/kubernetes-sigs/cri-tools - (cd cri-tools && git checkout tags/v1.11.0 && make && sudo -n -E make install) + # Install containerd & cri-tools. + GOPATH=$(mktemp -d --tmpdir gopathXXXXX) + install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}" + install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}" # Install gvisor-containerd-shim. local latest=/tmp/gvisor-containerd-shim-latest @@ -143,7 +156,8 @@ EOF sudo mv ${shim_config_tmp_path} ${shim_config_path} # Configure CNI. - sudo -n -E env PATH=${PATH} containerd/script/setup/install-cni + (cd "${GOPATH}" && sudo -n -E env PATH="${PATH}" GOPATH="${GOPATH}" \ + src/github.com/containerd/containerd/script/setup/install-cni) } # Run the tests that require docker. -- cgit v1.2.3 From d3ed9baac0dc967eaf6d3e3f986cafe60604121a Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Wed, 5 Jun 2019 13:59:01 -0700 Subject: Implement dumpability tracking and checks We don't actually support core dumps, but some applications want to get/set dumpability, which still has an effect in procfs. Lack of support for set-uid binaries or fs creds simplifies things a bit. As-is, processes started via CreateProcess (i.e., init and sentryctl exec) have normal dumpability. I'm a bit torn on whether sentryctl exec tasks should be dumpable, but at least since they have no parent normal UID/GID checks should protect them. PiperOrigin-RevId: 251712714 --- pkg/abi/linux/prctl.go | 7 +++++ pkg/sentry/fs/proc/inode.go | 40 ++++++++++++++++++++++-- pkg/sentry/fs/proc/task.go | 17 +++++++++- pkg/sentry/kernel/ptrace.go | 17 +++++++++- pkg/sentry/kernel/task_exec.go | 7 +++++ pkg/sentry/kernel/task_identity.go | 24 ++++++++++++-- pkg/sentry/mm/lifecycle.go | 6 ++-- pkg/sentry/mm/metadata.go | 30 ++++++++++++++++++ pkg/sentry/mm/mm.go | 6 ++++ pkg/sentry/syscalls/linux/sys_prctl.go | 33 ++++++++++++++++++-- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/prctl.cc | 34 ++++++++++++++++++++ test/syscalls/linux/proc.cc | 57 ++++++++++++++++++++++++++++++++++ 13 files changed, 268 insertions(+), 11 deletions(-) diff --git a/pkg/abi/linux/prctl.go b/pkg/abi/linux/prctl.go index 0428282dd..391cfaa1c 100644 --- a/pkg/abi/linux/prctl.go +++ b/pkg/abi/linux/prctl.go @@ -155,3 +155,10 @@ const ( ARCH_GET_GS = 0x1004 ARCH_SET_CPUID = 0x1012 ) + +// Flags for prctl(PR_SET_DUMPABLE), defined in include/linux/sched/coredump.h. +const ( + SUID_DUMP_DISABLE = 0 + SUID_DUMP_USER = 1 + SUID_DUMP_ROOT = 2 +) diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go index 379569823..986bc0a45 100644 --- a/pkg/sentry/fs/proc/inode.go +++ b/pkg/sentry/fs/proc/inode.go @@ -21,11 +21,14 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" ) // taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr -// method to return the task as the owner. +// method to return either the task or root as the owner, depending on the +// task's dumpability. // // +stateify savable type taskOwnedInodeOps struct { @@ -41,9 +44,42 @@ func (i *taskOwnedInodeOps) UnstableAttr(ctx context.Context, inode *fs.Inode) ( if err != nil { return fs.UnstableAttr{}, err } - // Set the task owner as the file owner. + + // By default, set the task owner as the file owner. creds := i.t.Credentials() uattr.Owner = fs.FileOwner{creds.EffectiveKUID, creds.EffectiveKGID} + + // Linux doesn't apply dumpability adjustments to world + // readable/executable directories so that applications can stat + // /proc/PID to determine the effective UID of a process. See + // fs/proc/base.c:task_dump_owner. + if fs.IsDir(inode.StableAttr) && uattr.Perms == fs.FilePermsFromMode(0555) { + return uattr, nil + } + + // If the task is not dumpable, then root (in the namespace preferred) + // owns the file. + var m *mm.MemoryManager + i.t.WithMuLocked(func(t *kernel.Task) { + m = t.MemoryManager() + }) + + if m == nil { + uattr.Owner.UID = auth.RootKUID + uattr.Owner.GID = auth.RootKGID + } else if m.Dumpability() != mm.UserDumpable { + if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() { + uattr.Owner.UID = kuid + } else { + uattr.Owner.UID = auth.RootKUID + } + if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() { + uattr.Owner.GID = kgid + } else { + uattr.Owner.GID = auth.RootKGID + } + } + return uattr, nil } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 77e03d349..21a965f90 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -96,7 +96,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks boo contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers) } - // TODO(b/31916171): Set EUID/EGID based on dumpability. + // N.B. taskOwnedInodeOps enforces dumpability-based ownership. d := &taskDir{ Dir: *ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)), t: t, @@ -667,6 +667,21 @@ func newComm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { return newProcInode(c, msrc, fs.SpecialFile, t) } +// Check implements fs.InodeOperations.Check. +func (c *comm) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool { + // This file can always be read or written by members of the same + // thread group. See fs/proc/base.c:proc_tid_comm_permission. + // + // N.B. This check is currently a no-op as we don't yet support writing + // and this file is world-readable anyways. + t := kernel.TaskFromContext(ctx) + if t != nil && t.ThreadGroup() == c.t.ThreadGroup() && !p.Execute { + return true + } + + return fs.ContextCanAccessFile(ctx, inode, p) +} + // GetFile implements fs.InodeOperations.GetFile. func (c *comm) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { return fs.NewFile(ctx, dirent, flags, &commFile{t: c.t}), nil diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 4423e7efd..193447b17 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -19,6 +19,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserror" ) @@ -92,6 +93,14 @@ const ( // ptrace(2), subsection "Ptrace access mode checking". If attach is true, it // checks for access mode PTRACE_MODE_ATTACH; otherwise, it checks for access // mode PTRACE_MODE_READ. +// +// NOTE(b/30815691): The result of CanTrace is immediately stale (e.g., a +// racing setuid(2) may change traceability). This may pose a risk when a task +// changes from traceable to not traceable. This is only problematic across +// execve, where privileges may increase. +// +// We currently do not implement privileged executables (set-user/group-ID bits +// and file capabilities), so that case is not reachable. func (t *Task) CanTrace(target *Task, attach bool) bool { // "1. If the calling thread and the target thread are in the same thread // group, access is always allowed." - ptrace(2) @@ -162,7 +171,13 @@ func (t *Task) CanTrace(target *Task, attach bool) bool { if cgid := callerCreds.RealKGID; cgid != targetCreds.RealKGID || cgid != targetCreds.EffectiveKGID || cgid != targetCreds.SavedKGID { return false } - // TODO(b/31916171): dumpability check + var targetMM *mm.MemoryManager + target.WithMuLocked(func(t *Task) { + targetMM = t.MemoryManager() + }) + if targetMM != nil && targetMM.Dumpability() != mm.UserDumpable { + return false + } if callerCreds.UserNamespace != targetCreds.UserNamespace { return false } diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 5d1425d5c..35d5cb90c 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -68,6 +68,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/sentry/arch" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/syserror" ) @@ -198,6 +199,12 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { return flags.CloseOnExec }) + // NOTE(b/30815691): We currently do not implement privileged + // executables (set-user/group-ID bits and file capabilities). This + // allows us to unconditionally enable user dumpability on the new mm. + // See fs/exec.c:setup_new_exec. + r.tc.MemoryManager.SetDumpability(mm.UserDumpable) + // Switch to the new process. t.MemoryManager().Deactivate() t.mu.Lock() diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go index 17f08729a..ec95f78d0 100644 --- a/pkg/sentry/kernel/task_identity.go +++ b/pkg/sentry/kernel/task_identity.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/syserror" ) @@ -206,8 +207,17 @@ func (t *Task) setKUIDsUncheckedLocked(newR, newE, newS auth.KUID) { // (filesystem UIDs aren't implemented, nor are any of the capabilities in // question) - // Not documented, but compare Linux's kernel/cred.c:commit_creds(). if oldE != newE { + // "[dumpability] is reset to the current value contained in + // the file /proc/sys/fs/suid_dumpable (which by default has + // the value 0), in the following circumstances: The process's + // effective user or group ID is changed." - prctl(2) + // + // (suid_dumpable isn't implemented, so we just use the + // default. + t.MemoryManager().SetDumpability(mm.NotDumpable) + + // Not documented, but compare Linux's kernel/cred.c:commit_creds(). t.parentDeathSignal = 0 } } @@ -303,8 +313,18 @@ func (t *Task) setKGIDsUncheckedLocked(newR, newE, newS auth.KGID) { t.creds = t.creds.Fork() // See doc for creds. t.creds.RealKGID, t.creds.EffectiveKGID, t.creds.SavedKGID = newR, newE, newS - // Not documented, but compare Linux's kernel/cred.c:commit_creds(). if oldE != newE { + // "[dumpability] is reset to the current value contained in + // the file /proc/sys/fs/suid_dumpable (which by default has + // the value 0), in the following circumstances: The process's + // effective user or group ID is changed." - prctl(2) + // + // (suid_dumpable isn't implemented, so we just use the + // default. + t.MemoryManager().SetDumpability(mm.NotDumpable) + + // Not documented, but compare Linux's + // kernel/cred.c:commit_creds(). t.parentDeathSignal = 0 } } diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 7a65a62a2..7646d5ab2 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -37,6 +37,7 @@ func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *Memo privateRefs: &privateRefs{}, users: 1, auxv: arch.Auxv{}, + dumpability: UserDumpable, aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, } } @@ -79,8 +80,9 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) { envv: mm.envv, auxv: append(arch.Auxv(nil), mm.auxv...), // IncRef'd below, once we know that there isn't an error. - executable: mm.executable, - aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, + executable: mm.executable, + dumpability: mm.dumpability, + aioManager: aioManager{contexts: make(map[uint64]*AIOContext)}, } // Copy vmas. diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go index 9768e51f1..c218006ee 100644 --- a/pkg/sentry/mm/metadata.go +++ b/pkg/sentry/mm/metadata.go @@ -20,6 +20,36 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" ) +// Dumpability describes if and how core dumps should be created. +type Dumpability int + +const ( + // NotDumpable indicates that core dumps should never be created. + NotDumpable Dumpability = iota + + // UserDumpable indicates that core dumps should be created, owned by + // the current user. + UserDumpable + + // RootDumpable indicates that core dumps should be created, owned by + // root. + RootDumpable +) + +// Dumpability returns the dumpability. +func (mm *MemoryManager) Dumpability() Dumpability { + mm.metadataMu.Lock() + defer mm.metadataMu.Unlock() + return mm.dumpability +} + +// SetDumpability sets the dumpability. +func (mm *MemoryManager) SetDumpability(d Dumpability) { + mm.metadataMu.Lock() + defer mm.metadataMu.Unlock() + mm.dumpability = d +} + // ArgvStart returns the start of the application argument vector. // // There is no guarantee that this value is sensible w.r.t. ArgvEnd. diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index eb6defa2b..0a026ff8c 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -219,6 +219,12 @@ type MemoryManager struct { // executable is protected by metadataMu. executable *fs.Dirent + // dumpability describes if and how this MemoryManager may be dumped to + // userspace. + // + // dumpability is protected by metadataMu. + dumpability Dumpability + // aioManager keeps track of AIOContexts used for async IOs. AIOManager // must be cloned when CLONE_VM is used. aioManager aioManager diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index 117ae1a0e..1b7e5616b 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -15,6 +15,7 @@ package linux import ( + "fmt" "syscall" "gvisor.googlesource.com/gvisor/pkg/abi/linux" @@ -23,6 +24,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" ) // Prctl implements linux syscall prctl(2). @@ -44,6 +46,33 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal())) return 0, nil, err + case linux.PR_GET_DUMPABLE: + d := t.MemoryManager().Dumpability() + switch d { + case mm.NotDumpable: + return linux.SUID_DUMP_DISABLE, nil, nil + case mm.UserDumpable: + return linux.SUID_DUMP_USER, nil, nil + case mm.RootDumpable: + return linux.SUID_DUMP_ROOT, nil, nil + default: + panic(fmt.Sprintf("Unknown dumpability %v", d)) + } + + case linux.PR_SET_DUMPABLE: + var d mm.Dumpability + switch args[1].Int() { + case linux.SUID_DUMP_DISABLE: + d = mm.NotDumpable + case linux.SUID_DUMP_USER: + d = mm.UserDumpable + default: + // N.B. Userspace may not pass SUID_DUMP_ROOT. + return 0, nil, syscall.EINVAL + } + t.MemoryManager().SetDumpability(d) + return 0, nil, nil + case linux.PR_GET_KEEPCAPS: if t.Credentials().KeepCaps { return 1, nil, nil @@ -171,9 +200,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } return 0, nil, t.DropBoundingCapability(cp) - case linux.PR_GET_DUMPABLE, - linux.PR_SET_DUMPABLE, - linux.PR_GET_TIMING, + case linux.PR_GET_TIMING, linux.PR_SET_TIMING, linux.PR_GET_TSC, linux.PR_SET_TSC, diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index ba9fd6d1f..7633ab162 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1317,6 +1317,7 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:capability_util", + "//test/util:cleanup", "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:test_util", diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index bce42dc74..bd1779557 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -17,10 +17,12 @@ #include #include #include + #include #include "gtest/gtest.h" #include "test/util/capability_util.h" +#include "test/util/cleanup.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" #include "test/util/test_util.h" @@ -35,6 +37,16 @@ namespace testing { namespace { +#ifndef SUID_DUMP_DISABLE +#define SUID_DUMP_DISABLE 0 +#endif /* SUID_DUMP_DISABLE */ +#ifndef SUID_DUMP_USER +#define SUID_DUMP_USER 1 +#endif /* SUID_DUMP_USER */ +#ifndef SUID_DUMP_ROOT +#define SUID_DUMP_ROOT 2 +#endif /* SUID_DUMP_ROOT */ + TEST(PrctlTest, NameInitialized) { const size_t name_length = 20; char name[name_length] = {}; @@ -178,6 +190,28 @@ TEST(PrctlTest, InvalidPrSetMM) { ASSERT_THAT(prctl(PR_SET_MM, 0, 0, 0, 0), SyscallFailsWithErrno(EPERM)); } +// Sanity check that dumpability is remembered. +TEST(PrctlTest, SetGetDumpability) { + int before; + ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); + auto cleanup = Cleanup([before] { + ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); + }); + + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_DISABLE), SyscallSucceeds()); + EXPECT_THAT(prctl(PR_GET_DUMPABLE), + SyscallSucceedsWithValue(SUID_DUMP_DISABLE)); + + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_USER), SyscallSucceeds()); + EXPECT_THAT(prctl(PR_GET_DUMPABLE), SyscallSucceedsWithValue(SUID_DUMP_USER)); +} + +// SUID_DUMP_ROOT cannot be set via PR_SET_DUMPABLE. +TEST(PrctlTest, RootDumpability) { + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_ROOT), + SyscallFailsWithErrno(EINVAL)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index ede6fb860..924b98e3a 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -69,9 +69,11 @@ // way to get it tested on both gVisor, PTrace and Linux. using ::testing::AllOf; +using ::testing::AnyOf; using ::testing::ContainerEq; using ::testing::Contains; using ::testing::ContainsRegex; +using ::testing::Eq; using ::testing::Gt; using ::testing::HasSubstr; using ::testing::IsSupersetOf; @@ -86,6 +88,16 @@ namespace gvisor { namespace testing { namespace { +#ifndef SUID_DUMP_DISABLE +#define SUID_DUMP_DISABLE 0 +#endif /* SUID_DUMP_DISABLE */ +#ifndef SUID_DUMP_USER +#define SUID_DUMP_USER 1 +#endif /* SUID_DUMP_USER */ +#ifndef SUID_DUMP_ROOT +#define SUID_DUMP_ROOT 2 +#endif /* SUID_DUMP_ROOT */ + // O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0 // because "it isn't needed", even though Linux can return it via F_GETFL. constexpr int kOLargeFile = 00100000; @@ -1896,6 +1908,51 @@ void CheckDuplicatesRecursively(std::string path) { TEST(Proc, NoDuplicates) { CheckDuplicatesRecursively("/proc"); } +// Most /proc/PID files are owned by the task user with SUID_DUMP_USER. +TEST(ProcPid, UserDumpableOwner) { + int before; + ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); + auto cleanup = Cleanup([before] { + ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); + }); + + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_USER), SyscallSucceeds()); + + // This applies to the task directory itself and files inside. + struct stat st; + ASSERT_THAT(stat("/proc/self/", &st), SyscallSucceeds()); + EXPECT_EQ(st.st_uid, geteuid()); + EXPECT_EQ(st.st_gid, getegid()); + + ASSERT_THAT(stat("/proc/self/stat", &st), SyscallSucceeds()); + EXPECT_EQ(st.st_uid, geteuid()); + EXPECT_EQ(st.st_gid, getegid()); +} + +// /proc/PID files are owned by root with SUID_DUMP_DISABLE. +TEST(ProcPid, RootDumpableOwner) { + int before; + ASSERT_THAT(before = prctl(PR_GET_DUMPABLE), SyscallSucceeds()); + auto cleanup = Cleanup([before] { + ASSERT_THAT(prctl(PR_SET_DUMPABLE, before), SyscallSucceeds()); + }); + + EXPECT_THAT(prctl(PR_SET_DUMPABLE, SUID_DUMP_DISABLE), SyscallSucceeds()); + + // This *does not* applies to the task directory itself (or other 0555 + // directories), but does to files inside. + struct stat st; + ASSERT_THAT(stat("/proc/self/", &st), SyscallSucceeds()); + EXPECT_EQ(st.st_uid, geteuid()); + EXPECT_EQ(st.st_gid, getegid()); + + // This file is owned by root. Also allow nobody in case this test is running + // in a userns without root mapped. + ASSERT_THAT(stat("/proc/self/stat", &st), SyscallSucceeds()); + EXPECT_THAT(st.st_uid, AnyOf(Eq(0), Eq(65534))); + EXPECT_THAT(st.st_gid, AnyOf(Eq(0), Eq(65534))); +} + } // namespace } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 3b950344b3abb1676cd10bf8b455a3355cfcc78d Mon Sep 17 00:00:00 2001 From: Nicolas Lacasse Date: Wed, 5 Jun 2019 14:16:00 -0700 Subject: Bump googletest version. PiperOrigin-RevId: 251716439 --- WORKSPACE | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index a54a80fb7..421453a87 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -173,10 +173,10 @@ http_archive( http_archive( name = "com_google_googletest", - sha256 = "574e884a41f0a9b76f849a5cdd89c393651e7537e5daa725cf12511232cbd74b", - strip_prefix = "googletest-61cdca569b1f7e4629f8b949f0a9606c28281a6b", + sha256 = "db657310d3c5ca2d3f674e3a4b79718d1d39da70604568ee0568ba8e39065ef4", + strip_prefix = "googletest-31200def0dec8a624c861f919e86e4444e6e6ee7", urls = [ - "https://mirror.bazel.build/github.com/google/googletest/archive/61cdca569b1f7e4629f8b949f0a9606c28281a6b.tar.gz", - "https://github.com/google/googletest/archive/61cdca569b1f7e4629f8b949f0a9606c28281a6b.tar.gz", + "https://mirror.bazel.build/github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz", + "https://github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz", ], ) -- cgit v1.2.3 From c08fcaa364e917b19aad0f74a8e3a1c700d0bfcc Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Wed, 5 Jun 2019 15:56:21 -0700 Subject: Give test instantiations meaningful names. PiperOrigin-RevId: 251737069 --- test/syscalls/linux/socket_abstract.cc | 4 ++-- test/syscalls/linux/socket_filesystem.cc | 4 ++-- test/syscalls/linux/socket_ip_loopback_blocking.cc | 2 +- test/syscalls/linux/socket_ip_tcp_generic_loopback.cc | 2 +- test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc | 2 +- test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc | 2 +- test/syscalls/linux/socket_ip_tcp_udp_generic.cc | 2 +- test/syscalls/linux/socket_ip_udp_loopback.cc | 6 +++--- test/syscalls/linux/socket_ip_udp_loopback_blocking.cc | 2 +- test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc | 2 +- .../linux/socket_ipv4_tcp_unbound_external_networking_test.cc | 2 +- .../linux/socket_ipv4_udp_unbound_external_networking_test.cc | 2 +- test/syscalls/linux/socket_unix_abstract_nonblock.cc | 2 +- test/syscalls/linux/socket_unix_blocking_local.cc | 2 +- test/syscalls/linux/socket_unix_dgram_local.cc | 6 +++--- test/syscalls/linux/socket_unix_dgram_non_blocking.cc | 2 +- test/syscalls/linux/socket_unix_filesystem_nonblock.cc | 2 +- test/syscalls/linux/socket_unix_non_stream_blocking_local.cc | 2 +- test/syscalls/linux/socket_unix_pair_nonblock.cc | 2 +- test/syscalls/linux/socket_unix_seqpacket_local.cc | 6 +++--- test/syscalls/linux/socket_unix_stream_blocking_local.cc | 2 +- test/syscalls/linux/socket_unix_stream_local.cc | 2 +- test/syscalls/linux/socket_unix_stream_nonblock_local.cc | 2 +- 23 files changed, 31 insertions(+), 31 deletions(-) diff --git a/test/syscalls/linux/socket_abstract.cc b/test/syscalls/linux/socket_abstract.cc index 2faf678f7..503ba986b 100644 --- a/test/syscalls/linux/socket_abstract.cc +++ b/test/syscalls/linux/socket_abstract.cc @@ -31,11 +31,11 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, AllSocketPairTest, + AbstractUnixSockets, AllSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnixSocketPairTest, + AbstractUnixSockets, UnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_filesystem.cc b/test/syscalls/linux/socket_filesystem.cc index f7cb72df4..e38a320f6 100644 --- a/test/syscalls/linux/socket_filesystem.cc +++ b/test/syscalls/linux/socket_filesystem.cc @@ -31,11 +31,11 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, AllSocketPairTest, + FilesystemUnixSockets, AllSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnixSocketPairTest, + FilesystemUnixSockets, UnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_ip_loopback_blocking.cc b/test/syscalls/linux/socket_ip_loopback_blocking.cc index d7fc20aad..d7fc9715b 100644 --- a/test/syscalls/linux/socket_ip_loopback_blocking.cc +++ b/test/syscalls/linux/socket_ip_loopback_blocking.cc @@ -39,7 +39,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, BlockingSocketPairTest, + BlockingIPSockets, BlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc index 2c6ae17bf..0dc274e2d 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc @@ -35,7 +35,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, TCPSocketPairTest, + AllTCPSockets, TCPSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc index d1ea8ef12..cd3ad97d0 100644 --- a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc +++ b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc @@ -35,7 +35,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, BlockingStreamSocketPairTest, + BlockingTCPSockets, BlockingStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc index 96c1b3b3d..1acdecc17 100644 --- a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc +++ b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc @@ -34,7 +34,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonBlockingSocketPairTest, + NonBlockingTCPSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc index 251817a9f..de63f79d9 100644 --- a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc @@ -69,7 +69,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllTCPSockets, TcpUdpSocketPairTest, + AllIPSockets, TcpUdpSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace diff --git a/test/syscalls/linux/socket_ip_udp_loopback.cc b/test/syscalls/linux/socket_ip_udp_loopback.cc index fc124e9ef..1df74a348 100644 --- a/test/syscalls/linux/socket_ip_udp_loopback.cc +++ b/test/syscalls/linux/socket_ip_udp_loopback.cc @@ -33,15 +33,15 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, AllSocketPairTest, + AllUDPSockets, AllSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonStreamSocketPairTest, + AllUDPSockets, NonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); INSTANTIATE_TEST_SUITE_P( - UDPSockets, UDPSocketPairTest, + AllUDPSockets, UDPSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc index 1c3d1c0ad..1e259efa7 100644 --- a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc +++ b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc @@ -30,7 +30,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, BlockingNonStreamSocketPairTest, + BlockingUDPSockets, BlockingNonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc index 7554b08d5..74cbd326d 100644 --- a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc +++ b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc @@ -30,7 +30,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonBlockingSocketPairTest, + NonBlockingUDPSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc index 040bb176e..92f03e045 100644 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc +++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc @@ -28,7 +28,7 @@ std::vector GetSockets() { AllBitwiseCombinations(List{0, SOCK_NONBLOCK})); } -INSTANTIATE_TEST_SUITE_P(IPv4TCPSockets, +INSTANTIATE_TEST_SUITE_P(IPv4TCPUnboundSockets, IPv4TCPUnboundExternalNetworkingSocketTest, ::testing::ValuesIn(GetSockets())); } // namespace testing diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc index ffbb8e6eb..9d4e1ab97 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc @@ -28,7 +28,7 @@ std::vector GetSockets() { AllBitwiseCombinations(List{0, SOCK_NONBLOCK})); } -INSTANTIATE_TEST_SUITE_P(IPv4UDPSockets, +INSTANTIATE_TEST_SUITE_P(IPv4UDPUnboundSockets, IPv4UDPUnboundExternalNetworkingSocketTest, ::testing::ValuesIn(GetSockets())); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_abstract_nonblock.cc b/test/syscalls/linux/socket_unix_abstract_nonblock.cc index 9de0f6dfe..be31ab2a7 100644 --- a/test/syscalls/linux/socket_unix_abstract_nonblock.cc +++ b/test/syscalls/linux/socket_unix_abstract_nonblock.cc @@ -30,7 +30,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonBlockingSocketPairTest, + NonBlockingAbstractUnixSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_blocking_local.cc b/test/syscalls/linux/socket_unix_blocking_local.cc index 320915b0f..1994139e6 100644 --- a/test/syscalls/linux/socket_unix_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_blocking_local.cc @@ -37,7 +37,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, BlockingSocketPairTest, + NonBlockingUnixDomainSockets, BlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_dgram_local.cc b/test/syscalls/linux/socket_unix_dgram_local.cc index 4ba2c80ae..8c5a473bd 100644 --- a/test/syscalls/linux/socket_unix_dgram_local.cc +++ b/test/syscalls/linux/socket_unix_dgram_local.cc @@ -41,15 +41,15 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, DgramUnixSocketPairTest, + DgramUnixSockets, DgramUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnixNonStreamSocketPairTest, + DgramUnixSockets, UnixNonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonStreamSocketPairTest, + DgramUnixSockets, NonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc index 9fe86cee8..707052af8 100644 --- a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc +++ b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc @@ -44,7 +44,7 @@ TEST_P(NonBlockingDgramUnixSocketPairTest, ReadOneSideClosed) { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonBlockingDgramUnixSocketPairTest, + NonBlockingDgramUnixSockets, NonBlockingDgramUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(std::vector{ UnixDomainSocketPair(SOCK_DGRAM | SOCK_NONBLOCK), FilesystemBoundUnixDomainSocketPair(SOCK_DGRAM | SOCK_NONBLOCK), diff --git a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc index 137db53c4..8ba7af971 100644 --- a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc +++ b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc @@ -30,7 +30,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonBlockingSocketPairTest, + NonBlockingFilesystemUnixSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc index 98cf1fe8a..da762cd83 100644 --- a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc @@ -34,7 +34,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, BlockingNonStreamSocketPairTest, + BlockingNonStreamUnixSockets, BlockingNonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_pair_nonblock.cc b/test/syscalls/linux/socket_unix_pair_nonblock.cc index 583506f08..3135d325f 100644 --- a/test/syscalls/linux/socket_unix_pair_nonblock.cc +++ b/test/syscalls/linux/socket_unix_pair_nonblock.cc @@ -30,7 +30,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonBlockingSocketPairTest, + NonBlockingUnixSockets, NonBlockingSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_seqpacket_local.cc b/test/syscalls/linux/socket_unix_seqpacket_local.cc index b903a9e8f..dff75a532 100644 --- a/test/syscalls/linux/socket_unix_seqpacket_local.cc +++ b/test/syscalls/linux/socket_unix_seqpacket_local.cc @@ -41,15 +41,15 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonStreamSocketPairTest, + SeqpacketUnixSockets, NonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, SeqpacketUnixSocketPairTest, + SeqpacketUnixSockets, SeqpacketUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, UnixNonStreamSocketPairTest, + SeqpacketUnixSockets, UnixNonStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_stream_blocking_local.cc index ce0f1e50d..fa0a9d367 100644 --- a/test/syscalls/linux/socket_unix_stream_blocking_local.cc +++ b/test/syscalls/linux/socket_unix_stream_blocking_local.cc @@ -32,7 +32,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, BlockingStreamSocketPairTest, + BlockingStreamUnixSockets, BlockingStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_stream_local.cc b/test/syscalls/linux/socket_unix_stream_local.cc index 6b840189c..65eef1a81 100644 --- a/test/syscalls/linux/socket_unix_stream_local.cc +++ b/test/syscalls/linux/socket_unix_stream_local.cc @@ -39,7 +39,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, StreamSocketPairTest, + StreamUnixSockets, StreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing diff --git a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc index ebec4e0ec..ec777c59f 100644 --- a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc +++ b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc @@ -31,7 +31,7 @@ std::vector GetSocketPairs() { } INSTANTIATE_TEST_SUITE_P( - AllUnixDomainSockets, NonBlockingStreamSocketPairTest, + NonBlockingStreamUnixSockets, NonBlockingStreamSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); } // namespace testing -- cgit v1.2.3 From d18bb4f38a7a89456dad1f3a0e8ff13a0b65ba7f Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Wed, 5 Jun 2019 16:07:18 -0700 Subject: Adjust route when looping multicast packets Multicast packets are special in that their destination address does not identify a specific interface. When sending out such a packet the multicast address is the remote address, but for incoming packets it is the local address. Hence, when looping a multicast packet, the route needs to be tweaked to reflect this. PiperOrigin-RevId: 251739298 --- pkg/tcpip/network/ipv4/ipv4.go | 4 +- pkg/tcpip/network/ipv6/ipv6.go | 4 +- pkg/tcpip/stack/route.go | 10 ++ .../socket_ipv4_udp_unbound_external_networking.cc | 129 +++++++++++++++++++++ 4 files changed, 145 insertions(+), 2 deletions(-) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index da07a39e5..44b1d5b9b 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -215,7 +215,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen views[0] = hdr.View() views = append(views, payload.Views()...) vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - e.HandlePacket(r, vv) + loopedR := r.MakeLoopedRoute() + e.HandlePacket(&loopedR, vv) + loopedR.Release() } if loop&stack.PacketOut == 0 { return nil diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4b8cd496b..bcae98e1f 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -108,7 +108,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen views[0] = hdr.View() views = append(views, payload.Views()...) vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - e.HandlePacket(r, vv) + loopedR := r.MakeLoopedRoute() + e.HandlePacket(&loopedR, vv) + loopedR.Release() } if loop&stack.PacketOut == 0 { return nil diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 3d4c282a9..55ed02479 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -187,3 +187,13 @@ func (r *Route) Clone() Route { r.ref.incRef() return *r } + +// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast. +func (r *Route) MakeLoopedRoute() Route { + l := r.Clone() + if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { + l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress + l.RemoteLinkAddress = l.LocalLinkAddress + } + return l +} diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index 53dcd58cd..6b92e05aa 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -559,5 +559,134 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } +// Check that two sockets can join the same multicast group at the same time, +// and both will receive data on it. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + std::unique_ptr receivers[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; + + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + auto receiver_addr = V4Any(); + int bound_port = 0; + for (auto& receiver : receivers) { + ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + // Bind the receiver to the v4 any address to ensure that we can receive the + // multicast packet. + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + // On the first iteration, save the port we are bound to. On the second + // iteration, verify the port is the same as the one from the first + // iteration. In other words, both sockets listen on the same port. + if (bound_port == 0) { + bound_port = + reinterpret_cast(&receiver_addr.addr)->sin_port; + } else { + EXPECT_EQ(bound_port, + reinterpret_cast(&receiver_addr.addr)->sin_port); + } + + // Register to receive multicast packets. + ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + } + + // Send a multicast packet to the group and verify both receivers get it. + auto send_addr = V4Multicast(); + reinterpret_cast(&send_addr.addr)->sin_port = bound_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + for (auto& receiver : receivers) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } +} + +// Check that when receiving a looped-back multicast packet, its source address +// is not a multicast address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + IpMulticastLoopbackFromAddr) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + auto receiver_addr = V4Any(); + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + int receiver_port = + reinterpret_cast(&receiver_addr.addr)->sin_port; + + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), + SyscallSucceeds()); + + // Connect to the multicast address. This binds us to the outgoing interface + // and allows us to get its IP (to be compared against the src-IP on the + // receiver side). + auto sendto_addr = V4Multicast(); + reinterpret_cast(&sendto_addr.addr)->sin_port = receiver_port; + ASSERT_THAT(RetryEINTR(connect)( + sender->get(), reinterpret_cast(&sendto_addr.addr), + sendto_addr.addr_len), + SyscallSucceeds()); + TestAddress sender_addr(""); + ASSERT_THAT( + getsockname(sender->get(), reinterpret_cast(&sender_addr.addr), + &sender_addr.addr_len), + SyscallSucceeds()); + ASSERT_EQ(sizeof(struct sockaddr_in), sender_addr.addr_len); + sockaddr_in* sender_addr_in = + reinterpret_cast(&sender_addr.addr); + + // Send a multicast packet. + char send_buf[4] = {}; + ASSERT_THAT(RetryEINTR(send)(sender->get(), send_buf, sizeof(send_buf), 0), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Receive a multicast packet. + char recv_buf[sizeof(send_buf)] = {}; + TestAddress src_addr(""); + ASSERT_THAT( + RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0, + reinterpret_cast(&src_addr.addr), + &src_addr.addr_len), + SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len); + sockaddr_in* src_addr_in = reinterpret_cast(&src_addr.addr); + + // Verify that the received source IP:port matches the sender one. + EXPECT_EQ(sender_addr_in->sin_port, src_addr_in->sin_port); + EXPECT_EQ(sender_addr_in->sin_addr.s_addr, src_addr_in->sin_addr.s_addr); +} + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From a12848ffebcc2e123f55d8ac805c5248d03a9055 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Wed, 5 Jun 2019 16:29:25 -0700 Subject: netstack/tcp: fix calculating a number of outstanding packets In case of GSO, a segment can container more than one packet and we need to use the pCount() helper to get a number of packets. PiperOrigin-RevId: 251743020 --- pkg/tcpip/transport/tcp/snd.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index afc1d0a55..3464e4be7 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -779,7 +779,7 @@ func (s *sender) sendData() { break } dataSent = true - s.outstanding++ + s.outstanding += s.pCount(seg) s.writeNext = seg.Next() } } -- cgit v1.2.3 From 57772db2e7351511de422baeecf807785709ee5d Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Wed, 5 Jun 2019 18:39:30 -0700 Subject: Shutdown host sockets on internal shutdown This is required to make the shutdown visible to peers outside the sandbox. The readClosed / writeClosed fields were dropped, as they were preventing a shutdown socket from reading the remainder of queued bytes. The host syscalls will return the appropriate errors for shutdown. The control message tests have been split out of socket_unix.cc to make the (few) remaining tests accessible to testing inherited host UDS, which don't support sending control messages. Updates #273 PiperOrigin-RevId: 251763060 --- pkg/sentry/fs/host/socket.go | 62 +- pkg/sentry/fs/host/socket_test.go | 156 --- runsc/boot/filter/config.go | 4 + test/syscalls/linux/BUILD | 23 + test/syscalls/linux/socket_abstract.cc | 5 + test/syscalls/linux/socket_filesystem.cc | 5 + test/syscalls/linux/socket_unix.cc | 1518 ++---------------------------- test/syscalls/linux/socket_unix_cmsg.cc | 1473 +++++++++++++++++++++++++++++ test/syscalls/linux/socket_unix_cmsg.h | 30 + test/syscalls/linux/socket_unix_pair.cc | 5 + 10 files changed, 1655 insertions(+), 1626 deletions(-) create mode 100644 test/syscalls/linux/socket_unix_cmsg.cc create mode 100644 test/syscalls/linux/socket_unix_cmsg.h diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 3ed137006..e4ec0f62c 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -15,6 +15,7 @@ package host import ( + "fmt" "sync" "syscall" @@ -51,20 +52,6 @@ type ConnectedEndpoint struct { // ref keeps track of references to a connectedEndpoint. ref refs.AtomicRefCount - // mu protects fd, readClosed and writeClosed. - mu sync.RWMutex `state:"nosave"` - - // file is an *fd.FD containing the FD backing this endpoint. It must be - // set to nil if it has been closed. - file *fd.FD `state:"nosave"` - - // readClosed is true if the FD has read shutdown or if it has been closed. - readClosed bool - - // writeClosed is true if the FD has write shutdown or if it has been - // closed. - writeClosed bool - // If srfd >= 0, it is the host FD that file was imported from. srfd int `state:"wait"` @@ -78,6 +65,13 @@ type ConnectedEndpoint struct { // prevent lots of small messages from filling the real send buffer // size on the host. sndbuf int `state:"nosave"` + + // mu protects the fields below. + mu sync.RWMutex `state:"nosave"` + + // file is an *fd.FD containing the FD backing this endpoint. It must be + // set to nil if it has been closed. + file *fd.FD `state:"nosave"` } // init performs initialization required for creating new ConnectedEndpoints and @@ -208,9 +202,6 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() - if c.writeClosed { - return 0, false, syserr.ErrClosedForSend - } if !controlMessages.Empty() { return 0, false, syserr.ErrInvalidEndpointState @@ -244,8 +235,13 @@ func (c *ConnectedEndpoint) SendNotify() {} // CloseSend implements transport.ConnectedEndpoint.CloseSend. func (c *ConnectedEndpoint) CloseSend() { c.mu.Lock() - c.writeClosed = true - c.mu.Unlock() + defer c.mu.Unlock() + + if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_WR); err != nil { + // A well-formed UDS shutdown can't fail. See + // net/unix/af_unix.c:unix_shutdown. + panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err)) + } } // CloseNotify implements transport.ConnectedEndpoint.CloseNotify. @@ -255,9 +251,7 @@ func (c *ConnectedEndpoint) CloseNotify() {} func (c *ConnectedEndpoint) Writable() bool { c.mu.RLock() defer c.mu.RUnlock() - if c.writeClosed { - return true - } + return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0 } @@ -285,9 +279,6 @@ func (c *ConnectedEndpoint) EventUpdate() { func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() - if c.readClosed { - return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive - } var cm unet.ControlMessage if numRights > 0 { @@ -344,31 +335,34 @@ func (c *ConnectedEndpoint) RecvNotify() {} // CloseRecv implements transport.Receiver.CloseRecv. func (c *ConnectedEndpoint) CloseRecv() { c.mu.Lock() - c.readClosed = true - c.mu.Unlock() + defer c.mu.Unlock() + + if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_RD); err != nil { + // A well-formed UDS shutdown can't fail. See + // net/unix/af_unix.c:unix_shutdown. + panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err)) + } } // Readable implements transport.Receiver.Readable. func (c *ConnectedEndpoint) Readable() bool { c.mu.RLock() defer c.mu.RUnlock() - if c.readClosed { - return true - } + return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0 } // SendQueuedSize implements transport.Receiver.SendQueuedSize. func (c *ConnectedEndpoint) SendQueuedSize() int64 { - // SendQueuedSize isn't supported for host sockets because we don't allow the - // sentry to call ioctl(2). + // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host + // sockets because we don't allow the sentry to call ioctl(2). return -1 } // RecvQueuedSize implements transport.Receiver.RecvQueuedSize. func (c *ConnectedEndpoint) RecvQueuedSize() int64 { - // RecvQueuedSize isn't supported for host sockets because we don't allow the - // sentry to call ioctl(2). + // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host + // sockets because we don't allow the sentry to call ioctl(2). return -1 } diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go index 06392a65a..bc3ce5627 100644 --- a/pkg/sentry/fs/host/socket_test.go +++ b/pkg/sentry/fs/host/socket_test.go @@ -198,20 +198,6 @@ func TestListen(t *testing.T) { } } -func TestSend(t *testing.T) { - e := ConnectedEndpoint{writeClosed: true} - if _, _, err := e.Send(nil, transport.ControlMessages{}, tcpip.FullAddress{}); err != syserr.ErrClosedForSend { - t.Errorf("Got %#v.Send() = %v, want = %v", e, err, syserr.ErrClosedForSend) - } -} - -func TestRecv(t *testing.T) { - e := ConnectedEndpoint{readClosed: true} - if _, _, _, _, _, _, err := e.Recv(nil, false, 0, false); err != syserr.ErrClosedForReceive { - t.Errorf("Got %#v.Recv() = %v, want = %v", e, err, syserr.ErrClosedForReceive) - } -} - func TestPasscred(t *testing.T) { e := ConnectedEndpoint{} if got, want := e.Passcred(), false; got != want { @@ -244,20 +230,6 @@ func TestQueuedSize(t *testing.T) { } } -func TestReadable(t *testing.T) { - e := ConnectedEndpoint{readClosed: true} - if got, want := e.Readable(), true; got != want { - t.Errorf("Got %#v.Readable() = %t, want = %t", e, got, want) - } -} - -func TestWritable(t *testing.T) { - e := ConnectedEndpoint{writeClosed: true} - if got, want := e.Writable(), true; got != want { - t.Errorf("Got %#v.Writable() = %t, want = %t", e, got, want) - } -} - func TestRelease(t *testing.T) { f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { @@ -272,131 +244,3 @@ func TestRelease(t *testing.T) { t.Errorf("got = %#v, want = %#v", c, want) } } - -func TestClose(t *testing.T) { - type testCase struct { - name string - cep *ConnectedEndpoint - addFD bool - f func() - want *ConnectedEndpoint - } - - var tests []testCase - - // nil is the value used by ConnectedEndpoint to indicate a closed file. - // Non-nil files are used to check if the file gets closed. - - f, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - tests = append(tests, testCase{ - name: "First CloseRecv", - cep: c, - addFD: false, - f: c.CloseRecv, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} - tests = append(tests, testCase{ - name: "Second CloseRecv", - cep: c, - addFD: false, - f: c.CloseRecv, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)} - tests = append(tests, testCase{ - name: "First CloseSend", - cep: c, - addFD: false, - f: c.CloseSend, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} - tests = append(tests, testCase{ - name: "Second CloseSend", - cep: c, - addFD: false, - f: c.CloseSend, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), writeClosed: true} - tests = append(tests, testCase{ - name: "CloseSend then CloseRecv", - cep: c, - addFD: true, - f: c.CloseRecv, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true} - tests = append(tests, testCase{ - name: "CloseRecv then CloseSend", - cep: c, - addFD: true, - f: c.CloseSend, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} - tests = append(tests, testCase{ - name: "Full close then CloseRecv", - cep: c, - addFD: false, - f: c.CloseRecv, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, - }) - - f, err = syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) - if err != nil { - t.Fatal("Creating socket:", err) - } - c = &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f), readClosed: true, writeClosed: true} - tests = append(tests, testCase{ - name: "Full close then CloseSend", - cep: c, - addFD: false, - f: c.CloseSend, - want: &ConnectedEndpoint{queue: c.queue, file: c.file, readClosed: true, writeClosed: true}, - }) - - for _, test := range tests { - if test.addFD { - fdnotifier.AddFD(int32(test.cep.file.FD()), nil) - } - if test.f(); !reflect.DeepEqual(test.cep, test.want) { - t.Errorf("%s: got = %#v, want = %#v", test.name, test.cep, test.want) - } - } -} diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 652da1cef..ef2dbfad2 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -246,6 +246,10 @@ var allowedSyscalls = seccomp.SyscallRules{ }, syscall.SYS_SETITIMER: {}, syscall.SYS_SHUTDOWN: []seccomp.Rule{ + // Used by fs/host to shutdown host sockets. + {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RD)}, + {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_WR)}, + // Used by unet to shutdown connections. {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)}, }, syscall.SYS_SIGALTSTACK: {}, diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 7633ab162..0cb7b47b6 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2096,6 +2096,7 @@ cc_binary( deps = [ ":socket_generic_test_cases", ":socket_test_util", + ":socket_unix_cmsg_test_cases", ":socket_unix_test_cases", ":unix_domain_socket_test_util", "//test/util:test_main", @@ -2369,6 +2370,7 @@ cc_binary( deps = [ ":socket_generic_test_cases", ":socket_test_util", + ":socket_unix_cmsg_test_cases", ":socket_unix_test_cases", ":unix_domain_socket_test_util", "//test/util:test_main", @@ -2490,6 +2492,26 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "socket_unix_cmsg_test_cases", + testonly = 1, + srcs = [ + "socket_unix_cmsg.cc", + ], + hdrs = [ + "socket_unix_cmsg.h", + ], + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], + alwayslink = 1, +) + cc_library( name = "socket_stream_blocking_test_cases", testonly = 1, @@ -2733,6 +2755,7 @@ cc_binary( linkstatic = 1, deps = [ ":socket_test_util", + ":socket_unix_cmsg_test_cases", ":socket_unix_test_cases", ":unix_domain_socket_test_util", "//test/util:test_main", diff --git a/test/syscalls/linux/socket_abstract.cc b/test/syscalls/linux/socket_abstract.cc index 503ba986b..715d87b76 100644 --- a/test/syscalls/linux/socket_abstract.cc +++ b/test/syscalls/linux/socket_abstract.cc @@ -17,6 +17,7 @@ #include "test/syscalls/linux/socket_generic.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/socket_unix.h" +#include "test/syscalls/linux/socket_unix_cmsg.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -38,5 +39,9 @@ INSTANTIATE_TEST_SUITE_P( AbstractUnixSockets, UnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +INSTANTIATE_TEST_SUITE_P( + AbstractUnixSockets, UnixSocketPairCmsgTest, + ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_filesystem.cc b/test/syscalls/linux/socket_filesystem.cc index e38a320f6..74e262959 100644 --- a/test/syscalls/linux/socket_filesystem.cc +++ b/test/syscalls/linux/socket_filesystem.cc @@ -17,6 +17,7 @@ #include "test/syscalls/linux/socket_generic.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/socket_unix.h" +#include "test/syscalls/linux/socket_unix_cmsg.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -38,5 +39,9 @@ INSTANTIATE_TEST_SUITE_P( FilesystemUnixSockets, UnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +INSTANTIATE_TEST_SUITE_P( + FilesystemUnixSockets, UnixSocketPairCmsgTest, + ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 95cf8d2a3..875f0391f 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -32,1437 +32,16 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -// This file is a generic socket test file. It must be built with another file -// that provides the test types. - -namespace gvisor { -namespace testing { - -namespace { - -TEST_P(UnixSocketPairTest, BasicFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairTest, BasicTwoFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); - - char received_data[20]; - int received_fds[] = {-1, -1}; - - ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 2, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); -} - -TEST_P(UnixSocketPairTest, BasicThreeFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair3 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); - - char received_data[20]; - int received_fds[] = {-1, -1, -1}; - - ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 3, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); - ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[2], pair3->first_fd())); -} - -TEST_P(UnixSocketPairTest, BadFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - int sent_fd = -1; - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(sent_fd))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = CMSG_LEN(sizeof(sent_fd)); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); - - struct iovec iov; - iov.iov_base = sent_data; - iov.iov_len = sizeof(sent_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EBADF)); -} - -// BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass. -// The difference is that when calling recvmsg, no space for FDs is provided, -// only space for the cmsg header. -TEST_P(UnixSocketPairTest, BasicFDPassNoSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - - struct msghdr msg = {}; - std::vector control(CMSG_SPACE(0)); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -// BasicFDPassNoSpaceMsgCtrunc sends an FD, but does not provide any space to -// receive it. It then verifies that the MSG_CTRUNC flag is set in the msghdr. -TEST_P(UnixSocketPairTest, BasicFDPassNoSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector control(CMSG_SPACE(0)); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicFDPassNullControlMsgCtrunc sends an FD and sets contradictory values for -// msg_controllen and msg_control. msg_controllen is set to the correct size to -// accomidate the FD, but msg_control is set to NULL. In this case, msg_control -// should override msg_controllen. -TEST_P(UnixSocketPairTest, BasicFDPassNullControlMsgCtrunc) { - // FIXME(gvisor.dev/issue/207): Fix handling of NULL msg_control. - SKIP_IF(IsRunningOnGvisor()); - - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - msg.msg_controllen = CMSG_SPACE(1); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicFDPassNotEnoughSpaceMsgCtrunc sends an FD, but does not provide enough -// space to receive it. It then verifies that the MSG_CTRUNC flag is set in the -// msghdr. -TEST_P(UnixSocketPairTest, BasicFDPassNotEnoughSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector control(CMSG_SPACE(0) + 1); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_controllen, 0); - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); -} - -// BasicThreeFDPassTruncationMsgCtrunc sends three FDs, but only provides enough -// space to receive two of them. It then verifies that the MSG_CTRUNC flag is -// set in the msghdr. -TEST_P(UnixSocketPairTest, BasicThreeFDPassTruncationMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - auto pair3 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - std::vector control(CMSG_SPACE(2 * sizeof(int))); - msg.msg_control = &control[0]; - msg.msg_controllen = control.size(); - - char received_data[sizeof(sent_data)]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(2 * sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -// BasicFDPassUnalignedRecv starts off by sending a single FD just like -// BasicFDPass. The difference is that when calling recvmsg, the length of the -// receive data is only aligned on a 4 byte boundry instead of the normal 8. -TEST_P(UnixSocketPairTest, BasicFDPassUnalignedRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFDUnaligned( - sockets->second_fd(), &fd, received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -// BasicFDPassUnalignedRecvNoMsgTrunc sends one FD and only provides enough -// space to receive just it. (Normally the minimum amount of space one would -// provide would be enough space for two FDs.) It then verifies that the -// MSG_CTRUNC flag is not set in the msghdr. -TEST_P(UnixSocketPairTest, BasicFDPassUnalignedRecvNoMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, 0); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -// BasicTwoFDPassUnalignedRecvTruncationMsgTrunc sends two FDs, but only -// provides enough space to receive one of them. It then verifies that the -// MSG_CTRUNC flag is set in the msghdr. -TEST_P(UnixSocketPairTest, BasicTwoFDPassUnalignedRecvTruncationMsgTrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - int sent_fds[] = {pair->first_fd(), pair->second_fd()}; - - ASSERT_NO_FATAL_FAILURE( - SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - // CMSG_SPACE rounds up to two FDs, we only want one. - char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); -} - -TEST_P(UnixSocketPairTest, ConcurrentBasicFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - int sockfd1 = sockets->first_fd(); - auto recv_func = [sockfd1, sent_data]() { - char received_data[20]; - int fd = -1; - RecvSingleFD(sockfd1, &fd, received_data, sizeof(received_data)); - ASSERT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - char buf[20]; - ASSERT_THAT(ReadFd(fd, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(WriteFd(fd, buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - }; - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->second_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - ScopedThread t(recv_func); - - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(WriteFd(pair->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - ASSERT_THAT(ReadFd(pair->first_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - t.Join(); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -// FDPassNoRecv checks that the control message can be safely ignored by using -// read(2) instead of recvmsg(2). -TEST_P(UnixSocketPairTest, FDPassNoRecv) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - // Read while ignoring the passed FD. - char received_data[20]; - ASSERT_THAT( - ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // Check that the socket still works for reads and writes. - ASSERT_NO_FATAL_FAILURE( - TransferTest(sockets->first_fd(), sockets->second_fd())); -} - -// FDPassInterspersed1 checks that sent control messages cannot be read before -// their associated data has been read. -TEST_P(UnixSocketPairTest, FDPassInterspersed1) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char written_data[20]; - RandomizeBuffer(written_data, sizeof(written_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), - SyscallSucceedsWithValue(sizeof(written_data))); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - // Check that we don't get a control message, but do get the data. - char received_data[20]; - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)); - EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); -} - -// FDPassInterspersed2 checks that sent control messages cannot be read after -// their assocated data has been read while ignoring the control message by -// using read(2) instead of recvmsg(2). -TEST_P(UnixSocketPairTest, FDPassInterspersed2) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char written_data[20]; - RandomizeBuffer(written_data, sizeof(written_data)); - ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), - SyscallSucceedsWithValue(sizeof(written_data))); - - char received_data[20]; - ASSERT_THAT( - ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); -} - -TEST_P(UnixSocketPairTest, FDPassNotCoalesced) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data1[20]; - RandomizeBuffer(sent_data1, sizeof(sent_data1)); - - auto pair1 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(), - sent_data1, sizeof(sent_data1))); - - char sent_data2[20]; - RandomizeBuffer(sent_data2, sizeof(sent_data2)); - - auto pair2 = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(), - sent_data2, sizeof(sent_data2))); - - char received_data1[sizeof(sent_data1) + sizeof(sent_data2)]; - int received_fd1 = -1; - - RecvSingleFD(sockets->second_fd(), &received_fd1, received_data1, - sizeof(received_data1), sizeof(sent_data1)); - - EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1))); - TransferTest(pair1->first_fd(), pair1->second_fd()); - - char received_data2[sizeof(sent_data1) + sizeof(sent_data2)]; - int received_fd2 = -1; - - RecvSingleFD(sockets->second_fd(), &received_fd2, received_data2, - sizeof(received_data2), sizeof(sent_data2)); - - EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2))); - TransferTest(pair2->first_fd(), pair2->second_fd()); -} - -TEST_P(UnixSocketPairTest, FDPassPeek) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char peek_data[20]; - int peek_fd = -1; - PeekSingleFD(sockets->second_fd(), &peek_fd, peek_data, sizeof(peek_data)); - EXPECT_EQ(0, memcmp(sent_data, peek_data, sizeof(sent_data))); - TransferTest(peek_fd, pair->first_fd()); - EXPECT_THAT(close(peek_fd), SyscallSucceeds()); - - char received_data[20]; - int received_fd = -1; - RecvSingleFD(sockets->second_fd(), &received_fd, received_data, - sizeof(received_data)); - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - TransferTest(received_fd, pair->first_fd()); - EXPECT_THAT(close(received_fd), SyscallSucceeds()); -} - -TEST_P(UnixSocketPairTest, BasicCredPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - EXPECT_EQ(sent_creds.pid, received_creds.pid); - EXPECT_EQ(sent_creds.uid, received_creds.uid); - EXPECT_EQ(sent_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, SendNullCredsBeforeSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, SendNullCredsAfterSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, SendNullCredsBeforeSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->first_fd()); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, SendNullCredsAfterSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, SendNullCredsBeforeSoPassCredRecvEndAfterSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_NO_FATAL_FAILURE( - SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, WriteBeforeSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, WriteAfterSoPassCredRecvEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->second_fd()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, WriteBeforeSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->first_fd()); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, WriteAfterSoPassCredSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - SetSoPassCred(sockets->first_fd()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - char received_data[20]; - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, WriteBeforeSoPassCredRecvEndAfterSendEnd) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - SetSoPassCred(sockets->first_fd()); - - ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), - SyscallSucceedsWithValue(sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - - struct ucred received_creds; - ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, - received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); -} - -TEST_P(UnixSocketPairTest, CredPassTruncated) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0) + sizeof(pid_t)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); - - pid_t pid = 0; - memcpy(&pid, CMSG_DATA(cmsg), sizeof(pid)); - EXPECT_EQ(pid, sent_creds.pid); -} - -// CredPassNoMsgCtrunc passes a full set of credentials. It then verifies that -// receiving the full set does not result in MSG_CTRUNC being set in the msghdr. -TEST_P(UnixSocketPairTest, CredPassNoMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(struct ucred))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should not be truncated. - EXPECT_EQ(msg.msg_flags, 0); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred))); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// CredPassNoSpaceMsgCtrunc passes a full set of credentials. It then receives -// the data without providing space for any credentials and verifies that -// MSG_CTRUNC is set in the msghdr. -TEST_P(UnixSocketPairTest, CredPassNoSpaceMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should be truncated. - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// CredPassTruncatedMsgCtrunc passes a full set of credentials. It then receives -// the data while providing enough space for only the first field of the -// credentials and verifies that MSG_CTRUNC is set in the msghdr. -TEST_P(UnixSocketPairTest, CredPassTruncatedMsgCtrunc) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - struct msghdr msg = {}; - char control[CMSG_SPACE(0) + sizeof(pid_t)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[sizeof(sent_data)] = {}; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - // The control message should be truncated. - EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -TEST_P(UnixSocketPairTest, SoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int opt; - socklen_t optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - SetSoPassCred(sockets->first_fd()); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_TRUE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - int zero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &zero, - sizeof(zero)), - SyscallSucceeds()); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); - - optLen = sizeof(opt); - EXPECT_THAT( - getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), - SyscallSucceeds()); - EXPECT_FALSE(opt); -} - -TEST_P(UnixSocketPairTest, NoDataCredPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct msghdr msg = {}; - - struct iovec iov; - iov.iov_base = sent_data; - iov.iov_len = sizeof(sent_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - char control[CMSG_SPACE(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_CREDENTIALS; - cmsg->cmsg_len = CMSG_LEN(0); - - ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UnixSocketPairTest, NoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - ASSERT_NO_FATAL_FAILURE( - SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); - - char received_data[20]; - - ASSERT_NO_FATAL_FAILURE( - RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); -} - -TEST_P(UnixSocketPairTest, CredAndFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - struct ucred sent_creds; - - ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendCredsAndFD(sockets->first_fd(), sent_creds, - pair->second_fd(), sent_data, - sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(sent_creds.pid, received_creds.pid); - EXPECT_EQ(sent_creds.uid, received_creds.uid); - EXPECT_EQ(sent_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairTest, FDPassBeforeSoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - SetSoPassCred(sockets->second_fd()); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds { - 0, 65534, 65534 - }; - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairTest, FDPassAfterSoPassCred) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - struct ucred received_creds; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, - &fd, received_data, - sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - struct ucred want_creds; - ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); - ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); - - EXPECT_EQ(want_creds.pid, received_creds.pid); - EXPECT_EQ(want_creds.uid, received_creds.uid); - EXPECT_EQ(want_creds.gid, received_creds.gid); - - ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); -} - -TEST_P(UnixSocketPairTest, CloexecDroppedWhenFDPassed) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = ASSERT_NO_ERRNO_AND_VALUE( - UnixDomainSocketPair(SOCK_SEQPACKET | SOCK_CLOEXEC).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - char received_data[20]; - int fd = -1; - ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, - sizeof(received_data))); - - EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(0)); -} - -TEST_P(UnixSocketPairTest, CloexecRecvFDPass) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int))]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - char received_data[20]; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CMSG_CLOEXEC), - SyscallSucceedsWithValue(sizeof(received_data))); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); - - int fd = -1; - memcpy(&fd, CMSG_DATA(cmsg), sizeof(int)); - - EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); -} - -TEST_P(UnixSocketPairTest, FDPassAfterSoPassCredWithoutCredSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_LEN(0)]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[20]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); - - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - - EXPECT_EQ(msg.msg_controllen, sizeof(control)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); - EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); - EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); -} - -// This test will validate that MSG_CTRUNC as an input flag to recvmsg will -// not appear as an output flag on the control message when truncation doesn't -// happen. -TEST_P(UnixSocketPairTest, MsgCtruncInputIsNoop) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_SPACE(sizeof(int)) /* we're passing a single fd */]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - struct iovec iov; - char received_data[20]; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CTRUNC), - SyscallSucceedsWithValue(sizeof(received_data))); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); - - // Now we should verify that MSG_CTRUNC wasn't set as an output flag. - EXPECT_EQ(msg.msg_flags & MSG_CTRUNC, 0); -} - -TEST_P(UnixSocketPairTest, FDPassAfterSoPassCredWithoutCredHeaderSpace) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - char sent_data[20]; - RandomizeBuffer(sent_data, sizeof(sent_data)); - - auto pair = - ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); - - SetSoPassCred(sockets->second_fd()); - - ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), - sent_data, sizeof(sent_data))); - - struct msghdr msg = {}; - char control[CMSG_LEN(0) / 2]; - msg.msg_control = control; - msg.msg_controllen = sizeof(control); - - char received_data[20]; - struct iovec iov; - iov.iov_base = received_data; - iov.iov_len = sizeof(received_data); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; +// This file contains tests specific to Unix domain sockets. It does not contain +// tests for UDS control messages. Those belong in socket_unix_cmsg.cc. +// +// This file is a generic socket test file. It must be built with another file +// that provides the test types. - ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), - SyscallSucceedsWithValue(sizeof(received_data))); +namespace gvisor { +namespace testing { - EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); - EXPECT_EQ(msg.msg_controllen, 0); -} +namespace { TEST_P(UnixSocketPairTest, InvalidGetSockOpt) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -1519,6 +98,14 @@ TEST_P(UnixSocketPairTest, RecvmmsgTimeoutAfterRecv) { TEST_P(UnixSocketPairTest, TIOCINQSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/273): Inherited host UDS don't support TIOCINQ. + // Skip the test. + int size = -1; + int ret = ioctl(sockets->first_fd(), TIOCINQ, &size); + SKIP_IF(ret == -1 && errno == ENOTTY); + } + int size = -1; EXPECT_THAT(ioctl(sockets->first_fd(), TIOCINQ, &size), SyscallSucceeds()); EXPECT_EQ(size, 0); @@ -1544,6 +131,14 @@ TEST_P(UnixSocketPairTest, TIOCINQSucceeds) { TEST_P(UnixSocketPairTest, TIOCOUTQSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/273): Inherited host UDS don't support TIOCOUTQ. + // Skip the test. + int size = -1; + int ret = ioctl(sockets->second_fd(), TIOCOUTQ, &size); + SKIP_IF(ret == -1 && errno == ENOTTY); + } + int size = -1; EXPECT_THAT(ioctl(sockets->second_fd(), TIOCOUTQ, &size), SyscallSucceeds()); EXPECT_EQ(size, 0); @@ -1580,19 +175,70 @@ TEST_P(UnixSocketPairTest, NetdeviceIoctlsSucceed) { } } -TEST_P(UnixSocketPairTest, SocketShutdown) { +TEST_P(UnixSocketPairTest, Shutdown) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - char buf[20]; + const std::string data = "abc"; - ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), 3), - SyscallSucceedsWithValue(3)); + ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); // Shutting down a socket does not clear the buffer. - ASSERT_THAT(ReadFd(sockets->second_fd(), buf, 3), - SyscallSucceedsWithValue(3)); - EXPECT_EQ(data, absl::string_view(buf, 3)); + char buf[3]; + ASSERT_THAT(ReadFd(sockets->second_fd(), buf, data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_EQ(data, absl::string_view(buf, data.size())); +} + +TEST_P(UnixSocketPairTest, ShutdownRead) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RD), SyscallSucceeds()); + + // When the socket is shutdown for read, read behavior varies between + // different socket types. This is covered by the various ReadOneSideClosed + // test cases. + + // ... and the peer cannot write. + const std::string data = "abc"; + EXPECT_THAT(WriteFd(sockets->second_fd(), data.c_str(), data.size()), + SyscallFailsWithErrno(EPIPE)); + + // ... but the socket can still write. + ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + + // ... and the peer can still read. + char buf[3]; + EXPECT_THAT(ReadFd(sockets->second_fd(), buf, data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_EQ(data, absl::string_view(buf, data.size())); +} + +TEST_P(UnixSocketPairTest, ShutdownWrite) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_WR), SyscallSucceeds()); + + // When the socket is shutdown for write, it cannot write. + const std::string data = "abc"; + EXPECT_THAT(WriteFd(sockets->first_fd(), data.c_str(), data.size()), + SyscallFailsWithErrno(EPIPE)); + + // ... and the peer read behavior varies between different socket types. This + // is covered by the various ReadOneSideClosed test cases. + + // ... but the peer can still write. + char buf[3]; + ASSERT_THAT(WriteFd(sockets->second_fd(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + + // ... and the socket can still read. + EXPECT_THAT(ReadFd(sockets->first_fd(), buf, data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_EQ(data, absl::string_view(buf, data.size())); } TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) { diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc new file mode 100644 index 000000000..b0ab26847 --- /dev/null +++ b/test/syscalls/linux/socket_unix_cmsg.cc @@ -0,0 +1,1473 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_unix_cmsg.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "absl/strings/string_view.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +// This file contains tests for control message in Unix domain sockets. +// +// This file is a generic socket test file. It must be built with another file +// that provides the test types. + +namespace gvisor { +namespace testing { + +namespace { + +TEST_P(UnixSocketPairCmsgTest, BasicFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, + sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, BasicTwoFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair1 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair2 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + int sent_fds[] = {pair1->second_fd(), pair2->second_fd()}; + + ASSERT_NO_FATAL_FAILURE( + SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); + + char received_data[20]; + int received_fds[] = {-1, -1}; + + ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 2, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair1 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair2 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair3 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; + + ASSERT_NO_FATAL_FAILURE( + SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); + + char received_data[20]; + int received_fds[] = {-1, -1, -1}; + + ASSERT_NO_FATAL_FAILURE(RecvFDs(sockets->second_fd(), received_fds, 3, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[0], pair1->first_fd())); + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[1], pair2->first_fd())); + ASSERT_NO_FATAL_FAILURE(TransferTest(received_fds[2], pair3->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, BadFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + int sent_fd = -1; + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(sent_fd))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = CMSG_LEN(sizeof(sent_fd)); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd)); + + struct iovec iov; + iov.iov_base = sent_data; + iov.iov_len = sizeof(sent_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), + SyscallFailsWithErrno(EBADF)); +} + +// BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass. +// The difference is that when calling recvmsg, no space for FDs is provided, +// only space for the cmsg header. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassNoSpace) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + + struct msghdr msg = {}; + std::vector control(CMSG_SPACE(0)); + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_controllen, 0); + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +// BasicFDPassNoSpaceMsgCtrunc sends an FD, but does not provide any space to +// receive it. It then verifies that the MSG_CTRUNC flag is set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassNoSpaceMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + std::vector control(CMSG_SPACE(0)); + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + char received_data[sizeof(sent_data)]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_controllen, 0); + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); +} + +// BasicFDPassNullControlMsgCtrunc sends an FD and sets contradictory values for +// msg_controllen and msg_control. msg_controllen is set to the correct size to +// accomidate the FD, but msg_control is set to NULL. In this case, msg_control +// should override msg_controllen. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassNullControlMsgCtrunc) { + // FIXME(gvisor.dev/issue/207): Fix handling of NULL msg_control. + SKIP_IF(IsRunningOnGvisor()); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + msg.msg_controllen = CMSG_SPACE(1); + + char received_data[sizeof(sent_data)]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_controllen, 0); + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); +} + +// BasicFDPassNotEnoughSpaceMsgCtrunc sends an FD, but does not provide enough +// space to receive it. It then verifies that the MSG_CTRUNC flag is set in the +// msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassNotEnoughSpaceMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + std::vector control(CMSG_SPACE(0) + 1); + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + char received_data[sizeof(sent_data)]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_controllen, 0); + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); +} + +// BasicThreeFDPassTruncationMsgCtrunc sends three FDs, but only provides enough +// space to receive two of them. It then verifies that the MSG_CTRUNC flag is +// set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicThreeFDPassTruncationMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair1 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair2 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + auto pair3 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + int sent_fds[] = {pair1->second_fd(), pair2->second_fd(), pair3->second_fd()}; + + ASSERT_NO_FATAL_FAILURE( + SendFDs(sockets->first_fd(), sent_fds, 3, sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + std::vector control(CMSG_SPACE(2 * sizeof(int))); + msg.msg_control = &control[0]; + msg.msg_controllen = control.size(); + + char received_data[sizeof(sent_data)]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(2 * sizeof(int))); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); +} + +// BasicFDPassUnalignedRecv starts off by sending a single FD just like +// BasicFDPass. The difference is that when calling recvmsg, the length of the +// receive data is only aligned on a 4 byte boundry instead of the normal 8. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecv) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvSingleFDUnaligned( + sockets->second_fd(), &fd, received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +// BasicFDPassUnalignedRecvNoMsgTrunc sends one FD and only provides enough +// space to receive just it. (Normally the minimum amount of space one would +// provide would be enough space for two FDs.) It then verifies that the +// MSG_CTRUNC flag is not set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicFDPassUnalignedRecvNoMsgTrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_flags, 0); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); +} + +// BasicTwoFDPassUnalignedRecvTruncationMsgTrunc sends two FDs, but only +// provides enough space to receive one of them. It then verifies that the +// MSG_CTRUNC flag is set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, BasicTwoFDPassUnalignedRecvTruncationMsgTrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + int sent_fds[] = {pair->first_fd(), pair->second_fd()}; + + ASSERT_NO_FATAL_FAILURE( + SendFDs(sockets->first_fd(), sent_fds, 2, sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + // CMSG_SPACE rounds up to two FDs, we only want one. + char control[CMSG_SPACE(sizeof(int)) - sizeof(int)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_RIGHTS); +} + +TEST_P(UnixSocketPairCmsgTest, ConcurrentBasicFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + int sockfd1 = sockets->first_fd(); + auto recv_func = [sockfd1, sent_data]() { + char received_data[20]; + int fd = -1; + RecvSingleFD(sockfd1, &fd, received_data, sizeof(received_data)); + ASSERT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + char buf[20]; + ASSERT_THAT(ReadFd(fd, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(WriteFd(fd, buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + }; + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->second_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + ScopedThread t(recv_func); + + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT(WriteFd(pair->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[20]; + ASSERT_THAT(ReadFd(pair->first_fd(), received_data, sizeof(received_data)), + SyscallSucceedsWithValue(sizeof(received_data))); + + t.Join(); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +// FDPassNoRecv checks that the control message can be safely ignored by using +// read(2) instead of recvmsg(2). +TEST_P(UnixSocketPairCmsgTest, FDPassNoRecv) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + // Read while ignoring the passed FD. + char received_data[20]; + ASSERT_THAT( + ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // Check that the socket still works for reads and writes. + ASSERT_NO_FATAL_FAILURE( + TransferTest(sockets->first_fd(), sockets->second_fd())); +} + +// FDPassInterspersed1 checks that sent control messages cannot be read before +// their associated data has been read. +TEST_P(UnixSocketPairCmsgTest, FDPassInterspersed1) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char written_data[20]; + RandomizeBuffer(written_data, sizeof(written_data)); + + ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), + SyscallSucceedsWithValue(sizeof(written_data))); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + // Check that we don't get a control message, but do get the data. + char received_data[20]; + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data)); + EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); +} + +// FDPassInterspersed2 checks that sent control messages cannot be read after +// their assocated data has been read while ignoring the control message by +// using read(2) instead of recvmsg(2). +TEST_P(UnixSocketPairCmsgTest, FDPassInterspersed2) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char written_data[20]; + RandomizeBuffer(written_data, sizeof(written_data)); + ASSERT_THAT(WriteFd(sockets->first_fd(), written_data, sizeof(written_data)), + SyscallSucceedsWithValue(sizeof(written_data))); + + char received_data[20]; + ASSERT_THAT( + ReadFd(sockets->second_fd(), received_data, sizeof(received_data)), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + EXPECT_EQ(0, memcmp(written_data, received_data, sizeof(written_data))); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassNotCoalesced) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data1[20]; + RandomizeBuffer(sent_data1, sizeof(sent_data1)); + + auto pair1 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair1->second_fd(), + sent_data1, sizeof(sent_data1))); + + char sent_data2[20]; + RandomizeBuffer(sent_data2, sizeof(sent_data2)); + + auto pair2 = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair2->second_fd(), + sent_data2, sizeof(sent_data2))); + + char received_data1[sizeof(sent_data1) + sizeof(sent_data2)]; + int received_fd1 = -1; + + RecvSingleFD(sockets->second_fd(), &received_fd1, received_data1, + sizeof(received_data1), sizeof(sent_data1)); + + EXPECT_EQ(0, memcmp(sent_data1, received_data1, sizeof(sent_data1))); + TransferTest(pair1->first_fd(), pair1->second_fd()); + + char received_data2[sizeof(sent_data1) + sizeof(sent_data2)]; + int received_fd2 = -1; + + RecvSingleFD(sockets->second_fd(), &received_fd2, received_data2, + sizeof(received_data2), sizeof(sent_data2)); + + EXPECT_EQ(0, memcmp(sent_data2, received_data2, sizeof(sent_data2))); + TransferTest(pair2->first_fd(), pair2->second_fd()); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassPeek) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char peek_data[20]; + int peek_fd = -1; + PeekSingleFD(sockets->second_fd(), &peek_fd, peek_data, sizeof(peek_data)); + EXPECT_EQ(0, memcmp(sent_data, peek_data, sizeof(sent_data))); + TransferTest(peek_fd, pair->first_fd()); + EXPECT_THAT(close(peek_fd), SyscallSucceeds()); + + char received_data[20]; + int received_fd = -1; + RecvSingleFD(sockets->second_fd(), &received_fd, received_data, + sizeof(received_data)); + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + TransferTest(received_fd, pair->first_fd()); + EXPECT_THAT(close(received_fd), SyscallSucceeds()); +} + +TEST_P(UnixSocketPairCmsgTest, BasicCredPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + EXPECT_EQ(sent_creds.pid, received_creds.pid); + EXPECT_EQ(sent_creds.uid, received_creds.uid); + EXPECT_EQ(sent_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, SendNullCredsBeforeSoPassCredRecvEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds { + 0, 65534, 65534 + }; + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, SendNullCredsAfterSoPassCredRecvEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + SetSoPassCred(sockets->second_fd()); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + char received_data[20]; + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, SendNullCredsBeforeSoPassCredSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->first_fd()); + + char received_data[20]; + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, SendNullCredsAfterSoPassCredSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + SetSoPassCred(sockets->first_fd()); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + char received_data[20]; + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, + SendNullCredsBeforeSoPassCredRecvEndAfterSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + SetSoPassCred(sockets->first_fd()); + + ASSERT_NO_FATAL_FAILURE( + SendNullCmsg(sockets->first_fd(), sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredRecvEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds { + 0, 65534, 65534 + }; + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, WriteAfterSoPassCredRecvEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + SetSoPassCred(sockets->second_fd()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[20]; + + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + SetSoPassCred(sockets->first_fd()); + + char received_data[20]; + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, WriteAfterSoPassCredSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + SetSoPassCred(sockets->first_fd()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + char received_data[20]; + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, WriteBeforeSoPassCredRecvEndAfterSendEnd) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + SetSoPassCred(sockets->first_fd()); + + ASSERT_THAT(WriteFd(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + + struct ucred received_creds; + ASSERT_NO_FATAL_FAILURE(RecvCreds(sockets->second_fd(), &received_creds, + received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); +} + +TEST_P(UnixSocketPairCmsgTest, CredPassTruncated) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + struct msghdr msg = {}; + char control[CMSG_SPACE(0) + sizeof(pid_t)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); + + pid_t pid = 0; + memcpy(&pid, CMSG_DATA(cmsg), sizeof(pid)); + EXPECT_EQ(pid, sent_creds.pid); +} + +// CredPassNoMsgCtrunc passes a full set of credentials. It then verifies that +// receiving the full set does not result in MSG_CTRUNC being set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, CredPassNoMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(struct ucred))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // The control message should not be truncated. + EXPECT_EQ(msg.msg_flags, 0); + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct ucred))); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); +} + +// CredPassNoSpaceMsgCtrunc passes a full set of credentials. It then receives +// the data without providing space for any credentials and verifies that +// MSG_CTRUNC is set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, CredPassNoSpaceMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + struct msghdr msg = {}; + char control[CMSG_SPACE(0)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // The control message should be truncated. + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); +} + +// CredPassTruncatedMsgCtrunc passes a full set of credentials. It then receives +// the data while providing enough space for only the first field of the +// credentials and verifies that MSG_CTRUNC is set in the msghdr. +TEST_P(UnixSocketPairCmsgTest, CredPassTruncatedMsgCtrunc) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + struct msghdr msg = {}; + char control[CMSG_SPACE(0) + sizeof(pid_t)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[sizeof(sent_data)] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + // The control message should be truncated. + EXPECT_EQ(msg.msg_flags, MSG_CTRUNC); + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); +} + +TEST_P(UnixSocketPairCmsgTest, SoPassCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int opt; + socklen_t optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); + + SetSoPassCred(sockets->first_fd()); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_TRUE(opt); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); + + int zero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &zero, + sizeof(zero)), + SyscallSucceeds()); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); + + optLen = sizeof(opt); + EXPECT_THAT( + getsockopt(sockets->second_fd(), SOL_SOCKET, SO_PASSCRED, &opt, &optLen), + SyscallSucceeds()); + EXPECT_FALSE(opt); +} + +TEST_P(UnixSocketPairCmsgTest, NoDataCredPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct msghdr msg = {}; + + struct iovec iov; + iov.iov_base = sent_data; + iov.iov_len = sizeof(sent_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + char control[CMSG_SPACE(0)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_CREDENTIALS; + cmsg->cmsg_len = CMSG_LEN(0); + + ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(UnixSocketPairCmsgTest, NoPassCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + ASSERT_NO_FATAL_FAILURE( + SendCreds(sockets->first_fd(), sent_creds, sent_data, sizeof(sent_data))); + + char received_data[20]; + + ASSERT_NO_FATAL_FAILURE( + RecvNoCmsg(sockets->second_fd(), received_data, sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); +} + +TEST_P(UnixSocketPairCmsgTest, CredAndFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + struct ucred sent_creds; + + ASSERT_THAT(sent_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(sent_creds.gid = getgid(), SyscallSucceeds()); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendCredsAndFD(sockets->first_fd(), sent_creds, + pair->second_fd(), sent_data, + sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, + &fd, received_data, + sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + EXPECT_EQ(sent_creds.pid, received_creds.pid); + EXPECT_EQ(sent_creds.uid, received_creds.uid); + EXPECT_EQ(sent_creds.gid, received_creds.gid); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassBeforeSoPassCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + SetSoPassCred(sockets->second_fd()); + + char received_data[20]; + struct ucred received_creds; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, + &fd, received_data, + sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds { + 0, 65534, 65534 + }; + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + SetSoPassCred(sockets->second_fd()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + struct ucred received_creds; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvCredsAndFD(sockets->second_fd(), &received_creds, + &fd, received_data, + sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + struct ucred want_creds; + ASSERT_THAT(want_creds.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.uid = getuid(), SyscallSucceeds()); + ASSERT_THAT(want_creds.gid = getgid(), SyscallSucceeds()); + + EXPECT_EQ(want_creds.pid, received_creds.pid); + EXPECT_EQ(want_creds.uid, received_creds.uid); + EXPECT_EQ(want_creds.gid, received_creds.gid); + + ASSERT_NO_FATAL_FAILURE(TransferTest(fd, pair->first_fd())); +} + +TEST_P(UnixSocketPairCmsgTest, CloexecDroppedWhenFDPassed) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = ASSERT_NO_ERRNO_AND_VALUE( + UnixDomainSocketPair(SOCK_SEQPACKET | SOCK_CLOEXEC).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + char received_data[20]; + int fd = -1; + ASSERT_NO_FATAL_FAILURE(RecvSingleFD(sockets->second_fd(), &fd, received_data, + sizeof(received_data))); + + EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(0)); +} + +TEST_P(UnixSocketPairCmsgTest, CloexecRecvFDPass) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(int))]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct iovec iov; + char received_data[20]; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CMSG_CLOEXEC), + SyscallSucceedsWithValue(sizeof(received_data))); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); + + int fd = -1; + memcpy(&fd, CMSG_DATA(cmsg), sizeof(int)); + + EXPECT_THAT(fcntl(fd, F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCredWithoutCredSpace) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + SetSoPassCred(sockets->second_fd()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_LEN(0)]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[20]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + + EXPECT_EQ(msg.msg_controllen, sizeof(control)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, sizeof(control)); + EXPECT_EQ(cmsg->cmsg_level, SOL_SOCKET); + EXPECT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); +} + +// This test will validate that MSG_CTRUNC as an input flag to recvmsg will +// not appear as an output flag on the control message when truncation doesn't +// happen. +TEST_P(UnixSocketPairCmsgTest, MsgCtruncInputIsNoop) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_SPACE(sizeof(int)) /* we're passing a single fd */]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + struct iovec iov; + char received_data[20]; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_CTRUNC), + SyscallSucceedsWithValue(sizeof(received_data))); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(int))); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SCM_RIGHTS); + + // Now we should verify that MSG_CTRUNC wasn't set as an output flag. + EXPECT_EQ(msg.msg_flags & MSG_CTRUNC, 0); +} + +TEST_P(UnixSocketPairCmsgTest, FDPassAfterSoPassCredWithoutCredHeaderSpace) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char sent_data[20]; + RandomizeBuffer(sent_data, sizeof(sent_data)); + + auto pair = + ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create()); + + SetSoPassCred(sockets->second_fd()); + + ASSERT_NO_FATAL_FAILURE(SendSingleFD(sockets->first_fd(), pair->second_fd(), + sent_data, sizeof(sent_data))); + + struct msghdr msg = {}; + char control[CMSG_LEN(0) / 2]; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + char received_data[20]; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, 0), + SyscallSucceedsWithValue(sizeof(received_data))); + + EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data))); + EXPECT_EQ(msg.msg_controllen, 0); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_cmsg.h b/test/syscalls/linux/socket_unix_cmsg.h new file mode 100644 index 000000000..431606903 --- /dev/null +++ b/test/syscalls/linux/socket_unix_cmsg.h @@ -0,0 +1,30 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to pairs of connected unix sockets about +// control messages. +using UnixSocketPairCmsgTest = SocketPairTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_CMSG_H_ diff --git a/test/syscalls/linux/socket_unix_pair.cc b/test/syscalls/linux/socket_unix_pair.cc index bacfc11e4..411fb4518 100644 --- a/test/syscalls/linux/socket_unix_pair.cc +++ b/test/syscalls/linux/socket_unix_pair.cc @@ -16,6 +16,7 @@ #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/socket_unix.h" +#include "test/syscalls/linux/socket_unix_cmsg.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" @@ -33,5 +34,9 @@ INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, UnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +INSTANTIATE_TEST_SUITE_P( + AllUnixDomainSockets, UnixSocketPairCmsgTest, + ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 79f7cb6c1c4c16e3aca44d7fdc8e9f2487a605cf Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Wed, 5 Jun 2019 22:50:48 -0700 Subject: netstack/sniffer: log GSO attributes PiperOrigin-RevId: 251788534 --- pkg/tcpip/link/sniffer/sniffer.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index fccabd554..98581e50e 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -118,7 +118,7 @@ func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcp // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, vv.First()) + logPacket("recv", protocol, vv.First(), nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { vs := vv.Views() @@ -198,7 +198,7 @@ func (e *endpoint) GSOMaxSize() uint32 { // the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, hdr.View()) + logPacket("send", protocol, hdr.View(), gso) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { hdrBuf := hdr.View() @@ -240,7 +240,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return e.lower.WritePacket(r, gso, hdr, payload, protocol) } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -404,5 +404,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie return } + if gso != nil { + details += fmt.Sprintf(" gso: %+v", gso) + } + log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) } -- cgit v1.2.3 From 85be01b42d4ac48698d1e8f50a4cf2607a4fc50b Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 6 Jun 2019 08:05:46 -0700 Subject: Add multi-fd support to fdbased endpoint. This allows an fdbased endpoint to have multiple underlying fd's from which packets can be read and dispatched/written to. This should allow for higher throughput as well as better scalability of the network stack as number of connections increases. Updates #231 PiperOrigin-RevId: 251852825 --- pkg/tcpip/link/fdbased/endpoint.go | 168 ++++++++++++++++++++++--------- pkg/tcpip/link/fdbased/endpoint_test.go | 2 +- pkg/tcpip/sample/tun_tcp_connect/main.go | 2 +- pkg/tcpip/sample/tun_tcp_echo/main.go | 2 +- pkg/urpc/urpc.go | 2 +- runsc/boot/config.go | 6 ++ runsc/boot/network.go | 39 ++++--- runsc/main.go | 25 +++-- runsc/sandbox/network.go | 119 +++++++++++++--------- runsc/test/testutil/testutil.go | 1 + 10 files changed, 249 insertions(+), 117 deletions(-) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 1f889c2a0..b88e2e7bf 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -21,12 +21,29 @@ // FD based endpoints can be used in the networking stack by calling New() to // create a new endpoint, and then passing it as an argument to // Stack.CreateNIC(). +// +// FD based endpoints can use more than one file descriptor to read incoming +// packets. If there are more than one FDs specified and the underlying FD is an +// AF_PACKET then the endpoint will enable FANOUT mode on the socket so that the +// host kernel will consistently hash the packets to the sockets. This ensures +// that packets for the same TCP streams are not reordered. +// +// Similarly if more than one FD's are specified where the underlying FD is not +// AF_PACKET then it's the caller's responsibility to ensure that all inbound +// packets on the descriptors are consistently 5 tuple hashed to one of the +// descriptors to prevent TCP reordering. +// +// Since netstack today does not compute 5 tuple hashes for outgoing packets we +// only use the first FD to write outbound packets. Once 5 tuple hashes for +// all outbound packets are available we will make use of all underlying FD's to +// write outbound packets. package fdbased import ( "fmt" "syscall" + "golang.org/x/sys/unix" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" @@ -65,8 +82,10 @@ const ( ) type endpoint struct { - // fd is the file descriptor used to send and receive packets. - fd int + // fds is the set of file descriptors each identifying one inbound/outbound + // channel. The endpoint will dispatch from all inbound channels as well as + // hash outbound packets to specific channels based on the packet hash. + fds []int // mtu (maximum transmission unit) is the maximum size of a packet. mtu uint32 @@ -85,8 +104,8 @@ type endpoint struct { // its end of the communication pipe. closed func(*tcpip.Error) - inboundDispatcher linkDispatcher - dispatcher stack.NetworkDispatcher + inboundDispatchers []linkDispatcher + dispatcher stack.NetworkDispatcher // packetDispatchMode controls the packet dispatcher used by this // endpoint. @@ -99,17 +118,47 @@ type endpoint struct { // Options specify the details about the fd-based endpoint to be created. type Options struct { - FD int - MTU uint32 - EthernetHeader bool - ClosedFunc func(*tcpip.Error) - Address tcpip.LinkAddress - SaveRestore bool - DisconnectOk bool - GSOMaxSize uint32 + // FDs is a set of FDs used to read/write packets. + FDs []int + + // MTU is the mtu to use for this endpoint. + MTU uint32 + + // EthernetHeader if true, indicates that the endpoint should read/write + // ethernet frames instead of IP packets. + EthernetHeader bool + + // ClosedFunc is a function to be called when an endpoint's peer (if + // any) closes its end of the communication pipe. + ClosedFunc func(*tcpip.Error) + + // Address is the link address for this endpoint. Only used if + // EthernetHeader is true. + Address tcpip.LinkAddress + + // SaveRestore if true, indicates that this NIC capability set should + // include CapabilitySaveRestore + SaveRestore bool + + // DisconnectOk if true, indicates that this NIC capability set should + // include CapabilityDisconnectOk. + DisconnectOk bool + + // GSOMaxSize is the maximum GSO packet size. It is zero if GSO is + // disabled. + GSOMaxSize uint32 + + // PacketDispatchMode specifies the type of inbound dispatcher to be + // used for this endpoint. PacketDispatchMode PacketDispatchMode - TXChecksumOffload bool - RXChecksumOffload bool + + // TXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityTXChecksumOffload. + TXChecksumOffload bool + + // RXChecksumOffload if true, indicates that this endpoints capability + // set should include CapabilityRXChecksumOffload. + RXChecksumOffload bool } // New creates a new fd-based endpoint. @@ -117,10 +166,6 @@ type Options struct { // Makes fd non-blocking, but does not take ownership of fd, which must remain // open for the lifetime of the returned endpoint. func New(opts *Options) (tcpip.LinkEndpointID, error) { - if err := syscall.SetNonblock(opts.FD, true); err != nil { - return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", opts.FD, err) - } - caps := stack.LinkEndpointCapabilities(0) if opts.RXChecksumOffload { caps |= stack.CapabilityRXChecksumOffload @@ -144,8 +189,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { caps |= stack.CapabilityDisconnectOk } + if len(opts.FDs) == 0 { + return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified") + } + e := &endpoint{ - fd: opts.FD, + fds: opts.FDs, mtu: opts.MTU, caps: caps, closed: opts.ClosedFunc, @@ -154,46 +203,71 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { packetDispatchMode: opts.PacketDispatchMode, } - isSocket, err := isSocketFD(e.fd) - if err != nil { - return 0, err - } - if isSocket { - if opts.GSOMaxSize != 0 { - e.caps |= stack.CapabilityGSO - e.gsoMaxSize = opts.GSOMaxSize + // Create per channel dispatchers. + for i := 0; i < len(e.fds); i++ { + fd := e.fds[i] + if err := syscall.SetNonblock(fd, true); err != nil { + return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) } - } - e.inboundDispatcher, err = createInboundDispatcher(e, isSocket) - if err != nil { - return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err) + + isSocket, err := isSocketFD(fd) + if err != nil { + return 0, err + } + if isSocket { + if opts.GSOMaxSize != 0 { + e.caps |= stack.CapabilityGSO + e.gsoMaxSize = opts.GSOMaxSize + } + } + inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket) + if err != nil { + return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err) + } + e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) } return stack.RegisterLinkEndpoint(e), nil } -func createInboundDispatcher(e *endpoint, isSocket bool) (linkDispatcher, error) { +func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) { // By default use the readv() dispatcher as it works with all kinds of // FDs (tap/tun/unix domain sockets and af_packet). - inboundDispatcher, err := newReadVDispatcher(e.fd, e) + inboundDispatcher, err := newReadVDispatcher(fd, e) if err != nil { - return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", e.fd, e, err) + return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", fd, e, err) } if isSocket { + sa, err := unix.Getsockname(fd) + if err != nil { + return nil, fmt.Errorf("unix.Getsockname(%d) = %v", fd, err) + } + switch sa.(type) { + case *unix.SockaddrLinklayer: + // enable PACKET_FANOUT mode is the underlying socket is + // of type AF_PACKET. + const fanoutID = 1 + const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG + fanoutArg := fanoutID | fanoutType<<16 + if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { + return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) + } + } + switch e.packetDispatchMode { case PacketMMap: - inboundDispatcher, err = newPacketMMapDispatcher(e.fd, e) + inboundDispatcher, err = newPacketMMapDispatcher(fd, e) if err != nil { - return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", e.fd, e, err) + return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", fd, e, err) } case RecvMMsg: // If the provided FD is a socket then we optimize // packet reads by using recvmmsg() instead of read() to // read packets in a batch. - inboundDispatcher, err = newRecvMMsgDispatcher(e.fd, e) + inboundDispatcher, err = newRecvMMsgDispatcher(fd, e) if err != nil { - return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", e.fd, e, err) + return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", fd, e, err) } } } @@ -215,7 +289,9 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { // Link endpoints are not savable. When transportation endpoints are // saved, they stop sending outgoing packets and all incoming packets // are rejected. - go e.dispatchLoop() // S/R-SAFE: See above. + for i := range e.inboundDispatchers { + go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above. + } } // IsAttached implements stack.LinkEndpoint.IsAttached. @@ -305,26 +381,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } } - return rawfile.NonBlockingWrite3(e.fd, vnetHdrBuf, hdr.View(), payload.ToView()) + return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView()) } if payload.Size() == 0 { - return rawfile.NonBlockingWrite(e.fd, hdr.View()) + return rawfile.NonBlockingWrite(e.fds[0], hdr.View()) } - return rawfile.NonBlockingWrite3(e.fd, hdr.View(), payload.ToView(), nil) + return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil) } // WriteRawPacket writes a raw packet directly to the file descriptor. func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { - return rawfile.NonBlockingWrite(e.fd, packet) + return rawfile.NonBlockingWrite(e.fds[0], packet) } // dispatchLoop reads packets from the file descriptor in a loop and dispatches // them to the network stack. -func (e *endpoint) dispatchLoop() *tcpip.Error { +func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error { for { - cont, err := e.inboundDispatcher.dispatch() + cont, err := inboundDispatcher.dispatch() if err != nil || !cont { if e.closed != nil { e.closed(err) @@ -363,7 +439,7 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti syscall.SetNonblock(fd, true) e := &InjectableEndpoint{endpoint: endpoint{ - fd: fd, + fds: []int{fd}, mtu: mtu, caps: capabilities, }} diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index fd1722074..ba3e09192 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -67,7 +67,7 @@ func newContext(t *testing.T, opt *Options) *context { done <- struct{}{} } - opt.FD = fds[1] + opt.FDs = []int{fds[1]} epID, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 1681de56e..1fa899e7e 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -137,7 +137,7 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{FD: fd, MTU: mtu}) + linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) if err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 642607f83..d47085581 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -129,7 +129,7 @@ func main() { } linkID, err := fdbased.New(&fdbased.Options{ - FD: fd, + FDs: []int{fd}, MTU: mtu, EthernetHeader: *tap, Address: tcpip.LinkAddress(maddr), diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go index 0f155ec74..4ea684659 100644 --- a/pkg/urpc/urpc.go +++ b/pkg/urpc/urpc.go @@ -35,7 +35,7 @@ import ( ) // maxFiles determines the maximum file payload. -const maxFiles = 16 +const maxFiles = 32 // ErrTooManyFiles is returned when too many file descriptors are mapped. var ErrTooManyFiles = errors.New("too many files") diff --git a/runsc/boot/config.go b/runsc/boot/config.go index 15f624f9b..8564c502d 100644 --- a/runsc/boot/config.go +++ b/runsc/boot/config.go @@ -221,6 +221,11 @@ type Config struct { // user, and without chrooting the sandbox process. This can be // necessary in test environments that have limited capabilities. TestOnlyAllowRunAsCurrentUserWithoutChroot bool + + // NumNetworkChannels controls the number of AF_PACKET sockets that map + // to the same underlying network device. This allows netstack to better + // scale for high throughput use cases. + NumNetworkChannels int } // ToFlags returns a slice of flags that correspond to the given Config. @@ -244,6 +249,7 @@ func (c *Config) ToFlags() []string { "--panic-signal=" + strconv.Itoa(c.PanicSignal), "--profile=" + strconv.FormatBool(c.ProfileEnable), "--net-raw=" + strconv.FormatBool(c.EnableRaw), + "--num-network-channels=" + strconv.Itoa(c.NumNetworkChannels), } if c.TestOnlyAllowRunAsCurrentUserWithoutChroot { // Only include if set since it is never to be used by users. diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 0a154d90b..82c259f47 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -57,6 +57,10 @@ type FDBasedLink struct { Routes []Route GSOMaxSize uint32 LinkAddress []byte + + // NumChannels controls how many underlying FD's are to be used to + // create this endpoint. + NumChannels int } // LoopbackLink configures a loopback li nk. @@ -68,8 +72,9 @@ type LoopbackLink struct { // CreateLinksAndRoutesArgs are arguments to CreateLinkAndRoutes. type CreateLinksAndRoutesArgs struct { - // FilePayload contains the fds associated with the FDBasedLinks. The - // two slices must have the same length. + // FilePayload contains the fds associated with the FDBasedLinks. The + // number of fd's should match the sum of the NumChannels field of the + // FDBasedLink entries below. urpc.FilePayload LoopbackLinks []LoopbackLink @@ -95,8 +100,12 @@ func (r *Route) toTcpipRoute(id tcpip.NICID) tcpip.Route { // CreateLinksAndRoutes creates links and routes in a network stack. It should // only be called once. func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct{}) error { - if len(args.FilePayload.Files) != len(args.FDBasedLinks) { - return fmt.Errorf("FilePayload must be same length at FDBasedLinks") + wantFDs := 0 + for _, l := range args.FDBasedLinks { + wantFDs += l.NumChannels + } + if got := len(args.FilePayload.Files); got != wantFDs { + return fmt.Errorf("args.FilePayload.Files has %d FD's but we need %d entries based on FDBasedLinks", got, wantFDs) } var nicID tcpip.NICID @@ -123,20 +132,26 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } } - for i, link := range args.FDBasedLinks { + fdOffset := 0 + for _, link := range args.FDBasedLinks { nicID++ nicids[link.Name] = nicID - // Copy the underlying FD. - oldFD := args.FilePayload.Files[i].Fd() - newFD, err := syscall.Dup(int(oldFD)) - if err != nil { - return fmt.Errorf("failed to dup FD %v: %v", oldFD, err) + FDs := []int{} + for j := 0; j < link.NumChannels; j++ { + // Copy the underlying FD. + oldFD := args.FilePayload.Files[fdOffset].Fd() + newFD, err := syscall.Dup(int(oldFD)) + if err != nil { + return fmt.Errorf("failed to dup FD %v: %v", oldFD, err) + } + FDs = append(FDs, newFD) + fdOffset++ } mac := tcpip.LinkAddress(link.LinkAddress) linkEP, err := fdbased.New(&fdbased.Options{ - FD: newFD, + FDs: FDs, MTU: uint32(link.MTU), EthernetHeader: true, Address: mac, @@ -148,7 +163,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct return err } - log.Infof("Enabling interface %q with id %d on addresses %+v (%v)", link.Name, nicID, link.Addresses, mac) + log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil { return err } diff --git a/runsc/main.go b/runsc/main.go index 11bc73f75..44ad23cba 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -60,16 +60,16 @@ var ( straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs") // Flags that control sandbox runtime behavior. - platform = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm") - network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") - gso = flag.Bool("gso", true, "enable generic segmenation offload") - fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") - overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") - watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.") - panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") - profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") - netRaw = flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.") - + platform = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm") + network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") + gso = flag.Bool("gso", true, "enable generic segmenation offload") + fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") + overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") + watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.") + panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") + profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") + netRaw = flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.") + numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.") testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") ) @@ -141,6 +141,10 @@ func main() { cmd.Fatalf("%v", err) } + if *numNetworkChannels <= 0 { + cmd.Fatalf("num_network_channels must be > 0, got: %d", *numNetworkChannels) + } + // Create a new Config from the flags. conf := &boot.Config{ RootDir: *rootDir, @@ -162,6 +166,7 @@ func main() { ProfileEnable: *profile, EnableRaw: *netRaw, TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, + NumNetworkChannels: *numNetworkChannels, } if len(*straceSyscalls) != 0 { conf.StraceSyscalls = strings.Split(*straceSyscalls, ",") diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 0460d5f1a..1fd091514 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -68,7 +68,7 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi // Build the path to the net namespace of the sandbox process. // This is what we will copy. nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net") - if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.GSO); err != nil { + if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.GSO, conf.NumNetworkChannels); err != nil { return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err) } case boot.NetworkHost: @@ -138,7 +138,7 @@ func isRootNS() (bool, error) { // createInterfacesAndRoutesFromNS scrapes the interface and routes from the // net namespace with the given path, creates them in the sandbox, and removes // them from the host. -func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO bool) error { +func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO bool, numNetworkChannels int) error { // Join the network namespace that we will be copying. restore, err := joinNetNS(nsPath) if err != nil { @@ -202,25 +202,6 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO continue } - // Create the socket. - const protocol = 0x0300 // htons(ETH_P_ALL) - fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) - if err != nil { - return fmt.Errorf("unable to create raw socket: %v", err) - } - deviceFile := os.NewFile(uintptr(fd), "raw-device-fd") - - // Bind to the appropriate device. - ll := syscall.SockaddrLinklayer{ - Protocol: protocol, - Ifindex: iface.Index, - Hatype: 0, // No ARP type. - Pkttype: syscall.PACKET_OTHERHOST, - } - if err := syscall.Bind(fd, &ll); err != nil { - return fmt.Errorf("unable to bind to %q: %v", iface.Name, err) - } - // Scrape the routes before removing the address, since that // will remove the routes as well. routes, def, err := routesForIface(iface) @@ -236,9 +217,10 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO } link := boot.FDBasedLink{ - Name: iface.Name, - MTU: iface.MTU, - Routes: routes, + Name: iface.Name, + MTU: iface.MTU, + Routes: routes, + NumChannels: numNetworkChannels, } // Get the link for the interface. @@ -248,30 +230,23 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO } link.LinkAddress = []byte(ifaceLink.Attrs().HardwareAddr) - if enableGSO { - gso, err := isGSOEnabled(fd, iface.Name) + log.Debugf("Setting up network channels") + // Create the socket for the device. + for i := 0; i < link.NumChannels; i++ { + log.Debugf("Creating Channel %d", i) + socketEntry, err := createSocket(iface, ifaceLink, enableGSO) if err != nil { - return fmt.Errorf("getting GSO for interface %q: %v", iface.Name, err) + return fmt.Errorf("failed to createSocket for %s : %v", iface.Name, err) } - if gso { - if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { - return fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) - } - link.GSOMaxSize = ifaceLink.Attrs().GSOMaxSize + if i == 0 { + link.GSOMaxSize = socketEntry.gsoMaxSize } else { - log.Infof("GSO not available in host.") + if link.GSOMaxSize != socketEntry.gsoMaxSize { + return fmt.Errorf("inconsistent gsoMaxSize %d and %d when creating multiple channels for same interface: %s", + link.GSOMaxSize, socketEntry.gsoMaxSize, iface.Name) + } } - } - - // Use SO_RCVBUFFORCE because on linux the receive buffer for an - // AF_PACKET socket is capped by "net.core.rmem_max". rmem_max - // defaults to a unusually low value of 208KB. This is too low - // for gVisor to be able to receive packets at high throughputs - // without incurring packet drops. - const rcvBufSize = 4 << 20 // 4MB. - - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, rcvBufSize); err != nil { - return fmt.Errorf("failed to increase socket rcv buffer to %d: %v", rcvBufSize, err) + args.FilePayload.Files = append(args.FilePayload.Files, socketEntry.deviceFile) } // Collect the addresses for the interface, enable forwarding, @@ -285,7 +260,6 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO } } - args.FilePayload.Files = append(args.FilePayload.Files, deviceFile) args.FDBasedLinks = append(args.FDBasedLinks, link) } @@ -296,6 +270,61 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO return nil } +type socketEntry struct { + deviceFile *os.File + gsoMaxSize uint32 +} + +// createSocket creates an underlying AF_PACKET socket and configures it for use by +// the sentry and returns an *os.File that wraps the underlying socket fd. +func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (*socketEntry, error) { + // Create the socket. + const protocol = 0x0300 // htons(ETH_P_ALL) + fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) + if err != nil { + return nil, fmt.Errorf("unable to create raw socket: %v", err) + } + deviceFile := os.NewFile(uintptr(fd), "raw-device-fd") + // Bind to the appropriate device. + ll := syscall.SockaddrLinklayer{ + Protocol: protocol, + Ifindex: iface.Index, + Hatype: 0, // No ARP type. + Pkttype: syscall.PACKET_OTHERHOST, + } + if err := syscall.Bind(fd, &ll); err != nil { + return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) + } + + gsoMaxSize := uint32(0) + if enableGSO { + gso, err := isGSOEnabled(fd, iface.Name) + if err != nil { + return nil, fmt.Errorf("getting GSO for interface %q: %v", iface.Name, err) + } + if gso { + if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { + return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) + } + gsoMaxSize = ifaceLink.Attrs().GSOMaxSize + } else { + log.Infof("GSO not available in host.") + } + } + + // Use SO_RCVBUFFORCE because on linux the receive buffer for an + // AF_PACKET socket is capped by "net.core.rmem_max". rmem_max + // defaults to a unusually low value of 208KB. This is too low + // for gVisor to be able to receive packets at high throughputs + // without incurring packet drops. + const rcvBufSize = 4 << 20 // 4MB. + + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, rcvBufSize); err != nil { + return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", rcvBufSize, err) + } + return &socketEntry{deviceFile, gsoMaxSize}, nil +} + // loopbackLinks collects the links for a loopback interface. func loopbackLinks(iface net.Interface, addrs []net.Addr) ([]boot.LoopbackLink, error) { var links []boot.LoopbackLink diff --git a/runsc/test/testutil/testutil.go b/runsc/test/testutil/testutil.go index 9efb1ba8e..727b648a6 100644 --- a/runsc/test/testutil/testutil.go +++ b/runsc/test/testutil/testutil.go @@ -136,6 +136,7 @@ func TestConfig() *boot.Config { Strace: true, FileAccess: boot.FileAccessExclusive, TestOnlyAllowRunAsCurrentUserWithoutChroot: true, + NumNetworkChannels: 1, } } -- cgit v1.2.3 From 720ec3590d9bbf6dc2f9533ed5ef2cbc0b01627a Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Thu, 6 Jun 2019 10:48:19 -0700 Subject: Send error message to docker/kubectl exec on failure Containerd uses the last error message sent to the log to print as failure cause for create/exec. This required a few changes in the logging logic for runsc: - cmd.Errorf/Fatalf: now writes a message with 'error' level to containerd log, in addition to stderr and debug logs, like before. - log.Infof/Warningf/Fatalf: are not sent to containerd log anymore. They are mostly used for debugging and not useful to containerd. In most cases, --debug-log is enabled and this avoids the logs messages from being duplicated. - stderr is not used as default log destination anymore. Some commands assume stdio is for the container/process running inside the sandbox and it's better to never use it for logging. By default, logs are supressed now. PiperOrigin-RevId: 251881815 --- runsc/cmd/BUILD | 1 + runsc/cmd/cmd.go | 19 ---------- runsc/cmd/create.go | 1 - runsc/cmd/error.go | 72 +++++++++++++++++++++++++++++++++++++ runsc/cmd/exec.go | 28 +++++++++------ runsc/cmd/start.go | 1 - runsc/main.go | 60 +++++++++++++++---------------- runsc/test/integration/exec_test.go | 23 ++++++++++++ 8 files changed, 142 insertions(+), 63 deletions(-) create mode 100644 runsc/cmd/error.go diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index b7551a5ab..173b7671e 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -14,6 +14,7 @@ go_library( "debug.go", "delete.go", "do.go", + "error.go", "events.go", "exec.go", "gofer.go", diff --git a/runsc/cmd/cmd.go b/runsc/cmd/cmd.go index a2fc377d1..5b4cc4a39 100644 --- a/runsc/cmd/cmd.go +++ b/runsc/cmd/cmd.go @@ -17,34 +17,15 @@ package cmd import ( "fmt" - "os" "runtime" "strconv" "syscall" - "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/runsc/specutils" ) -// Errorf logs to stderr and returns subcommands.ExitFailure. -func Errorf(s string, args ...interface{}) subcommands.ExitStatus { - // If runsc is being invoked by docker or cri-o, then we might not have - // access to stderr, so we log a serious-looking warning in addition to - // writing to stderr. - log.Warningf("FATAL ERROR: "+s, args...) - fmt.Fprintf(os.Stderr, s+"\n", args...) - // Return an error that is unlikely to be used by the application. - return subcommands.ExitFailure -} - -// Fatalf logs to stderr and exits with a failure status code. -func Fatalf(s string, args ...interface{}) { - Errorf(s, args...) - os.Exit(128) -} - // intFlags can be used with int flags that appear multiple times. type intFlags []int diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go index 629c198fd..8bf9b7dcf 100644 --- a/runsc/cmd/create.go +++ b/runsc/cmd/create.go @@ -16,7 +16,6 @@ package cmd import ( "context" - "flag" "github.com/google/subcommands" "gvisor.googlesource.com/gvisor/runsc/boot" diff --git a/runsc/cmd/error.go b/runsc/cmd/error.go new file mode 100644 index 000000000..700b19f14 --- /dev/null +++ b/runsc/cmd/error.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "encoding/json" + "fmt" + "io" + "os" + "time" + + "github.com/google/subcommands" + "gvisor.googlesource.com/gvisor/pkg/log" +) + +// ErrorLogger is where error messages should be written to. These messages are +// consumed by containerd and show up to users of command line tools, +// like docker/kubectl. +var ErrorLogger io.Writer + +type jsonError struct { + Msg string `json:"msg"` + Level string `json:"level"` + Time time.Time `json:"time"` +} + +// Errorf logs error to containerd log (--log), to stderr, and debug logs. It +// returns subcommands.ExitFailure for convenience with subcommand.Execute() +// methods: +// return Errorf("Danger! Danger!") +// +func Errorf(format string, args ...interface{}) subcommands.ExitStatus { + // If runsc is being invoked by docker or cri-o, then we might not have + // access to stderr, so we log a serious-looking warning in addition to + // writing to stderr. + log.Warningf("FATAL ERROR: "+format, args...) + fmt.Fprintf(os.Stderr, format+"\n", args...) + + j := jsonError{ + Msg: fmt.Sprintf(format, args...), + Level: "error", + Time: time.Now(), + } + b, err := json.Marshal(j) + if err != nil { + panic(err) + } + if ErrorLogger != nil { + ErrorLogger.Write(b) + } + + return subcommands.ExitFailure +} + +// Fatalf logs the same way as Errorf() does, plus *exits* the process. +func Fatalf(format string, args ...interface{}) { + Errorf(format, args...) + // Return an error that is unlikely to be used by the application. + os.Exit(128) +} diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index 8cd070e61..0eeaaadba 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -143,13 +143,16 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // write the child's PID to the pid file. So when the container returns, the // child process will also return and signal containerd. if ex.detach { - return ex.execAndWait(waitStatus) + return ex.execChildAndWait(waitStatus) } + return ex.exec(c, e, waitStatus) +} +func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *syscall.WaitStatus) subcommands.ExitStatus { // Start the new process and get it pid. pid, err := c.Execute(e) if err != nil { - Fatalf("executing processes for container: %v", err) + return Errorf("executing processes for container: %v", err) } if e.StdioIsPty { @@ -163,29 +166,29 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) if ex.internalPidFile != "" { pidStr := []byte(strconv.Itoa(int(pid))) if err := ioutil.WriteFile(ex.internalPidFile, pidStr, 0644); err != nil { - Fatalf("writing internal pid file %q: %v", ex.internalPidFile, err) + return Errorf("writing internal pid file %q: %v", ex.internalPidFile, err) } } - // Generate the pid file after the internal pid file is generated, so that users - // can safely assume that the internal pid file is ready after `runsc exec -d` - // returns. + // Generate the pid file after the internal pid file is generated, so that + // users can safely assume that the internal pid file is ready after + // `runsc exec -d` returns. if ex.pidFile != "" { if err := ioutil.WriteFile(ex.pidFile, []byte(strconv.Itoa(os.Getpid())), 0644); err != nil { - Fatalf("writing pid file: %v", err) + return Errorf("writing pid file: %v", err) } } // Wait for the process to exit. ws, err := c.WaitPID(pid) if err != nil { - Fatalf("waiting on pid %d: %v", pid, err) + return Errorf("waiting on pid %d: %v", pid, err) } *waitStatus = ws return subcommands.ExitSuccess } -func (ex *Exec) execAndWait(waitStatus *syscall.WaitStatus) subcommands.ExitStatus { +func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.ExitStatus { var args []string for _, a := range os.Args[1:] { if !strings.Contains(a, "detach") { @@ -193,7 +196,7 @@ func (ex *Exec) execAndWait(waitStatus *syscall.WaitStatus) subcommands.ExitStat } } - // The command needs to write a pid file so that execAndWait can tell + // The command needs to write a pid file so that execChildAndWait can tell // when it has started. If no pid-file was provided, we should use a // filename in a temp directory. pidFile := ex.pidFile @@ -262,7 +265,10 @@ func (ex *Exec) execAndWait(waitStatus *syscall.WaitStatus) subcommands.ExitStat return false, nil } if err := specutils.WaitForReady(cmd.Process.Pid, 10*time.Second, ready); err != nil { - Fatalf("unexpected error waiting for PID file, err: %v", err) + // Don't log fatal error here, otherwise it will override the error logged + // by the child process that has failed to start. + log.Warningf("Unexpected error waiting for PID file, err: %v", err) + return subcommands.ExitFailure } *waitStatus = 0 diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index 657726251..31e8f42bb 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -16,7 +16,6 @@ package cmd import ( "context" - "flag" "github.com/google/subcommands" "gvisor.googlesource.com/gvisor/runsc/boot" diff --git a/runsc/main.go b/runsc/main.go index 44ad23cba..39c43507c 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -117,6 +117,22 @@ func main() { os.Exit(0) } + var errorLogger io.Writer + if *logFD > -1 { + errorLogger = os.NewFile(uintptr(*logFD), "error log file") + + } else if *logFilename != "" { + // We must set O_APPEND and not O_TRUNC because Docker passes + // the same log file for all commands (and also parses these + // log files), so we can't destroy them on each command. + var err error + errorLogger, err = os.OpenFile(*logFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + cmd.Fatalf("error opening log file %q: %v", *logFilename, err) + } + } + cmd.ErrorLogger = errorLogger + platformType, err := boot.MakePlatformType(*platform) if err != nil { cmd.Fatalf("%v", err) @@ -179,24 +195,7 @@ func main() { subcommand := flag.CommandLine.Arg(0) - var logFile io.Writer = os.Stderr - if *logFD > -1 { - logFile = os.NewFile(uintptr(*logFD), "log file") - } else if *logFilename != "" { - // We must set O_APPEND and not O_TRUNC because Docker passes - // the same log file for all commands (and also parses these - // log files), so we can't destroy them on each command. - f, err := os.OpenFile(*logFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) - if err != nil { - cmd.Fatalf("error opening log file %q: %v", *logFilename, err) - } - logFile = f - } else if subcommand == "do" { - logFile = ioutil.Discard - } - - e := newEmitter(*logFormat, logFile) - + var e log.Emitter if *debugLogFD > -1 { f := os.NewFile(uintptr(*debugLogFD), "debug log file") @@ -206,28 +205,27 @@ func main() { cmd.Fatalf("flag --debug-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand) } - // If we are the boot process, then we own our stdio FDs and - // can do what we want with them. Since Docker and Containerd - // both eat boot's stderr, we dup our stderr to the provided - // log FD so that panics will appear in the logs, rather than - // just disappear. + // If we are the boot process, then we own our stdio FDs and can do what we + // want with them. Since Docker and Containerd both eat boot's stderr, we + // dup our stderr to the provided log FD so that panics will appear in the + // logs, rather than just disappear. if err := syscall.Dup2(int(f.Fd()), int(os.Stderr.Fd())); err != nil { cmd.Fatalf("error dup'ing fd %d to stderr: %v", f.Fd(), err) } - if logFile == os.Stderr { - // Suppress logging to stderr when debug log is enabled. Otherwise all - // messages will be duplicated in the debug log (see Dup2() call above). - e = newEmitter(*debugLogFormat, f) - } else { - e = log.MultiEmitter{e, newEmitter(*debugLogFormat, f)} - } + e = newEmitter(*debugLogFormat, f) + } else if *debugLog != "" { f, err := specutils.DebugLogFile(*debugLog, subcommand) if err != nil { cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err) } - e = log.MultiEmitter{e, newEmitter(*debugLogFormat, f)} + e = newEmitter(*debugLogFormat, f) + + } else { + // Stderr is reserved for the application, just discard the logs if no debug + // log is specified. + e = newEmitter("text", ioutil.Discard) } log.SetTarget(e) diff --git a/runsc/test/integration/exec_test.go b/runsc/test/integration/exec_test.go index 7af064d79..7c0e61ac3 100644 --- a/runsc/test/integration/exec_test.go +++ b/runsc/test/integration/exec_test.go @@ -29,6 +29,7 @@ package integration import ( "fmt" "strconv" + "strings" "syscall" "testing" "time" @@ -136,3 +137,25 @@ func TestExecJobControl(t *testing.T) { t.Errorf("ws.ExitedStatus got %d, want %d", got, want) } } + +// Test that failure to exec returns proper error message. +func TestExecError(t *testing.T) { + if err := testutil.Pull("alpine"); err != nil { + t.Fatalf("docker pull failed: %v", err) + } + d := testutil.MakeDocker("exec-error-test") + + // Start the container. + if err := d.Run("alpine", "sleep", "1000"); err != nil { + t.Fatalf("docker run failed: %v", err) + } + defer d.CleanUp() + + _, err := d.Exec("no_can_find") + if err == nil { + t.Fatalf("docker exec didn't fail") + } + if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(err.Error(), want) { + t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want) + } +} -- cgit v1.2.3 From 81eafb2c5e6242251618456f7e2a5657133a103c Mon Sep 17 00:00:00 2001 From: Googler Date: Thu, 6 Jun 2019 12:26:01 -0700 Subject: Internal change. PiperOrigin-RevId: 251902567 --- test/syscalls/linux/pipe.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index bce351e08..67b93ecf5 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -55,7 +55,7 @@ class PipeTest : public ::testing::TestWithParam { FileDescriptor wfd; public: - static void SetUpTestCase() { + static void SetUpTestSuite() { // Tests intentionally generate SIGPIPE. TEST_PCHECK(signal(SIGPIPE, SIG_IGN) != SIG_ERR); } @@ -82,7 +82,7 @@ class PipeTest : public ::testing::TestWithParam { return s1; } - static void TearDownTestCase() { + static void TearDownTestSuite() { TEST_PCHECK(signal(SIGPIPE, SIG_DFL) != SIG_ERR); } -- cgit v1.2.3 From 8b8bd8d5b28a8e41f59fc3465c38964986bfb084 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 14:30:50 -0700 Subject: Try increase listen backlog. PiperOrigin-RevId: 251928000 --- test/syscalls/linux/socket_inet_loopback.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index b216d14cb..9b3b70b01 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -198,7 +198,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { ASSERT_THAT( bind(fd, reinterpret_cast(&listen_addr), listener.addr_len), SyscallSucceeds()); - ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + ASSERT_THAT(listen(fd, 512), SyscallSucceeds()); // On the first bind we need to determine which port was bound. if (i != 0) { -- cgit v1.2.3 From bf0b1b9d767736e632fa56b90d904fee968d8d3d Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Thu, 6 Jun 2019 14:37:12 -0700 Subject: Add overlay dimension to FS related syscall tests PiperOrigin-RevId: 251929314 --- test/syscalls/BUILD | 195 ++++++++++++++++++++++++++++------- test/syscalls/build_defs.bzl | 19 +++- test/syscalls/syscall_test_runner.go | 8 +- 3 files changed, 183 insertions(+), 39 deletions(-) diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 4ea4cee30..2985275bb 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -13,11 +13,17 @@ syscall_test( test = "//test/syscalls/linux:accept_bind_test", ) -syscall_test(test = "//test/syscalls/linux:access_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:access_test", +) syscall_test(test = "//test/syscalls/linux:affinity_test") -syscall_test(test = "//test/syscalls/linux:aio_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:aio_test", +) syscall_test( size = "medium", @@ -30,6 +36,7 @@ syscall_test(test = "//test/syscalls/linux:bad_test") syscall_test( size = "large", + add_overlay = True, test = "//test/syscalls/linux:bind_test", ) @@ -37,17 +44,27 @@ syscall_test(test = "//test/syscalls/linux:brk_test") syscall_test(test = "//test/syscalls/linux:socket_test") -syscall_test(test = "//test/syscalls/linux:chdir_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:chdir_test", +) -syscall_test(test = "//test/syscalls/linux:chmod_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:chmod_test", +) syscall_test( size = "medium", + add_overlay = True, test = "//test/syscalls/linux:chown_test", use_tmpfs = True, # chwon tests require gofer to be running as root. ) -syscall_test(test = "//test/syscalls/linux:chroot_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:chroot_test", +) syscall_test(test = "//test/syscalls/linux:clock_getres_test") @@ -60,11 +77,17 @@ syscall_test(test = "//test/syscalls/linux:clock_nanosleep_test") syscall_test(test = "//test/syscalls/linux:concurrency_test") -syscall_test(test = "//test/syscalls/linux:creat_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:creat_test", +) syscall_test(test = "//test/syscalls/linux:dev_test") -syscall_test(test = "//test/syscalls/linux:dup_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:dup_test", +) syscall_test(test = "//test/syscalls/linux:epoll_test") @@ -74,23 +97,34 @@ syscall_test(test = "//test/syscalls/linux:exceptions_test") syscall_test( size = "medium", + add_overlay = True, test = "//test/syscalls/linux:exec_test", ) syscall_test( size = "medium", + add_overlay = True, test = "//test/syscalls/linux:exec_binary_test", ) syscall_test(test = "//test/syscalls/linux:exit_test") -syscall_test(test = "//test/syscalls/linux:fadvise64_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:fadvise64_test", +) -syscall_test(test = "//test/syscalls/linux:fallocate_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:fallocate_test", +) syscall_test(test = "//test/syscalls/linux:fault_test") -syscall_test(test = "//test/syscalls/linux:fchdir_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:fchdir_test", +) syscall_test( size = "medium", @@ -99,6 +133,7 @@ syscall_test( syscall_test( size = "medium", + add_overlay = True, test = "//test/syscalls/linux:flock_test", ) @@ -108,7 +143,10 @@ syscall_test(test = "//test/syscalls/linux:fpsig_fork_test") syscall_test(test = "//test/syscalls/linux:fpsig_nested_test") -syscall_test(test = "//test/syscalls/linux:fsync_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:fsync_test", +) syscall_test( size = "medium", @@ -120,7 +158,10 @@ syscall_test(test = "//test/syscalls/linux:getcpu_host_test") syscall_test(test = "//test/syscalls/linux:getcpu_test") -syscall_test(test = "//test/syscalls/linux:getdents_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:getdents_test", +) syscall_test(test = "//test/syscalls/linux:getrandom_test") @@ -128,11 +169,13 @@ syscall_test(test = "//test/syscalls/linux:getrusage_test") syscall_test( size = "medium", + add_overlay = False, # TODO(gvisor.dev/issue/317): enable when fixed. test = "//test/syscalls/linux:inotify_test", ) syscall_test( size = "medium", + add_overlay = True, test = "//test/syscalls/linux:ioctl_test", ) @@ -144,11 +187,15 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:kill_test") syscall_test( + add_overlay = True, test = "//test/syscalls/linux:link_test", use_tmpfs = True, # gofer needs CAP_DAC_READ_SEARCH to use AT_EMPTY_PATH with linkat(2) ) -syscall_test(test = "//test/syscalls/linux:lseek_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:lseek_test", +) syscall_test(test = "//test/syscalls/linux:madvise_test") @@ -158,9 +205,13 @@ syscall_test(test = "//test/syscalls/linux:mempolicy_test") syscall_test(test = "//test/syscalls/linux:mincore_test") -syscall_test(test = "//test/syscalls/linux:mkdir_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:mkdir_test", +) syscall_test( + add_overlay = True, test = "//test/syscalls/linux:mknod_test", use_tmpfs = True, # mknod is not supported over gofer. ) @@ -171,7 +222,10 @@ syscall_test( test = "//test/syscalls/linux:mmap_test", ) -syscall_test(test = "//test/syscalls/linux:mount_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:mount_test", +) syscall_test( size = "medium", @@ -185,9 +239,15 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:munmap_test") -syscall_test(test = "//test/syscalls/linux:open_create_test") +syscall_test( + add_overlay = False, # TODO(gvisor.dev/issue/316): enable when fixed. + test = "//test/syscalls/linux:open_create_test", +) -syscall_test(test = "//test/syscalls/linux:open_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:open_test", +) syscall_test(test = "//test/syscalls/linux:partial_bad_buffer_test") @@ -195,6 +255,7 @@ syscall_test(test = "//test/syscalls/linux:pause_test") syscall_test( size = "large", + add_overlay = False, # TODO(gvisor.dev/issue/318): enable when fixed. shard_count = 5, test = "//test/syscalls/linux:pipe_test", ) @@ -210,11 +271,20 @@ syscall_test(test = "//test/syscalls/linux:prctl_setuid_test") syscall_test(test = "//test/syscalls/linux:prctl_test") -syscall_test(test = "//test/syscalls/linux:pread64_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:pread64_test", +) -syscall_test(test = "//test/syscalls/linux:preadv_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:preadv_test", +) -syscall_test(test = "//test/syscalls/linux:preadv2_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:preadv2_test", +) syscall_test(test = "//test/syscalls/linux:priority_test") @@ -239,13 +309,22 @@ syscall_test( test = "//test/syscalls/linux:pty_test", ) -syscall_test(test = "//test/syscalls/linux:pwritev2_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:pwritev2_test", +) -syscall_test(test = "//test/syscalls/linux:pwrite64_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:pwrite64_test", +) syscall_test(test = "//test/syscalls/linux:raw_socket_ipv4_test") -syscall_test(test = "//test/syscalls/linux:read_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:read_test", +) syscall_test( size = "medium", @@ -254,11 +333,13 @@ syscall_test( syscall_test( size = "medium", + add_overlay = True, test = "//test/syscalls/linux:readv_test", ) syscall_test( size = "medium", + add_overlay = True, test = "//test/syscalls/linux:rename_test", ) @@ -279,11 +360,20 @@ syscall_test( test = "//test/syscalls/linux:semaphore_test", ) -syscall_test(test = "//test/syscalls/linux:sendfile_socket_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:sendfile_socket_test", +) -syscall_test(test = "//test/syscalls/linux:sendfile_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:sendfile_test", +) -syscall_test(test = "//test/syscalls/linux:splice_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:splice_test", +) syscall_test(test = "//test/syscalls/linux:sigaction_test") @@ -330,11 +420,13 @@ syscall_test( syscall_test( size = "medium", + add_overlay = True, test = "//test/syscalls/linux:socket_filesystem_non_blocking_test", ) syscall_test( size = "large", + add_overlay = True, shard_count = 10, test = "//test/syscalls/linux:socket_filesystem_test", ) @@ -430,6 +522,7 @@ syscall_test( syscall_test( size = "large", + add_overlay = True, shard_count = 10, test = "//test/syscalls/linux:socket_unix_pair_test", ) @@ -472,19 +565,40 @@ syscall_test( test = "//test/syscalls/linux:socket_unix_unbound_stream_test", ) -syscall_test(test = "//test/syscalls/linux:statfs_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:statfs_test", +) -syscall_test(test = "//test/syscalls/linux:stat_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:stat_test", +) -syscall_test(test = "//test/syscalls/linux:stat_times_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:stat_times_test", +) -syscall_test(test = "//test/syscalls/linux:sticky_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:sticky_test", +) -syscall_test(test = "//test/syscalls/linux:symlink_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:symlink_test", +) -syscall_test(test = "//test/syscalls/linux:sync_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:sync_test", +) -syscall_test(test = "//test/syscalls/linux:sync_file_range_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:sync_file_range_test", +) syscall_test(test = "//test/syscalls/linux:sysinfo_test") @@ -508,7 +622,10 @@ syscall_test(test = "//test/syscalls/linux:time_test") syscall_test(test = "//test/syscalls/linux:tkill_test") -syscall_test(test = "//test/syscalls/linux:truncate_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:truncate_test", +) syscall_test(test = "//test/syscalls/linux:udp_bind_test") @@ -522,7 +639,10 @@ syscall_test(test = "//test/syscalls/linux:uidgid_test") syscall_test(test = "//test/syscalls/linux:uname_test") -syscall_test(test = "//test/syscalls/linux:unlink_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:unlink_test", +) syscall_test(test = "//test/syscalls/linux:unshare_test") @@ -544,7 +664,10 @@ syscall_test( test = "//test/syscalls/linux:wait_test", ) -syscall_test(test = "//test/syscalls/linux:write_test") +syscall_test( + add_overlay = True, + test = "//test/syscalls/linux:write_test", +) syscall_test( test = "//test/syscalls/linux:proc_net_unix_test", diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index cd74a769d..9f2fc9109 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -7,6 +7,7 @@ def syscall_test( shard_count = 1, size = "small", use_tmpfs = False, + add_overlay = False, tags = None, parallel = True): _syscall_test( @@ -39,6 +40,18 @@ def syscall_test( parallel = parallel, ) + if add_overlay: + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "ptrace", + use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. + tags = tags, + parallel = parallel, + overlay = True, + ) + if not use_tmpfs: # Also test shared gofer access. _syscall_test( @@ -60,7 +73,8 @@ def _syscall_test( use_tmpfs, tags, parallel, - file_access = "exclusive"): + file_access = "exclusive", + overlay = False): test_name = test.split(":")[1] # Prepend "runsc" to non-native platform names. @@ -69,6 +83,8 @@ def _syscall_test( name = test_name + "_" + full_platform if file_access == "shared": name += "_shared" + if overlay: + name += "_overlay" if tags == None: tags = [] @@ -92,6 +108,7 @@ def _syscall_test( "--platform=" + platform, "--use-tmpfs=" + str(use_tmpfs), "--file-access=" + file_access, + "--overlay=" + str(overlay), ] if parallel: diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go index 9a8e0600b..eb04a4fab 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/syscalls/syscall_test_runner.go @@ -47,6 +47,7 @@ var ( platform = flag.String("platform", "ptrace", "platform to run on") useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp") fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") + overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") parallel = flag.Bool("parallel", false, "run tests in parallel") runscPath = flag.String("runsc", "", "path to runsc binary") ) @@ -184,10 +185,13 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { "-platform", *platform, "-root", rootDir, "-file-access", *fileAccess, - "--network=none", + "-network=none", "-log-format=text", "-TESTONLY-unsafe-nonroot=true", - "--net-raw=true", + "-net-raw=true", + } + if *overlay { + args = append(args, "-overlay") } if *debug { args = append(args, "-debug", "-log-packets=true") -- cgit v1.2.3 From 2d2831e3541c8ae3c84f17cfd1bf0a26f2027044 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 15:03:44 -0700 Subject: Track and export socket state. This is necessary for implementing network diagnostic interfaces like /proc/net/{tcp,udp,unix} and sock_diag(7). For pass-through endpoints such as hostinet, we obtain the socket state from the backend. For netstack, we add explicit tracking of TCP states. PiperOrigin-RevId: 251934850 --- pkg/abi/linux/socket.go | 16 ++ pkg/sentry/fs/proc/net.go | 20 +-- pkg/sentry/socket/epsocket/epsocket.go | 44 +++++ pkg/sentry/socket/hostinet/socket.go | 24 +++ pkg/sentry/socket/netlink/socket.go | 5 + pkg/sentry/socket/rpcinet/socket.go | 6 + pkg/sentry/socket/socket.go | 4 + pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 9 ++ pkg/sentry/socket/unix/transport/connectionless.go | 16 ++ pkg/sentry/socket/unix/transport/unix.go | 4 + pkg/sentry/socket/unix/unix.go | 5 + pkg/tcpip/stack/transport_test.go | 4 + pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 6 + pkg/tcpip/transport/raw/endpoint.go | 5 + pkg/tcpip/transport/tcp/accept.go | 12 +- pkg/tcpip/transport/tcp/connect.go | 26 ++- pkg/tcpip/transport/tcp/endpoint.go | 174 ++++++++++++++------ pkg/tcpip/transport/tcp/endpoint_state.go | 42 ++--- pkg/tcpip/transport/tcp/rcv.go | 37 +++++ pkg/tcpip/transport/tcp/snd.go | 4 + pkg/tcpip/transport/tcp/tcp_test.go | 131 ++++++++++++--- pkg/tcpip/transport/tcp/testing/context/context.go | 39 ++++- pkg/tcpip/transport/udp/endpoint.go | 6 + test/syscalls/linux/proc_net_unix.cc | 178 +++++++++++++++++++++ 27 files changed, 696 insertions(+), 127 deletions(-) diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 417840731..44bd69df6 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -200,6 +200,22 @@ const ( SS_DISCONNECTING = 4 // In process of disconnecting. ) +// TCP protocol states, from include/net/tcp_states.h. +const ( + TCP_ESTABLISHED uint32 = iota + 1 + TCP_SYN_SENT + TCP_SYN_RECV + TCP_FIN_WAIT1 + TCP_FIN_WAIT2 + TCP_TIME_WAIT + TCP_CLOSE + TCP_CLOSE_WAIT + TCP_LAST_ACK + TCP_LISTEN + TCP_CLOSING + TCP_NEW_SYN_RECV +) + // SockAddrMax is the maximum size of a struct sockaddr, from // uapi/linux/socket.h. const SockAddrMax = 128 diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 4a107c739..3daaa962c 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -240,24 +240,6 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s } } - var sockState int - switch sops.Endpoint().Type() { - case linux.SOCK_DGRAM: - sockState = linux.SS_CONNECTING - // Unlike Linux, we don't have unbound connection-less sockets, - // so no SS_DISCONNECTING. - - case linux.SOCK_SEQPACKET: - fallthrough - case linux.SOCK_STREAM: - // Connectioned. - if sops.Endpoint().(transport.ConnectingEndpoint).Connected() { - sockState = linux.SS_CONNECTED - } else { - sockState = linux.SS_UNCONNECTED - } - } - // In the socket entry below, the value for the 'Num' field requires // some consideration. Linux prints the address to the struct // unix_sock representing a socket in the kernel, but may redact the @@ -282,7 +264,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s 0, // Protocol, always 0 for UDS. sockFlags, // Flags. sops.Endpoint().Type(), // Type. - sockState, // State. + sops.State(), // State. sfile.InodeID(), // Inode. ) diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index de4b963da..f91c5127a 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -52,6 +52,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -2281,3 +2282,46 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { } return rv } + +// State implements socket.Socket.State. State translates the internal state +// returned by netstack to values defined by Linux. +func (s *SocketOperations) State() uint32 { + if s.family != linux.AF_INET && s.family != linux.AF_INET6 { + // States not implemented for this socket's family. + return 0 + } + + if !s.isPacketBased() { + // TCP socket. + switch tcp.EndpointState(s.Endpoint.State()) { + case tcp.StateEstablished: + return linux.TCP_ESTABLISHED + case tcp.StateSynSent: + return linux.TCP_SYN_SENT + case tcp.StateSynRecv: + return linux.TCP_SYN_RECV + case tcp.StateFinWait1: + return linux.TCP_FIN_WAIT1 + case tcp.StateFinWait2: + return linux.TCP_FIN_WAIT2 + case tcp.StateTimeWait: + return linux.TCP_TIME_WAIT + case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError: + return linux.TCP_CLOSE + case tcp.StateCloseWait: + return linux.TCP_CLOSE_WAIT + case tcp.StateLastAck: + return linux.TCP_LAST_ACK + case tcp.StateListen: + return linux.TCP_LISTEN + case tcp.StateClosing: + return linux.TCP_CLOSING + default: + // Internal or unknown state. + return 0 + } + } + + // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. + return 0 +} diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 41f9693bb..0d75580a3 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -19,7 +19,9 @@ import ( "syscall" "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/binary" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" + "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" @@ -519,6 +521,28 @@ func translateIOSyscallError(err error) error { return err } +// State implements socket.Socket.State. +func (s *socketOperations) State() uint32 { + info := linux.TCPInfo{} + buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo) + if err != nil { + if err != syscall.ENOPROTOOPT { + log.Warningf("Failed to get TCP socket info from %+v: %v", s, err) + } + // For non-TCP sockets, silently ignore the failure. + return 0 + } + if len(buf) != linux.SizeOfTCPInfo { + // Unmarshal below will panic if getsockopt returns a buffer of + // unexpected size. + log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo) + return 0 + } + + binary.Unmarshal(buf, usermem.ByteOrder, &info) + return uint32(info.State) +} + type socketProvider struct { family int } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index afd06ca33..16c79aa33 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -616,3 +616,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } + +// State implements socket.Socket.State. +func (s *Socket) State() uint32 { + return s.ep.State() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 55e0b6665..bf42bdf69 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -830,6 +830,12 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } } +// State implements socket.Socket.State. +func (s *socketOperations) State() uint32 { + // TODO(b/127845868): Define a new rpc to query the socket state. + return 0 +} + type socketProvider struct { family int } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 9393acd28..a99423365 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -116,6 +116,10 @@ type Socket interface { // SendTimeout gets the current timeout (in ns) for send operations. Zero // means no timeout, and negative means DONTWAIT. SendTimeout() int64 + + // State returns the current state of the socket, as represented by Linux in + // procfs. The returned state value is protocol-specific. + State() uint32 } // Provider is the interface implemented by providers of sockets for specific diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 5a2de0c4c..52f324eed 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -28,6 +28,7 @@ go_library( importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ + "//pkg/abi/linux", "//pkg/ilist", "//pkg/refs", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 18e492862..9c8ec0365 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -17,6 +17,7 @@ package transport import ( "sync" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -458,3 +459,11 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return ready } + +// State implements socket.Socket.State. +func (e *connectionedEndpoint) State() uint32 { + if e.Connected() { + return linux.SS_CONNECTED + } + return linux.SS_UNCONNECTED +} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 43ff875e4..c034cf984 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -15,6 +15,7 @@ package transport import ( + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -194,3 +195,18 @@ func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMa return ready } + +// State implements socket.Socket.State. +func (e *connectionlessEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + + switch { + case e.isBound(): + return linux.SS_UNCONNECTED + case e.Connected(): + return linux.SS_CONNECTING + default: + return linux.SS_DISCONNECTING + } +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 37d82bb6b..5fc09af55 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -191,6 +191,10 @@ type Endpoint interface { // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + + // State returns the current state of the socket, as represented by Linux in + // procfs. + State() uint32 } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 388cc0d8b..375542350 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -596,6 +596,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } +// State implements socket.Socket.State. +func (s *SocketOperations) State() uint32 { + return s.ep.State() +} + // provider is a unix domain socket provider. type provider struct{} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8d74f1543..e8a9392b5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -188,6 +188,10 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s f.proto.controlCount++ } +func (f *fakeTransportEndpoint) State() uint32 { + return 0 +} + type fakeTransportGoodOption bool type fakeTransportBadOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f9886c6e4..85ef014d0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -377,6 +377,10 @@ type Endpoint interface { // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error + + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 } // WriteOptions contains options for Endpoint.Write. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9aa6f3978..84a2b53b7 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", + "//pkg/tcpip/transport/tcp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e2b90ef10..b8005093a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -708,3 +708,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } + +// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't +// expose internal socket state. +func (e *endpoint) State() uint32 { + return 0 +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 1daf5823f..e4ff50c91 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -519,3 +519,8 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b ep.waiterQueue.Notify(waiter.EventIn) } } + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 31e365ae5..a32e20b06 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -226,7 +226,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n.isRegistered = true - n.state = stateConnecting // Create sender and receiver. // @@ -258,8 +257,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.Close() return nil, err } - - ep.state = stateConnected + ep.mu.Lock() + ep.state = StateEstablished + ep.mu.Unlock() // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -276,7 +276,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.RLock() state := e.state e.mu.RUnlock() - if state == stateListen { + if state == StateListen { e.acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { @@ -406,7 +406,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. - n.state = stateConnected + n.state = StateEstablished // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -429,7 +429,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = stateClosed + e.state = StateClose // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 371d2ed29..0ad7bfb38 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -151,6 +151,9 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.listenEP = listenEP + h.ep.mu.Lock() + h.ep.state = StateSynRecv + h.ep.mu.Unlock() } // checkAck checks if the ACK number, if present, of a segment received during @@ -219,6 +222,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // but resend our own SYN and wait for it to be acknowledged in the // SYN-RCVD state. h.state = handshakeSynRcvd + h.ep.mu.Lock() + h.ep.state = StateSynRecv + h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: rcvSynOpts.TS, @@ -668,7 +674,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -719,8 +725,7 @@ func (e *endpoint) handleClose() *tcpip.Error { // protocol goroutine. func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) - - e.state = stateError + e.state = StateError e.hardError = err } @@ -876,14 +881,19 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // handshake, and then inform potential waiters about its // completion. h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + e.mu.Lock() + h.ep.state = StateSynSent + e.mu.Unlock() + if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() e.mu.Lock() - e.state = stateError + e.state = StateError e.hardError = err + // Lock released below. epilogue() @@ -905,7 +915,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = stateConnected + e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -1005,7 +1015,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } } - if e.state != stateError { + if e.state != StateError { close(e.drainDone) <-e.undrain } @@ -1061,8 +1071,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() - if e.state != stateError { - e.state = stateClosed + if e.state != StateError { + e.state = StateClose } // Lock released below. epilogue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index fd697402e..23422ca5e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -32,18 +32,81 @@ import ( "gvisor.googlesource.com/gvisor/pkg/waiter" ) -type endpointState int +// EndpointState represents the state of a TCP endpoint. +type EndpointState uint32 +// Endpoint states. Note that are represented in a netstack-specific manner and +// may not be meaningful externally. Specifically, they need to be translated to +// Linux's representation for these states if presented to userspace. const ( - stateInitial endpointState = iota - stateBound - stateListen - stateConnecting - stateConnected - stateClosed - stateError + // Endpoint states internal to netstack. These map to the TCP state CLOSED. + StateInitial EndpointState = iota + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError + + // TCP protocol states. + StateEstablished + StateSynSent + StateSynRecv + StateFinWait1 + StateFinWait2 + StateTimeWait + StateClose + StateCloseWait + StateLastAck + StateListen + StateClosing ) +// connected is the set of states where an endpoint is connected to a peer. +func (s EndpointState) connected() bool { + switch s { + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + return true + default: + return false + } +} + +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnecting: + return "CONNECTING" + case StateError: + return "ERROR" + case StateEstablished: + return "ESTABLISHED" + case StateSynSent: + return "SYN-SENT" + case StateSynRecv: + return "SYN-RCVD" + case StateFinWait1: + return "FIN-WAIT1" + case StateFinWait2: + return "FIN-WAIT2" + case StateTimeWait: + return "TIME-WAIT" + case StateClose: + return "CLOSED" + case StateCloseWait: + return "CLOSE-WAIT" + case StateLastAck: + return "LAST-ACK" + case StateListen: + return "LISTEN" + case StateClosing: + return "CLOSING" + default: + panic("unreachable") + } +} + // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota @@ -108,10 +171,14 @@ type endpoint struct { rcvBufUsed int // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID - state endpointState `state:".(endpointState)"` - isPortReserved bool `state:"manual"` + mu sync.RWMutex `state:"nosave"` + id stack.TransportEndpointID + + // state endpointState `state:".(endpointState)"` + // pState ProtocolState + state EndpointState `state:".(EndpointState)"` + + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` @@ -304,6 +371,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite stack: stack, netProto: netProto, waiterQueue: waiterQueue, + state: StateInitial, rcvBufSize: DefaultBufferSize, sndBufSize: DefaultBufferSize, sndMTU: int(math.MaxInt32), @@ -351,14 +419,14 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { defer e.mu.RUnlock() switch e.state { - case stateInitial, stateBound, stateConnecting: + case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case stateClosed, stateError: + case StateClose, StateError: // Ready for anything. result = mask - case stateListen: + case StateListen: // Check if there's anything in the accepted channel. if (mask & waiter.EventIn) != 0 { if len(e.acceptedChan) > 0 { @@ -366,7 +434,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -427,7 +495,7 @@ func (e *endpoint) Close() { // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == stateListen && e.isPortReserved { + if e.state == StateListen && e.isPortReserved { if e.isRegistered { e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.isRegistered = false @@ -487,15 +555,15 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RLock() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received - // would cause the state to become stateError so we should allow the + // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 { + if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.hardError e.mu.RUnlock() - if s == stateError { + if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -511,7 +579,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || e.state != stateConnected { + if e.rcvClosed || !e.state.connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -547,9 +615,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c defer e.mu.RUnlock() // The endpoint cannot be written to if it's not connected. - if e.state != stateConnected { + if !e.state.connected() { switch e.state { - case stateError: + case StateError: return 0, nil, e.hardError default: return 0, nil, tcpip.ErrClosedForSend @@ -612,8 +680,8 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; s != stateConnected && s != stateClosed { - if s == stateError { + if s := e.state; !s.connected() && s != StateClose { + if s == StateError { return 0, tcpip.ControlMessages{}, e.hardError } return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -623,7 +691,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || e.state != stateConnected { + if e.rcvClosed || !e.state.connected() { return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -789,7 +857,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -841,7 +909,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == stateListen { + if e.state == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1057,7 +1125,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er nicid := addr.NIC switch e.state { - case stateBound: + case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. if e.boundNICID == 0 { @@ -1070,16 +1138,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er nicid = e.boundNICID - case stateInitial: - // Nothing to do. We'll eventually fill-in the gaps in the ID - // (if any) when we find a route. + case StateInitial: + // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) + // when we find a route. - case stateConnecting: - // A connection request has already been issued but hasn't - // completed yet. + case StateConnecting, StateSynSent, StateSynRecv: + // A connection request has already been issued but hasn't completed + // yet. return tcpip.ErrAlreadyConnecting - case stateConnected: + case StateEstablished: // The endpoint is already connected. If caller hasn't been notified yet, return success. if !e.isConnectNotified { e.isConnectNotified = true @@ -1088,7 +1156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // Otherwise return that it's already connected. return tcpip.ErrAlreadyConnected - case stateError: + case StateError: return e.hardError default: @@ -1154,7 +1222,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } e.isRegistered = true - e.state = stateConnecting + e.state = StateConnecting e.route = r.Clone() e.boundNICID = nicid e.effectiveNetProtos = netProtos @@ -1175,7 +1243,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = stateConnected + e.state = StateEstablished } if run { @@ -1199,8 +1267,8 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { defer e.mu.Unlock() e.shutdownFlags |= flags - switch e.state { - case stateConnected: + switch { + case e.state.connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1241,7 +1309,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.sndCloseWaker.Assert() } - case stateListen: + case e.state == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1269,7 +1337,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == stateListen && !e.workerCleanup { + if e.state == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1288,7 +1356,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Endpoint must be bound before it can transition to listen mode. - if e.state != stateBound { + if e.state != StateBound { return tcpip.ErrInvalidEndpointState } @@ -1298,7 +1366,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } e.isRegistered = true - e.state = stateListen + e.state = StateListen if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } @@ -1325,7 +1393,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != stateListen { + if e.state != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -1353,7 +1421,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrAlreadyBound } @@ -1408,7 +1476,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = stateBound + e.state = StateBound return nil } @@ -1430,7 +1498,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { + if !e.state.connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -1739,3 +1807,11 @@ func (e *endpoint) initGSO() { gso.MaxSize = e.route.GSOMaxSize() e.gso = gso } + +// State implements tcpip.Endpoint.State. It exports the endpoint's protocol +// state for diagnostics. +func (e *endpoint) State() uint32 { + e.mu.Lock() + defer e.mu.Unlock() + return uint32(e.state) +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index e8aed2875..5f30c2374 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -49,8 +49,8 @@ func (e *endpoint) beforeSave() { defer e.mu.Unlock() switch e.state { - case stateInitial, stateBound: - case stateConnected: + case StateInitial, StateBound: + case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) @@ -66,17 +66,17 @@ func (e *endpoint) beforeSave() { break } fallthrough - case stateListen, stateConnecting: + case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != stateClosed && e.state != stateError { + if e.state != StateClose && e.state != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } fallthrough - case stateError, stateClosed: - for e.state == stateError && e.workerRunning { + case StateError, StateClose: + for e.state == StateError && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -92,7 +92,7 @@ func (e *endpoint) beforeSave() { panic("endpoint still has waiters upon save") } - if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) { + if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -132,7 +132,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { } // saveState is invoked by stateify. -func (e *endpoint) saveState() endpointState { +func (e *endpoint) saveState() EndpointState { return e.state } @@ -146,15 +146,15 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state endpointState) { +func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. switch state { - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: connectedLoading.Add(1) - case stateListen: + case StateListen: listenLoading.Add(1) - case stateConnecting: + case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } e.state = state @@ -168,7 +168,7 @@ func (e *endpoint) afterLoad() { state := e.state switch state { - case stateInitial, stateBound, stateListen, stateConnecting, stateConnected: + case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { @@ -181,7 +181,7 @@ func (e *endpoint) afterLoad() { } bind := func() { - e.state = stateInitial + e.state = StateInitial if len(e.bindAddress) == 0 { e.bindAddress = e.id.LocalAddress } @@ -191,7 +191,7 @@ func (e *endpoint) afterLoad() { } switch state { - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { // This endpoint is accepted by netstack but not yet by @@ -211,7 +211,7 @@ func (e *endpoint) afterLoad() { panic("endpoint connecting failed: " + err.String()) } connectedLoading.Done() - case stateListen: + case StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -223,7 +223,7 @@ func (e *endpoint) afterLoad() { listenLoading.Done() tcpip.AsyncLoading.Done() }() - case stateConnecting: + case StateConnecting, StateSynSent, StateSynRecv: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -235,7 +235,7 @@ func (e *endpoint) afterLoad() { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case stateBound: + case StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -244,7 +244,7 @@ func (e *endpoint) afterLoad() { bind() tcpip.AsyncLoading.Done() }() - case stateClosed: + case StateClose: if e.isPortReserved { tcpip.AsyncLoading.Add(1) go func() { @@ -252,12 +252,12 @@ func (e *endpoint) afterLoad() { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = stateClosed + e.state = StateClose tcpip.AsyncLoading.Done() }() } fallthrough - case stateError: + case StateError: tcpip.DeleteDanglingEndpoint(e) } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index b08a0e356..f02fa6105 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -134,6 +134,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // sequence numbers that have been consumed. TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { r.rcvNxt++ @@ -144,6 +145,25 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum r.closed = true r.ep.readyToRead(nil) + // We just received a FIN, our next state depends on whether we sent a + // FIN already or not. + r.ep.mu.Lock() + switch r.ep.state { + case StateEstablished: + r.ep.state = StateCloseWait + case StateFinWait1: + if s.flagIsSet(header.TCPFlagAck) { + // FIN-ACK, transition to TIME-WAIT. + r.ep.state = StateTimeWait + } else { + // Simultaneous close, expecting a final ACK. + r.ep.state = StateClosing + } + case StateFinWait2: + r.ep.state = StateTimeWait + } + r.ep.mu.Unlock() + // Flush out any pending segments, except the very first one if // it happens to be the one we're handling now because the // caller is using it. @@ -156,6 +176,23 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum r.pendingRcvdSegments[i].decRef() } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] + + return true + } + + // Handle ACK (not FIN-ACK, which we handled above) during one of the + // shutdown states. + if s.flagIsSet(header.TCPFlagAck) { + r.ep.mu.Lock() + switch r.ep.state { + case StateFinWait1: + r.ep.state = StateFinWait2 + case StateClosing: + r.ep.state = StateTimeWait + case StateLastAck: + r.ep.state = StateClose + } + r.ep.mu.Unlock() } return true diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 3464e4be7..b236d7af2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -632,6 +632,10 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) + // Transition to FIN-WAIT1 state since we're initiating an active close. + s.ep.mu.Lock() + s.ep.state = StateFinWait1 + s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index b8f0ccaf1..56b490aaa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -168,8 +168,8 @@ func TestTCPResetsSentIncrement(t *testing.T) { // Receive the SYN-ACK reply. b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -269,8 +269,8 @@ func TestConnectResetAfterClose(t *testing.T) { time.Sleep(3 * time.Second) for { b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin { + tcpHdr := header.TCP(header.IPv4(b).Payload()) + if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { // This is a retransmit of the FIN, ignore it. continue } @@ -553,9 +553,13 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } - // This final should be ignored because an ACK on a reset doesn't - // mean anything. + // This final ACK should be ignored because an ACK on a reset doesn't mean + // anything. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -618,6 +622,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.SeqNum(uint32(c.IRS)+1), )) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Cause a RST to be generated by closing the read end now since we have // unread data. c.EP.Shutdown(tcpip.ShutdownRead) @@ -630,6 +638,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } // The ACK to the FIN should now be rejected since the connection has been // closed by a RST. @@ -1510,8 +1522,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { for bytesReceived != dataLen { b := c.GetPacket() numPackets++ - tcp := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcp.Payload()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + payloadLen := len(tcpHdr.Payload()) checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), @@ -1522,7 +1534,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { ) pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcp.Payload(); !bytes.Equal(pdata, p) { + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { t.Fatalf("got data = %v, want = %v", p, pdata) } bytesReceived += payloadLen @@ -1530,7 +1542,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { if c.TimeStampEnabled { // If timestamp option is enabled, echo back the timestamp and increment // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcp.ParsedOptions() + parsedOpts := tcpHdr.ParsedOptions() tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) options = tsOpt[:] @@ -1757,8 +1769,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) // Wait for retransmit. time.Sleep(1 * time.Second) @@ -1766,8 +1778,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcp.SourcePort()), - checker.SeqNum(tcp.SequenceNumber()), + checker.SrcPort(tcpHdr.SourcePort()), + checker.SeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -1775,8 +1787,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Send SYN-ACK. iss := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -2523,8 +2535,8 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.TCPFlags(header.TCPFlagAck), ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcp.AckNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + ack := seqnum.Value(tcpHdr.AckNumber()) if ack == last { break } @@ -2568,6 +2580,10 @@ func TestReadAfterClosedState(t *testing.T) { ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Send some data and acknowledge the FIN. data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -2589,9 +2605,15 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. + // Give the stack the chance to transition to closed state. Note that since + // both the sender and receiver are now closed, we effectively skip the + // TIME-WAIT state. time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Wait for receive to be notified. select { case <-ch: @@ -3680,9 +3702,15 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } stats := c.Stack().Stats() want := stats.TCP.PassiveConnectionOpenings.Value() + 1 @@ -3826,3 +3854,68 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } } } + +func TestEndpointBindListenAcceptState(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + aep, _, err := ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + aep, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Listening endpoint remains in listen state. + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + ep.Close() + // Give worker goroutines time to receive the close notification. + time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Accepted endpoint remains open when the listen endpoint is closed. + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6e12413c6..69a43b6f4 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -532,6 +532,9 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if epRcvBuf != nil { if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { @@ -557,13 +560,16 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. checker.TCPFlags(header.TCPFlagSyn), ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -591,8 +597,11 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. case <-time.After(1 * time.Second): c.t.Fatalf("Timed out waiting for connection") } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - c.Port = tcp.SourcePort() + c.Port = tcpHdr.SourcePort() } // RawEndpoint is just a small wrapper around a TCP endpoint's state to make @@ -690,6 +699,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * if err != nil { c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -719,6 +731,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * }), ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpSeg := header.TCP(header.IPv4(b).Payload()) synOptions := header.ParseSynOptions(tcpSeg.Options(), false) @@ -782,6 +798,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * case <-time.After(1 * time.Second): c.t.Fatalf("Timed out waiting for connection") } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } // Store the source port in use by the endpoint. c.Port = tcpSeg.SourcePort() @@ -821,10 +840,16 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := ep.Listen(10); err != nil { c.t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) @@ -847,6 +872,10 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption c.t.Fatalf("Timed out waiting for accept") } } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + return rep } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 3d52a4f31..fa7278286 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1000,3 +1000,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } + +// State implements socket.Socket.State. +func (e *endpoint) State() uint32 { + // TODO(b/112063468): Translate internal state to values returned by Linux. + return 0 +} diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 6d745f728..82d325c17 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -34,6 +34,16 @@ using absl::StrFormat; constexpr char kProcNetUnixHeader[] = "Num RefCount Protocol Flags Type St Inode Path"; +// Possible values of the "st" field in a /proc/net/unix entry. Source: Linux +// kernel, include/uapi/linux/net.h. +enum { + SS_FREE = 0, // Not allocated + SS_UNCONNECTED, // Unconnected to any socket + SS_CONNECTING, // In process of connecting + SS_CONNECTED, // Connected to socket + SS_DISCONNECTING // In process of disconnecting +}; + // UnixEntry represents a single entry from /proc/net/unix. struct UnixEntry { uintptr_t addr; @@ -71,7 +81,12 @@ PosixErrorOr> ProcNetUnixEntries() { bool skipped_header = false; std::vector entries; std::vector lines = absl::StrSplit(content, absl::ByAnyChar("\n")); + std::cerr << "" << std::endl; for (std::string line : lines) { + // Emit the proc entry to the test output to provide context for the test + // results. + std::cerr << line << std::endl; + if (!skipped_header) { EXPECT_EQ(line, kProcNetUnixHeader); skipped_header = true; @@ -139,6 +154,7 @@ PosixErrorOr> ProcNetUnixEntries() { entries.push_back(entry); } + std::cerr << "" << std::endl; return entries; } @@ -241,6 +257,168 @@ TEST(ProcNetUnix, SocketPair) { EXPECT_EQ(entries.size(), 2); } +TEST(ProcNetUnix, StreamSocketStateUnconnectedOnBind) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, StreamSocketStateStateUnconnectedOnListen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); + + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry listen_entry; + ASSERT_TRUE( + FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); + EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); + // The bind and listen entries should refer to the same socket. + EXPECT_EQ(listen_entry.inode, bind_entry.inode); +} + +TEST(ProcNetUnix, StreamSocketStateStateConnectedOnAccept) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + const std::string address = ExtractPath(sockets->first_addr()); + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry listen_entry; + ASSERT_TRUE( + FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + int clientfd; + ASSERT_THAT(clientfd = accept(sockets->first_fd(), nullptr, nullptr), + SyscallSucceeds()); + + // Find the entry for the accepted socket. UDS proc entries don't have a + // remote address, so we distinguish the accepted socket from the listen + // socket by checking for a different inode. + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry accept_entry; + ASSERT_TRUE(FindBy( + entries, &accept_entry, [address, listen_entry](const UnixEntry& e) { + return e.path == address && e.inode != listen_entry.inode; + })); + EXPECT_EQ(accept_entry.state, SS_CONNECTED); + // Listen entry should still be in SS_UNCONNECTED state. + ASSERT_TRUE(FindBy(entries, &listen_entry, + [&sockets, listen_entry](const UnixEntry& e) { + return e.path == ExtractPath(sockets->first_addr()) && + e.inode == listen_entry.inode; + })); + EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // On gVisor, the only two UDS on the system are the ones we just created and + // we rely on this to locate the test socket entries in the remainder of the + // test. On a generic Linux system, we have no easy way to locate the + // corresponding entries, as they don't have an address yet. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + for (auto e : entries) { + ASSERT_EQ(e.state, SS_DISCONNECTING); + } + } + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // On gVisor, the only two UDS on the system are the ones we just created and + // we rely on this to locate the test socket entries in the remainder of the + // test. On a generic Linux system, we have no easy way to locate the + // corresponding entries, as they don't have an address yet. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + for (auto e : entries) { + ASSERT_EQ(e.state, SS_DISCONNECTING); + } + } + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // Once again, we have no easy way to identify the connecting socket as it has + // no listed address. We can only identify the entry as the "non-bind socket + // entry" on gVisor, where we're guaranteed to have only the two entries we + // create during this test. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + UnixEntry connect_entry; + ASSERT_TRUE( + FindBy(entries, &connect_entry, [bind_entry](const UnixEntry& e) { + return e.inode != bind_entry.inode; + })); + EXPECT_EQ(connect_entry.state, SS_CONNECTING); + } +} + } // namespace } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 93aa7d11673392ca51ba69122ff5fe1aad7331b9 Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Thu, 6 Jun 2019 15:54:54 -0700 Subject: Remove tmpfs restriction from test runsc supports UDS over gofer mounts and tmpfs is not needed for this test. PiperOrigin-RevId: 251944870 --- test/syscalls/BUILD | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 2985275bb..731e2aa85 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -669,12 +669,7 @@ syscall_test( test = "//test/syscalls/linux:write_test", ) -syscall_test( - test = "//test/syscalls/linux:proc_net_unix_test", - # Unix domain socket creation isn't supported on all file systems. The - # sentry-internal tmpfs is known to support it. - use_tmpfs = True, -) +syscall_test(test = "//test/syscalls/linux:proc_net_unix_test") go_binary( name = "syscall_test_runner", -- cgit v1.2.3 From a26043ee53a2f38b81c9eaa098d115025e87f4c3 Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Thu, 6 Jun 2019 16:26:00 -0700 Subject: Implement reclaim-driven MemoryFile eviction. PiperOrigin-RevId: 251950660 --- pkg/sentry/hostmm/BUILD | 18 ++++++ pkg/sentry/hostmm/cgroup.go | 111 ++++++++++++++++++++++++++++++++++++ pkg/sentry/hostmm/hostmm.go | 130 ++++++++++++++++++++++++++++++++++++++++++ pkg/sentry/pgalloc/BUILD | 1 + pkg/sentry/pgalloc/pgalloc.go | 63 +++++++++++++++++--- runsc/boot/loader.go | 3 + 6 files changed, 317 insertions(+), 9 deletions(-) create mode 100644 pkg/sentry/hostmm/BUILD create mode 100644 pkg/sentry/hostmm/cgroup.go create mode 100644 pkg/sentry/hostmm/hostmm.go diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD new file mode 100644 index 000000000..1a4632a54 --- /dev/null +++ b/pkg/sentry/hostmm/BUILD @@ -0,0 +1,18 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "hostmm", + srcs = [ + "cgroup.go", + "hostmm.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/hostmm", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/fd", + "//pkg/log", + "//pkg/sentry/usermem", + ], +) diff --git a/pkg/sentry/hostmm/cgroup.go b/pkg/sentry/hostmm/cgroup.go new file mode 100644 index 000000000..e5cc26ab2 --- /dev/null +++ b/pkg/sentry/hostmm/cgroup.go @@ -0,0 +1,111 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostmm + +import ( + "bufio" + "fmt" + "os" + "path" + "strings" +) + +// currentCgroupDirectory returns the directory for the cgroup for the given +// controller in which the calling process resides. +func currentCgroupDirectory(ctrl string) (string, error) { + root, err := cgroupRootDirectory(ctrl) + if err != nil { + return "", err + } + cg, err := currentCgroup(ctrl) + if err != nil { + return "", err + } + return path.Join(root, cg), nil +} + +// cgroupRootDirectory returns the root directory for the cgroup hierarchy in +// which the given cgroup controller is mounted in the calling process' mount +// namespace. +func cgroupRootDirectory(ctrl string) (string, error) { + const path = "/proc/self/mounts" + file, err := os.Open(path) + if err != nil { + return "", err + } + defer file.Close() + + // Per proc(5) -> fstab(5): + // Each line of /proc/self/mounts describes a mount. + scanner := bufio.NewScanner(file) + for scanner.Scan() { + // Each line consists of 6 space-separated fields. Find the line for + // which the third field (fs_vfstype) is cgroup, and the fourth field + // (fs_mntops, a comma-separated list of mount options) contains + // ctrl. + var spec, file, vfstype, mntopts, freq, passno string + const nrfields = 6 + line := scanner.Text() + n, err := fmt.Sscan(line, &spec, &file, &vfstype, &mntopts, &freq, &passno) + if err != nil { + return "", fmt.Errorf("failed to parse %s: %v", path, err) + } + if n != nrfields { + return "", fmt.Errorf("failed to parse %s: line %q: got %d fields, wanted %d", path, line, n, nrfields) + } + if vfstype != "cgroup" { + continue + } + for _, mntopt := range strings.Split(mntopts, ",") { + if mntopt == ctrl { + return file, nil + } + } + } + return "", fmt.Errorf("no cgroup hierarchy mounted for controller %s", ctrl) +} + +// currentCgroup returns the cgroup for the given controller in which the +// calling process resides. The returned string is a path that should be +// interpreted as relative to cgroupRootDirectory(ctrl). +func currentCgroup(ctrl string) (string, error) { + const path = "/proc/self/cgroup" + file, err := os.Open(path) + if err != nil { + return "", err + } + defer file.Close() + + // Per proc(5) -> cgroups(7): + // Each line of /proc/self/cgroups describes a cgroup hierarchy. + scanner := bufio.NewScanner(file) + for scanner.Scan() { + // Each line consists of 3 colon-separated fields. Find the line for + // which the second field (controller-list, a comma-separated list of + // cgroup controllers) contains ctrl. + line := scanner.Text() + const nrfields = 3 + fields := strings.Split(line, ":") + if len(fields) != nrfields { + return "", fmt.Errorf("failed to parse %s: line %q: got %d fields, wanted %d", path, line, len(fields), nrfields) + } + for _, controller := range strings.Split(fields[1], ",") { + if controller == ctrl { + return fields[2], nil + } + } + } + return "", fmt.Errorf("not a member of a cgroup hierarchy for controller %s", ctrl) +} diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go new file mode 100644 index 000000000..5432cada9 --- /dev/null +++ b/pkg/sentry/hostmm/hostmm.go @@ -0,0 +1,130 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package hostmm provides tools for interacting with the host Linux kernel's +// virtual memory management subsystem. +package hostmm + +import ( + "fmt" + "os" + "path" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/fd" + "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" +) + +// NotifyCurrentMemcgPressureCallback requests that f is called whenever the +// calling process' memory cgroup indicates memory pressure of the given level, +// as specified by Linux's Documentation/cgroup-v1/memory.txt. +// +// If NotifyCurrentMemcgPressureCallback succeeds, it returns a function that +// terminates the requested memory pressure notifications. This function may be +// called at most once. +func NotifyCurrentMemcgPressureCallback(f func(), level string) (func(), error) { + cgdir, err := currentCgroupDirectory("memory") + if err != nil { + return nil, err + } + + pressurePath := path.Join(cgdir, "memory.pressure_level") + pressureFile, err := os.Open(pressurePath) + if err != nil { + return nil, err + } + defer pressureFile.Close() + + eventControlPath := path.Join(cgdir, "cgroup.event_control") + eventControlFile, err := os.OpenFile(eventControlPath, os.O_WRONLY, 0) + if err != nil { + return nil, err + } + defer eventControlFile.Close() + + eventFD, err := newEventFD() + if err != nil { + return nil, err + } + + // Don't use fmt.Fprintf since the whole string needs to be written in a + // single syscall. + eventControlStr := fmt.Sprintf("%d %d %s", eventFD.FD(), pressureFile.Fd(), level) + if n, err := eventControlFile.Write([]byte(eventControlStr)); n != len(eventControlStr) || err != nil { + eventFD.Close() + return nil, fmt.Errorf("error writing %q to %s: got (%d, %v), wanted (%d, nil)", eventControlStr, eventControlPath, n, err, len(eventControlStr)) + } + + log.Debugf("Receiving memory pressure level notifications from %s at level %q", pressurePath, level) + const sizeofUint64 = 8 + // The most significant bit of the eventfd value is set by the stop + // function, which is practically unambiguous since it's not plausible for + // 2**63 pressure events to occur between eventfd reads. + const stopVal = 1 << 63 + stopCh := make(chan struct{}) + go func() { // S/R-SAFE: f provides synchronization if necessary + rw := fd.NewReadWriter(eventFD.FD()) + var buf [sizeofUint64]byte + for { + n, err := rw.Read(buf[:]) + if err != nil { + if err == syscall.EINTR { + continue + } + panic(fmt.Sprintf("failed to read from memory pressure level eventfd: %v", err)) + } + if n != sizeofUint64 { + panic(fmt.Sprintf("short read from memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64)) + } + val := usermem.ByteOrder.Uint64(buf[:]) + if val >= stopVal { + // Assume this was due to the notifier's "destructor" (the + // function returned by NotifyCurrentMemcgPressureCallback + // below) being called. + eventFD.Close() + close(stopCh) + return + } + f() + } + }() + return func() { + rw := fd.NewReadWriter(eventFD.FD()) + var buf [sizeofUint64]byte + usermem.ByteOrder.PutUint64(buf[:], stopVal) + for { + n, err := rw.Write(buf[:]) + if err != nil { + if err == syscall.EINTR { + continue + } + panic(fmt.Sprintf("failed to write to memory pressure level eventfd: %v", err)) + } + if n != sizeofUint64 { + panic(fmt.Sprintf("short write to memory pressure level eventfd: got %d bytes, wanted %d", n, sizeofUint64)) + } + break + } + <-stopCh + }, nil +} + +func newEventFD() (*fd.FD, error) { + f, _, e := syscall.Syscall(syscall.SYS_EVENTFD2, 0, 0, 0) + if e != 0 { + return nil, fmt.Errorf("failed to create eventfd: %v", e) + } + return fd.New(int(f)), nil +} diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 8a8a0e4e4..bbdb1f922 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -65,6 +65,7 @@ go_library( "//pkg/log", "//pkg/sentry/arch", "//pkg/sentry/context", + "//pkg/sentry/hostmm", "//pkg/sentry/memutil", "//pkg/sentry/platform", "//pkg/sentry/safemem", diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 2b9924ad7..6d91f1a7b 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -32,6 +32,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/sentry/context" + "gvisor.googlesource.com/gvisor/pkg/sentry/hostmm" "gvisor.googlesource.com/gvisor/pkg/sentry/platform" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/usage" @@ -162,6 +163,11 @@ type MemoryFile struct { // evictionWG counts the number of goroutines currently performing evictions. evictionWG sync.WaitGroup + + // stopNotifyPressure stops memory cgroup pressure level + // notifications used to drive eviction. stopNotifyPressure is + // immutable. + stopNotifyPressure func() } // MemoryFileOpts provides options to NewMemoryFile. @@ -169,6 +175,11 @@ type MemoryFileOpts struct { // DelayedEviction controls the extent to which the MemoryFile may delay // eviction of evictable allocations. DelayedEviction DelayedEvictionType + + // If UseHostMemcgPressure is true, use host memory cgroup pressure level + // notifications to determine when eviction is necessary. This option has + // no effect unless DelayedEviction is DelayedEvictionEnabled. + UseHostMemcgPressure bool } // DelayedEvictionType is the type of MemoryFileOpts.DelayedEviction. @@ -186,9 +197,14 @@ const ( // evictable allocations until doing so is considered necessary to avoid // performance degradation due to host memory pressure, or OOM kills. // - // As of this writing, DelayedEvictionEnabled delays evictions until the - // reclaimer goroutine is out of work (pages to reclaim), then evicts all - // pending evictable allocations immediately. + // As of this writing, the behavior of DelayedEvictionEnabled depends on + // whether or not MemoryFileOpts.UseHostMemcgPressure is enabled: + // + // - If UseHostMemcgPressure is true, evictions are delayed until memory + // pressure is indicated. + // + // - Otherwise, evictions are only delayed until the reclaimer goroutine + // is out of work (pages to reclaim). DelayedEvictionEnabled // DelayedEvictionManual requires that evictable allocations are only @@ -292,6 +308,22 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) { } f.mappings.Store(make([]uintptr, initialSize/chunkSize)) f.reclaimCond.L = &f.mu + + if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure { + stop, err := hostmm.NotifyCurrentMemcgPressureCallback(func() { + f.mu.Lock() + startedAny := f.startEvictionsLocked() + f.mu.Unlock() + if startedAny { + log.Debugf("pgalloc.MemoryFile performing evictions due to memcg pressure") + } + }, "low") + if err != nil { + return nil, fmt.Errorf("failed to configure memcg pressure level notifications: %v", err) + } + f.stopNotifyPressure = stop + } + go f.runReclaim() // S/R-SAFE: f.mu // The Linux kernel contains an optional feature called "Integrity @@ -692,9 +724,11 @@ func (f *MemoryFile) MarkEvictable(user EvictableMemoryUser, er EvictableRange) // Kick off eviction immediately. f.startEvictionGoroutineLocked(user, info) case DelayedEvictionEnabled: - // Ensure that the reclaimer goroutine is running, so that it can - // start eviction when necessary. - f.reclaimCond.Signal() + if !f.opts.UseHostMemcgPressure { + // Ensure that the reclaimer goroutine is running, so that it + // can start eviction when necessary. + f.reclaimCond.Signal() + } } } } @@ -992,11 +1026,12 @@ func (f *MemoryFile) runReclaim() { } f.markReclaimed(fr) } + // We only get here if findReclaimable finds f.destroyed set and returns // false. f.mu.Lock() - defer f.mu.Unlock() if !f.destroyed { + f.mu.Unlock() panic("findReclaimable broke out of reclaim loop, but destroyed is no longer set") } f.file.Close() @@ -1016,6 +1051,13 @@ func (f *MemoryFile) runReclaim() { } // Similarly, invalidate f.mappings. (atomic.Value.Store(nil) panics.) f.mappings.Store([]uintptr{}) + f.mu.Unlock() + + // This must be called without holding f.mu to avoid circular lock + // ordering. + if f.stopNotifyPressure != nil { + f.stopNotifyPressure() + } } func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { @@ -1029,7 +1071,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { if f.reclaimable { break } - if f.opts.DelayedEviction == DelayedEvictionEnabled { + if f.opts.DelayedEviction == DelayedEvictionEnabled && !f.opts.UseHostMemcgPressure { // No work to do. Evict any pending evictable allocations to // get more reclaimable pages before going to sleep. f.startEvictionsLocked() @@ -1089,14 +1131,17 @@ func (f *MemoryFile) StartEvictions() { } // Preconditions: f.mu must be locked. -func (f *MemoryFile) startEvictionsLocked() { +func (f *MemoryFile) startEvictionsLocked() bool { + startedAny := false for user, info := range f.evictable { // Don't start multiple goroutines to evict the same user's // allocations. if !info.evicting { f.startEvictionGoroutineLocked(user, info) + startedAny = true } } + return startedAny } // Preconditions: info == f.evictable[user]. !info.evicting. f.mu must be diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index a997776f8..ef4ccd0bd 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -424,6 +424,9 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) { return nil, fmt.Errorf("error creating memfd: %v", err) } memfile := os.NewFile(uintptr(memfd), memfileName) + // We can't enable pgalloc.MemoryFileOpts.UseHostMemcgPressure even if + // there are memory cgroups specified, because at this point we're already + // in a mount namespace in which the relevant cgroupfs is not visible. mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{}) if err != nil { memfile.Close() -- cgit v1.2.3 From b3f104507d7a04c0ca058cbcacc5ff78d853f4ba Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Thu, 6 Jun 2019 16:27:09 -0700 Subject: "Implement" mbind(2). We still only advertise a single NUMA node, and ignore mempolicy accordingly, but mbind() at least now succeeds and has effects reflected by get_mempolicy(). Also fix handling of nodemasks: round sizes to unsigned long (as documented and done by Linux), and zero trailing bits when copying them out. PiperOrigin-RevId: 251950859 --- pkg/abi/linux/mm.go | 9 + pkg/sentry/kernel/task.go | 7 +- pkg/sentry/kernel/task_sched.go | 4 +- pkg/sentry/mm/mm.go | 6 + pkg/sentry/mm/syscalls.go | 53 +++++ pkg/sentry/mm/vma.go | 3 + pkg/sentry/syscalls/linux/BUILD | 1 + pkg/sentry/syscalls/linux/linux64.go | 3 +- pkg/sentry/syscalls/linux/sys_mempolicy.go | 312 +++++++++++++++++++++++++++++ pkg/sentry/syscalls/linux/sys_mmap.go | 145 -------------- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/mempolicy.cc | 37 +++- 12 files changed, 426 insertions(+), 155 deletions(-) create mode 100644 pkg/sentry/syscalls/linux/sys_mempolicy.go diff --git a/pkg/abi/linux/mm.go b/pkg/abi/linux/mm.go index 0b02f938a..cd043dac3 100644 --- a/pkg/abi/linux/mm.go +++ b/pkg/abi/linux/mm.go @@ -114,3 +114,12 @@ const ( MPOL_MODE_FLAGS = (MPOL_F_STATIC_NODES | MPOL_F_RELATIVE_NODES) ) + +// Flags for mbind(2). +const ( + MPOL_MF_STRICT = 1 << 0 + MPOL_MF_MOVE = 1 << 1 + MPOL_MF_MOVE_ALL = 1 << 2 + + MPOL_MF_VALID = MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL +) diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index f9378c2de..4d889422f 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -455,12 +455,13 @@ type Task struct { // single numa node, all policies are no-ops. We only track this information // so that we can return reasonable values if the application calls // get_mempolicy(2) after setting a non-default policy. Note that in the - // real syscall, nodemask can be longer than 4 bytes, but we always report a - // single node so never need to save more than a single bit. + // real syscall, nodemask can be longer than a single unsigned long, but we + // always report a single node so never need to save more than a single + // bit. // // numaPolicy and numaNodeMask are protected by mu. numaPolicy int32 - numaNodeMask uint32 + numaNodeMask uint64 // If netns is true, the task is in a non-root network namespace. Network // namespaces aren't currently implemented in full; being in a network diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index 5455f6ea9..1c94ab11b 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -622,14 +622,14 @@ func (t *Task) SetNiceness(n int) { } // NumaPolicy returns t's current numa policy. -func (t *Task) NumaPolicy() (policy int32, nodeMask uint32) { +func (t *Task) NumaPolicy() (policy int32, nodeMask uint64) { t.mu.Lock() defer t.mu.Unlock() return t.numaPolicy, t.numaNodeMask } // SetNumaPolicy sets t's numa policy. -func (t *Task) SetNumaPolicy(policy int32, nodeMask uint32) { +func (t *Task) SetNumaPolicy(policy int32, nodeMask uint64) { t.mu.Lock() defer t.mu.Unlock() t.numaPolicy = policy diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 0a026ff8c..604866d04 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -276,6 +276,12 @@ type vma struct { mlockMode memmap.MLockMode + // numaPolicy is the NUMA policy for this vma set by mbind(). + numaPolicy int32 + + // numaNodemask is the NUMA nodemask for this vma set by mbind(). + numaNodemask uint64 + // If id is not nil, it controls the lifecycle of mappable and provides vma // metadata shown in /proc/[pid]/maps, and the vma holds a reference. id memmap.MappingIdentity diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index af1e53f5d..9cf136532 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -973,6 +973,59 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error return nil } +// NumaPolicy implements the semantics of Linux's get_mempolicy(MPOL_F_ADDR). +func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (int32, uint64, error) { + mm.mappingMu.RLock() + defer mm.mappingMu.RUnlock() + vseg := mm.vmas.FindSegment(addr) + if !vseg.Ok() { + return 0, 0, syserror.EFAULT + } + vma := vseg.ValuePtr() + return vma.numaPolicy, vma.numaNodemask, nil +} + +// SetNumaPolicy implements the semantics of Linux's mbind(). +func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy int32, nodemask uint64) error { + if !addr.IsPageAligned() { + return syserror.EINVAL + } + // Linux allows this to overflow. + la, _ := usermem.Addr(length).RoundUp() + ar, ok := addr.ToRange(uint64(la)) + if !ok { + return syserror.EINVAL + } + if ar.Length() == 0 { + return nil + } + + mm.mappingMu.Lock() + defer mm.mappingMu.Unlock() + defer func() { + mm.vmas.MergeRange(ar) + mm.vmas.MergeAdjacent(ar) + }() + vseg := mm.vmas.LowerBoundSegment(ar.Start) + lastEnd := ar.Start + for { + if !vseg.Ok() || lastEnd < vseg.Start() { + // "EFAULT: ... there was an unmapped hole in the specified memory + // range specified [sic] by addr and len." - mbind(2) + return syserror.EFAULT + } + vseg = mm.vmas.Isolate(vseg, ar) + vma := vseg.ValuePtr() + vma.numaPolicy = policy + vma.numaNodemask = nodemask + lastEnd = vseg.End() + if ar.End <= lastEnd { + return nil + } + vseg, _ = vseg.NextNonEmpty() + } +} + // Decommit implements the semantics of Linux's madvise(MADV_DONTNEED). func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error { ar, ok := addr.ToRange(length) diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 02203f79f..0af8de5b0 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -107,6 +107,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp private: opts.Private, growsDown: opts.GrowsDown, mlockMode: opts.MLockMode, + numaPolicy: linux.MPOL_DEFAULT, id: opts.MappingIdentity, hint: opts.Hint, } @@ -436,6 +437,8 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa vma1.private != vma2.private || vma1.growsDown != vma2.growsDown || vma1.mlockMode != vma2.mlockMode || + vma1.numaPolicy != vma2.numaPolicy || + vma1.numaNodemask != vma2.numaNodemask || vma1.id != vma2.id || vma1.hint != vma2.hint { return vma{}, false diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index f76989ae2..1c057526b 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -19,6 +19,7 @@ go_library( "sys_identity.go", "sys_inotify.go", "sys_lseek.go", + "sys_mempolicy.go", "sys_mmap.go", "sys_mount.go", "sys_pipe.go", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 3e4d312af..ad88b1391 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -360,8 +360,7 @@ var AMD64 = &kernel.SyscallTable{ 235: Utimes, // @Syscall(Vserver, note:Not implemented by Linux) 236: syscalls.Error(syscall.ENOSYS), // Vserver, not implemented by Linux - // @Syscall(Mbind, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise), TODO(b/117792295) - 237: syscalls.CapError(linux.CAP_SYS_NICE), // may require cap_sys_nice + 237: Mbind, 238: SetMempolicy, 239: GetMempolicy, // 240: @Syscall(MqOpen), TODO(b/29354921) diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go new file mode 100644 index 000000000..652b2c206 --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go @@ -0,0 +1,312 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "fmt" + + "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" + "gvisor.googlesource.com/gvisor/pkg/syserror" +) + +// We unconditionally report a single NUMA node. This also means that our +// "nodemask_t" is a single unsigned long (uint64). +const ( + maxNodes = 1 + allowedNodemask = (1 << maxNodes) - 1 +) + +func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, error) { + // "nodemask points to a bit mask of node IDs that contains up to maxnode + // bits. The bit mask size is rounded to the next multiple of + // sizeof(unsigned long), but the kernel will use bits only up to maxnode. + // A NULL value of nodemask or a maxnode value of zero specifies the empty + // set of nodes. If the value of maxnode is zero, the nodemask argument is + // ignored." - set_mempolicy(2). Unfortunately, most of this is inaccurate + // because of what appears to be a bug: mm/mempolicy.c:get_nodes() uses + // maxnode-1, not maxnode, as the number of bits. + bits := maxnode - 1 + if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0 + return 0, syserror.EINVAL + } + if bits == 0 { + return 0, nil + } + // Copy in the whole nodemask. + numUint64 := (bits + 63) / 64 + buf := t.CopyScratchBuffer(int(numUint64) * 8) + if _, err := t.CopyInBytes(addr, buf); err != nil { + return 0, err + } + val := usermem.ByteOrder.Uint64(buf) + // Check that only allowed bits in the first unsigned long in the nodemask + // are set. + if val&^allowedNodemask != 0 { + return 0, syserror.EINVAL + } + // Check that all remaining bits in the nodemask are 0. + for i := 8; i < len(buf); i++ { + if buf[i] != 0 { + return 0, syserror.EINVAL + } + } + return val, nil +} + +func copyOutNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32, val uint64) error { + // mm/mempolicy.c:copy_nodes_to_user() also uses maxnode-1 as the number of + // bits. + bits := maxnode - 1 + if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0 + return syserror.EINVAL + } + if bits == 0 { + return nil + } + // Copy out the first unsigned long in the nodemask. + buf := t.CopyScratchBuffer(8) + usermem.ByteOrder.PutUint64(buf, val) + if _, err := t.CopyOutBytes(addr, buf); err != nil { + return err + } + // Zero out remaining unsigned longs in the nodemask. + if bits > 64 { + remAddr, ok := addr.AddLength(8) + if !ok { + return syserror.EFAULT + } + remUint64 := (bits - 1) / 64 + if _, err := t.MemoryManager().ZeroOut(t, remAddr, int64(remUint64)*8, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return err + } + } + return nil +} + +// GetMempolicy implements the syscall get_mempolicy(2). +func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + mode := args[0].Pointer() + nodemask := args[1].Pointer() + maxnode := args[2].Uint() + addr := args[3].Pointer() + flags := args[4].Uint() + + if flags&^(linux.MPOL_F_NODE|linux.MPOL_F_ADDR|linux.MPOL_F_MEMS_ALLOWED) != 0 { + return 0, nil, syserror.EINVAL + } + nodeFlag := flags&linux.MPOL_F_NODE != 0 + addrFlag := flags&linux.MPOL_F_ADDR != 0 + memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0 + + // "EINVAL: The value specified by maxnode is less than the number of node + // IDs supported by the system." - get_mempolicy(2) + if nodemask != 0 && maxnode < maxNodes { + return 0, nil, syserror.EINVAL + } + + // "If flags specifies MPOL_F_MEMS_ALLOWED [...], the mode argument is + // ignored and the set of nodes (memories) that the thread is allowed to + // specify in subsequent calls to mbind(2) or set_mempolicy(2) (in the + // absence of any mode flags) is returned in nodemask." + if memsAllowed { + // "It is not permitted to combine MPOL_F_MEMS_ALLOWED with either + // MPOL_F_ADDR or MPOL_F_NODE." + if nodeFlag || addrFlag { + return 0, nil, syserror.EINVAL + } + if err := copyOutNodemask(t, nodemask, maxnode, allowedNodemask); err != nil { + return 0, nil, err + } + return 0, nil, nil + } + + // "If flags specifies MPOL_F_ADDR, then information is returned about the + // policy governing the memory address given in addr. ... If the mode + // argument is not NULL, then get_mempolicy() will store the policy mode + // and any optional mode flags of the requested NUMA policy in the location + // pointed to by this argument. If nodemask is not NULL, then the nodemask + // associated with the policy will be stored in the location pointed to by + // this argument." + if addrFlag { + policy, nodemaskVal, err := t.MemoryManager().NumaPolicy(addr) + if err != nil { + return 0, nil, err + } + if nodeFlag { + // "If flags specifies both MPOL_F_NODE and MPOL_F_ADDR, + // get_mempolicy() will return the node ID of the node on which the + // address addr is allocated into the location pointed to by mode. + // If no page has yet been allocated for the specified address, + // get_mempolicy() will allocate a page as if the thread had + // performed a read (load) access to that address, and return the + // ID of the node where that page was allocated." + buf := t.CopyScratchBuffer(1) + _, err := t.CopyInBytes(addr, buf) + if err != nil { + return 0, nil, err + } + policy = 0 // maxNodes == 1 + } + if mode != 0 { + if _, err := t.CopyOut(mode, policy); err != nil { + return 0, nil, err + } + } + if nodemask != 0 { + if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil { + return 0, nil, err + } + } + return 0, nil, nil + } + + // "EINVAL: ... flags specified MPOL_F_ADDR and addr is NULL, or flags did + // not specify MPOL_F_ADDR and addr is not NULL." This is partially + // inaccurate: if flags specifies MPOL_F_ADDR, + // mm/mempolicy.c:do_get_mempolicy() doesn't special-case NULL; it will + // just (usually) fail to find a VMA at address 0 and return EFAULT. + if addr != 0 { + return 0, nil, syserror.EINVAL + } + + // "If flags is specified as 0, then information about the calling thread's + // default policy (as set by set_mempolicy(2)) is returned, in the buffers + // pointed to by mode and nodemask. ... If flags specifies MPOL_F_NODE, but + // not MPOL_F_ADDR, and the thread's current policy is MPOL_INTERLEAVE, + // then get_mempolicy() will return in the location pointed to by a + // non-NULL mode argument, the node ID of the next node that will be used + // for interleaving of internal kernel pages allocated on behalf of the + // thread." + policy, nodemaskVal := t.NumaPolicy() + if nodeFlag { + if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE { + return 0, nil, syserror.EINVAL + } + policy = 0 // maxNodes == 1 + } + if mode != 0 { + if _, err := t.CopyOut(mode, policy); err != nil { + return 0, nil, err + } + } + if nodemask != 0 { + if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil { + return 0, nil, err + } + } + return 0, nil, nil +} + +// SetMempolicy implements the syscall set_mempolicy(2). +func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + modeWithFlags := args[0].Int() + nodemask := args[1].Pointer() + maxnode := args[2].Uint() + + modeWithFlags, nodemaskVal, err := copyInMempolicyNodemask(t, modeWithFlags, nodemask, maxnode) + if err != nil { + return 0, nil, err + } + + t.SetNumaPolicy(modeWithFlags, nodemaskVal) + return 0, nil, nil +} + +// Mbind implements the syscall mbind(2). +func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + length := args[1].Uint64() + mode := args[2].Int() + nodemask := args[3].Pointer() + maxnode := args[4].Uint() + flags := args[5].Uint() + + if flags&^linux.MPOL_MF_VALID != 0 { + return 0, nil, syserror.EINVAL + } + // "If MPOL_MF_MOVE_ALL is passed in flags ... [the] calling thread must be + // privileged (CAP_SYS_NICE) to use this flag." - mbind(2) + if flags&linux.MPOL_MF_MOVE_ALL != 0 && !t.HasCapability(linux.CAP_SYS_NICE) { + return 0, nil, syserror.EPERM + } + + mode, nodemaskVal, err := copyInMempolicyNodemask(t, mode, nodemask, maxnode) + if err != nil { + return 0, nil, err + } + + // Since we claim to have only a single node, all flags can be ignored + // (since all pages must already be on that single node). + err = t.MemoryManager().SetNumaPolicy(addr, length, mode, nodemaskVal) + return 0, nil, err +} + +func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags int32, nodemask usermem.Addr, maxnode uint32) (int32, uint64, error) { + flags := modeWithFlags & linux.MPOL_MODE_FLAGS + mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS + if flags == linux.MPOL_MODE_FLAGS { + // Can't specify both mode flags simultaneously. + return 0, 0, syserror.EINVAL + } + if mode < 0 || mode >= linux.MPOL_MAX { + // Must specify a valid mode. + return 0, 0, syserror.EINVAL + } + + var nodemaskVal uint64 + if nodemask != 0 { + var err error + nodemaskVal, err = copyInNodemask(t, nodemask, maxnode) + if err != nil { + return 0, 0, err + } + } + + switch mode { + case linux.MPOL_DEFAULT: + // "nodemask must be specified as NULL." - set_mempolicy(2). This is inaccurate; + // Linux allows a nodemask to be specified, as long as it is empty. + if nodemaskVal != 0 { + return 0, 0, syserror.EINVAL + } + case linux.MPOL_BIND, linux.MPOL_INTERLEAVE: + // These require a non-empty nodemask. + if nodemaskVal == 0 { + return 0, 0, syserror.EINVAL + } + case linux.MPOL_PREFERRED: + // This permits an empty nodemask, as long as no flags are set. + if nodemaskVal == 0 && flags != 0 { + return 0, 0, syserror.EINVAL + } + case linux.MPOL_LOCAL: + // This requires an empty nodemask and no flags set ... + if nodemaskVal != 0 || flags != 0 { + return 0, 0, syserror.EINVAL + } + // ... and is implemented as MPOL_PREFERRED. + mode = linux.MPOL_PREFERRED + default: + // Unknown mode, which we should have rejected above. + panic(fmt.Sprintf("unknown mode: %v", mode)) + } + + return mode | flags, nodemaskVal, nil +} diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 64a6e639c..9926f0ac5 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -204,151 +204,6 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } } -func copyOutIfNotNull(t *kernel.Task, ptr usermem.Addr, val interface{}) (int, error) { - if ptr != 0 { - return t.CopyOut(ptr, val) - } - return 0, nil -} - -// GetMempolicy implements the syscall get_mempolicy(2). -func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - mode := args[0].Pointer() - nodemask := args[1].Pointer() - maxnode := args[2].Uint() - addr := args[3].Pointer() - flags := args[4].Uint() - - memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0 - nodeFlag := flags&linux.MPOL_F_NODE != 0 - addrFlag := flags&linux.MPOL_F_ADDR != 0 - - // TODO(rahat): Once sysfs is implemented, report a single numa node in - // /sys/devices/system/node. - if nodemask != 0 && maxnode < 1 { - return 0, nil, syserror.EINVAL - } - - // 'addr' provided iff 'addrFlag' set. - if addrFlag == (addr == 0) { - return 0, nil, syserror.EINVAL - } - - // Default policy for the thread. - if flags == 0 { - policy, nodemaskVal := t.NumaPolicy() - if _, err := copyOutIfNotNull(t, mode, policy); err != nil { - return 0, nil, syserror.EFAULT - } - if _, err := copyOutIfNotNull(t, nodemask, nodemaskVal); err != nil { - return 0, nil, syserror.EFAULT - } - return 0, nil, nil - } - - // Report all nodes available to caller. - if memsAllowed { - // MPOL_F_NODE and MPOL_F_ADDR not allowed with MPOL_F_MEMS_ALLOWED. - if nodeFlag || addrFlag { - return 0, nil, syserror.EINVAL - } - - // Report a single numa node. - if _, err := copyOutIfNotNull(t, nodemask, uint32(0x1)); err != nil { - return 0, nil, syserror.EFAULT - } - return 0, nil, nil - } - - if addrFlag { - if nodeFlag { - // Return the id for the node where 'addr' resides, via 'mode'. - // - // The real get_mempolicy(2) allocates the page referenced by 'addr' - // by simulating a read, if it is unallocated before the call. It - // then returns the node the page is allocated on through the mode - // pointer. - b := t.CopyScratchBuffer(1) - _, err := t.CopyInBytes(addr, b) - if err != nil { - return 0, nil, syserror.EFAULT - } - if _, err := copyOutIfNotNull(t, mode, int32(0)); err != nil { - return 0, nil, syserror.EFAULT - } - } else { - storedPolicy, _ := t.NumaPolicy() - // Return the policy governing the memory referenced by 'addr'. - if _, err := copyOutIfNotNull(t, mode, int32(storedPolicy)); err != nil { - return 0, nil, syserror.EFAULT - } - } - return 0, nil, nil - } - - storedPolicy, _ := t.NumaPolicy() - if nodeFlag && (storedPolicy&^linux.MPOL_MODE_FLAGS == linux.MPOL_INTERLEAVE) { - // Policy for current thread is to interleave memory between - // nodes. Return the next node we'll allocate on. Since we only have a - // single node, this is always node 0. - if _, err := copyOutIfNotNull(t, mode, int32(0)); err != nil { - return 0, nil, syserror.EFAULT - } - return 0, nil, nil - } - - return 0, nil, syserror.EINVAL -} - -func allowedNodesMask() uint32 { - const maxNodes = 1 - return ^uint32((1 << maxNodes) - 1) -} - -// SetMempolicy implements the syscall set_mempolicy(2). -func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - modeWithFlags := args[0].Int() - nodemask := args[1].Pointer() - maxnode := args[2].Uint() - - if nodemask != 0 && maxnode < 1 { - return 0, nil, syserror.EINVAL - } - - if modeWithFlags&linux.MPOL_MODE_FLAGS == linux.MPOL_MODE_FLAGS { - // Can't specify multiple modes simultaneously. - return 0, nil, syserror.EINVAL - } - - mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS - if mode < 0 || mode >= linux.MPOL_MAX { - // Must specify a valid mode. - return 0, nil, syserror.EINVAL - } - - var nodemaskVal uint32 - // Nodemask may be empty for some policy modes. - if nodemask != 0 && maxnode > 0 { - if _, err := t.CopyIn(nodemask, &nodemaskVal); err != nil { - return 0, nil, syserror.EFAULT - } - } - - if (mode == linux.MPOL_INTERLEAVE || mode == linux.MPOL_BIND) && nodemaskVal == 0 { - // Mode requires a non-empty nodemask, but got an empty nodemask. - return 0, nil, syserror.EINVAL - } - - if nodemaskVal&allowedNodesMask() != 0 { - // Invalid node specified. - return 0, nil, syserror.EINVAL - } - - t.SetNumaPolicy(int32(modeWithFlags), nodemaskVal) - - return 0, nil, nil -} - // Mincore implements the syscall mincore(2). func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 0cb7b47b6..9bafc6e4f 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -999,6 +999,7 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:cleanup", + "//test/util:memory_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", diff --git a/test/syscalls/linux/mempolicy.cc b/test/syscalls/linux/mempolicy.cc index 4ac4cb88f..9d5f47651 100644 --- a/test/syscalls/linux/mempolicy.cc +++ b/test/syscalls/linux/mempolicy.cc @@ -18,6 +18,7 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "test/util/cleanup.h" +#include "test/util/memory_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -34,7 +35,7 @@ namespace { #define MPOL_PREFERRED 1 #define MPOL_BIND 2 #define MPOL_INTERLEAVE 3 -#define MPOL_MAX MPOL_INTERLEAVE +#define MPOL_LOCAL 4 #define MPOL_F_NODE (1 << 0) #define MPOL_F_ADDR (1 << 1) #define MPOL_F_MEMS_ALLOWED (1 << 2) @@ -44,11 +45,17 @@ namespace { int get_mempolicy(int *policy, uint64_t *nmask, uint64_t maxnode, void *addr, int flags) { - return syscall(__NR_get_mempolicy, policy, nmask, maxnode, addr, flags); + return syscall(SYS_get_mempolicy, policy, nmask, maxnode, addr, flags); } int set_mempolicy(int mode, uint64_t *nmask, uint64_t maxnode) { - return syscall(__NR_set_mempolicy, mode, nmask, maxnode); + return syscall(SYS_set_mempolicy, mode, nmask, maxnode); +} + +int mbind(void *addr, unsigned long len, int mode, + const unsigned long *nodemask, unsigned long maxnode, + unsigned flags) { + return syscall(SYS_mbind, addr, len, mode, nodemask, maxnode, flags); } // Creates a cleanup object that resets the calling thread's mempolicy to the @@ -252,6 +259,30 @@ TEST(MempolicyTest, GetMempolicyNextInterleaveNode) { EXPECT_EQ(0, mode); } +TEST(MempolicyTest, Mbind) { + // Temporarily set the thread policy to MPOL_PREFERRED. + const auto cleanup_thread_policy = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSetMempolicy(MPOL_PREFERRED, nullptr, 0)); + + const auto mapping = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS)); + + // vmas default to MPOL_DEFAULT irrespective of the thread policy (currently + // MPOL_PREFERRED). + int mode; + ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, mapping.ptr(), MPOL_F_ADDR), + SyscallSucceeds()); + EXPECT_EQ(mode, MPOL_DEFAULT); + + // Set MPOL_PREFERRED for the vma and read it back. + ASSERT_THAT( + mbind(mapping.ptr(), mapping.len(), MPOL_PREFERRED, nullptr, 0, 0), + SyscallSucceeds()); + ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, mapping.ptr(), MPOL_F_ADDR), + SyscallSucceeds()); + EXPECT_EQ(mode, MPOL_PREFERRED); +} + } // namespace } // namespace testing -- cgit v1.2.3 From 02ab1f187cd24c67b754b004229421d189cee264 Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Thu, 6 Jun 2019 16:44:40 -0700 Subject: Copy up parent when binding UDS on overlayfs Overlayfs was expecting the parent to exist when bind(2) was called, which may not be the case. The fix is to copy the parent directory to the upper layer before binding the UDS. There is not good place to add tests for it. Syscall tests would be ideal, but it's hard to guarantee that the directory where the socket is created hasn't been touched before (and thus copied the parent to the upper layer). Added it to runsc integration tests for now. If it turns out we have lots of these kind of tests, we can consider moving them somewhere more appropriate. PiperOrigin-RevId: 251954156 --- pkg/sentry/fs/dirent.go | 2 +- pkg/sentry/fs/inode.go | 4 +-- pkg/sentry/fs/inode_overlay.go | 12 ++++----- runsc/test/integration/BUILD | 1 + runsc/test/integration/regression_test.go | 45 +++++++++++++++++++++++++++++++ 5 files changed, 55 insertions(+), 9 deletions(-) create mode 100644 runsc/test/integration/regression_test.go diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index c0bc261a2..a0a35c242 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -805,7 +805,7 @@ func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data trans var childDir *Dirent err := d.genericCreate(ctx, root, name, func() error { var e error - childDir, e = d.Inode.Bind(ctx, name, data, perms) + childDir, e = d.Inode.Bind(ctx, d, name, data, perms) if e != nil { return e } diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index aef1a1cb9..0b54c2e77 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -220,9 +220,9 @@ func (i *Inode) Rename(ctx context.Context, oldParent *Dirent, renamed *Dirent, } // Bind calls i.InodeOperations.Bind with i as the directory. -func (i *Inode) Bind(ctx context.Context, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { +func (i *Inode) Bind(ctx context.Context, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { if i.overlay != nil { - return overlayBind(ctx, i.overlay, name, data, perm) + return overlayBind(ctx, i.overlay, parent, name, data, perm) } return i.InodeOperations.Bind(ctx, i, name, data, perm) } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index cdffe173b..06506fb20 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -398,14 +398,14 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena return nil } -func overlayBind(ctx context.Context, o *overlayEntry, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { +func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) { + if err := copyUp(ctx, parent); err != nil { + return nil, err + } + o.copyMu.RLock() defer o.copyMu.RUnlock() - // We do not support doing anything exciting with sockets unless there - // is already a directory in the upper filesystem. - if o.upper == nil { - return nil, syserror.EOPNOTSUPP - } + d, err := o.upper.InodeOperations.Bind(ctx, o.upper, name, data, perm) if err != nil { return nil, err diff --git a/runsc/test/integration/BUILD b/runsc/test/integration/BUILD index 0c4e4fa80..04ed885c6 100644 --- a/runsc/test/integration/BUILD +++ b/runsc/test/integration/BUILD @@ -8,6 +8,7 @@ go_test( srcs = [ "exec_test.go", "integration_test.go", + "regression_test.go", ], embed = [":integration"], tags = [ diff --git a/runsc/test/integration/regression_test.go b/runsc/test/integration/regression_test.go new file mode 100644 index 000000000..80bae9970 --- /dev/null +++ b/runsc/test/integration/regression_test.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "strings" + "testing" + + "gvisor.googlesource.com/gvisor/runsc/test/testutil" +) + +// Test that UDS can be created using overlay when parent directory is in lower +// layer only (b/134090485). +// +// Prerequisite: the directory where the socket file is created must not have +// been open for write before bind(2) is called. +func TestBindOverlay(t *testing.T) { + if err := testutil.Pull("ubuntu:trusty"); err != nil { + t.Fatal("docker pull failed:", err) + } + d := testutil.MakeDocker("bind-overlay-test") + + cmd := "nc -l -U /var/run/sock& sleep 1 && echo foobar-asdf | nc -U /var/run/sock" + got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd) + if err != nil { + t.Fatal("docker run failed:", err) + } + + if want := "foobar-asdf"; !strings.Contains(got, want) { + t.Fatalf("docker run output is missing %q: %s", want, got) + } + defer d.CleanUp() +} -- cgit v1.2.3 From 6a4c0065642922c157511fa2cd3feea85cb7c44b Mon Sep 17 00:00:00 2001 From: Ian Lewis Date: Thu, 6 Jun 2019 16:57:18 -0700 Subject: Add the gVisor gitter badge to the README Moves the build badge to just below the logo and adds the gitter badge next to it for consistency. PiperOrigin-RevId: 251956383 --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f0252025c..ba01ffc1e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ ![gVisor](g3doc/logo.png) +[![Status](https://storage.googleapis.com/gvisor-build-badges/build.svg)](https://storage.googleapis.com/gvisor-build-badges/build.html) +[![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community) + ## What is gVisor? **gVisor** is a user-space kernel, written in Go, that implements a substantial @@ -36,8 +39,6 @@ be found at [gvisor.dev][gvisor-dev]. ## Installing from source -[![Status](https://storage.googleapis.com/gvisor-build-badges/build.svg)](https://storage.googleapis.com/gvisor-build-badges/build.html) - gVisor currently requires x86\_64 Linux to build, though support for other architectures may become available in the future. -- cgit v1.2.3 From 315cf9a523d409dc6ddd5ce25f8f0315068ccc67 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 16:59:21 -0700 Subject: Use common definition of SockType. SockType isn't specific to unix domain sockets, and the current definition basically mirrors the linux ABI's definition. PiperOrigin-RevId: 251956740 --- pkg/abi/linux/socket.go | 18 +++++++++++------- pkg/sentry/fs/gofer/socket.go | 11 ++++++----- pkg/sentry/fs/host/socket.go | 11 ++++++----- pkg/sentry/socket/epsocket/BUILD | 1 - pkg/sentry/socket/epsocket/epsocket.go | 11 +++++------ pkg/sentry/socket/epsocket/provider.go | 7 +++---- pkg/sentry/socket/hostinet/BUILD | 1 - pkg/sentry/socket/hostinet/socket.go | 5 ++--- pkg/sentry/socket/netlink/provider.go | 7 +++---- pkg/sentry/socket/rpcinet/BUILD | 1 - pkg/sentry/socket/rpcinet/socket.go | 9 ++++----- pkg/sentry/socket/socket.go | 8 ++++---- pkg/sentry/socket/unix/transport/connectioned.go | 20 ++++++++++---------- pkg/sentry/socket/unix/transport/connectionless.go | 4 ++-- pkg/sentry/socket/unix/transport/unix.go | 22 ++++------------------ pkg/sentry/socket/unix/unix.go | 4 ++-- pkg/sentry/strace/socket.go | 14 +++++++------- pkg/sentry/syscalls/linux/sys_socket.go | 4 ++-- 18 files changed, 71 insertions(+), 87 deletions(-) diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 44bd69df6..a714ac86d 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -102,15 +102,19 @@ const ( SOL_NETLINK = 270 ) +// A SockType is a type (as opposed to family) of sockets. These are enumerated +// below as SOCK_* constants. +type SockType int + // Socket types, from linux/net.h. const ( - SOCK_STREAM = 1 - SOCK_DGRAM = 2 - SOCK_RAW = 3 - SOCK_RDM = 4 - SOCK_SEQPACKET = 5 - SOCK_DCCP = 6 - SOCK_PACKET = 10 + SOCK_STREAM SockType = 1 + SOCK_DGRAM = 2 + SOCK_RAW = 3 + SOCK_RDM = 4 + SOCK_SEQPACKET = 5 + SOCK_DCCP = 6 + SOCK_PACKET = 10 ) // SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 7376fd76f..7ac0a421f 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -15,6 +15,7 @@ package gofer import ( + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/p9" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" @@ -61,13 +62,13 @@ type endpoint struct { path string } -func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { +func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) { switch t { - case transport.SockStream: + case linux.SOCK_STREAM: return p9.StreamSocket, true - case transport.SockSeqpacket: + case linux.SOCK_SEQPACKET: return p9.SeqpacketSocket, true - case transport.SockDgram: + case linux.SOCK_DGRAM: return p9.DgramSocket, true } return 0, false @@ -75,7 +76,7 @@ func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { // BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect. func (e *endpoint) BidirectionalConnect(ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error { - cf, ok := unixSockToP9(ce.Type()) + cf, ok := sockTypeToP9(ce.Type()) if !ok { return syserr.ErrConnectionRefused } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index e4ec0f62c..6423ad938 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -19,6 +19,7 @@ import ( "sync" "syscall" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/fd" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" "gvisor.googlesource.com/gvisor/pkg/log" @@ -56,7 +57,7 @@ type ConnectedEndpoint struct { srfd int `state:"wait"` // stype is the type of Unix socket. - stype transport.SockType + stype linux.SockType // sndbuf is the size of the send buffer. // @@ -105,7 +106,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { return syserr.ErrInvalidEndpointState } - c.stype = transport.SockType(stype) + c.stype = linux.SockType(stype) c.sndbuf = sndbuf return nil @@ -163,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -195,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil + return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil } // Send implements transport.ConnectedEndpoint.Send. @@ -209,7 +210,7 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.Contro // Since stream sockets don't preserve message boundaries, we can write // only as much of the message as fits in the send buffer. - truncate := c.stype == transport.SockStream + truncate := c.stype == linux.SOCK_STREAM n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate) if n < totalLen && err == nil { diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD index 44bb97b5b..7e2679ea0 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/epsocket/BUILD @@ -32,7 +32,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f91c5127a..e1e29de35 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -44,7 +44,6 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -228,7 +227,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint - skType transport.SockType + skType linux.SockType // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -253,8 +252,8 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { - if skType == transport.SockStream { +func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { + if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -638,7 +637,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -664,7 +663,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_TYPE: diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index ec930d8d5..e48a106ea 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -23,7 +23,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" @@ -42,7 +41,7 @@ type provider struct { // getTransportProtocol figures out transport protocol. Currently only TCP, // UDP, and ICMP are supported. -func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { +func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { switch stype { case linux.SOCK_STREAM: if protocol != 0 && protocol != syscall.IPPROTO_TCP { @@ -80,7 +79,7 @@ func getTransportProtocol(ctx context.Context, stype transport.SockType, protoco } // Socket creates a new socket object for the AF_INET or AF_INET6 family. -func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Fail right away if we don't have a stack. stack := t.NetworkContext() if stack == nil { @@ -116,7 +115,7 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int } // Pair just returns nil sockets (not supported). -func (*provider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a469af7ac..975f47bc3 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -30,7 +30,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 0d75580a3..4517951a0 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -30,7 +30,6 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" @@ -548,7 +547,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the host network stack. stack := t.NetworkContext() if stack == nil { @@ -590,7 +589,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 76cf12fd4..863edc241 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -22,7 +22,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" ) @@ -66,10 +65,10 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Netlink sockets must be specified as datagram or raw, but they // behave the same regardless of type. - if stype != transport.SockDgram && stype != transport.SockRaw { + if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW { return nil, syserr.ErrSocketNotSupported } @@ -94,7 +93,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol } // Pair implements socket.Provider.Pair by returning an error. -func (*socketProvider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { // Netlink sockets never supports creating socket pairs. return nil, nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 4da14a1e0..33ba20de7 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/sentry/socket/hostinet", "//pkg/sentry/socket/rpcinet/conn", "//pkg/sentry/socket/rpcinet/notifier", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index bf42bdf69..2d5b5b58f 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -32,7 +32,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier" pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -70,7 +69,7 @@ type socketOperations struct { var _ = socket.Socket(&socketOperations{}) // New creates a new RPC socket. -func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, protocol int) (*fs.File, *syserr.Error) { +func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) { id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */) <-c @@ -841,7 +840,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the RPC network stack. stack := t.NetworkContext() if stack == nil { @@ -857,7 +856,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p // // Try to restrict the flags we will accept to minimize backwards // incompatibility with netstack. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -881,7 +880,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index a99423365..f1021ec67 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -130,12 +130,12 @@ type Provider interface { // If a nil Socket _and_ a nil error is returned, it means that the // protocol is not supported. A non-nil error should only be returned // if the protocol is supported, but an error occurs during creation. - Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) + Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) // Pair creates a pair of connected sockets. // // See Socket for error information. - Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) + Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) } // families holds a map of all known address families and their providers. @@ -149,7 +149,7 @@ func RegisterProvider(family int, provider Provider) { } // New creates a new socket with the given family, type and protocol. -func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { for _, p := range families[family] { s, err := p.Socket(t, stype, protocol) if err != nil { @@ -166,7 +166,7 @@ func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*f // Pair creates a new connected socket pair with the given family, type and // protocol. -func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { providers, ok := families[family] if !ok { return nil, nil, syserr.ErrAddressFamilyNotSupported diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9c8ec0365..db79ac904 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -45,7 +45,7 @@ type ConnectingEndpoint interface { // Type returns the socket type, typically either SockStream or // SockSeqpacket. The connection attempt must be aborted if this // value doesn't match the ConnectableEndpoint's type. - Type() SockType + Type() linux.SockType // GetLocalAddress returns the bound path. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) @@ -101,7 +101,7 @@ type connectionedEndpoint struct { // stype is used by connecting sockets to ensure that they are the // same type. The value is typically either tcpip.SockSeqpacket or // tcpip.SockStream. - stype SockType + stype linux.SockType // acceptedChan is per the TCP endpoint implementation. Note that the // sockets in this channel are _already in the connected state_, and @@ -112,7 +112,7 @@ type connectionedEndpoint struct { } // NewConnectioned creates a new unbound connectionedEndpoint. -func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { +func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -122,7 +122,7 @@ func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. -func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { +func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { a := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -139,7 +139,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} - if stype == SockStream { + if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}} } else { @@ -163,7 +163,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. -func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { +func NewExternal(stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), @@ -178,7 +178,7 @@ func (e *connectionedEndpoint) ID() uint64 { } // Type implements ConnectingEndpoint.Type and Endpoint.Type. -func (e *connectionedEndpoint) Type() SockType { +func (e *connectionedEndpoint) Type() linux.SockType { return e.stype } @@ -294,7 +294,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} - if e.stype == SockStream { + if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { ne.receiver = &queueReceiver{readQueue: writeQueue} @@ -309,7 +309,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur writeQueue: writeQueue, } readQueue.IncRef() - if e.stype == SockStream { + if e.stype == linux.SOCK_STREAM { returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) } else { returnConnect(&queueReceiver{readQueue: readQueue}, connected) @@ -429,7 +429,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { // Stream sockets do not support specifying the endpoint. Seqpacket // sockets ignore the passed endpoint. - if e.stype == SockStream && to != nil { + if e.stype == linux.SOCK_STREAM && to != nil { return 0, syserr.ErrNotSupported } return e.baseEndpoint.SendMsg(data, c, to) diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index c034cf984..81ebfba10 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -119,8 +119,8 @@ func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to Bo } // Type implements Endpoint.Type. -func (e *connectionlessEndpoint) Type() SockType { - return SockDgram +func (e *connectionlessEndpoint) Type() linux.SockType { + return linux.SOCK_DGRAM } // Connect attempts to connect directly to server. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 5fc09af55..5c55c529e 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -19,6 +19,7 @@ import ( "sync" "sync/atomic" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" @@ -28,21 +29,6 @@ import ( // initialLimit is the starting limit for the socket buffers. const initialLimit = 16 * 1024 -// A SockType is a type (as opposed to family) of sockets. These are enumerated -// in the syscall package as syscall.SOCK_* constants. -type SockType int - -const ( - // SockStream corresponds to syscall.SOCK_STREAM. - SockStream SockType = 1 - // SockDgram corresponds to syscall.SOCK_DGRAM. - SockDgram SockType = 2 - // SockRaw corresponds to syscall.SOCK_RAW. - SockRaw SockType = 3 - // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET. - SockSeqpacket SockType = 5 -) - // A RightsControlMessage is a control message containing FDs. type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. @@ -175,7 +161,7 @@ type Endpoint interface { // Type return the socket type, typically either SockStream, SockDgram // or SockSeqpacket. - Type() SockType + Type() linux.SockType // GetLocalAddress returns the address to which the endpoint is bound. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) @@ -629,7 +615,7 @@ type connectedEndpoint struct { GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) // Type implements Endpoint.Type. - Type() SockType + Type() linux.SockType } writeQueue *queue @@ -653,7 +639,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, } truncate := false - if e.endpoint.Type() == SockStream { + if e.endpoint.Type() == linux.SOCK_STREAM { // Since stream sockets don't preserve message boundaries, we // can write only as much of the message as fits in the queue. truncate = true diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 375542350..56ed63e21 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -605,7 +605,7 @@ func (s *SocketOperations) State() uint32 { type provider struct{} // Socket returns a new unix domain socket. -func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check arguments. if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { return nil, syserr.ErrProtocolNotSupported @@ -631,7 +631,7 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) } // Pair creates a new pair of AF_UNIX connected sockets. -func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Check arguments. if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { return nil, nil, syserr.ErrProtocolNotSupported diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index dbe53b9a2..0b5ef84c4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -76,13 +76,13 @@ var SocketFamily = abi.ValueSet{ // SocketType are the possible socket(2) types. var SocketType = abi.ValueSet{ - linux.SOCK_STREAM: "SOCK_STREAM", - linux.SOCK_DGRAM: "SOCK_DGRAM", - linux.SOCK_RAW: "SOCK_RAW", - linux.SOCK_RDM: "SOCK_RDM", - linux.SOCK_SEQPACKET: "SOCK_SEQPACKET", - linux.SOCK_DCCP: "SOCK_DCCP", - linux.SOCK_PACKET: "SOCK_PACKET", + uint64(linux.SOCK_STREAM): "SOCK_STREAM", + uint64(linux.SOCK_DGRAM): "SOCK_DGRAM", + uint64(linux.SOCK_RAW): "SOCK_RAW", + uint64(linux.SOCK_RDM): "SOCK_RDM", + uint64(linux.SOCK_SEQPACKET): "SOCK_SEQPACKET", + uint64(linux.SOCK_DCCP): "SOCK_DCCP", + uint64(linux.SOCK_PACKET): "SOCK_PACKET", } // SocketFlagSet are the possible socket(2) flags. diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 8f4dbf3bc..31295a6a9 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -188,7 +188,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Create the new socket. - s, e := socket.New(t, domain, transport.SockType(stype&0xf), protocol) + s, e := socket.New(t, domain, linux.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } @@ -227,7 +227,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Create the socket pair. - s1, s2, e := socket.Pair(t, domain, transport.SockType(stype&0xf), protocol) + s1, s2, e := socket.Pair(t, domain, linux.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } -- cgit v1.2.3 From 9ea248489b2144b4b477797ad744f500a9215dbc Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Thu, 6 Jun 2019 17:20:43 -0700 Subject: Cap initial usermem.CopyStringIn buffer size. Almost (?) all uses of CopyStringIn are via linux.copyInPath(), which passes maxlen = linux.PATH_MAX = 4096. Pre-allocating a buffer of this size is measurably inefficient in most cases: most paths will not be this long, 4 KB is a lot of bytes to zero, and as of this writing the Go runtime allocator maps only two 4 KB objects to each 8 KB span, necessitating a call to runtime.mcache.refill() on ~every other call. Limit the initial buffer size to 256 B instead, and geometrically reallocate if necessary. PiperOrigin-RevId: 251960441 --- pkg/sentry/usermem/usermem.go | 36 +++++++++++++++++++++++------------- pkg/sentry/usermem/usermem_test.go | 15 ++++++++++++++- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go index 31e4d6ada..9dde327a2 100644 --- a/pkg/sentry/usermem/usermem.go +++ b/pkg/sentry/usermem/usermem.go @@ -222,9 +222,11 @@ func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts return int(r.Addr - addr), nil } -// copyStringIncrement is the maximum number of bytes that are copied from -// virtual memory at a time by CopyStringIn. -const copyStringIncrement = 64 +// CopyStringIn tuning parameters, defined outside that function for tests. +const ( + copyStringIncrement = 64 + copyStringMaxInitBufLen = 256 +) // CopyStringIn copies a NUL-terminated string of unknown length from the // memory mapped at addr in uio and returns it as a string (not including the @@ -234,31 +236,38 @@ const copyStringIncrement = 64 // // Preconditions: As for IO.CopyFromUser. maxlen >= 0. func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) { - buf := make([]byte, maxlen) + initLen := maxlen + if initLen > copyStringMaxInitBufLen { + initLen = copyStringMaxInitBufLen + } + buf := make([]byte, initLen) var done int for done < maxlen { - start, ok := addr.AddLength(uint64(done)) - if !ok { - // Last page of kernel memory. The application can't use this - // anyway. - return stringFromImmutableBytes(buf[:done]), syserror.EFAULT - } // Read up to copyStringIncrement bytes at a time. readlen := copyStringIncrement if readlen > maxlen-done { readlen = maxlen - done } - end, ok := start.AddLength(uint64(readlen)) + end, ok := addr.AddLength(uint64(readlen)) if !ok { return stringFromImmutableBytes(buf[:done]), syserror.EFAULT } // Shorten the read to avoid crossing page boundaries, since faulting // in a page unnecessarily is expensive. This also ensures that partial // copies up to the end of application-mappable memory succeed. - if start.RoundDown() != end.RoundDown() { + if addr.RoundDown() != end.RoundDown() { end = end.RoundDown() + readlen = int(end - addr) + } + // Ensure that our buffer is large enough to accommodate the read. + if done+readlen > len(buf) { + newBufLen := len(buf) * 2 + if newBufLen > maxlen { + newBufLen = maxlen + } + buf = append(buf, make([]byte, newBufLen-len(buf))...) } - n, err := uio.CopyIn(ctx, start, buf[done:done+int(end-start)], opts) + n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts) // Look for the terminating zero byte, which may have occurred before // hitting err. for i, c := range buf[done : done+n] { @@ -270,6 +279,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt if err != nil { return stringFromImmutableBytes(buf[:done]), err } + addr = end } return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG } diff --git a/pkg/sentry/usermem/usermem_test.go b/pkg/sentry/usermem/usermem_test.go index 4a07118b7..575e5039d 100644 --- a/pkg/sentry/usermem/usermem_test.go +++ b/pkg/sentry/usermem/usermem_test.go @@ -192,6 +192,7 @@ func TestCopyObject(t *testing.T) { } func TestCopyStringInShort(t *testing.T) { + // Tests for string length <= copyStringIncrement. want := strings.Repeat("A", copyStringIncrement-2) mem := want + "\x00" if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { @@ -200,13 +201,25 @@ func TestCopyStringInShort(t *testing.T) { } func TestCopyStringInLong(t *testing.T) { - want := strings.Repeat("A", copyStringIncrement+1) + // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen + // (requiring multiple calls to IO.CopyIn()). + want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4) mem := want + "\x00" if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) } } +func TestCopyStringInVeryLong(t *testing.T) { + // Tests for string length > copyStringMaxInitBufLen (requiring buffer + // reallocation). + want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4) + mem := want + "\x00" + if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil { + t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) + } +} + func TestCopyStringInNoTerminatingZeroByte(t *testing.T) { want := strings.Repeat("A", copyStringIncrement-1) got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{}) -- cgit v1.2.3 From 2e43dcb26b4ccbc4d4f314be61806a82f073a50e Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Thu, 6 Jun 2019 17:48:53 -0700 Subject: Add alsologtostderr option When set sends log messages to the error log: sudo ./runsc --logtostderr do ls I0531 17:59:58.105064 144564 x:0] *************************** I0531 17:59:58.105087 144564 x:0] Args: [runsc --logtostderr do ls] I0531 17:59:58.105112 144564 x:0] PID: 144564 I0531 17:59:58.105125 144564 x:0] UID: 0, GID: 0 [...] PiperOrigin-RevId: 251964377 --- runsc/main.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/runsc/main.go b/runsc/main.go index 39c43507c..6f8e6e378 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -48,11 +48,12 @@ var ( // system that are not covered by the runtime spec. // Debugging flags. - debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") - logPackets = flag.Bool("log-packets", false, "enable network packet logging") - logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") - debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") - debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s") + debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") + logPackets = flag.Bool("log-packets", false, "enable network packet logging") + logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") + debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") + debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s") + alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr") // Debugging flags: strace related strace = flag.Bool("strace", false, "enable strace") @@ -228,6 +229,10 @@ func main() { e = newEmitter("text", ioutil.Discard) } + if *alsoLogToStderr { + e = log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)} + } + log.SetTarget(e) log.Infof("***************************") -- cgit v1.2.3 From c933f3eede5634bf778dfb757fb68d927a43a7a8 Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Thu, 6 Jun 2019 17:57:42 -0700 Subject: Change visibility of //pkg/sentry/time. PiperOrigin-RevId: 251965598 --- pkg/sentry/time/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index b2f8f6832..b50579a92 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -32,7 +32,7 @@ go_library( "tsc_arm64.s", ], importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/time", - visibility = ["//pkg/sentry:internal"], + visibility = ["//:sandbox"], deps = [ "//pkg/log", "//pkg/metric", -- cgit v1.2.3 From e5fb3aab122c546441c595c2135a273468c5a997 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Thu, 6 Jun 2019 22:08:49 -0700 Subject: BUILD: Use runsc to generate version This also ensures BUILD files are correctly formatted. PiperOrigin-RevId: 251990267 --- runsc/BUILD | 10 ++++++---- test/BUILD | 6 +----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/runsc/BUILD b/runsc/BUILD index af8e928c5..3d6c92e4c 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -1,6 +1,4 @@ -package( - licenses = ["notice"], # Apache 2.0 -) +package(licenses = ["notice"]) # Apache 2.0 load("@io_bazel_rules_go//go:def.bzl", "go_binary") load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_deb", "pkg_tar") @@ -84,8 +82,9 @@ pkg_tar( genrule( name = "deb-version", outs = ["version.txt"], - cmd = "cat bazel-out/volatile-status.txt | grep VERSION | sed 's/^[^0-9]*//' >$@", + cmd = "$(location :runsc) -version | head -n 1 | sed 's/^[^0-9]*//' > $@", stamp = 1, + tools = [":runsc"], ) pkg_deb( @@ -98,4 +97,7 @@ pkg_deb( package = "runsc", postinst = "debian/postinst.sh", version_file = ":version.txt", + visibility = [ + "//visibility:public", + ], ) diff --git a/test/BUILD b/test/BUILD index e99b4e501..8e1dc5228 100644 --- a/test/BUILD +++ b/test/BUILD @@ -1,8 +1,4 @@ -# gVisor is a general-purpose sandbox. - -package(licenses = ["notice"]) - -exports_files(["LICENSE"]) +package(licenses = ["notice"]) # Apache 2.0 # We need to define a bazel platform and toolchain to specify dockerPrivileged # and dockerRunAsRoot options, they are required to run tests on the RBE -- cgit v1.2.3 From 48961d27a8bcc76b3783a7cc4a4a5ebcd5532d25 Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Fri, 7 Jun 2019 14:51:18 -0700 Subject: Move //pkg/sentry/memutil to //pkg/memutil. PiperOrigin-RevId: 252124156 --- pkg/memutil/BUILD | 11 +++++++ pkg/memutil/memutil_unsafe.go | 42 +++++++++++++++++++++++++++ pkg/sentry/context/contexttest/BUILD | 2 +- pkg/sentry/context/contexttest/contexttest.go | 2 +- pkg/sentry/memutil/BUILD | 14 --------- pkg/sentry/memutil/memutil.go | 16 ---------- pkg/sentry/memutil/memutil_unsafe.go | 39 ------------------------- pkg/sentry/pgalloc/BUILD | 2 +- pkg/sentry/usage/BUILD | 2 +- pkg/sentry/usage/memory.go | 2 +- runsc/boot/BUILD | 2 +- runsc/boot/loader.go | 2 +- 12 files changed, 60 insertions(+), 76 deletions(-) create mode 100644 pkg/memutil/BUILD create mode 100644 pkg/memutil/memutil_unsafe.go delete mode 100644 pkg/sentry/memutil/BUILD delete mode 100644 pkg/sentry/memutil/memutil.go delete mode 100644 pkg/sentry/memutil/memutil_unsafe.go diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD new file mode 100644 index 000000000..71b48a972 --- /dev/null +++ b/pkg/memutil/BUILD @@ -0,0 +1,11 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "memutil", + srcs = ["memutil_unsafe.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/memutil", + visibility = ["//visibility:public"], + deps = ["@org_golang_x_sys//unix:go_default_library"], +) diff --git a/pkg/memutil/memutil_unsafe.go b/pkg/memutil/memutil_unsafe.go new file mode 100644 index 000000000..979d942a9 --- /dev/null +++ b/pkg/memutil/memutil_unsafe.go @@ -0,0 +1,42 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +// Package memutil provides a wrapper for the memfd_create() system call. +package memutil + +import ( + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +// CreateMemFD creates a memfd file and returns the fd. +func CreateMemFD(name string, flags int) (int, error) { + p, err := syscall.BytePtrFromString(name) + if err != nil { + return -1, err + } + fd, _, e := syscall.Syscall(unix.SYS_MEMFD_CREATE, uintptr(unsafe.Pointer(p)), uintptr(flags), 0) + if e != 0 { + if e == syscall.ENOSYS { + return -1, fmt.Errorf("memfd_create(2) is not implemented. Check that you have Linux 3.17 or higher") + } + return -1, e + } + return int(fd), nil +} diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD index ce4f1e42c..d17b1bdcf 100644 --- a/pkg/sentry/context/contexttest/BUILD +++ b/pkg/sentry/context/contexttest/BUILD @@ -9,11 +9,11 @@ go_library( importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/context/contexttest", visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/memutil", "//pkg/sentry/context", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", - "//pkg/sentry/memutil", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", "//pkg/sentry/platform/ptrace", diff --git a/pkg/sentry/context/contexttest/contexttest.go b/pkg/sentry/context/contexttest/contexttest.go index 210a235d2..83da40711 100644 --- a/pkg/sentry/context/contexttest/contexttest.go +++ b/pkg/sentry/context/contexttest/contexttest.go @@ -21,11 +21,11 @@ import ( "testing" "time" + "gvisor.googlesource.com/gvisor/pkg/memutil" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/limits" - "gvisor.googlesource.com/gvisor/pkg/sentry/memutil" "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc" "gvisor.googlesource.com/gvisor/pkg/sentry/platform" "gvisor.googlesource.com/gvisor/pkg/sentry/platform/ptrace" diff --git a/pkg/sentry/memutil/BUILD b/pkg/sentry/memutil/BUILD deleted file mode 100644 index 68b03d4cc..000000000 --- a/pkg/sentry/memutil/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "memutil", - srcs = [ - "memutil.go", - "memutil_unsafe.go", - ], - importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/memutil", - visibility = ["//pkg/sentry:internal"], - deps = ["@org_golang_x_sys//unix:go_default_library"], -) diff --git a/pkg/sentry/memutil/memutil.go b/pkg/sentry/memutil/memutil.go deleted file mode 100644 index a4154c42a..000000000 --- a/pkg/sentry/memutil/memutil.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package memutil contains the utility functions for memory operations. -package memutil diff --git a/pkg/sentry/memutil/memutil_unsafe.go b/pkg/sentry/memutil/memutil_unsafe.go deleted file mode 100644 index 92eab8a26..000000000 --- a/pkg/sentry/memutil/memutil_unsafe.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memutil - -import ( - "fmt" - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -// CreateMemFD creates a memfd file and returns the fd. -func CreateMemFD(name string, flags int) (int, error) { - p, err := syscall.BytePtrFromString(name) - if err != nil { - return -1, err - } - fd, _, e := syscall.Syscall(unix.SYS_MEMFD_CREATE, uintptr(unsafe.Pointer(p)), uintptr(flags), 0) - if e != 0 { - if e == syscall.ENOSYS { - return -1, fmt.Errorf("memfd_create(2) is not implemented. Check that you have Linux 3.17 or higher") - } - return -1, e - } - return int(fd), nil -} diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index bbdb1f922..ca2d5ba6f 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -63,10 +63,10 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", + "//pkg/memutil", "//pkg/sentry/arch", "//pkg/sentry/context", "//pkg/sentry/hostmm", - "//pkg/sentry/memutil", "//pkg/sentry/platform", "//pkg/sentry/safemem", "//pkg/sentry/usage", diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index 09198496b..860733061 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -17,6 +17,6 @@ go_library( ], deps = [ "//pkg/bits", - "//pkg/sentry/memutil", + "//pkg/memutil", ], ) diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index c316f1597..9ed974ccb 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -22,7 +22,7 @@ import ( "syscall" "gvisor.googlesource.com/gvisor/pkg/bits" - "gvisor.googlesource.com/gvisor/pkg/sentry/memutil" + "gvisor.googlesource.com/gvisor/pkg/memutil" ) // MemoryKind represents a type of memory used by the application. diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index df9907e52..ac28c4339 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/cpuid", "//pkg/eventchannel", "//pkg/log", + "//pkg/memutil", "//pkg/rand", "//pkg/sentry/arch", "//pkg/sentry/arch:registers_go_proto", @@ -51,7 +52,6 @@ go_library( "//pkg/sentry/kernel/kdefs", "//pkg/sentry/limits", "//pkg/sentry/loader", - "//pkg/sentry/memutil", "//pkg/sentry/pgalloc", "//pkg/sentry/platform", "//pkg/sentry/platform/kvm", diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index ef4ccd0bd..42bddb2e8 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -29,6 +29,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/cpuid" "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/memutil" "gvisor.googlesource.com/gvisor/pkg/rand" "gvisor.googlesource.com/gvisor/pkg/sentry/arch" "gvisor.googlesource.com/gvisor/pkg/sentry/control" @@ -37,7 +38,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/loader" - "gvisor.googlesource.com/gvisor/pkg/sentry/memutil" "gvisor.googlesource.com/gvisor/pkg/sentry/pgalloc" "gvisor.googlesource.com/gvisor/pkg/sentry/platform" "gvisor.googlesource.com/gvisor/pkg/sentry/platform/kvm" -- cgit v1.2.3 From a00157cc0e216a9829f2659ce35c856a22aa5ba2 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Mon, 10 Jun 2019 15:16:42 -0700 Subject: Store more information in the kernel socket table. Store enough information in the kernel socket table to distinguish between different types of sockets. Previously we were only storing the socket family, but this isn't enough to classify sockets. For example, TCPv4 and UDPv4 sockets are both AF_INET, and ICMP sockets are SOCK_DGRAM sockets with a particular protocol. Instead of creating more sub-tables, flatten the socket table and provide a filtering mechanism based on the socket entry. Also generate and store a socket entry index ("sl" in linux) which allows us to output entries in a stable order from procfs. PiperOrigin-RevId: 252495895 --- pkg/sentry/fs/host/socket.go | 4 +-- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/net.go | 14 ++++---- pkg/sentry/kernel/BUILD | 13 +++++++ pkg/sentry/kernel/kernel.go | 55 +++++++++++++--------------- pkg/sentry/socket/epsocket/epsocket.go | 13 +++++-- pkg/sentry/socket/epsocket/provider.go | 2 +- pkg/sentry/socket/hostinet/socket.go | 32 +++++++++++------ pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/socket.go | 12 ++++++- pkg/sentry/socket/rpcinet/socket.go | 16 +++++++-- pkg/sentry/socket/socket.go | 9 +++-- pkg/sentry/socket/unix/unix.go | 65 ++++++++++++++++++++-------------- 13 files changed, 152 insertions(+), 86 deletions(-) diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 6423ad938..305eea718 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -164,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -196,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil + return unixsocket.New(ctx, ep, e.stype), nil } // Send implements transport.ConnectedEndpoint.Send. diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index d19c360e0..1728fe0b5 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -45,6 +45,7 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/mm", + "//pkg/sentry/socket", "//pkg/sentry/socket/rpcinet", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 3daaa962c..034950158 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -27,6 +27,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs" "gvisor.googlesource.com/gvisor/pkg/sentry/inet" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) @@ -213,17 +214,18 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s fmt.Fprintf(&buf, "Num RefCount Protocol Flags Type St Inode Path\n") // Entries - for _, sref := range n.k.ListSockets(linux.AF_UNIX) { - s := sref.Get() + for _, se := range n.k.ListSockets() { + s := se.Sock.Get() if s == nil { - log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", sref) + log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock) continue } sfile := s.(*fs.File) - sops, ok := sfile.FileOperations.(*unix.SocketOperations) - if !ok { - panic(fmt.Sprintf("Found non-unix socket file in unix socket table: %+v", sfile)) + if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX { + // Not a unix socket. + continue } + sops := sfile.FileOperations.(*unix.SocketOperations) addr, err := sops.Endpoint().GetLocalAddress() if err != nil { diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 99a2fd964..04e375910 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -64,6 +64,18 @@ go_template_instance( }, ) +go_template_instance( + name = "socket_list", + out = "socket_list.go", + package = "kernel", + prefix = "socket", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*SocketEntry", + "Linker": "*SocketEntry", + }, +) + proto_library( name = "uncaught_signal_proto", srcs = ["uncaught_signal.proto"], @@ -104,6 +116,7 @@ go_library( "sessions.go", "signal.go", "signal_handlers.go", + "socket_list.go", "syscalls.go", "syscalls_state.go", "syslog.go", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 85d73ace2..f253a81d9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -182,9 +182,13 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // socketTable is used to track all sockets on the system. Protected by + // sockets is the list of all network sockets the system. Protected by // extMu. - socketTable map[int]map[*refs.WeakRef]struct{} + sockets socketList + + // nextSocketEntry is the next entry number to use in sockets. Protected + // by extMu. + nextSocketEntry uint64 // deviceRegistry is used to save/restore device.SimpleDevices. deviceRegistry struct{} `state:".(*device.Registry)"` @@ -283,7 +287,6 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() - k.socketTable = make(map[int]map[*refs.WeakRef]struct{}) return nil } @@ -1137,51 +1140,43 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { }) } -// socketEntry represents a socket recorded in Kernel.socketTable. It implements +// SocketEntry represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // // +stateify savable -type socketEntry struct { - k *Kernel - sock *refs.WeakRef - family int +type SocketEntry struct { + socketEntry + k *Kernel + Sock *refs.WeakRef + ID uint64 // Socket table entry number. } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *socketEntry) WeakRefGone() { +func (s *SocketEntry) WeakRefGone() { s.k.extMu.Lock() - // k.socketTable is guaranteed to point to a valid socket table for s.family - // at this point, since we made sure of the fact when we created this - // socketEntry, and we never delete socket tables. - delete(s.k.socketTable[s.family], s.sock) + s.k.sockets.Remove(s) s.k.extMu.Unlock() } // RecordSocket adds a socket to the system-wide socket table for tracking. // // Precondition: Caller must hold a reference to sock. -func (k *Kernel) RecordSocket(sock *fs.File, family int) { +func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() - table, ok := k.socketTable[family] - if !ok { - table = make(map[*refs.WeakRef]struct{}) - k.socketTable[family] = table - } - se := socketEntry{k: k, family: family} - se.sock = refs.NewWeakRef(sock, &se) - table[se.sock] = struct{}{} + id := k.nextSocketEntry + k.nextSocketEntry++ + s := &SocketEntry{k: k, ID: id} + s.Sock = refs.NewWeakRef(sock, s) + k.sockets.PushBack(s) k.extMu.Unlock() } -// ListSockets returns a snapshot of all sockets of a given family. -func (k *Kernel) ListSockets(family int) []*refs.WeakRef { +// ListSockets returns a snapshot of all sockets. +func (k *Kernel) ListSockets() []*SocketEntry { k.extMu.Lock() - socks := []*refs.WeakRef{} - if table, ok := k.socketTable[family]; ok { - socks = make([]*refs.WeakRef, 0, len(table)) - for s := range table { - socks = append(socks, s) - } + var socks []*SocketEntry + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, s) } k.extMu.Unlock() return socks diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e1e29de35..f67451179 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -228,6 +228,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint skType linux.SockType + protocol int // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -252,7 +253,7 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) @@ -266,6 +267,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, family: family, Endpoint: endpoint, skType: skType, + protocol: protocol, }), nil } @@ -550,7 +552,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns, err := New(t, s.family, s.skType, wq, ep) + ns, err := New(t, s.family, s.skType, s.protocol, wq, ep) if err != nil { return 0, nil, 0, err } @@ -578,7 +580,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(ns, s.family) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, syserr.FromError(e) } @@ -2324,3 +2326,8 @@ func (s *SocketOperations) State() uint32 { // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } + +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.skType, s.protocol +} diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index e48a106ea..516582828 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -111,7 +111,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return nil, syserr.TranslateNetstackError(e) } - return New(t, p.family, stype, wq, ep) + return New(t, p.family, stype, protocol, wq, ep) } // Pair just returns nil sockets (not supported). diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 4517951a0..c62c8d8f1 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -56,15 +56,22 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. - fd int // must be O_NONBLOCK - queue waiter.Queue + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd int // must be O_NONBLOCK + queue waiter.Queue } var _ = socket.Socket(&socketOperations{}) -func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) { - s := &socketOperations{family: family, fd: fd} +func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { + s := &socketOperations{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + } if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -222,7 +229,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0) + f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) if err != nil { syscall.Close(fd) return 0, nil, 0, err @@ -233,7 +240,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, } kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(f, s.family) + t.Kernel().RecordSocket(f) return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } @@ -542,6 +549,11 @@ func (s *socketOperations) State() uint32 { return uint32(info.State) } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -558,7 +570,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto } // Only accept TCP and UDP. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -581,11 +593,11 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto // Conservatively ignore all flags specified by the application and add // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. - fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) + fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { return nil, syserr.FromError(err) } - return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) + return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 863edc241..5dc103877 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -82,7 +82,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int return nil, err } - s, err := NewSocket(t, p) + s, err := NewSocket(t, stype, p) if err != nil { return nil, err } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 16c79aa33..62659784a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -80,6 +80,10 @@ type Socket struct { // protocol is the netlink protocol implementation. protocol Protocol + // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for + // netlink sockets. + skType linux.SockType + // ep is a datagram unix endpoint used to buffer messages sent from the // kernel to userspace. RecvMsg reads messages from this endpoint. ep transport.Endpoint @@ -105,7 +109,7 @@ type Socket struct { var _ socket.Socket = (*Socket)(nil) // NewSocket creates a new Socket. -func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { +func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { // Datagram endpoint used to buffer kernel -> user messages. ep := transport.NewConnectionless() @@ -126,6 +130,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { return &Socket{ ports: t.Kernel().NetlinkPorts(), protocol: protocol, + skType: skType, ep: ep, connection: connection, sendBufferSize: defaultSendBufferSize, @@ -621,3 +626,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, func (s *Socket) State() uint32 { return s.ep.State() } + +// Type implements socket.Socket.Type. +func (s *Socket) Type() (family int, skType linux.SockType, protocol int) { + return linux.AF_NETLINK, s.skType, s.protocol.Protocol() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 2d5b5b58f..c22ff1ff0 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -53,7 +53,10 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd uint32 // must be O_NONBLOCK wq *waiter.Queue rpcConn *conn.RPCConnection @@ -86,6 +89,8 @@ func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.S defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{ family: family, + stype: skType, + protocol: protocol, wq: &wq, fd: fd, rpcConn: stack.rpcConn, @@ -332,7 +337,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, syserr.FromError(err) } - t.Kernel().RecordSocket(file, s.family) + t.Kernel().RecordSocket(file) if peerRequested { return fd, payload.Address.Address, payload.Address.Length, nil @@ -835,6 +840,11 @@ func (s *socketOperations) State() uint32 { return 0 } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -876,7 +886,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto return nil, nil } - return newSocketFile(t, s, p.family, stype, 0) + return newSocketFile(t, s, p.family, stype, protocol) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index f1021ec67..d60944b6b 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -120,6 +120,9 @@ type Socket interface { // State returns the current state of the socket, as represented by Linux in // procfs. The returned state value is protocol-specific. State() uint32 + + // Type returns the family, socket type and protocol of the socket. + Type() (family int, skType linux.SockType, protocol int) } // Provider is the interface implemented by providers of sockets for specific @@ -156,7 +159,7 @@ func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.Fi return nil, err } if s != nil { - t.Kernel().RecordSocket(s, family) + t.Kernel().RecordSocket(s) return s, nil } } @@ -179,8 +182,8 @@ func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.F } if s1 != nil && s2 != nil { k := t.Kernel() - k.RecordSocket(s1, family) - k.RecordSocket(s2, family) + k.RecordSocket(s1) + k.RecordSocket(s2) return s1, s2, nil } } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 56ed63e21..b07e8d67b 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -17,6 +17,7 @@ package unix import ( + "fmt" "strings" "syscall" @@ -55,22 +56,22 @@ type SocketOperations struct { refs.AtomicRefCount socket.SendReceiveTimeout - ep transport.Endpoint - isPacket bool + ep transport.Endpoint + stype linux.SockType } // New creates a new unix socket. -func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File { +func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) defer dirent.DecRef() - return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true}) + return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true}) } // NewWithDirent creates a new unix socket using an existing dirent. -func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File { +func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File { return fs.NewFile(ctx, d, flags, &SocketOperations{ - ep: ep, - isPacket: isPacket, + ep: ep, + stype: stype, }) } @@ -88,6 +89,18 @@ func (s *SocketOperations) Release() { s.DecRef() } +func (s *SocketOperations) isPacket() bool { + switch s.stype { + case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + return true + case linux.SOCK_STREAM: + return false + default: + // We shouldn't have allowed any other socket types during creation. + panic(fmt.Sprintf("Invalid socket type %d", s.stype)) + } +} + // Endpoint extracts the transport.Endpoint. func (s *SocketOperations) Endpoint() transport.Endpoint { return s.ep @@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns := New(t, ep, s.isPacket) + ns := New(t, ep, s.stype) defer ns.DecRef() if flags&linux.SOCK_NONBLOCK != 0 { @@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, nil, 0, syserr.FromError(e) } - t.Kernel().RecordSocket(ns, linux.AF_UNIX) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, nil } @@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 waitAll := flags&linux.MSG_WAITALL != 0 + isPacket := s.isPacket() // Calculate the number of FDs for which we have space and if we are // requesting credentials. @@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msgFlags |= linux.MSG_CTRUNC } - if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { - if s.isPacket && n < int64(r.MsgSize) { + if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } @@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags total += n } - if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { + if err != nil || !waitAll || isPacket || n >= dst.NumBytes() { if total > 0 { err = nil } - if s.isPacket && n < int64(r.MsgSize) { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) @@ -601,6 +615,12 @@ func (s *SocketOperations) State() uint32 { return s.ep.State() } +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + // Unix domain sockets always have a protocol of 0. + return linux.AF_UNIX, s.stype, 0 +} + // provider is a unix domain socket provider. type provider struct{} @@ -613,21 +633,16 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs // Create the endpoint and socket. var ep transport.Endpoint - var isPacket bool switch stype { case linux.SOCK_DGRAM: - isPacket = true ep = transport.NewConnectionless() - case linux.SOCK_SEQPACKET: - isPacket = true - fallthrough - case linux.SOCK_STREAM: + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: ep = transport.NewConnectioned(stype, t.Kernel()) default: return nil, syserr.ErrInvalidArgument } - return New(t, ep, isPacket), nil + return New(t, ep, stype), nil } // Pair creates a new pair of AF_UNIX connected sockets. @@ -637,19 +652,17 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F return nil, nil, syserr.ErrProtocolNotSupported } - var isPacket bool switch stype { - case linux.SOCK_STREAM: - case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: - isPacket = true + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + // Ok default: return nil, nil, syserr.ErrInvalidArgument } // Create the endpoints and sockets. ep1, ep2 := transport.NewPair(stype, t.Kernel()) - s1 := New(t, ep1, isPacket) - s2 := New(t, ep2, isPacket) + s1 := New(t, ep1, stype) + s2 := New(t, ep2, stype) return s1, s2, nil } -- cgit v1.2.3