diff options
329 files changed, 9221 insertions, 2327 deletions
@@ -13,7 +13,7 @@ # limitations under the License. # Display the current git revision in the info block. -build --workspace_status_command tools/workspace_status.sh +build --stamp --workspace_status_command tools/workspace_status.sh # Enable remote execution so actions are performed on the remote systems. build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 638942a42..5d46168bc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -83,6 +83,8 @@ Rules: ### Code reviews +Before sending code reviews, run `bazel test ...` to ensure tests are passing. + Code changes are accepted via [pull request][github]. When approved, the change will be submitted by a team member and automatically @@ -100,6 +102,36 @@ form `b/1234`. These correspond to bugs in our internal bug tracker. Eventually these bugs will be moved to the GitHub Issues, but until then they can simply be ignored. +### Build and test with Docker + +`scripts/dev.sh` is a convenient script that builds and installs `runsc` as a +new Docker runtime for you. The scripts tries to extract the runtime name from +your local environment and will print it at the end. You can also customize it. +The script creates one regular runtime and another with debug flags enabled. +Here are a few examples: + +```bash +# Default case (inside branch my-branch) +$ scripts/dev.sh +... +Runtimes my-branch and my-branch-d (debug enabled) setup. +Use --runtime=my-branch with your Docker command. + docker run --rm --runtime=my-branch --rm hello-world + +If you rebuild, use scripts/dev.sh --refresh. +Logs are in: /tmp/my-branch/logs + +# --refresh just updates the runtime binary and doesn't restart docker. +$ git/my_branch> scripts/dev.sh --refresh + +# Using a custom runtime name +$ git/my_branch> scripts/dev.sh my-runtime +... +Runtimes my-runtime and my-runtime-d (debug enabled) setup. +Use --runtime=my-runtime with your Docker command. + docker run --rm --runtime=my-runtime --rm hello-world +``` + ### The small print Contributions made by corporations are covered by a different agreement than the @@ -48,7 +48,7 @@ Make sure the following dependencies are installed: * Linux 4.14.77+ ([older linux][old-linux]) * [git][git] -* [Bazel][bazel] 0.23.0+ +* [Bazel][bazel] 0.28.0+ * [Python][python] * [Docker version 17.09.0 or greater][docker] * Gold linker (e.g. `binutils-gold` package on Ubuntu) @@ -3,19 +3,19 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( name = "io_bazel_rules_go", - sha256 = "313f2c7a23fecc33023563f082f381a32b9b7254f727a7dd2d6380ccc6dfe09b", + sha256 = "ae8c36ff6e565f674c7a3692d6a9ea1096e4c1ade497272c2108a810fb39acd2", urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.3/rules_go-0.19.3.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/0.19.3/rules_go-0.19.3.tar.gz", + "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz", ], ) http_archive( name = "bazel_gazelle", - sha256 = "be9296bfd64882e3c08e3283c58fcb461fa6dd3c171764fcc4cf322f60615a9b", + sha256 = "7fc87f4170011201b1690326e8c16c5d802836e3a0d617d8f75c3af2b23180c4", urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/0.18.1/bazel-gazelle-0.18.1.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/0.18.1/bazel-gazelle-0.18.1.tar.gz", + "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/0.18.2/bazel-gazelle-0.18.2.tar.gz", + "https://github.com/bazelbuild/bazel-gazelle/releases/download/0.18.2/bazel-gazelle-0.18.2.tar.gz", ], ) @@ -24,7 +24,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to go_rules_dependencies() go_register_toolchains( - go_version = "1.12.9", + go_version = "1.13", nogo = "@//:nogo", ) @@ -44,13 +44,14 @@ http_archive( ) # Load protobuf dependencies. -load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") - -git_repository( +http_archive( name = "com_google_protobuf", - commit = "09745575a923640154bcf307fba8aedff47f240a", - remote = "https://github.com/protocolbuffers/protobuf", - shallow_since = "1558721209 -0700", + sha256 = "532d2575d8c0992065bb19ec5fba13aa3683499726f6055c11b474f91a00bb0c", + strip_prefix = "protobuf-7f520092d9050d96fb4b707ad11a51701af4ce49", + urls = [ + "https://mirror.bazel.build/github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip", + "https://github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip", + ], ) load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") @@ -61,11 +62,11 @@ protobuf_deps() # See releases at https://releases.bazel.build/bazel-toolchains.html http_archive( name = "bazel_toolchains", - sha256 = "e71eadcfcbdb47b4b740eb48b32ca4226e36aabc425d035a18dd40c2dda808c1", - strip_prefix = "bazel-toolchains-0.28.4", + sha256 = "a019fbd579ce5aed0239de865b2d8281dbb809efd537bf42e0d366783e8dec65", + strip_prefix = "bazel-toolchains-0.29.2", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.28.4.tar.gz", - "https://github.com/bazelbuild/bazel-toolchains/archive/0.28.4.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz", ], ) @@ -195,7 +196,7 @@ go_repository( go_repository( name = "org_golang_x_time", - commit = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef", + commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e", importpath = "golang.org/x/time", ) @@ -221,16 +222,6 @@ go_repository( # System Call test dependencies. http_archive( - name = "com_github_gflags_gflags", - sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf", - strip_prefix = "gflags-2.2.2", - urls = [ - "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.2.tar.gz", - "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz", - ], -) - -http_archive( name = "com_google_absl", sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93", strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e", @@ -242,10 +233,10 @@ http_archive( http_archive( name = "com_google_googletest", - sha256 = "db657310d3c5ca2d3f674e3a4b79718d1d39da70604568ee0568ba8e39065ef4", - strip_prefix = "googletest-31200def0dec8a624c861f919e86e4444e6e6ee7", + sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912", + strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34", urls = [ - "https://mirror.bazel.build/github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz", - "https://github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz", + "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", + "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", ], ) diff --git a/cloudbuild/go.Dockerfile b/cloudbuild/go.Dockerfile deleted file mode 100644 index 226442fd2..000000000 --- a/cloudbuild/go.Dockerfile +++ /dev/null @@ -1,2 +0,0 @@ -FROM ubuntu -RUN apt-get -q update && apt-get install -qqy git rsync diff --git a/cloudbuild/go.yaml b/cloudbuild/go.yaml deleted file mode 100644 index a38ef71fc..000000000 --- a/cloudbuild/go.yaml +++ /dev/null @@ -1,22 +0,0 @@ -steps: -- name: 'gcr.io/cloud-builders/git' - args: ['fetch', '--all', '--unshallow'] -- name: 'gcr.io/cloud-builders/bazel' - args: ['build', ':gopath'] -- name: 'gcr.io/cloud-builders/docker' - args: ['build', '-t', 'gcr.io/$PROJECT_ID/go-branch', '-f', 'cloudbuild/go.Dockerfile', '.'] -- name: 'gcr.io/$PROJECT_ID/go-branch' - args: ['tools/go_branch.sh'] -- name: 'gcr.io/cloud-builders/git' - args: ['checkout', 'go'] -- name: 'gcr.io/cloud-builders/git' - args: ['clean', '-f'] -- name: 'golang' - args: ['go', 'build', './...'] -- name: 'gcr.io/cloud-builders/git' - entrypoint: 'bash' - args: - - '-c' - - 'if [[ "$BRANCH_NAME" == "master" ]]; then git push "${_ORIGIN}" go:go; fi' -substitutions: - _ORIGIN: origin diff --git a/kokoro/build.cfg b/kokoro/build.cfg index d67af4694..084347dde 100644 --- a/kokoro/build.cfg +++ b/kokoro/build.cfg @@ -11,13 +11,12 @@ before_action { env_vars { key: "KOKORO_REPO_KEY" - value: "$KOKORO_ROOT/src/keystore/73898_kokoro-repo-key" + value: "73898_kokoro-repo-key" } action { define_artifacts { - regex: "**/runsc" - regex: "**/runsc.sha256" - regex: "**/repo/**" + regex: "**/runsc.*" + regex: "**/dists/**" } } diff --git a/kokoro/continuous.cfg b/kokoro/continuous.cfg deleted file mode 100644 index 88694220a..000000000 --- a/kokoro/continuous.cfg +++ /dev/null @@ -1,11 +0,0 @@ -# This is a temporary file. It will be removed when new Kokoro jobs exist for -# all the other presubmits. -build_file: "repo/scripts/build.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/kokoro/go.cfg b/kokoro/go.cfg index d1577252a..b9c1fcb12 100644 --- a/kokoro/go.cfg +++ b/kokoro/go.cfg @@ -1,5 +1,19 @@ build_file: "repo/scripts/go.sh" +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73898 + keyname: "kokoro-github-access-token" + } + } +} + +env_vars { + key: "KOKORO_GITHUB_ACCESS_TOKEN" + value: "73898_kokoro-github-access-token" +} + env_vars { key: "KOKORO_GO_PUSH" value: "true" diff --git a/kokoro/go_test.cfg b/kokoro/go_tests.cfg index 5eb51041a..5eb51041a 100644 --- a/kokoro/go_test.cfg +++ b/kokoro/go_tests.cfg diff --git a/kokoro/presubmit.cfg b/kokoro/presubmit.cfg deleted file mode 100644 index eb0c78ea4..000000000 --- a/kokoro/presubmit.cfg +++ /dev/null @@ -1,11 +0,0 @@ -# This is a temporary file. It will be removed when new Kokoro jobs exist for -# all the other presubmits. -build_file: "repo/kokoro/run_tests.sh" - -action { - define_artifacts { - regex: "**/sponge_log.xml" - regex: "**/sponge_log.log" - regex: "**/outputs.zip" - } -} diff --git a/kokoro/release-nightly.cfg b/kokoro/release-nightly.cfg deleted file mode 100644 index ae134258c..000000000 --- a/kokoro/release-nightly.cfg +++ /dev/null @@ -1,10 +0,0 @@ -# This file is a temporary bridge. It will be removed shortly, when Kokoro jobs -# are configured to point at the new build and release configurations. -build_file: "repo/kokoro/run_build.sh" - -action { - define_artifacts { - regex: "**/runsc" - regex: "**/runsc.sha512" - } -} diff --git a/kokoro/run_build.sh b/kokoro/run_build.sh deleted file mode 100755 index da6a0c85e..000000000 --- a/kokoro/run_build.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is a temporary bridge. We will create multiple independent Kokoro -# workflows that call each of the build scripts independently. -KOKORO_BUILD_NIGHTLY=true $(dirname $0)/../scripts/build.sh diff --git a/kokoro/ubuntu1604/20_bazel.sh b/kokoro/ubuntu1604/20_bazel.sh index 74b4b8be2..b9a894024 100755 --- a/kokoro/ubuntu1604/20_bazel.sh +++ b/kokoro/ubuntu1604/20_bazel.sh @@ -16,9 +16,7 @@ set -xeo pipefail -# We need to install a specific version of bazel due to a bug with the RBE -# environment not respecting the dockerPrivileged configuration. -declare -r BAZEL_VERSION=0.28.1 +declare -r BAZEL_VERSION=0.29.1 # Install bazel dependencies. apt-get update && apt-get install -y openjdk-8-jdk-headless unzip diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index ba233b93f..f45934466 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -2,9 +2,11 @@ # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix # when the host OS may not be Linux. +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "linux", @@ -44,6 +46,7 @@ go_library( "sem.go", "shm.go", "signal.go", + "signalfd.go", "socket.go", "splice.go", "tcp.go", diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go new file mode 100644 index 000000000..85fad9956 --- /dev/null +++ b/pkg/abi/linux/signalfd.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +const ( + // SFD_NONBLOCK is a signalfd(2) flag. + SFD_NONBLOCK = 00004000 + + // SFD_CLOEXEC is a signalfd(2) flag. + SFD_CLOEXEC = 02000000 +) + +// SignalfdSiginfo is the siginfo encoding for signalfds. +type SignalfdSiginfo struct { + Signo uint32 + Errno int32 + Code int32 + PID uint32 + UID uint32 + FD int32 + TID uint32 + Band uint32 + Overrun uint32 + TrapNo uint32 + Status int32 + Int int32 + Ptr uint64 + UTime uint64 + STime uint64 + Addr uint64 + AddrLSB uint16 + _ [48]uint8 +} diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 39d253b98..6bc486b62 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 47ab65346..5f59866fa 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD index 09d6c2c1f..543fb54bf 100644 --- a/pkg/binary/BUILD +++ b/pkg/binary/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 0c2dde4f8..51967b811 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index b692aa3b1..8d31e068c 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "bpf", diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index cdec96df1..a0b21d4bd 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index 830e19e07..32422f9e2 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "cpuid", diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 9961baaa9..71f2abc83 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index 785c685a0..afa8f7659 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index e54e7371c..56495cbd9 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index bd1d614b6..5643d5f26 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -18,6 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", + "//third_party/gvsync", ], ) diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index d59159912..8390915a2 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error { *ep.dataLen() = w.Len() // Return control to the client. + raceBecomeInactive() if err := ep.futexSwitchToPeer(); err != nil { return err } diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 991018684..386cee42c 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -180,7 +180,11 @@ const ( // Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), // ep.SendRecv(), and ep.SendLast() have never been called. func (ep *Endpoint) Connect() error { - return ep.ctrlConnect() + err := ep.ctrlConnect() + if err == nil { + raceBecomeActive() + } + return err } // RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then @@ -192,6 +196,7 @@ func (ep *Endpoint) RecvFirst() (uint32, error) { if err := ep.ctrlWaitFirst(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -218,9 +223,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { // after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then // they can only shoot themselves in the foot. *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlRoundTrip(); err != nil { return 0, err } + raceBecomeActive() recvDataLen := atomic.LoadUint32(ep.dataLen()) if recvDataLen > ep.dataCap { return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap) @@ -240,6 +247,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) } *ep.dataLen() = dataLen + raceBecomeInactive() if err := ep.ctrlWakeLast(); err != nil { return err } diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index 435e4eeae..168a487ec 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -62,6 +62,9 @@ func (c *testConnection) destroy() { } func testSendRecv(t *testing.T, c *testConnection) { + // This shared variable is used to confirm that synchronization between + // flipcall endpoints is visible to the Go race detector. + state := 0 var serverRun sync.WaitGroup serverRun.Add(1) go func() { @@ -71,11 +74,19 @@ func testSendRecv(t *testing.T, c *testConnection) { t.Errorf("server Endpoint.RecvFirst() failed: %v", err) return } + state++ + if state != 2 { + t.Errorf("shared state counter: got %d, wanted 2", state) + } t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3") if _, err := c.serverEP.SendRecv(0); err != nil { t.Errorf("server Endpoint.SendRecv() failed: %v", err) return } + state++ + if state != 4 { + t.Errorf("shared state counter: got %d, wanted 4", state) + } t.Logf("server Endpoint got packet 3") }() defer func() { @@ -89,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) { if err := c.clientEP.Connect(); err != nil { t.Fatalf("client Endpoint.Connect() failed: %v", err) } + state++ + if state != 1 { + t.Errorf("shared state counter: got %d, wanted 1", state) + } t.Logf("client Endpoint sending packet 1 and waiting for packet 2") if _, err := c.clientEP.SendRecv(0); err != nil { t.Fatalf("client Endpoint.SendRecv() failed: %v", err) } + state++ + if state != 3 { + t.Errorf("shared state counter: got %d, wanted 3", state) + } t.Logf("client Endpoint got packet 2, sending packet 3") if err := c.clientEP.SendLast(0); err != nil { t.Fatalf("client Endpoint.SendLast() failed: %v", err) diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 73e6eef29..a37952637 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -17,6 +17,8 @@ package flipcall import ( "reflect" "unsafe" + + "gvisor.dev/gvisor/third_party/gvsync" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte { bsReflect.Cap = int(ep.dataCap) return bs } + +// ioSync is a dummy variable used to indicate synchronization to the Go race +// detector. Compare syscall.ioSync. +var ioSync int64 + +func raceBecomeActive() { + if gvsync.RaceEnabled { + gvsync.RaceAcquire((unsafe.Pointer)(&ioSync)) + } +} + +func raceBecomeInactive() { + if gvsync.RaceEnabled { + gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + } +} diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index 11716af81..0c5f50397 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index e6a8dbd02..4b9321711 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD index 8f3defa25..34d2673ef 100644 --- a/pkg/ilist/BUILD +++ b/pkg/ilist/BUILD @@ -1,5 +1,6 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index c8e923a74..a5d980d14 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 12615240c..fc5f5779b 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 3b8a691f4..dd6ca6d39 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -1,5 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -21,6 +23,12 @@ proto_library( visibility = ["//:sandbox"], ) +cc_proto_library( + name = "metric_cc_proto", + visibility = ["//:sandbox"], + deps = [":metric_proto"], +) + go_proto_library( name = "metric_go_proto", importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto", diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index c6737bf97..f32244c69 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:public"], @@ -19,11 +20,14 @@ go_library( "pool.go", "server.go", "transport.go", + "transport_flipcall.go", "version.go", ], importpath = "gvisor.dev/gvisor/pkg/p9", deps = [ "//pkg/fd", + "//pkg/fdchannel", + "//pkg/flipcall", "//pkg/log", "//pkg/unet", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 7dc20aeef..2412aa5e1 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -20,6 +20,8 @@ import ( "sync" "syscall" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -77,6 +79,47 @@ type Client struct { // fidPool is the collection of available fids. fidPool pool + // messageSize is the maximum total size of a message. + messageSize uint32 + + // payloadSize is the maximum payload size of a read or write. + // + // For large reads and writes this means that the read or write is + // broken up into buffer-size/payloadSize requests. + payloadSize uint32 + + // version is the agreed upon version X of 9P2000.L.Google.X. + // version 0 implies 9P2000.L. + version uint32 + + // closedWg is marked as done when the Client.watch() goroutine, which is + // responsible for closing channels and the socket fd, returns. + closedWg sync.WaitGroup + + // sendRecv is the transport function. + // + // This is determined dynamically based on whether or not the server + // supports flipcall channels (preferred as it is faster and more + // efficient, and does not require tags). + sendRecv func(message, message) error + + // -- below corresponds to sendRecvChannel -- + + // channelsMu protects channels. + channelsMu sync.Mutex + + // channelsWg counts the number of channels for which channel.active == + // true. + channelsWg sync.WaitGroup + + // channels is the set of all initialized channels. + channels []*channel + + // availableChannels is a FIFO of inactive channels. + availableChannels []*channel + + // -- below corresponds to sendRecvLegacy -- + // pending is the set of pending messages. pending map[Tag]*response pendingMu sync.Mutex @@ -89,25 +132,12 @@ type Client struct { // Whoever writes to this channel is permitted to call recv. When // finished calling recv, this channel should be emptied. recvr chan bool - - // messageSize is the maximum total size of a message. - messageSize uint32 - - // payloadSize is the maximum payload size of a read or write - // request. For large reads and writes this means that the - // read or write is broken up into buffer-size/payloadSize - // requests. - payloadSize uint32 - - // version is the agreed upon version X of 9P2000.L.Google.X. - // version 0 implies 9P2000.L. - version uint32 } // NewClient creates a new client. It performs a Tversion exchange with // the server to assert that messageSize is ok to use. // -// You should not use the same socket for multiple clients. +// If NewClient succeeds, ownership of socket is transferred to the new Client. func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { // Need at least one byte of payload. if messageSize <= msgRegistry.largestFixedSize { @@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client return nil, ErrBadVersionString } for { + // Always exchange the version using the legacy version of the + // protocol. If the protocol supports flipcall, then we switch + // our sendRecv function to use that functionality. Otherwise, + // we stick to sendRecvLegacy. rversion := Rversion{} - err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion) + err := c.sendRecvLegacy(&Tversion{ + Version: versionString(requested), + MSize: messageSize, + }, &rversion) // The server told us to try again with a lower version. if err == syscall.EAGAIN { @@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client c.version = version break } + + // Can we switch to use the more advanced channels and create + // independent channels for communication? Prefer it if possible. + if versionSupportsFlipcall(c.version) { + // Attempt to initialize IPC-based communication. + for i := 0; i < channelsPerClient; i++ { + if err := c.openChannel(i); err != nil { + log.Warningf("error opening flipcall channel: %v", err) + break // Stop. + } + } + if len(c.channels) >= 1 { + // At least one channel created. + c.sendRecv = c.sendRecvChannel + } else { + // Channel setup failed; fallback. + c.sendRecv = c.sendRecvLegacy + } + } else { + // No channels available: use the legacy mechanism. + c.sendRecv = c.sendRecvLegacy + } + + // Ensure that the socket and channels are closed when the socket is shut + // down. + c.closedWg.Add(1) + go c.watch(socket) // S/R-SAFE: not relevant. + return c, nil } +// watch watches the given socket and releases resources on hangup events. +// +// This is intended to be called as a goroutine. +func (c *Client) watch(socket *unet.Socket) { + defer c.closedWg.Done() + + events := []unix.PollFd{ + unix.PollFd{ + Fd: int32(socket.FD()), + Events: unix.POLLHUP | unix.POLLRDHUP, + }, + } + + // Wait for a shutdown event. + for { + n, err := unix.Ppoll(events, nil, nil) + if err == syscall.EINTR || err == syscall.EAGAIN { + continue + } + if err != nil { + log.Warningf("p9.Client.watch(): %v", err) + break + } + if n != 1 { + log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) + } + break + } + + // Set availableChannels to nil so that future calls to c.sendRecvChannel() + // don't attempt to activate a channel, and concurrent calls to + // c.sendRecvChannel() don't mark released channels as available. + c.channelsMu.Lock() + c.availableChannels = nil + + // Shut down all active channels. + for _, ch := range c.channels { + if ch.active { + log.Debugf("shutting down active channel@%p...", ch) + ch.Shutdown() + } + } + c.channelsMu.Unlock() + + // Wait for active channels to become inactive. + c.channelsWg.Wait() + + // Close all channels. + c.channelsMu.Lock() + for _, ch := range c.channels { + ch.Close() + } + c.channelsMu.Unlock() + + // Close the main socket. + c.socket.Close() +} + +// openChannel attempts to open a client channel. +// +// Note that this function returns naked errors which should not be propagated +// directly to a caller. It is expected that the errors will be logged and a +// fallback path will be used instead. +func (c *Client) openChannel(id int) error { + var ( + rchannel0 Rchannel + rchannel1 Rchannel + res = new(channel) + ) + + // Open the data channel. + if err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 0, + }, &rchannel0); err != nil { + return fmt.Errorf("error handling Tchannel message: %v", err) + } + if rchannel0.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on primary channel") + } + + // We don't need to hold this. + defer rchannel0.FilePayload().Close() + + // Open the channel for file descriptors. + if err := c.sendRecvLegacy(&Tchannel{ + ID: uint32(id), + Control: 1, + }, &rchannel1); err != nil { + return err + } + if rchannel1.FilePayload() == nil { + return fmt.Errorf("missing file descriptor on file descriptor channel") + } + + // Construct the endpoints. + res.desc = flipcall.PacketWindowDescriptor{ + FD: rchannel0.FilePayload().FD(), + Offset: int64(rchannel0.Offset), + Length: int(rchannel0.Length), + } + if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil { + rchannel1.FilePayload().Close() + return err + } + + // The fds channel owns the control payload, and it will be closed when + // the channel object is closed. + res.fds.Init(rchannel1.FilePayload().Release()) + + // Save the channel. + c.channelsMu.Lock() + defer c.channelsMu.Unlock() + c.channels = append(c.channels, res) + c.availableChannels = append(c.availableChannels, res) + return nil +} + // handleOne handles a single incoming message. // // This should only be called with the token from recvr. Note that the received @@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error { } } -// sendRecv performs a roundtrip message exchange. +// sendRecvLegacy performs a roundtrip message exchange. // // This is called by internal functions. -func (c *Client) sendRecv(t message, r message) error { +func (c *Client) sendRecvLegacy(t message, r message) error { tag, ok := c.tagPool.Get() if !ok { return ErrOutOfTags @@ -296,12 +479,62 @@ func (c *Client) sendRecv(t message, r message) error { return nil } +// sendRecvChannel uses channels to send a message. +func (c *Client) sendRecvChannel(t message, r message) error { + // Acquire an available channel. + c.channelsMu.Lock() + if len(c.availableChannels) == 0 { + c.channelsMu.Unlock() + return c.sendRecvLegacy(t, r) + } + idx := len(c.availableChannels) - 1 + ch := c.availableChannels[idx] + c.availableChannels = c.availableChannels[:idx] + ch.active = true + c.channelsWg.Add(1) + c.channelsMu.Unlock() + + // Ensure that it's connected. + if !ch.connected { + ch.connected = true + if err := ch.data.Connect(); err != nil { + // The channel is unusable, so don't return it to + // c.availableChannels. However, we still have to mark it as + // inactive so c.watch() doesn't wait for it. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + return err + } + } + + // Send the message. + err := ch.sendRecv(c, t, r) + + // Release the channel. + c.channelsMu.Lock() + ch.active = false + // If c.availableChannels is nil, c.watch() has fired and we should not + // mark this channel as available. + if c.availableChannels != nil { + c.availableChannels = append(c.availableChannels, ch) + } + c.channelsMu.Unlock() + c.channelsWg.Done() + + return err +} + // Version returns the negotiated 9P2000.L.Google version number. func (c *Client) Version() uint32 { return c.version } -// Close closes the underlying socket. -func (c *Client) Close() error { - return c.socket.Close() +// Close closes the underlying socket and channels. +func (c *Client) Close() { + // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already + // been called (by c.watch()). + c.socket.Shutdown() + c.closedWg.Wait() } diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index 87b2dd61e..29a0afadf 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -35,23 +35,23 @@ func TestVersion(t *testing.T) { go s.Handle(serverSocket) // NewClient does a Tversion exchange, so this is our test for success. - c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString()) + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) if err != nil { t.Fatalf("got %v, expected nil", err) } // Check a bogus version string. - if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { + if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { t.Errorf("got %v expected %v", err, syscall.EINVAL) } // Check a bogus version number. - if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL { + if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL { t.Errorf("got %v expected %v", err, syscall.EINVAL) } // Check a too high version number. - if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN { + if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN { t.Errorf("got %v expected %v", err, syscall.EAGAIN) } @@ -60,3 +60,45 @@ func TestVersion(t *testing.T) { t.Errorf("got %v expected %v", err, syscall.EINVAL) } } + +func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { + // See above. + serverSocket, clientSocket, err := unet.SocketPair(false) + if err != nil { + b.Fatalf("socketpair got err %v expected nil", err) + } + defer clientSocket.Close() + + // See above. + s := NewServer(nil) + go s.Handle(serverSocket) + + // See above. + c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString()) + if err != nil { + b.Fatalf("got %v, expected nil", err) + } + + // Initialize messages. + sendRecv := fn(c) + tversion := &Tversion{ + Version: versionString(highestSupportedVersion), + MSize: DefaultMessageSize, + } + rversion := new(Rversion) + + // Run in a loop. + for i := 0; i < b.N; i++ { + if err := sendRecv(tversion, rversion); err != nil { + b.Fatalf("got unexpected err: %v", err) + } + } +} + +func BenchmarkSendRecvLegacy(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy }) +} + +func BenchmarkSendRecvChannel(b *testing.B) { + benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel }) +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 999b4f684..ba9a55d6d 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -305,7 +305,9 @@ func (t *Tlopen) handle(cs *connState) message { ref.opened = true ref.openFlags = t.Flags - return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile} + rlopen := &Rlopen{QID: qid, IoUnit: ioUnit} + rlopen.SetFilePayload(osFile) + return rlopen } func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { @@ -364,7 +366,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { // Replace the FID reference. cs.InsertFID(t.FID, newRef) - return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil + rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}} + rlcreate.SetFilePayload(osFile) + return rlcreate, nil } // handle implements handler.handle. @@ -1287,5 +1291,48 @@ func (t *Tlconnect) handle(cs *connState) message { return newErr(err) } - return &Rlconnect{File: osFile} + rlconnect := &Rlconnect{} + rlconnect.SetFilePayload(osFile) + return rlconnect +} + +// handle implements handler.handle. +func (t *Tchannel) handle(cs *connState) message { + // Ensure that channels are enabled. + if err := cs.initializeChannels(); err != nil { + return newErr(err) + } + + // Lookup the given channel. + ch := cs.lookupChannel(t.ID) + if ch == nil { + return newErr(syscall.ENOSYS) + } + + // Return the payload. Note that we need to duplicate the file + // descriptor for the channel allocator, because sending is a + // destructive operation between sendRecvLegacy (and now the newer + // channel send operations). Same goes for the client FD. + rchannel := &Rchannel{ + Offset: uint64(ch.desc.Offset), + Length: uint64(ch.desc.Length), + } + switch t.Control { + case 0: + // Open the main data channel. + mfd, err := syscall.Dup(int(cs.channelAlloc.FD())) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(mfd)) + case 1: + cfd, err := syscall.Dup(ch.client.FD()) + if err != nil { + return newErr(err) + } + rchannel.SetFilePayload(fd.New(cfd)) + default: + return newErr(syscall.EINVAL) + } + return rchannel } diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index fd9eb1c5d..ffdd7e8c6 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -64,6 +64,21 @@ type filer interface { SetFilePayload(*fd.FD) } +// filePayload embeds a File object. +type filePayload struct { + File *fd.FD +} + +// FilePayload returns the file payload. +func (f *filePayload) FilePayload() *fd.FD { + return f.File +} + +// SetFilePayload sets the received file. +func (f *filePayload) SetFilePayload(file *fd.FD) { + f.File = file +} + // Tversion is a version request. type Tversion struct { // MSize is the message size to use. @@ -524,10 +539,7 @@ type Rlopen struct { // IoUnit is the recommended I/O unit. IoUnit uint32 - // File may be attached via the socket. - // - // This is an extension specific to this package. - File *fd.FD + filePayload } // Decode implements encoder.Decode. @@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType { return MsgRlopen } -// FilePayload returns the file payload. -func (r *Rlopen) FilePayload() *fd.FD { - return r.File -} - -// SetFilePayload sets the received file. -func (r *Rlopen) SetFilePayload(file *fd.FD) { - r.File = file -} - // String implements fmt.Stringer. func (r *Rlopen) String() string { return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File) @@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string { // Rlconnect is a connect response. type Rlconnect struct { - // File is a host socket. - File *fd.FD + filePayload } // Decode implements encoder.Decode. @@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType { return MsgRlconnect } -// FilePayload returns the file payload. -func (r *Rlconnect) FilePayload() *fd.FD { - return r.File +// String implements fmt.Stringer. +func (r *Rlconnect) String() string { + return fmt.Sprintf("Rlconnect{File: %v}", r.File) } -// SetFilePayload sets the received file. -func (r *Rlconnect) SetFilePayload(file *fd.FD) { - r.File = file +// Tchannel creates a new channel. +type Tchannel struct { + // ID is the channel ID. + ID uint32 + + // Control is 0 if the Rchannel response should provide the flipcall + // component of the channel, and 1 if the Rchannel response should + // provide the fdchannel component of the channel. + Control uint32 +} + +// Decode implements encoder.Decode. +func (t *Tchannel) Decode(b *buffer) { + t.ID = b.Read32() + t.Control = b.Read32() +} + +// Encode implements encoder.Encode. +func (t *Tchannel) Encode(b *buffer) { + b.Write32(t.ID) + b.Write32(t.Control) +} + +// Type implements message.Type. +func (*Tchannel) Type() MsgType { + return MsgTchannel } // String implements fmt.Stringer. -func (r *Rlconnect) String() string { - return fmt.Sprintf("Rlconnect{File: %v}", r.File) +func (t *Tchannel) String() string { + return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control) +} + +// Rchannel is the channel response. +type Rchannel struct { + Offset uint64 + Length uint64 + filePayload +} + +// Decode implements encoder.Decode. +func (r *Rchannel) Decode(b *buffer) { + r.Offset = b.Read64() + r.Length = b.Read64() +} + +// Encode implements encoder.Encode. +func (r *Rchannel) Encode(b *buffer) { + b.Write64(r.Offset) + b.Write64(r.Length) +} + +// Type implements message.Type. +func (*Rchannel) Type() MsgType { + return MsgRchannel +} + +// String implements fmt.Stringer. +func (r *Rchannel) String() string { + return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) } const maxCacheSize = 3 @@ -2356,4 +2409,6 @@ func init() { msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} }) msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) + msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) + msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index e12831dbd..25530adca 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -378,6 +378,8 @@ const ( MsgRlconnect = 137 MsgTallocate = 138 MsgRallocate = 139 + MsgTchannel = 250 + MsgRchannel = 251 ) // QIDType represents the file type for QIDs. diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 6e939a49a..28707c0ca 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") package(licenses = ["notice"]) @@ -77,7 +77,7 @@ go_library( go_test( name = "client_test", - size = "small", + size = "medium", srcs = ["client_test.go"], embed = [":p9test"], deps = [ diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index fe649c2e8..8bbdb2488 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) { } } } + +func TestReadWriteConcurrent(t *testing.T) { + h, c := NewHarness(t) + defer h.Finish() + + _, root := newRoot(h, c) + defer root.Close() + + const ( + instances = 10 + iterations = 10000 + dataSize = 1024 + ) + var ( + dataSets [instances][dataSize]byte + backends [instances]*Mock + files [instances]p9.File + ) + + // Walk to the file normally. + for i := 0; i < instances; i++ { + _, backends[i], files[i] = walkHelper(h, "file", root) + defer files[i].Close() + } + + // Open the files. + for i := 0; i < instances; i++ { + backends[i].EXPECT().Open(p9.ReadWrite) + if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil { + t.Fatalf("open got %v, wanted nil", err) + } + } + + // Initialize random data for each instance. + for i := 0; i < instances; i++ { + if _, err := rand.Read(dataSets[i][:]); err != nil { + t.Fatalf("error initializing dataSet#%d, got %v", i, err) + } + } + + // Define our random read/write mechanism. + randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) { + // Prepare the backend. + backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if n := copy(p, data); n != len(data) { + // Note that we have to assert the result here, as the Return statement + // below cannot be dynamic: it will be bound before this call is made. + h.t.Errorf("wanted length %d, got %d", len(data), n) + } + }).Return(len(data), nil) + + // Execute the read. + if n, err := f.ReadAt(test, 0); n != len(test) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err) + return // No sense doing check below. + } + if !bytes.Equal(test, data) { + t.Errorf("data integrity failed during read") // Not as expected. + } + } + randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) { + // Prepare the backend. + backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) { + if !bytes.Equal(p, data) { + h.t.Errorf("data integrity failed during write") // Not as expected. + } + }).Return(len(data), nil) + + // Execute the write. + if n, err := f.WriteAt(data, 0); n != len(data) || err != nil { + t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err) + } + } + randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) { + test := make([]byte, len(data)) + for i := 0; i < n; i++ { + if rand.Intn(2) == 0 { + randRead(h, backend, f, data, test) + } else { + randWrite(h, backend, f, data) + } + } + } + + // Start reading and writing. + var wg sync.WaitGroup + for i := 0; i < instances; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:]) + }(i) + } + wg.Wait() +} diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 95846e5f7..4d3271b37 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator { // Finish completes all checks and shuts down the server. func (h *Harness) Finish() { - h.clientSocket.Close() + h.clientSocket.Shutdown() h.wg.Wait() h.mockCtrl.Finish() } @@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) { }() // Create the client. - client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString()) + client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString()) if err != nil { serverSocket.Close() clientSocket.Close() diff --git a/pkg/p9/server.go b/pkg/p9/server.go index b294efbb0..69c886a5d 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -21,6 +21,9 @@ import ( "sync/atomic" "syscall" + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -45,7 +48,6 @@ type Server struct { } // NewServer returns a new server. -// func NewServer(attacher Attacher) *Server { return &Server{ attacher: attacher, @@ -85,6 +87,8 @@ type connState struct { // version 0 implies 9P2000.L. version uint32 + // -- below relates to the legacy handler -- + // recvOkay indicates that a receive may start. recvOkay chan bool @@ -93,6 +97,20 @@ type connState struct { // sendDone is signalled when a send is finished. sendDone chan error + + // -- below relates to the flipcall handler -- + + // channelMu protects below. + channelMu sync.Mutex + + // channelWg represents active workers. + channelWg sync.WaitGroup + + // channelAlloc allocates channel memory. + channelAlloc *flipcall.PacketWindowAllocator + + // channels are the set of initialized channels. + channels []*channel } // fidRef wraps a node and tracks references. @@ -386,6 +404,99 @@ func (cs *connState) WaitTag(t Tag) { <-ch } +// initializeChannels initializes all channels. +// +// This is a no-op if channels are already initialized. +func (cs *connState) initializeChannels() (err error) { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + + // Initialize our channel allocator. + if cs.channelAlloc == nil { + alloc, err := flipcall.NewPacketWindowAllocator() + if err != nil { + return err + } + cs.channelAlloc = alloc + } + + // Create all the channels. + for len(cs.channels) < channelsPerClient { + res := &channel{ + done: make(chan struct{}), + } + + res.desc, err = cs.channelAlloc.Allocate(channelSize) + if err != nil { + return err + } + if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil { + return err + } + + socks, err := fdchannel.NewConnectedSockets() + if err != nil { + res.data.Destroy() // Cleanup. + return err + } + res.fds.Init(socks[0]) + res.client = fd.New(socks[1]) + + cs.channels = append(cs.channels, res) + + // Start servicing the channel. + // + // When we call stop, we will close all the channels and these + // routines should finish. We need the wait group to ensure + // that active handlers are actually finished before cleanup. + cs.channelWg.Add(1) + go func() { // S/R-SAFE: Server side. + defer cs.channelWg.Done() + res.service(cs) + }() + } + + return nil +} + +// lookupChannel looks up the channel with given id. +// +// The function returns nil if no such channel is available. +func (cs *connState) lookupChannel(id uint32) *channel { + cs.channelMu.Lock() + defer cs.channelMu.Unlock() + if id >= uint32(len(cs.channels)) { + return nil + } + return cs.channels[id] +} + +// handle handles a single message. +func (cs *connState) handle(m message) (r message) { + defer func() { + if r == nil { + // Don't allow a panic to propagate. + recover() + + // Include a useful log message. + log.Warningf("panic in handler: %s", debug.Stack()) + + // Wrap in an EFAULT error; we don't really have a + // better way to describe this kind of error. It will + // usually manifest as a result of the test framework. + r = newErr(syscall.EFAULT) + } + }() + if handler, ok := m.(handler); ok { + // Call the message handler. + r = handler.handle(cs) + } else { + // Produce an ENOSYS error. + r = newErr(syscall.ENOSYS) + } + return +} + // handleRequest handles a single request. // // The recvDone channel is signaled when recv is done (with a error if @@ -428,41 +539,20 @@ func (cs *connState) handleRequest() { } // Handle the message. - var r message // r is the response. - defer func() { - if r == nil { - // Don't allow a panic to propagate. - recover() + r := cs.handle(m) - // Include a useful log message. - log.Warningf("panic in handler: %s", debug.Stack()) + // Clear the tag before sending. That's because as soon as this hits + // the wire, the client can legally send the same tag. + cs.ClearTag(tag) - // Wrap in an EFAULT error; we don't really have a - // better way to describe this kind of error. It will - // usually manifest as a result of the test framework. - r = newErr(syscall.EFAULT) - } + // Send back the result. + cs.sendMu.Lock() + err = send(cs.conn, tag, r) + cs.sendMu.Unlock() + cs.sendDone <- err - // Clear the tag before sending. That's because as soon as this - // hits the wire, the client can legally send another message - // with the same tag. - cs.ClearTag(tag) - - // Send back the result. - cs.sendMu.Lock() - err = send(cs.conn, tag, r) - cs.sendMu.Unlock() - cs.sendDone <- err - }() - if handler, ok := m.(handler); ok { - // Call the message handler. - r = handler.handle(cs) - } else { - // Produce an ENOSYS error. - r = newErr(syscall.ENOSYS) - } + // Return the message to the cache. msgRegistry.put(m) - m = nil // 'm' should not be touched after this point. } func (cs *connState) handleRequests() { @@ -477,7 +567,27 @@ func (cs *connState) stop() { close(cs.recvDone) close(cs.sendDone) - for _, fidRef := range cs.fids { + // Free the channels. + cs.channelMu.Lock() + for _, ch := range cs.channels { + ch.Shutdown() + } + cs.channelWg.Wait() + for _, ch := range cs.channels { + ch.Close() + } + cs.channels = nil // Clear. + cs.channelMu.Unlock() + + // Free the channel memory. + if cs.channelAlloc != nil { + cs.channelAlloc.Destroy() + } + + // Close all remaining fids. + for fid, fidRef := range cs.fids { + delete(cs.fids, fid) + // Drop final reference in the FID table. Note this should // always close the file, since we've ensured that there are no // handlers running via the wait for Pending => 0 below. @@ -510,7 +620,7 @@ func (cs *connState) service() error { for i := 0; i < pending; i++ { <-cs.sendDone } - return err + return nil } // This handler is now pending. diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 5648df589..6e8b4bbcd 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -54,7 +54,10 @@ const ( headerLength uint32 = 7 // maximumLength is the largest possible message. - maximumLength uint32 = 4 * 1024 * 1024 + maximumLength uint32 = 1 << 20 + + // DefaultMessageSize is a sensible default. + DefaultMessageSize uint32 = 64 << 10 // initialBufferLength is the initial data buffer we allocate. initialBufferLength uint32 = 64 diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go new file mode 100644 index 000000000..7cdf4ecc3 --- /dev/null +++ b/pkg/p9/transport_flipcall.go @@ -0,0 +1,263 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package p9 + +import ( + "runtime" + "syscall" + + "gvisor.dev/gvisor/pkg/fd" + "gvisor.dev/gvisor/pkg/fdchannel" + "gvisor.dev/gvisor/pkg/flipcall" + "gvisor.dev/gvisor/pkg/log" +) + +// channelsPerClient is the number of channels to create per client. +// +// While the client and server will generally agree on this number, in reality +// it's completely up to the server. We simply define a minimum of 2, and a +// maximum of 4, and select the number of available processes as a tie-breaker. +// Note that we don't want the number of channels to be too large, because each +// will account for channelSize memory used, which can be large. +var channelsPerClient = func() int { + n := runtime.NumCPU() + if n < 2 { + return 2 + } + if n > 4 { + return 4 + } + return n +}() + +// channelSize is the channel size to create. +// +// We simply ensure that this is larger than the largest possible message size, +// plus the flipcall packet header, plus the two bytes we write below. +const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength) + +// channel is a fast IPC channel. +// +// The same object is used by both the server and client implementations. In +// general, the client will use only the send and recv methods. +type channel struct { + desc flipcall.PacketWindowDescriptor + data flipcall.Endpoint + fds fdchannel.Endpoint + buf buffer + + // -- client only -- + connected bool + active bool + + // -- server only -- + client *fd.FD + done chan struct{} +} + +// reset resets the channel buffer. +func (ch *channel) reset(sz uint32) { + ch.buf.data = ch.data.Data()[:sz] +} + +// service services the channel. +func (ch *channel) service(cs *connState) error { + rsz, err := ch.data.RecvFirst() + if err != nil { + return err + } + for rsz > 0 { + m, err := ch.recv(nil, rsz) + if err != nil { + return err + } + r := cs.handle(m) + msgRegistry.put(m) + rsz, err = ch.send(r) + if err != nil { + return err + } + } + return nil // Done. +} + +// Shutdown shuts down the channel. +// +// This must be called before Close. +func (ch *channel) Shutdown() { + ch.data.Shutdown() +} + +// Close closes the channel. +// +// This must only be called once, and cannot return an error. Note that +// synchronization for this method is provided at a high-level, depending on +// whether it is the client or server. This cannot be called while there are +// active callers in either service or sendRecv. +// +// Precondition: the channel should be shutdown. +func (ch *channel) Close() error { + // Close all backing transports. + ch.fds.Destroy() + ch.data.Destroy() + if ch.client != nil { + ch.client.Close() + } + return nil +} + +// send sends the given message. +// +// The return value is the size of the received response. Not that in the +// server case, this is the size of the next request. +func (ch *channel) send(m message) (uint32, error) { + if log.IsLogging(log.Debug) { + log.Debugf("send [channel @%p] %s", ch, m.String()) + } + + // Send any file payload. + sentFD := false + if filer, ok := m.(filer); ok { + if f := filer.FilePayload(); f != nil { + if err := ch.fds.SendFD(f.FD()); err != nil { + return 0, syscall.EIO // Map everything to EIO. + } + f.Close() // Per sendRecvLegacy. + sentFD = true // To mark below. + } + } + + // Encode the message. + // + // Note that IPC itself encodes the length of messages, so we don't + // need to encode a standard 9P header. We write only the message type. + ch.reset(0) + + ch.buf.WriteMsgType(m.Type()) + if sentFD { + ch.buf.Write8(1) // Incoming FD. + } else { + ch.buf.Write8(0) // No incoming FD. + } + m.Encode(&ch.buf) + ssz := uint32(len(ch.buf.data)) // Updated below. + + // Is there a payload? + if payloader, ok := m.(payloader); ok { + p := payloader.Payload() + copy(ch.data.Data()[ssz:], p) + ssz += uint32(len(p)) + } + + // Perform the one-shot communication. + n, err := ch.data.SendRecv(ssz) + if err != nil { + if n > 0 { + return n, nil + } + return 0, syscall.EIO // See above. + } + + return n, nil +} + +// recv decodes a message that exists on the channel. +// +// If the passed r is non-nil, then the type must match or an error will be +// generated. If the passed r is nil, then a new message will be created and +// returned. +func (ch *channel) recv(r message, rsz uint32) (message, error) { + // Decode the response from the inline buffer. + ch.reset(rsz) + t := ch.buf.ReadMsgType() + hasFD := ch.buf.Read8() != 0 + if t == MsgRlerror { + // Change the message type. We check for this special case + // after decoding below, and transform into an error. + r = &Rlerror{} + } else if r == nil { + nr, err := msgRegistry.get(0, t) + if err != nil { + return nil, err + } + r = nr // New message. + } else if t != r.Type() { + // Not an error and not the expected response; propagate. + return nil, &ErrBadResponse{Got: t, Want: r.Type()} + } + + // Is there a payload? Copy from the latter portion. + if payloader, ok := r.(payloader); ok { + fs := payloader.FixedSize() + p := payloader.Payload() + payloadData := ch.buf.data[fs:] + if len(p) < len(payloadData) { + p = make([]byte, len(payloadData)) + copy(p, payloadData) + payloader.SetPayload(p) + } else if n := copy(p, payloadData); n < len(p) { + payloader.SetPayload(p[:n]) + } + ch.buf.data = ch.buf.data[:fs] + } + + r.Decode(&ch.buf) + if ch.buf.isOverrun() { + // Nothing valid was available. + log.Debugf("recv [got %d bytes, needed more]", rsz) + return nil, ErrNoValidMessage + } + + // Read any FD result. + if hasFD { + if rfd, err := ch.fds.RecvFDNonblock(); err == nil { + f := fd.New(rfd) + if filer, ok := r.(filer); ok { + // Set the payload. + filer.SetFilePayload(f) + } else { + // Don't want the FD. + f.Close() + } + } else { + // The header bit was set but nothing came in. + log.Warningf("expected FD, got err: %v", err) + } + } + + // Log a message. + if log.IsLogging(log.Debug) { + log.Debugf("recv [channel @%p] %s", ch, r.String()) + } + + // Convert errors appropriately; see above. + if rlerr, ok := r.(*Rlerror); ok { + return nil, syscall.Errno(rlerr.Error) + } + + return r, nil +} + +// sendRecv sends the given message over the channel. +// +// This is used by the client. +func (ch *channel) sendRecv(c *Client, m, r message) error { + rsz, err := ch.send(m) + if err != nil { + return err + } + _, err = ch.recv(r, rsz) + return err +} diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index cdb3bc841..2f50ff3ea 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) { t.Fatalf("unable to create file: %v", err) } - if err := send(client, Tag(1), &Rlopen{File: f}); err != nil { + rlopen := &Rlopen{} + rlopen.SetFilePayload(f) + if err := send(client, Tag(1), rlopen); err != nil { t.Fatalf("send got err %v expected nil", err) } diff --git a/pkg/p9/version.go b/pkg/p9/version.go index c2a2885ae..f1ffdd23a 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 7 + highestSupportedVersion uint32 = 8 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -148,3 +148,10 @@ func VersionSupportsMultiUser(v uint32) bool { func versionSupportsTallocate(v uint32) bool { return v >= 7 } + +// versionSupportsFlipcall returns true if version v supports IPC channels from +// the flipcall package. Note that these must be negotiated, but this version +// string indicates that such a facility exists. +func versionSupportsFlipcall(v uint32) bool { + return v >= 8 +} diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 697e7a2f4..078f084b2 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 9c08452fc..827385139 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "weak_ref_list", diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index d1024e49d..af94e944d 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -1,5 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test") package(licenses = ["notice"]) diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD index f38fb39f3..22abdc69f 100644 --- a/pkg/secio/BUILD +++ b/pkg/secio/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index 694486296..12d7c77d2 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package( default_visibility = ["//visibility:private"], diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD index 53989301f..2d6379c86 100644 --- a/pkg/sentry/BUILD +++ b/pkg/sentry/BUILD @@ -8,5 +8,7 @@ package_group( packages = [ "//pkg/sentry/...", "//runsc/...", + # Code generated by go_marshal relies on go_marshal libraries. + "//tools/go_marshal/...", ], ) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 7aace2d7b..c71cff9f3 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -1,4 +1,5 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -42,6 +43,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "registers_cc_proto", + visibility = ["//visibility:public"], + deps = [":registers_proto"], +) + go_proto_library( name = "registers_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto", diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index bf802d1b6..5522cecd0 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 7e8918722..0c86197f7 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "device", diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index d7259b47b..3119a61b6 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "fs", diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index fbca06761..3cb73bd78 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -1126,7 +1126,7 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error { // Remove removes the given file or symlink. The root dirent is used to // resolve name, and must not be nil. -func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error { +func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error { // Check the root. if root == nil { panic("Dirent.Remove: root must not be nil") @@ -1151,6 +1151,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error { // Remove cannot remove directories. if IsDir(child.Inode.StableAttr) { return syscall.EISDIR + } else if dirPath { + return syscall.ENOTDIR } // Remove cannot remove a mount point. diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go index 884e3ff06..47bc72a88 100644 --- a/pkg/sentry/fs/dirent_refs_test.go +++ b/pkg/sentry/fs/dirent_refs_test.go @@ -343,7 +343,7 @@ func TestRemoveExtraRefs(t *testing.T) { } d := f.Dirent - if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil { + if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil { t.Fatalf("root.Remove(root, %q) failed: %v", name, err) } diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index bf00b9c09..b9bd9ed17 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "fdpipe", diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index bb8117f89..c0a6e884b 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -515,6 +515,11 @@ type lockedReader struct { // File is the file to read from. File *File + + // Offset is the offset to start at. + // + // This applies only to Read, not ReadAt. + Offset int64 } // Read implements io.Reader.Read. @@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) { if r.Ctx.Interrupted() { return 0, syserror.ErrInterrupted } - n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset) + n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset) + r.Offset += n return int(n), err } @@ -544,11 +550,21 @@ type lockedWriter struct { // File is the file to write to. File *File + + // Offset is the offset to start at. + // + // This applies only to Write, not WriteAt. + Offset int64 } // Write implements io.Writer.Write. func (w *lockedWriter) Write(buf []byte) (int, error) { - return w.WriteAt(buf, w.File.offset) + if w.Ctx.Interrupted() { + return 0, syserror.ErrInterrupted + } + n, err := w.WriteAt(buf, w.Offset) + w.Offset += int64(n) + return int(n), err } // WriteAt implements io.Writer.WriteAt. @@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) { // io.Copy, since our own Write interface does not have this same // contract. Enforce that here. for written < len(buf) { + if w.Ctx.Interrupted() { + return written, syserror.ErrInterrupted + } var n int64 n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written)) if n > 0 { diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index d86f5bf45..b88303f17 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -15,6 +15,8 @@ package fs import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -105,8 +107,11 @@ type FileOperations interface { // on the destination, following by a buffered copy with standard Read // and Write operations. // + // If dup is set, the data should be duplicated into the destination + // and retained. + // // The same preconditions as Read apply. - WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error) + WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error) // Write writes src to file at offset and returns the number of bytes // written which must be greater than or equal to 0. Like Read, file @@ -126,7 +131,7 @@ type FileOperations interface { // source. See WriteTo for details regarding how this is called. // // The same preconditions as Write apply; FileFlags.Write must be set. - ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error) + ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error) // Fsync writes buffered modifications of file and/or flushes in-flight // operations to backing storage based on syncType. The range to sync is diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 9820f0b13..225e40186 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "gvisor.dev/gvisor/pkg/refs" @@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme } // WriteTo implements FileOperations.WriteTo. -func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) { err = f.onTop(ctx, file, func(file *File, ops FileOperations) error { - n, err = ops.WriteTo(ctx, file, dst, opts) + n, err = ops.WriteTo(ctx, file, dst, count, dup) return err // Will overwrite itself. }) return @@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm } // ReadFrom implements FileOperations.ReadFrom. -func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) { +func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) { // See above; f.upper must be non-nil. - return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts) + return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count) } // Fsync implements FileOperations.Fsync. diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 6499f87ac..b4ac83dc4 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "dirty_set_impl", diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go index 626b9126a..fc5b3b1a1 100644 --- a/pkg/sentry/fs/fsutil/file.go +++ b/pkg/sentry/fs/fsutil/file.go @@ -15,6 +15,8 @@ package fsutil import ( + "io" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu type FileNoSplice struct{} // WriteTo implements fs.FileOperations.WriteTo. -func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } // ReadFrom implements fs.FileOperations.ReadFrom. -func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) { +func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 6b993928c..2b71ca0e1 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "gofer", diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index b1080fb1a..3e532332e 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "host", diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index c7f4e2d13..ba3e0233d 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -15,6 +15,7 @@ package fs import ( + "io" "sync" "sync/atomic" @@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i } // WriteTo implements FileOperations.WriteTo. -func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) { return 0, syserror.ENOSYS } @@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error { } // ReadFrom implements FileOperations.ReadFrom. -func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) { +func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) { return 0, syserror.ENOSYS } diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 08d7c0c57..5a7a5b8cd 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "lock_range", diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index c7599d1f6..1c93e8886 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "proc", diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 20c3eefc8..76433c7d0 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "seqfile", diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 516efcc4c..d0f351e5a 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "ramfs", diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index eed1c2854..b03b7f836 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -18,7 +18,6 @@ import ( "io" "sync/atomic" - "gvisor.dev/gvisor/pkg/secio" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserror" ) @@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } // Check whether or not the objects being sliced are stream-oriented - // (i.e. pipes or sockets). If yes, we elide checks and offset locks. - srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr) - dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr) + // (i.e. pipes or sockets). For all stream-oriented files and files + // where a specific offiset is not request, we acquire the file mutex. + // This has two important side effects. First, it provides the standard + // protection against concurrent writes that would mutate the offset. + // Second, it prevents Splice deadlocks. Only internal anonymous files + // implement the ReadFrom and WriteTo methods directly, and since such + // anonymous files are referred to by a unique fs.File object, we know + // that the file mutex takes strict precedence over internal locks. + // Since we enforce lock ordering here, we can't deadlock by using + // using a file in two different splice operations simultaneously. + srcPipe := !IsRegular(src.Dirent.Inode.StableAttr) + dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr) + dstAppend := !dstPipe && dst.Flags().Append + srcLock := srcPipe || !opts.SrcOffset + dstLock := dstPipe || !opts.DstOffset || dstAppend - if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset { + switch { + case srcLock && dstLock: switch { case dst.UniqueID < src.UniqueID: // Acquire dst first. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() if !src.mu.Lock(ctx) { + dst.mu.Unlock() return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() case dst.UniqueID > src.UniqueID: // Acquire src first. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() if !dst.mu.Lock(ctx) { + src.mu.Unlock() return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() case dst.UniqueID == src.UniqueID: // Acquire only one lock; it's the same file. This is a // bit of a edge case, but presumably it's possible. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() + srcLock = false // Only need one unlock. } // Use both offsets (locked). opts.DstStart = dst.offset opts.SrcStart = src.offset - } else if !dstPipe && !opts.DstOffset { + case dstLock: // Acquire only dst. if !dst.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer dst.mu.Unlock() opts.DstStart = dst.offset // Safe: locked. - } else if !srcPipe && !opts.SrcOffset { + case srcLock: // Acquire only src. if !src.mu.Lock(ctx) { return 0, syserror.ErrInterrupted } - defer src.mu.Unlock() opts.SrcStart = src.offset // Safe: locked. } - // Check append-only mode and the limit. - if !dstPipe { + var err error + if dstAppend { unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append) defer unlock() - if dst.Flags().Append { - if opts.DstOffset { - // We need to acquire the lock. - if !dst.mu.Lock(ctx) { - return 0, syserror.ErrInterrupted - } - defer dst.mu.Unlock() - } - // Figure out the appropriate offset to use. - if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil { - return 0, err - } - } + // Figure out the appropriate offset to use. + err = dst.offsetForAppend(ctx, &opts.DstStart) + } + if err == nil && !dstPipe { // Enforce file limits. limit, ok := dst.checkLimit(ctx, opts.DstStart) switch { case ok && limit == 0: - return 0, syserror.ErrExceedsFileSizeLimit + err = syserror.ErrExceedsFileSizeLimit case ok && limit < opts.Length: opts.Length = limit // Cap the write. } } + if err != nil { + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return 0, err + } - // Attempt to do a WriteTo; this is likely the most efficient. - // - // The underlying implementation may be able to donate buffers. - newOpts := SpliceOpts{ - Length: opts.Length, - SrcStart: opts.SrcStart, - SrcOffset: !srcPipe, - Dup: opts.Dup, - DstStart: opts.DstStart, - DstOffset: !dstPipe, + // Construct readers and writers for the splice. This is used to + // provide a safer locking path for the WriteTo/ReadFrom operations + // (since they will otherwise go through public interface methods which + // conflict with locking done above), and simplifies the fallback path. + w := &lockedWriter{ + Ctx: ctx, + File: dst, + Offset: opts.DstStart, } - n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts) - if n == 0 && err != nil { - // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also - // be more efficient than a copy if buffers are cached or readily - // available. (It's unlikely that they can actually be donate - n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts) + r := &lockedReader{ + Ctx: ctx, + File: src, + Offset: opts.SrcStart, } - if n == 0 && err != nil { - // If we've failed up to here, and at least one of the sources - // is a pipe or socket, then we can't properly support dup. - // Return an error indicating that this operation is not - // supported. - if (srcPipe || dstPipe) && newOpts.Dup { - return 0, syserror.EINVAL - } - // We failed to splice the files. But that's fine; we just fall - // back to a slow path in this case. This copies without doing - // any mode changes, so should still be more efficient. - var ( - r io.Reader - w io.Writer - ) - fw := &lockedWriter{ - Ctx: ctx, - File: dst, - } - if newOpts.DstOffset { - // Use the provided offset. - w = secio.NewOffsetWriter(fw, newOpts.DstStart) - } else { - // Writes will proceed with no offset. - w = fw - } - fr := &lockedReader{ - Ctx: ctx, - File: src, - } - if newOpts.SrcOffset { - // Limit to the given offset and length. - r = io.NewSectionReader(fr, opts.SrcStart, opts.Length) - } else { - // Limit just to the given length. - r = &io.LimitedReader{fr, opts.Length} - } + // Attempt to do a WriteTo; this is likely the most efficient. + n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) + if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup { + // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be + // more efficient than a copy if buffers are cached or readily + // available. (It's unlikely that they can actually be donated). + n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length) + } - // Copy between the two. - n, err = io.Copy(w, r) + // Support one last fallback option, but only if at least one of + // the source and destination are regular files. This is because + // if we block at some point, we could lose data. If the source is + // not a pipe then reading is not destructive; if the destination + // is a regular file, then it is guaranteed not to block writing. + if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup && (!dstPipe || !srcPipe) { + // Fallback to an in-kernel copy. + n, err = io.Copy(w, &io.LimitedReader{ + R: r, + N: opts.Length, + }) } // Update offsets, if required. @@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, } } + // Drop locks. + if dstLock { + dst.mu.Unlock() + } + if srcLock { + src.mu.Unlock() + } + return n, err } diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 8f7eb5757..11b680929 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tmpfs", diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5e9327aec..25811f668 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tty", @@ -23,6 +25,7 @@ go_library( "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 1d128532b..2f639c823 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -129,6 +129,9 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode { // Release implements fs.InodeOperations.Release. func (d *dirInodeOperations) Release(ctx context.Context) { + d.mu.Lock() + defer d.mu.Unlock() + d.master.DecRef() if len(d.slaves) != 0 { panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d)) diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index 92ec1ca18..19b7557d5 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -172,6 +172,19 @@ func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io userme return 0, mf.t.ld.windowSize(ctx, io, args) case linux.TIOCSWINSZ: return 0, mf.t.ld.setWindowSize(ctx, io, args) + case linux.TIOCSCTTY: + // Make the given terminal the controlling terminal of the + // calling process. + return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY @@ -185,8 +198,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { linux.TCSETS, linux.TCSETSW, linux.TCSETSF, - linux.TIOCGPGRP, - linux.TIOCSPGRP, linux.TIOCGWINSZ, linux.TIOCSWINSZ, linux.TIOCSETD, @@ -200,8 +211,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) { linux.TIOCEXCL, linux.TIOCNXCL, linux.TIOCGEXCL, - linux.TIOCNOTTY, - linux.TIOCSCTTY, linux.TIOCGSID, linux.TIOCGETD, linux.TIOCVHANGUP, diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go index e30266404..944c4ada1 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/slave.go @@ -152,9 +152,16 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - // TODO(b/129283598): Implement once we have support for job - // control. - return 0, nil + return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */) + case linux.TIOCNOTTY: + // Release this process's controlling terminal. + return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) + case linux.TIOCGPGRP: + // Get the foreground process group. + return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) + case linux.TIOCSPGRP: + // Set the foreground process group. + return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index b7cecb2ed..ff8138820 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -17,7 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) // Terminal is a pseudoterminal. @@ -26,23 +29,100 @@ import ( type Terminal struct { refs.AtomicRefCount - // n is the terminal index. + // n is the terminal index. It is immutable. n uint32 - // d is the containing directory. + // d is the containing directory. It is immutable. d *dirInodeOperations - // ld is the line discipline of the terminal. + // ld is the line discipline of the terminal. It is immutable. ld *lineDiscipline + + // masterKTTY contains the controlling process of the master end of + // this terminal. This field is immutable. + masterKTTY *kernel.TTY + + // slaveKTTY contains the controlling process of the slave end of this + // terminal. This field is immutable. + slaveKTTY *kernel.TTY } func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal { termios := linux.DefaultSlaveTermios t := Terminal{ - d: d, - n: n, - ld: newLineDiscipline(termios), + d: d, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{}, + slaveKTTY: &kernel.TTY{}, } t.EnableLeakCheck("tty.Terminal") return &t } + +// setControllingTTY makes tm the controlling terminal of the calling thread +// group. +func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("setControllingTTY must be called from a task context") + } + + return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int()) +} + +// releaseControllingTTY removes tm as the controlling terminal of the calling +// thread group. +func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("releaseControllingTTY must be called from a task context") + } + + return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster)) +} + +// foregroundProcessGroup gets the process group ID of tm's foreground process. +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("foregroundProcessGroup must be called from a task context") + } + + ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster)) + if err != nil { + return 0, err + } + + // Write it out to *arg. + _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err +} + +// foregroundProcessGroup sets tm's foreground process. +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + panic("setForegroundProcessGroup must be called from a task context") + } + + // Read in the process group ID. + var pgid int32 + if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + + ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid)) + return uintptr(ret), err +} + +func (tm *Terminal) tty(isMaster bool) *kernel.TTY { + if isMaster { + return tm.masterKTTY + } + return tm.slaveKTTY +} diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 9e8ebb907..b0c286b7a 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") go_template_instance( diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 9fddb4c4c..bfc46dfa6 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index b51f3e18d..0b471d121 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba } if !cb.Handle(vfs.Dirent{ - Name: child.diskDirent.FileName(), - Type: fs.ToDirentType(childType), - Ino: uint64(child.diskDirent.Inode()), - Off: fd.off, + Name: child.diskDirent.FileName(), + Type: fs.ToDirentType(childType), + Ino: uint64(child.diskDirent.Inode()), + NextOff: fd.off + 1, }) { dir.childList.InsertBefore(child, fd.iter) return nil diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 907d35b7e..2d50e30aa 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "disklayout", diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 63cf7aeaf..1aa2bd6a4 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) { // Ignore the inode number and offset of dirents because those are likely to // change as the underlying image changes. cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool { - return p.String() == "Ino" || p.String() == "Off" + return p.String() == "Ino" || p.String() == "NextOff" }, cmp.Ignore()) if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" { t.Errorf("dirents mismatch (-want +got):\n%s", diff) diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD index d2450e810..7e364c5fd 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/memfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go index c52dc781c..c620227c9 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/memfs/directory.go @@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if fd.off == 0 { if !cb.Handle(vfs.Dirent{ - Name: ".", - Type: linux.DT_DIR, - Ino: vfsd.Impl().(*dentry).inode.ino, - Off: 0, + Name: ".", + Type: linux.DT_DIR, + Ino: vfsd.Impl().(*dentry).inode.ino, + NextOff: 1, }) { return nil } @@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba if fd.off == 1 { parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode if !cb.Handle(vfs.Dirent{ - Name: "..", - Type: parentInode.direntType(), - Ino: parentInode.ino, - Off: 1, + Name: "..", + Type: parentInode.direntType(), + Ino: parentInode.ino, + NextOff: 2, }) { return nil } @@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba // Skip other directoryFD iterators. if child.inode != nil { if !cb.Handle(vfs.Dirent{ - Name: child.vfsd.Name(), - Type: child.inode.direntType(), - Ino: child.inode.ino, - Off: fd.off, + Name: child.vfsd.Name(), + Type: child.inode.direntType(), + Ino: child.inode.ino, + NextOff: fd.off + 1, }) { dir.childList.InsertBefore(child, fd.iter) return nil diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 3d8a4deaf..ade6ac946 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD index f989f2f8b..359468ccc 100644 --- a/pkg/sentry/hostcpu/BUILD +++ b/pkg/sentry/hostcpu/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,6 +7,7 @@ go_library( name = "hostcpu", srcs = [ "getcpu_amd64.s", + "getcpu_arm64.s", "hostcpu.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu", diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s new file mode 100644 index 000000000..caf9abb89 --- /dev/null +++ b/pkg/sentry/hostcpu/getcpu_arm64.s @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall for +// the lack of an optimazed way of getting the current CPU number on arm64. + +// func GetCPU() (cpu uint32) +TEXT ·GetCPU(SB), NOSPLIT, $0-4 + MOVW ZR, cpu+0(FP) + MOVD $cpu+0(FP), R0 + MOVD $0x0, R1 // unused + MOVD $0x0, R2 // unused + MOVD $0xA8, R8 // SYS_GETCPU + SVC + RET diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e61d39c82..aba2414d4 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,9 +1,11 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "pending_signals_list", @@ -83,6 +85,12 @@ proto_library( deps = ["//pkg/sentry/arch:registers_proto"], ) +cc_proto_library( + name = "uncaught_signal_cc_proto", + visibility = ["//visibility:public"], + deps = [":uncaught_signal_proto"], +) + go_proto_library( name = "uncaught_signal_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto", @@ -144,6 +152,7 @@ go_library( "threads.go", "timekeeper.go", "timekeeper_state.go", + "tty.go", "uts_namespace.go", "vdso.go", "version.go", diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index f46c43128..65427b112 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "epoll_list", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index 1c5f979d4..983ca67ed 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "eventfd", diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 6a31dc044..41f44999c 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "atomicptr_bucket", diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index ebcfaa619..d7a7d1169 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -24,6 +25,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "memory_events_cc_proto", + visibility = ["//visibility:public"], + deps = [":memory_events_proto"], +) + go_proto_library( name = "memory_events_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto", diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 4d15cca85..2ce8952e2 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "buffer_list", diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 69ef2a720..95bee2d37 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" @@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return n, err } +// WriteFromReader writes to the buffer from an io.Reader. +func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) { + dst := b.data[b.write:] + if count < int64(len(dst)) { + dst = b.data[b.write:][:count] + } + n, err := r.Read(dst) + b.write += n + return int64(n), err +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write])) @@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return n, err } +// ReadToWriter reads from the buffer into an io.Writer. +func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) { + src := b.data[b.read:b.write] + if count < int64(len(src)) { + src = b.data[b.read:][:count] + } + n, err := w.Write(src) + if !dup { + b.read += n + } + return int64(n), err +} + // bufferPool is a pool for buffers. var bufferPool = sync.Pool{ New: func() interface{} { diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 247e2928e..93b50669f 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F } } +type readOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit limits subsequence reads. + limit func(int64) + + // read performs the actual read operation. + read func(*buffer) (int64, error) +} + // read reads data from the pipe into dst and returns the number of bytes // read, or returns ErrWouldBlock if the pipe is empty. // // Precondition: this pipe must have readers. -func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) { +func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { // Don't block for a zero-length read even if the pipe is empty. - if dst.NumBytes() == 0 { + if ops.left() == 0 { return 0, nil } @@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Limit how much we consume. - if dst.NumBytes() > p.size { - dst = dst.TakeFirst64(p.size) + if ops.left() > p.size { + ops.limit(p.size) } done := int64(0) - for dst.NumBytes() > 0 { + for ops.left() > 0 { // Pop the first buffer. first := p.data.Front() if first == nil { @@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) } // Copy user data. - n, err := dst.CopyOutFrom(ctx, first) + n, err := ops.read(first) done += int64(n) p.size -= n - dst = dst.DropFirst64(n) // Empty buffer? if first.Empty() { @@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) return done, nil } +// dup duplicates all data from this pipe into the given writer. +// +// There is no blocking behavior implemented here. The writer may propagate +// some blocking error. All the writes must be complete writes. +func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Is the pipe empty? + if p.size == 0 { + if !p.HasWriters() { + // See above. + return 0, nil + } + return 0, syserror.ErrWouldBlock + } + + // Limit how much we consume. + if ops.left() > p.size { + ops.limit(p.size) + } + + done := int64(0) + for buf := p.data.Front(); buf != nil; buf = buf.Next() { + n, err := ops.read(buf) + done += n + if err != nil { + return done, err + } + } + + return done, nil +} + +type writeOps struct { + // left returns the bytes remaining. + left func() int64 + + // limit should limit subsequent writes. + limit func(int64) + + // write should write to the provided buffer. + write func(*buffer) (int64, error) +} + // write writes data from sv into the pipe and returns the number of bytes // written. If no bytes are written because the pipe is full (or has less than // atomicIOBytes free capacity), write returns ErrWouldBlock. // // Precondition: this pipe must have writers. -func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) { +func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { p.mu.Lock() defer p.mu.Unlock() @@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. - wanted := src.NumBytes() + wanted := ops.left() if avail := p.max - p.size; wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } - // Limit to the available capacity. - src = src.TakeFirst64(avail) + ops.limit(avail) } done := int64(0) - for src.NumBytes() > 0 { + for ops.left() > 0 { // Need a new buffer? last := p.data.Back() if last == nil || last.Full() { @@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) } // Copy user data. - n, err := src.CopyInTo(ctx, last) + n, err := ops.write(last) done += int64(n) p.size += n - src = src.DropFirst64(n) // Handle errors. if err != nil { diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index f69dbf27b..7c307f013 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -15,6 +15,7 @@ package pipe import ( + "io" "math" "syscall" @@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() { // Read implements fs.FileOperations.Read. func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.read(ctx, dst) + n, err := rw.Pipe.read(ctx, readOps{ + left: func() int64 { + return dst.NumBytes() + }, + limit: func(l int64) { + dst = dst.TakeFirst64(l) + }, + read: func(buf *buffer) (int64, error) { + n, err := dst.CopyOutFrom(ctx, buf) + dst = dst.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventOut) + } + return n, err +} + +// WriteTo implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { + ops := readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(buf *buffer) (int64, error) { + n, err := buf.ReadToWriter(w, count, dup) + count -= n + return n, err + }, + } + if dup { + // There is no notification for dup operations. + return rw.Pipe.dup(ctx, ops) + } + n, err := rw.Pipe.read(ctx, ops) if n > 0 { rw.Pipe.Notify(waiter.EventOut) } @@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, src) + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return src.NumBytes() + }, + limit: func(l int64) { + src = src.TakeFirst64(l) + }, + write: func(buf *buffer) (int64, error) { + n, err := src.CopyInTo(ctx, buf) + src = src.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + rw.Pipe.Notify(waiter.EventIn) + } + return n, err +} + +// ReadFrom implements fs.FileOperations.WriteTo. +func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + n, err := rw.Pipe.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(buf *buffer) (int64, error) { + n, err := buf.WriteFromReader(r, count) + count -= n + return n, err + }, + }) if n > 0 { rw.Pipe.Notify(waiter.EventIn) } diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD index 1725b8562..98ea7a0d8 100644 --- a/pkg/sentry/kernel/sched/BUILD +++ b/pkg/sentry/kernel/sched/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 36edf10f3..80e5e5da3 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "waiter_list", diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 81fcd8258..047b5214d 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -47,6 +47,11 @@ type Session struct { // The id is immutable. id SessionID + // foreground is the foreground process group. + // + // This is protected by TaskSet.mu. + foreground *ProcessGroup + // ProcessGroups is a list of process groups in this Session. This is // protected by TaskSet.mu. processGroups processGroupList @@ -260,12 +265,14 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error { func (tg *ThreadGroup) CreateSession() error { tg.pidns.owner.mu.Lock() defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() return tg.createSession() } // createSession creates a new session for a threadgroup. // -// Precondition: callers must hold TaskSet.mu for writing. +// Precondition: callers must hold TaskSet.mu and the signal mutex for writing. func (tg *ThreadGroup) createSession() error { // Get the ID for this thread in the current namespace. id := tg.pidns.tgids[tg] @@ -321,8 +328,14 @@ func (tg *ThreadGroup) createSession() error { childTG.processGroup.incRefWithParent(pg) childTG.processGroup.decRefWithParent(oldParentPG) }) - tg.processGroup.decRefWithParent(oldParentPG) + // If tg.processGroup is an orphan, decRefWithParent will lock + // the signal mutex of each thread group in tg.processGroup. + // However, tg's signal mutex may already be locked at this + // point. We change tg's process group before calling + // decRefWithParent to avoid locking tg's signal mutex twice. + oldPG := tg.processGroup tg.processGroup = pg + oldPG.decRefWithParent(oldParentPG) } else { // The current process group may be nil only in the case of an // unparented thread group (i.e. the init process). This would @@ -346,6 +359,9 @@ func (tg *ThreadGroup) createSession() error { ns.processGroups[ProcessGroupID(local)] = pg } + // Disconnect from the controlling terminal. + tg.tty = nil + return nil } diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD new file mode 100644 index 000000000..50b69d154 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -0,0 +1,22 @@ +package(licenses = ["notice"]) + +load("//tools/go_stateify:defs.bzl", "go_library") + +go_library( + name = "signalfd", + srcs = ["signalfd.go"], + importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/sentry/context", + "//pkg/sentry/fs", + "//pkg/sentry/fs/anon", + "//pkg/sentry/fs/fsutil", + "//pkg/sentry/kernel", + "//pkg/sentry/usermem", + "//pkg/syserror", + "//pkg/waiter", + ], +) diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go new file mode 100644 index 000000000..06fd5ec88 --- /dev/null +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -0,0 +1,137 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package signalfd provides an implementation of signal file descriptors. +package signalfd + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/anon" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SignalOperations represent a file with signalfd semantics. +// +// +stateify savable +type SignalOperations struct { + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FilePipeSeek `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoFsync `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + // target is the original task target. + // + // The semantics here are a bit broken. Linux will always use current + // for all reads, regardless of where the signalfd originated. We can't + // do exactly that because we need to plumb the context through + // EventRegister in order to support proper blocking behavior. This + // will undoubtedly become very complicated quickly. + target *kernel.Task + + // mu protects below. + mu sync.Mutex `state:"nosave"` + + // mask is the signal mask. Protected by mu. + mask linux.SignalSet +} + +// New creates a new signalfd object with the supplied mask. +func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // No task context? Not valid. + return nil, syserror.EINVAL + } + // name matches fs/signalfd.c:signalfd4. + dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]") + return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{ + target: t, + mask: mask, + }), nil +} + +// Release implements fs.FileOperations.Release. +func (s *SignalOperations) Release() {} + +// Mask returns the signal mask. +func (s *SignalOperations) Mask() linux.SignalSet { + s.mu.Lock() + mask := s.mask + s.mu.Unlock() + return mask +} + +// SetMask sets the signal mask. +func (s *SignalOperations) SetMask(mask linux.SignalSet) { + s.mu.Lock() + s.mask = mask + s.mu.Unlock() +} + +// Read implements fs.FileOperations.Read. +func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { + // Attempt to dequeue relevant signals. + info, err := s.target.Sigtimedwait(s.Mask(), 0) + if err != nil { + // There must be no signal available. + return 0, syserror.ErrWouldBlock + } + + // Copy out the signal info using the specified format. + var buf [128]byte + binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + Signo: uint32(info.Signo), + Errno: info.Errno, + Code: info.Code, + PID: uint32(info.Pid()), + UID: uint32(info.Uid()), + Status: info.Status(), + Overrun: uint32(info.Overrun()), + Addr: info.Addr(), + }) + n, err := dst.CopyOut(ctx, buf[:]) + return int64(n), err +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return mask & waiter.EventIn +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) { + // Register for the signal set; ignore the passed events. + s.target.SignalRegister(entry, waiter.EventMask(s.Mask())) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SignalOperations) EventUnregister(entry *waiter.Entry) { + // Unregister the original entry. + s.target.SignalUnregister(entry) +} diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index e91f82bb3..c82ef5486 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/third_party/gvsync" ) @@ -133,6 +134,13 @@ type Task struct { // signalStack is exclusive to the task goroutine. signalStack arch.SignalStack + // signalQueue is a set of registered waiters for signal-related events. + // + // signalQueue is protected by the signalMutex. Note that the task does + // not implement all queue methods, specifically the readiness checks. + // The task only broadcast a notification on signal delivery. + signalQueue waiter.Queue `state:"zerovalue"` + // If groupStopPending is true, the task should participate in a group // stop in the interrupt path. // diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 266959a07..39cd1340d 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -28,6 +28,7 @@ import ( ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" ) // SignalAction is an internal signal action. @@ -497,6 +498,9 @@ func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) { // // Preconditions: The signal mutex must be locked. func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool { + // Notify that the signal is queued. + t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig))) + // - Do not choose tasks that are blocking the signal. if linux.SignalSetOf(sig)&t.signalMask != 0 { return false @@ -1108,3 +1112,17 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { t.tg.signalHandlers.mu.Unlock() return t.deliverSignal(info, act) } + +// SignalRegister registers a waiter for pending signals. +func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) { + t.tg.signalHandlers.mu.Lock() + t.signalQueue.EventRegister(e, mask) + t.tg.signalHandlers.mu.Unlock() +} + +// SignalUnregister unregisters a waiter for pending signals. +func (t *Task) SignalUnregister(e *waiter.Entry) { + t.tg.signalHandlers.mu.Lock() + t.signalQueue.EventUnregister(e) + t.tg.signalHandlers.mu.Unlock() +} diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index d60cd62c7..ae6fc4025 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -172,9 +172,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { if parentPG := tg.parentPG(); parentPG == nil { tg.createSession() } else { - // Inherit the process group. + // Inherit the process group and terminal. parentPG.incRefWithParent(parentPG) tg.processGroup = parentPG + tg.tty = t.parent.tg.tty } } tg.tasks.PushBack(t) diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 2a97e3e8e..0eef24bfb 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -19,10 +19,13 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/syserror" ) // A ThreadGroup is a logical grouping of tasks that has widespread @@ -245,6 +248,12 @@ type ThreadGroup struct { // // mounts is immutable. mounts *fs.MountNamespace + + // tty is the thread group's controlling terminal. If nil, there is no + // controlling terminal. + // + // tty is protected by the signal mutex. + tty *TTY } // newThreadGroup returns a new, empty thread group in PID namespace ns. The @@ -324,6 +333,176 @@ func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) { } } +// SetControllingTTY sets tty as the controlling terminal of tg. +func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error { + tty.mu.Lock() + defer tty.mu.Unlock() + + // We might be asked to set the controlling terminal of multiple + // processes, so we lock both the TaskSet and SignalHandlers. + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // "The calling process must be a session leader and not have a + // controlling terminal already." - tty_ioctl(4) + if tg.processGroup.session.leader != tg || tg.tty != nil { + return syserror.EINVAL + } + + // "If this terminal is already the controlling terminal of a different + // session group, then the ioctl fails with EPERM, unless the caller + // has the CAP_SYS_ADMIN capability and arg equals 1, in which case the + // terminal is stolen, and all processes that had it as controlling + // terminal lose it." - tty_ioctl(4) + if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session { + if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 { + return syserror.EPERM + } + // Steal the TTY away. Unlike TIOCNOTTY, don't send signals. + for othertg := range tg.pidns.owner.Root.tgids { + // This won't deadlock by locking tg.signalHandlers + // because at this point: + // - We only lock signalHandlers if it's in the same + // session as the tty's controlling thread group. + // - We know that the calling thread group is not in + // the same session as the tty's controlling thread + // group. + if othertg.processGroup.session == tty.tg.processGroup.session { + othertg.signalHandlers.mu.Lock() + othertg.tty = nil + othertg.signalHandlers.mu.Unlock() + } + } + } + + // Set the controlling terminal and foreground process group. + tg.tty = tty + tg.processGroup.session.foreground = tg.processGroup + // Set this as the controlling process of the terminal. + tty.tg = tg + + return nil +} + +// ReleaseControllingTTY gives up tty as the controlling tty of tg. +func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error { + tty.mu.Lock() + defer tty.mu.Unlock() + + // We might be asked to set the controlling terminal of multiple + // processes, so we lock both the TaskSet and SignalHandlers. + tg.pidns.owner.mu.RLock() + defer tg.pidns.owner.mu.RUnlock() + + // Just below, we may re-lock signalHandlers in order to send signals. + // Thus we can't defer Unlock here. + tg.signalHandlers.mu.Lock() + + if tg.tty == nil || tg.tty != tty { + tg.signalHandlers.mu.Unlock() + return syserror.ENOTTY + } + + // "If the process was session leader, then send SIGHUP and SIGCONT to + // the foreground process group and all processes in the current + // session lose their controlling terminal." - tty_ioctl(4) + // Remove tty as the controlling tty for each process in the session, + // then send them SIGHUP and SIGCONT. + + // If we're not the session leader, we don't have to do much. + if tty.tg != tg { + tg.tty = nil + tg.signalHandlers.mu.Unlock() + return nil + } + + tg.signalHandlers.mu.Unlock() + + // We're the session leader. SIGHUP and SIGCONT the foreground process + // group and remove all controlling terminals in the session. + var lastErr error + for othertg := range tg.pidns.owner.Root.tgids { + if othertg.processGroup.session == tg.processGroup.session { + othertg.signalHandlers.mu.Lock() + othertg.tty = nil + if othertg.processGroup == tg.processGroup.session.foreground { + if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil { + lastErr = err + } + if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil { + lastErr = err + } + } + othertg.signalHandlers.mu.Unlock() + } + } + + return lastErr +} + +// ForegroundProcessGroup returns the process group ID of the foreground +// process group. +func (tg *ThreadGroup) ForegroundProcessGroup(tty *TTY) (int32, error) { + tty.mu.Lock() + defer tty.mu.Unlock() + + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // "When fd does not refer to the controlling terminal of the calling + // process, -1 is returned" - tcgetpgrp(3) + if tg.tty != tty { + return -1, syserror.ENOTTY + } + + return int32(tg.processGroup.session.foreground.id), nil +} + +// SetForegroundProcessGroup sets the foreground process group of tty to pgid. +func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) (int32, error) { + tty.mu.Lock() + defer tty.mu.Unlock() + + tg.pidns.owner.mu.Lock() + defer tg.pidns.owner.mu.Unlock() + tg.signalHandlers.mu.Lock() + defer tg.signalHandlers.mu.Unlock() + + // TODO(b/129283598): "If tcsetpgrp() is called by a member of a + // background process group in its session, and the calling process is + // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all + // members of this background process group." + + // tty must be the controlling terminal. + if tg.tty != tty { + return -1, syserror.ENOTTY + } + + // pgid must be positive. + if pgid < 0 { + return -1, syserror.EINVAL + } + + // pg must not be empty. Empty process groups are removed from their + // pid namespaces. + pg, ok := tg.pidns.processGroups[pgid] + if !ok { + return -1, syserror.ESRCH + } + + // pg must be part of this process's session. + if tg.processGroup.session != pg.session { + return -1, syserror.EPERM + } + + tg.processGroup.session.foreground.id = pgid + return 0, nil +} + // itimerRealListener implements ktime.Listener for ITIMER_REAL expirations. // // +stateify savable diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go new file mode 100644 index 000000000..34f84487a --- /dev/null +++ b/pkg/sentry/kernel/tty.go @@ -0,0 +1,28 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import "sync" + +// TTY defines the relationship between a thread group and its controlling +// terminal. +// +// +stateify savable +type TTY struct { + mu sync.Mutex `state:"nosave"` + + // tg is protected by mu. + tg *ThreadGroup +} diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 40025d62d..59649c770 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "limits", diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 29c14ec56..9687e7e76 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "mappable_range", diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 072745a08..b35c8c673 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "file_refcount_set", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 858f895f2..3fd904c67 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "evictable_range", diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index eeb634644..b6d008dbe 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index fe979dccf..31fa48ec5 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go index 47957bb3b..72c7ec564 100644 --- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go +++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go @@ -154,3 +154,19 @@ func (t *thread) clone() (*thread, error) { cpu: ^uint32(0), }, nil } + +// getEventMessage retrieves a message about the ptrace event that just happened. +func (t *thread) getEventMessage() (uintptr, error) { + var msg uintptr + _, _, errno := syscall.RawSyscall6( + syscall.SYS_PTRACE, + syscall.PTRACE_GETEVENTMSG, + uintptr(t.tid), + 0, + uintptr(unsafe.Pointer(&msg)), + 0, 0) + if errno != 0 { + return msg, errno + } + return msg, nil +} diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 6bf7cd097..4f8f9c5d9 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -355,7 +355,8 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal { } if stopSig == syscall.SIGTRAP { if status.TrapCause() == syscall.PTRACE_EVENT_EXIT { - t.dumpAndPanic("wait failed: the process exited") + msg, err := t.getEventMessage() + t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) } // Re-encode the trap cause the way it's expected. return stopSig | syscall.Signal(status.TrapCause()<<8) @@ -426,6 +427,9 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) { break } else { // Some other signal caused a thread stop; ignore. + if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD { + log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig) + } continue } } diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 3b95af617..ea090b686 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD index 924d8a6d6..6769cd0a5 100644 --- a/pkg/sentry/platform/safecopy/BUILD +++ b/pkg/sentry/platform/safecopy/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD index fd6dc8e6e..884020f7b 100644 --- a/pkg/sentry/safemem/BUILD +++ b/pkg/sentry/safemem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 0e37ce61b..3e66f9cbb 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -26,6 +26,7 @@ package epsocket import ( "bytes" + "io" "math" "reflect" "sync" @@ -208,6 +209,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and + // transport.Endpoint.SetSockOptInt. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error @@ -227,7 +232,6 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout *waiter.Queue @@ -412,17 +416,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return int64(n), nil } -// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand -// based on the requested size. +// WriteTo implements fs.FileOperations.WriteTo. +func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { + s.readMu.Lock() + + // Copy as much data as possible. + done := int64(0) + for count > 0 { + // This may return a blocking error. + if err := s.fetchReadView(); err != nil { + s.readMu.Unlock() + return done, err.ToError() + } + + // Write to the underlying file. + n, err := dst.Write(s.readView) + done += int64(n) + count -= int64(n) + if dup { + // That's all we support for dup. This is generally + // supported by any Linux system calls, but the + // expectation is that now a caller will call read to + // actually remove these bytes from the socket. + break + } + + // Drop that part of the view. + s.readView.TrimFront(n) + if err != nil { + s.readMu.Unlock() + return done, err + } + } + + s.readMu.Unlock() + return done, nil +} + +// ioSequencePayload implements tcpip.Payload. +// +// t copies user memory bytes on demand based on the requested size. type ioSequencePayload struct { ctx context.Context src usermem.IOSequence } -// Get implements tcpip.Payload. -func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { - if size > i.Size() { - size = i.Size() +// FullPayload implements tcpip.Payloader.FullPayload +func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { + return i.Payload(int(i.src.NumBytes())) +} + +// Payload implements tcpip.Payloader.Payload. +func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { + if max := int(i.src.NumBytes()); size > max { + size = max } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { @@ -431,11 +478,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { return v, nil } -// Size implements tcpip.Payload. -func (i *ioSequencePayload) Size() int { - return int(i.src.NumBytes()) -} - // DropFirst drops the first n bytes from underlying src. func (i *ioSequencePayload) DropFirst(n int) { i.src = i.src.DropFirst(int(n)) @@ -469,6 +511,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), nil } +// readerPayload implements tcpip.Payloader. +// +// It allocates a view and reads from a reader on-demand, based on available +// capacity in the endpoint. +type readerPayload struct { + ctx context.Context + r io.Reader + count int64 + err error +} + +// FullPayload implements tcpip.Payloader.FullPayload. +func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { + return r.Payload(int(r.count)) +} + +// Payload implements tcpip.Payloader.Payload. +func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { + if size > int(r.count) { + size = int(r.count) + } + v := buffer.NewView(size) + n, err := r.r.Read(v) + if n > 0 { + // We ignore the error here. It may re-occur on subsequent + // reads, but for now we can enqueue some amount of data. + r.count -= int64(n) + return v[:n], nil + } + if err == syserror.ErrWouldBlock { + return nil, tcpip.ErrWouldBlock + } else if err != nil { + r.err = err // Save for propation. + return nil, tcpip.ErrBadAddress + } + + // There is no data and no error. Return an error, which will propagate + // r.err, which will be nil. This is the desired result: (0, nil). + return nil, tcpip.ErrBadAddress +} + +// ReadFrom implements fs.FileOperations.ReadFrom. +func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + f := &readerPayload{ctx: ctx, r: r, count: count} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ + Atomic: true, // See above. + }) + } + if err == tcpip.ErrWouldBlock { + return n, syserror.ErrWouldBlock + } else if err != nil { + return int64(n), f.err // Propagate error. + } + + return int64(n), nil +} + // Readiness returns a mask of ready events for socket s. func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) @@ -777,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -793,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1165,7 +1279,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1173,7 +1287,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -2060,7 +2174,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= int64(v.Size()) || dontWait) { + if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. return int(n), nil } @@ -2085,7 +2199,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.TranslateNetstackError(err) } - if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } @@ -2207,9 +2321,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCOUTQ: - var v tcpip.SendQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 9e2e12799..445080aa4 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "port", diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 5061dcbde..3a6baa308 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -49,6 +50,14 @@ proto_library( ], ) +cc_proto_library( + name = "syscall_rpc_cc_proto", + visibility = [ + "//visibility:public", + ], + deps = [":syscall_rpc_proto"], +) + go_proto_library( name = "syscall_rpc_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2b0ad6395..1867b3a5c 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has + // the int type. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error @@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrQueueSizeNotSupported } return v, nil - default: - return -1, tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - case *tcpip.SendQueueSizeOption: + case tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v := e.connected.SendQueuedSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported - } - *o = qs - return nil - - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - return nil + return int(v), nil - case *tcpip.SendBufferSizeOption: + case tcpip.SendBufferSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + v := e.connected.SendMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs - return nil + return int(v), nil - case *tcpip.ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + v := e.receiver.RecvMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return int(v), nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) } - *o = qs return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 445d25010..7d7b42eba 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -44,6 +45,12 @@ proto_library( visibility = ["//visibility:public"], ) +cc_proto_library( + name = "strace_cc_proto", + visibility = ["//visibility:public"], + deps = [":strace_proto"], +) + go_proto_library( name = "strace_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto", diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go index 3650fd6e1..5d57b75af 100644 --- a/pkg/sentry/strace/linux64.go +++ b/pkg/sentry/strace/linux64.go @@ -335,4 +335,5 @@ var linuxAMD64 = SyscallMap{ 315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex), 316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex), 317: makeSyscallInfo("seccomp", Hex, Hex, Hex), + 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex), } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 33a40b9c6..e76ee27d2 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -74,6 +74,7 @@ go_library( "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/sched", "//pkg/sentry/kernel/shm", + "//pkg/sentry/kernel/signalfd", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/memmap", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ed996ba51..18d24ab61 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -320,21 +320,21 @@ var AMD64 = &kernel.SyscallTable{ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) 280: syscalls.Supported("utimensat", Utimensat), 281: syscalls.Supported("epoll_pwait", EpollPwait), - 282: syscalls.ErrorWithEvent("signalfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), 283: syscalls.Supported("timerfd_create", TimerfdCreate), 284: syscalls.Supported("eventfd", Eventfd), 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), 286: syscalls.Supported("timerfd_settime", TimerfdSettime), 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), 288: syscalls.Supported("accept4", Accept4), - 289: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426) + 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), 290: syscalls.Supported("eventfd2", Eventfd2), 291: syscalls.Supported("epoll_create1", EpollCreate1), 292: syscalls.Supported("dup3", Dup3), diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 2e00a91ce..b9a8e3e21 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -1423,9 +1423,6 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { if err != nil { return err } - if dirPath { - return syserror.ENOENT - } return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error { if !fs.IsDir(d.Inode.StableAttr) { @@ -1436,7 +1433,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { return err } - return d.Remove(t, root, name) + return d.Remove(t, root, name, dirPath) }) } diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 0104a94c0..fb6efd5d8 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -20,7 +20,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd" + "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -506,3 +509,77 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?") return 0, nil, syserror.EINTR } + +// sharedSignalfd is shared between the two calls. +func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) { + // Copy in the signal mask. + mask, err := copyInSigSet(t, sigset, sigsetsize) + if err != nil { + return 0, nil, err + } + + // Always check for valid flags, even if not creating. + if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 { + return 0, nil, syserror.EINVAL + } + + // Is this a change to an existing signalfd? + // + // The spec indicates that this should adjust the mask. + if fd != -1 { + file := t.GetFile(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Is this a signalfd? + if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok { + s.SetMask(mask) + return 0, nil, nil + } + + // Not a signalfd. + return 0, nil, syserror.EINVAL + } + + // Create a new file. + file, err := signalfd.New(t, mask) + if err != nil { + return 0, nil, err + } + defer file.DecRef() + + // Set appropriate flags. + file.SetFlags(fs.SettableFileFlags{ + NonBlocking: flags&linux.SFD_NONBLOCK != 0, + }) + + // Create a new descriptor. + fd, err = t.NewFDFrom(0, file, kernel.FDFlags{ + CloseOnExec: flags&linux.SFD_CLOEXEC != 0, + }) + if err != nil { + return 0, nil, err + } + + // Done. + return uintptr(fd), nil, nil +} + +// Signalfd implements the linux syscall signalfd(2). +func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + sigset := args[1].Pointer() + sigsetsize := args[2].SizeT() + return sharedSignalfd(t, fd, sigset, sigsetsize, 0) +} + +// Signalfd4 implements the linux syscall signalfd4(2). +func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + sigset := args[1].Pointer() + sigsetsize := args[2].SizeT() + flags := args[3].Int() + return sharedSignalfd(t, fd, sigset, sigsetsize, flags) +} diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 8a98fedcb..f0a292f2f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB total int64 n int64 err error - ch chan struct{} - inW bool - outW bool + inCh chan struct{} + outCh chan struct{} ) for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) @@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB break } - // Are we a registered waiter? - if ch == nil { - ch = make(chan struct{}, 1) - } - if !inW && !inFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - inFile.EventRegister(&w, EventMaskRead) - defer inFile.EventUnregister(&w) - inW = true // Registered. - } else if !outW && !outFile.Flags().NonBlocking { - w, _ := waiter.NewChannelEntry(ch) - outFile.EventRegister(&w, EventMaskWrite) - defer outFile.EventUnregister(&w) - outW = true // Registered. - } - - // Was anything registered? If no, everything is non-blocking. - if !inW && !outW { - break - } - - if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) { - // Something became ready, try again without blocking. - continue + // Note that the blocking behavior here is a bit different than the + // normal pattern. Because we need to have both data to read and data + // to write simultaneously, we actually explicitly block on both of + // these cases in turn before returning to the splice operation. + if inFile.Readiness(EventMaskRead) == 0 { + if inCh == nil { + inCh = make(chan struct{}, 1) + inW, _ := waiter.NewChannelEntry(inCh) + inFile.EventRegister(&inW, EventMaskRead) + defer inFile.EventUnregister(&inW) + continue // Need to refresh readiness. + } + if err = t.Block(inCh); err != nil { + break + } } - - // Block until there's data. - if err = t.Block(ch); err != nil { - break + if outFile.Readiness(EventMaskWrite) == 0 { + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + continue // Need to refresh readiness. + } + if err = t.Block(outCh); err != nil { + break + } } } @@ -149,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Length: count, SrcOffset: true, SrcStart: offset, - }, false) + }, outFile.Flags().NonBlocking) // Copy out the new offset. if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { @@ -159,7 +156,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Send data using splice. n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, - }, false) + }, outFile.Flags().NonBlocking) } // We can only pass a single file to handleIOError, so pick inFile @@ -181,12 +178,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. Note that unlike in Linux, this - // flag is applied consistently. We will have either fully blocking or - // non-blocking behavior below, regardless of the underlying files - // being spliced to. It's unclear if this is a bug or not yet. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -200,6 +191,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } defer inFile.DecRef() + // The operation is non-blocking if anything is non-blocking. + // + // N.B. This is a rather simplistic heuristic that avoids some + // poor edge case behavior since the exact semantics here are + // underspecified and vary between versions of Linux itself. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Construct our options. // // Note that exactly one of the underlying buffers must be a pipe. We @@ -257,7 +255,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Splice data. - n, err := doSplice(t, outFile, inFile, opts, nonBlocking) + n, err := doSplice(t, outFile, inFile, opts, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile) @@ -275,9 +273,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } - // Only non-blocking is meaningful. - nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0 - // Get files. outFile := t.GetFile(outFD) if outFile == nil { @@ -301,11 +296,14 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo return 0, nil, syserror.EINVAL } + // The operation is non-blocking if anything is non-blocking. + nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0) + // Splice data. n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, Dup: true, - }, nonBlocking) + }, nonBlock) // See above; inFile is chosen arbitrarily here. return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 8aa6a3017..beb43ba13 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD index b69603da3..fc7614fff 100644 --- a/pkg/sentry/unimpl/BUILD +++ b/pkg/sentry/unimpl/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -10,6 +11,12 @@ proto_library( deps = ["//pkg/sentry/arch:registers_proto"], ) +cc_proto_library( + name = "unimplemented_syscall_cc_proto", + visibility = ["//visibility:public"], + deps = [":unimplemented_syscall_proto"], +) + go_proto_library( name = "unimplemented_syscall_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto", diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index a5b4206bb..cc5d25762 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "addr_range", diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 0f247bf77..eff4b44f6 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 86bde7fb3..7eb2b2821 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -199,8 +199,11 @@ type Dirent struct { // Ino is the inode number. Ino uint64 - // Off is this Dirent's offset. - Off int64 + // NextOff is the offset of the *next* Dirent in the directory; that is, + // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will + // cause the next call to FileDescription.IterDirents() to yield the next + // Dirent. (The offset of the first Dirent in a directory is always 0.) + NextOff int64 } // IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents. diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index 00665c939..bdca80d37 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/state/BUILD b/pkg/state/BUILD index c0f3c658d..329904457 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,5 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index e70f4a79f..8a865d229 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index b149f9e02..bd3f9fd28 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index df37c7d5a..3fd9e3134 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "tcpip", diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 0d2637ee4..78df5a0b1 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 3301967fb..b4e8d6810 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "buffer", diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 29b30be9c..0c5c20cea 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 76ef02f13..b558350c3 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "header", diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 093850e25..9d3abc0e4 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -76,6 +76,13 @@ const ( // IPv6Version is the version of the ipv6 protocol. IPv6Version = 6 + // IPv6AllNodesMulticastAddress is a link-local multicast group that + // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all nodes on a link. + // + // The address is ff02::1. + IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 @@ -221,6 +228,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool { return addr[0] == 0xff } +// IsV6UnicastAddress determines if the provided address is a valid IPv6 +// unicast (and specified) address. That is, IsV6UnicastAddress returns +// true if addr contains IPv6AddressSize bytes, is not the unspecified +// address and is not a multicast address. +func IsV6UnicastAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + + // Must not be unspecified + if addr == IPv6Any { + return false + } + + // Return if not a multicast. + return addr[0] != 0xff +} + // SolicitedNodeAddr computes the solicited-node multicast address. This is // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 // address. diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index c1f454805..74412c894 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -27,6 +27,11 @@ const ( udpChecksum = 6 ) +const ( + // UDPMaximumPacketSize is the largest possible UDP packet. + UDPMaximumPacketSize = 0xffff +) + // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c40744b8e..18adb2085 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -44,14 +44,12 @@ type Endpoint struct { } // New creates a new channel endpoint. -func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ +func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { + return &Endpoint{ C: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } - - return stack.RegisterLinkEndpoint(e), e } // Drain removes all outbound packets from the channel and counts them. @@ -135,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } + +// Wait implements stack.LinkEndpoint.Wait. +func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 74fbbb896..8fa9e3984 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 77f988b9f..584db710e 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,6 +41,7 @@ package fdbased import ( "fmt" + "sync" "syscall" "golang.org/x/sys/unix" @@ -81,6 +82,7 @@ const ( PacketMMap ) +// An endpoint implements the link-layer using a message-oriented file descriptor. type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. The endpoint will dispatch from all inbound channels as well as @@ -114,6 +116,9 @@ type endpoint struct { // gsoMaxSize is the maximum GSO packet size. It is zero if GSO is // disabled. gsoMaxSize uint32 + + // wg keeps track of running goroutines. + wg sync.WaitGroup } // Options specify the details about the fd-based endpoint to be created. @@ -164,8 +169,9 @@ type Options struct { // New creates a new fd-based endpoint. // // Makes fd non-blocking, but does not take ownership of fd, which must remain -// open for the lifetime of the returned endpoint. -func New(opts *Options) (tcpip.LinkEndpointID, error) { +// open for the lifetime of the returned endpoint (until after the endpoint has +// stopped being using and Wait returns). +func New(opts *Options) (stack.LinkEndpoint, error) { caps := stack.LinkEndpointCapabilities(0) if opts.RXChecksumOffload { caps |= stack.CapabilityRXChecksumOffload @@ -190,7 +196,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } if len(opts.FDs) == 0 { - return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified") + return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } e := &endpoint{ @@ -207,12 +213,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { for i := 0; i < len(e.fds); i++ { fd := e.fds[i] if err := syscall.SetNonblock(fd, true); err != nil { - return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) + return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err) } isSocket, err := isSocketFD(fd) if err != nil { - return 0, err + return nil, err } if isSocket { if opts.GSOMaxSize != 0 { @@ -222,12 +228,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) { } inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket) if err != nil { - return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err) + return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err) } e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) { @@ -290,7 +296,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { // saved, they stop sending outgoing packets and all incoming packets // are rejected. for i := range e.inboundDispatchers { - go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above. + e.wg.Add(1) + go func(i int) { // S/R-SAFE: See above. + e.dispatchLoop(e.inboundDispatchers[i]) + e.wg.Done() + }(i) } } @@ -320,6 +330,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } +// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop +// reading from its FD. +func (e *endpoint) Wait() { + e.wg.Wait() +} + // virtioNetHdr is declared in linux/virtio_net.h. type virtioNetHdr struct { flags uint8 @@ -435,14 +451,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf } // NewInjectable creates a new fd-based InjectableEndpoint. -func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) { +func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint { syscall.SetNonblock(fd, true) - e := &InjectableEndpoint{endpoint: endpoint{ + return &InjectableEndpoint{endpoint: endpoint{ fds: []int{fd}, mtu: mtu, caps: capabilities, }} - - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e305252d6..04406bc9a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context { } opt.FDs = []int{fds[1]} - epID, err := New(opt) + ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } - ep := stack.FindLinkEndpoint(epID).(*endpoint) c := &context{ t: t, diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ab6a53988..b36629d2c 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -32,8 +32,8 @@ type endpoint struct { // New creates a new loopback endpoint. This link-layer endpoint just turns // outbound packets into inbound packets. -func New() tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{}) +func New() stack.LinkEndpoint { + return &endpoint{} } // Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- @@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa return nil } + +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index ea12ef1ac..1bab380b0 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a577a3d52..7c946101d 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -104,10 +104,16 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) * return endpoint.WriteRawPacket(dest, packet) } +// Wait implements stack.LinkEndpoint.Wait. +func (m *InjectableEndpoint) Wait() { + for _, ep := range m.routes { + ep.Wait() + } +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. -func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) { - e := &InjectableEndpoint{ +func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { + return &InjectableEndpoint{ routes: routes, } - return stack.RegisterLinkEndpoint(e), e } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 174b9330f..3086fec00 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc if err != nil { t.Fatal("Failed to create socket pair:", err) } - _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) + underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone) routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint} - _, endpoint := NewInjectableEndpoint(routes) + endpoint := NewInjectableEndpoint(routes) return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP } diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index f2998aa98..0a5ea3dc4 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 94725cb11..330ed5e94 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 160a8f864..de1ce043d 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 834ea5c40..9e71d4edf 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -94,7 +94,7 @@ type endpoint struct { // New creates a new shared-memory-based endpoint. Buffers will be broken up // into buffers of "bufferSize" bytes. -func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) { +func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) { e := &endpoint{ mtu: mtu, bufferSize: bufferSize, @@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc } if err := e.tx.init(bufferSize, &tx); err != nil { - return 0, err + return nil, err } if err := e.rx.init(bufferSize, &rx); err != nil { e.tx.cleanup() - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(e), nil + return e, nil } // Close frees all resources associated with the endpoint. @@ -132,7 +132,8 @@ func (e *endpoint) Close() { } } -// Wait waits until all workers have stopped after a Close() call. +// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have +// stopped after a Close() call. func (e *endpoint) Wait() { e.completed.Wait() } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 98036f367..0e9ba0846 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress initQueue(t, &c.txq, &c.txCfg) initQueue(t, &c.rxq, &c.rxCfg) - id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) if err != nil { t.Fatalf("New failed: %v", err) } - c.ep = stack.FindLinkEndpoint(id).(*endpoint) + c.ep = ep.(*endpoint) c.ep.Attach(c) return c diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 36c8c46fc..e401dce44 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -58,10 +58,10 @@ type endpoint struct { // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. -func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID { - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), - }) +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + return &endpoint{ + lower: lower, + } } func zoneOffset() (int32, error) { @@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { // snapLen is the maximum amount of a packet to be saved. Packets with a length // less than or equal too snapLen will be saved in their entirety. Longer // packets will be truncated to snapLen. -func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) { +func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) { if err := writePCAPHeader(file, snapLen); err != nil { - return 0, err + return nil, err } - return stack.RegisterLinkEndpoint(&endpoint{ - lower: stack.FindLinkEndpoint(lower), + return &endpoint{ + lower: lower, file: file, maxPCAPLen: snapLen, - }), nil + }, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is @@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return e.lower.WritePacket(r, gso, hdr, payload, protocol) } +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 2597d4b3e..0746dc8ec 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 3b6ac2ff7..5a1791cb5 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -40,11 +40,10 @@ type Endpoint struct { // New creates a new waitable link-layer endpoint. It wraps around another // endpoint and allows the caller to block new write/dispatch calls and wait for // the inflight ones to finish before returning. -func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { - e := &Endpoint{ - lower: stack.FindLinkEndpoint(lower), +func New(lower stack.LinkEndpoint) *Endpoint { + return &Endpoint{ + lower: lower, } - return stack.RegisterLinkEndpoint(e), e } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. @@ -121,3 +120,6 @@ func (e *Endpoint) WaitWrite() { func (e *Endpoint) WaitDispatch() { e.dispatchGate.Close() } + +// Wait implements stack.LinkEndpoint.Wait. +func (e *Endpoint) Wait() {} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 56e18ecb0..ae23c96b7 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -70,9 +70,12 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P return nil } +// Wait implements stack.LinkEndpoint.Wait. +func (*countedEndpoint) Wait() {} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Write and check that it goes through. wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) @@ -97,7 +100,7 @@ func TestWaitWrite(t *testing.T) { func TestWaitDispatch(t *testing.T) { ep := &countedEndpoint{} - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) // Check that attach happens. wep.Attach(ep) @@ -139,7 +142,7 @@ func TestOtherMethods(t *testing.T) { hdrLen: hdrLen, linkAddr: linkAddr, } - _, wep := New(stack.RegisterLinkEndpoint(ep)) + wep := New(ep) if v := wep.MTU(); v != mtu { t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index f36f49453..9d16ff8c9 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d95d44f56..df0d3a8c0 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 4c4b54469..387fca96e 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -47,11 +47,13 @@ func newTestContext(t *testing.T) *testContext { s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) const defaultMTU = 65536 - id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(256, defaultMTU, stackLinkAddr) + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -73,7 +75,7 @@ func newTestContext(t *testing.T) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 118bfc763..c5c7aad86 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "reassembler_list", diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4b3bd74fa..6a40e7ee3 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*testObject) Wait() {} + // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index be84fa63d..58e537aad 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 1b5a55bea..ae827ca27 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -36,11 +36,11 @@ func TestExcludeBroadcast(t *testing.T) { s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) const defaultMTU = 65536 - id, _ := channel.New(256, defaultMTU, "") + ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { - id = sniffer.New(id) + ep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -184,15 +184,12 @@ type errorChannel struct { // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket // will return successive errors from packetCollectorErrors until the list is // empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { - _, e := channel.New(size, mtu, linkAddr) - ec := errorChannel{ - Endpoint: e, +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { + return &errorChannel{ + Endpoint: channel.New(size, mtu, linkAddr), Ch: make(chan packetInfo, size), packetCollectorErrors: packetCollectorErrors, } - - return stack.RegisterLinkEndpoint(e), &ec } // packetInfo holds all the information about an outbound packet. @@ -242,9 +239,8 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) - _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - linkEPId := stack.RegisterLinkEndpoint(linkEP) - s.CreateNIC(1, linkEPId) + ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + s.CreateNIC(1, ep) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -266,7 +262,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 } return context{ Route: r, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index c71b69123..f06622a8b 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -25,6 +26,7 @@ go_test( size = "small", srcs = [ "icmp_test.go", + "ipv6_test.go", "ndp_test.go", ], embed = [":ipv6"], @@ -36,6 +38,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 227a65cf2..653d984e9 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -83,8 +83,7 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li func TestICMPCounts(t *testing.T) { s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -211,36 +210,27 @@ func newTestContext(t *testing.T) *testContext { } const defaultMTU = 65536 - _, linkEP0 := channel.New(256, defaultMTU, linkAddr0) - c.linkEP0 = linkEP0 - wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0} - id0 := stack.RegisterLinkEndpoint(wrappedEP0) + c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + + wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { - id0 = sniffer.New(id0) + wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, id0); err != nil { + if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { - t.Fatalf("AddAddress sn lladdr0: %v", err) - } - _, linkEP1 := channel.New(256, defaultMTU, linkAddr1) - c.linkEP1 = linkEP1 - wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1} - id1 := stack.RegisterLinkEndpoint(wrappedEP1) - if err := c.s1.CreateNIC(1, id1); err != nil { + c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) + if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { - t.Fatalf("AddAddress sn lladdr1: %v", err) - } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) if err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go new file mode 100644 index 000000000..57bcd5455 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -0,0 +1,258 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // The least significant 3 bytes are the same as addr2 so both addr2 and + // addr3 will have the same solicited-node address. + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" +) + +// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the +// expected Neighbor Advertisement received count after receiving the packet. +func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + // Receive ICMP packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stats := s.Stats().ICMP.V6PacketsReceived + + if got := stats.NeighborAdvert.Value(); got != want { + t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) + } +} + +// testReceiveICMP tests receiving a UDP packet from src to dst. want is the +// expected UDP received count after receiving the packet. +func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + + // Receive UDP Packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + + // UDP pseudo-header checksum. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) + + // UDP checksum + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + + stat := s.Stats().UDP.PacketsReceived + + if got := stat.Value(); got != want { + t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) + } +} + +// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and +// UDP packets destined to the IPv6 link-local all-nodes multicast address. +func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { + tests := []struct { + name string + protocolName string + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.ProtocolName6, testReceiveICMP}, + {"UDP", udp.ProtocolName, testReceiveUDP}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should receive a packet destined to the all-nodes + // multicast address. + test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) + }) + } +} + +// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP +// packets destined to the IPv6 solicited-node address of an assigned IPv6 +// address. +func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + tests := []struct { + name string + protocolName string + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.ProtocolName6, testReceiveICMP}, + {"UDP", udp.ProtocolName, testReceiveUDP}, + } + + snmc := header.SolicitedNodeAddr(addr2) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{}) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as we haven't added + // those addresses. + test.rxf(t, s, e, addr1, snmc, 0) + + if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + } + + // Should receive a packet destined to the solicited + // node address of addr2/addr3 now that we have added + // added addr2. + test.rxf(t, s, e, addr1, snmc, 1) + + if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have added addr3. + test.rxf(t, s, e, addr1, snmc, 2) + + if err := s.RemoveAddress(1, addr2); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have removed addr2. + test.rxf(t, s, e, addr1, snmc, 3) + + if err := s.RemoveAddress(1, addr3); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as both of them got + // removed. + test.rxf(t, s, e, addr1, snmc, 3) + }) + } +} + +// TestAddIpv6Address tests adding IPv6 addresses. +func TestAddIpv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + }{ + // This test is in response to b/140943433. + { + "Nil", + tcpip.Address([]byte(nil)), + }, + { + "ValidUnicast", + addr1, + }, + { + "ValidLinkLocalUnicast", + lladdr0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New([]string{ProtocolName}, nil, stack.Options{}) + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8e4cf0e74..571915d3f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -32,15 +32,14 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Helper() s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) - { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } + + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) } + { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 989058413..11efb4e44 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index e2021cd15..f12189580 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -138,11 +138,11 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) + linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu}) if err != nil { log.Fatal(err) } - if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil { + if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 1716be285..329941775 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -128,7 +128,7 @@ func main() { log.Fatal(err) } - linkID, err := fdbased.New(&fdbased.Options{ + linkEP, err := fdbased.New(&fdbased.Options{ FDs: []int{fd}, MTU: mtu, EthernetHeader: *tap, @@ -137,7 +137,7 @@ func main() { if err != nil { log.Fatal(err) } - if err := s.CreateNIC(1, linkID); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { log.Fatal(err) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 788de3dfe..28c49e8ff 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "linkaddrentry_list", diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go index f8156be47..3a20839da 100644 --- a/pkg/tcpip/stack/icmp_rate_limit.go +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "golang.org/x/time/rate" ) @@ -33,54 +31,11 @@ const ( // ICMPRateLimiter is a global rate limiter that controls the generation of // ICMP messages generated by the stack. type ICMPRateLimiter struct { - mu sync.RWMutex - l *rate.Limiter + *rate.Limiter } // NewICMPRateLimiter returns a global rate limiter for controlling the rate // at which ICMP messages are generated by the stack. func NewICMPRateLimiter() *ICMPRateLimiter { - return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)} -} - -// Allow returns true if we are allowed to send at least 1 message at the -// moment. -func (i *ICMPRateLimiter) Allow() bool { - i.mu.RLock() - allow := i.l.Allow() - i.mu.RUnlock() - return allow -} - -// Limit returns the maximum number of ICMP messages that can be sent in one -// second. -func (i *ICMPRateLimiter) Limit() rate.Limit { - i.mu.RLock() - defer i.mu.RUnlock() - return i.l.Limit() -} - -// SetLimit sets the maximum number of ICMP messages that can be sent in one -// second. -func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) { - i.mu.RLock() - defer i.mu.RUnlock() - i.l.SetLimit(newLimit) -} - -// Burst returns how many ICMP messages can be sent at any single instant. -func (i *ICMPRateLimiter) Burst() int { - i.mu.RLock() - defer i.mu.RUnlock() - return i.l.Burst() -} - -// SetBurst sets the maximum number of ICMP messages allowed at any single -// instant. -// -// NOTE: Changing Burst causes the underlying rate limiter to be recreated. -func (i *ICMPRateLimiter) SetBurst(burst int) { - i.mu.Lock() - i.l = rate.NewLimiter(i.l.Limit(), burst) - i.mu.Unlock() + return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 43719085e..a719058b4 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -102,6 +102,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + } + + return nil +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -307,6 +326,8 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP address before adding them. + // Sanity check. id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if _, ok := n.endpoints[id]; ok { @@ -339,6 +360,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 unicast address, join the solicited-node + // multicast address. + if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -467,13 +497,27 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || r.getKind() != permanent { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } r.setKind(permanentExpired) - r.decRefLocked() + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } + + // At this point the endpoint is deleted. + + // If we are removing an IPv6 unicast address, leave the solicited-node + // multicast address. + if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -491,6 +535,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -518,6 +569,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -802,11 +860,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 67b70b2ee..07e4c770d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -297,6 +295,15 @@ type LinkEndpoint interface { // IsAttached returns whether a NetworkDispatcher is attached to the // endpoint. IsAttached() bool + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // For now, requesting that an endpoint's worker goroutine(s) stop is + // implementation specific. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -379,10 +386,6 @@ var ( networkProtocols = make(map[string]NetworkProtocolFactory) unassociatedFactory UnassociatedEndpointFactory - - linkEPMu sync.RWMutex - nextLinkEndpointID tcpip.LinkEndpointID = 1 - linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) ) // RegisterTransportProtocolFactory registers a new transport protocol factory @@ -406,28 +409,6 @@ func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { unassociatedFactory = f } -// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an -// ID that can be used to refer to it. -func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { - linkEPMu.Lock() - defer linkEPMu.Unlock() - - v := nextLinkEndpointID - nextLinkEndpointID++ - - linkEndpoints[v] = linkEP - - return v -} - -// FindLinkEndpoint finds the link endpoint associated with the given ID. -func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { - linkEPMu.RLock() - defer linkEPMu.RUnlock() - - return linkEndpoints[id] -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6beca6ae8..1fe21b68e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -620,12 +620,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { - ep := FindLinkEndpoint(linkEP) - if ep == nil { - return tcpip.ErrBadLinkEndpoint - } - +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -638,40 +633,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint s.nics[id] = n if enabled { - n.attachLinkEndpoint() + return n.enable() } return nil } // CreateNIC creates a NIC with the provided id and link-layer endpoint. -func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true, false) +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, false) +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, false) } // CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer // endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, true) +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false, false) +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false, false) +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -685,9 +680,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - nic.attachLinkEndpoint() - - return nil + return nic.enable() } // CheckNIC checks if a NIC is usable. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c6a8160af..0c26c9911 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct { prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { @@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto } func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() + return f.ep.Capabilities() } func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return nil } - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) } func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -189,14 +189,14 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, - linkEP: linkEP, + ep: ep, }, nil } @@ -225,9 +225,9 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") + ep := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -245,7 +245,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -255,7 +255,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -265,7 +265,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -274,7 +274,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -284,7 +284,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -307,59 +307,59 @@ func send(r stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) } -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := sendTo(s, addr, payload); err != nil { t.Error("sendTo failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("sendTo packet count: got = %d, want %d", got, want) } } -func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := send(r, payload); err != nil { t.Error("send failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("send packet count: got = %d, want %d", got, want) } } -func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) } } -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we expect to receive it. want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we do NOT expect to receive it. want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -369,9 +369,9 @@ func TestNetworkSend(t *testing.T) { // Create a stack with the fake network protocol, one nic, and one // address: 1. The route table sends all packets through the only // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") + ep := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -388,7 +388,7 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", linkEP, nil) + testSendTo(t, s, "\x03", ep, nil) } func TestNetworkSendMultiRoute(t *testing.T) { @@ -397,8 +397,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { // even addresses. s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -410,8 +410,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -442,10 +442,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - testSendTo(t, s, "\x05", linkEP1, nil) + testSendTo(t, s, "\x05", ep1, nil) // Send a packet to an even destination. - testSendTo(t, s, "\x06", linkEP2, nil) + testSendTo(t, s, "\x06", ep2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -478,8 +478,8 @@ func TestRoutes(t *testing.T) { // even addresses. s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -491,8 +491,8 @@ func TestRoutes(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -556,8 +556,8 @@ func TestAddressRemoval(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -578,15 +578,15 @@ func TestAddressRemoval(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -601,9 +601,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatal("CreateNIC failed:", err) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) @@ -626,17 +626,17 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -690,8 +690,8 @@ func TestEndpointExpiration(t *testing.T) { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -724,15 +724,15 @@ func TestEndpointExpiration(t *testing.T) { //----------------------- verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 2. Add Address, everything should work. @@ -741,8 +741,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. @@ -752,15 +752,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 4. Add Address back, everything should work again. @@ -769,8 +769,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 5. Take a reference to the endpoint by getting a route. Verify that // we can still send/receive, including sending using the route. @@ -779,9 +779,9 @@ func TestEndpointExpiration(t *testing.T) { if err != nil { t.Fatal("FindRoute failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -791,16 +791,16 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 7. Add Address back, everything should work again. @@ -809,16 +809,16 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -828,15 +828,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } }) } @@ -846,8 +846,8 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -867,13 +867,13 @@ func TestPromiscuousMode(t *testing.T) { // have a matching endpoint. const localAddrByte byte = 0x01 buf[0] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -886,7 +886,7 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestSpoofingWithAddress(t *testing.T) { @@ -896,8 +896,8 @@ func TestSpoofingWithAddress(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -936,8 +936,8 @@ func TestSpoofingWithAddress(t *testing.T) { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } // Sending a packet works. - testSendTo(t, s, dstAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testSendTo(t, s, dstAddr, ep, nil) + testSend(t, r, ep, nil) // FindRoute should also work with a local address that exists on the NIC. r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) @@ -951,7 +951,7 @@ func TestSpoofingWithAddress(t *testing.T) { t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. - testSend(t, r, linkEP, nil) + testSend(t, r, ep, nil) } func TestSpoofingNoAddress(t *testing.T) { @@ -960,8 +960,8 @@ func TestSpoofingNoAddress(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -980,7 +980,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -999,14 +999,14 @@ func TestSpoofingNoAddress(t *testing.T) { } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) @@ -1076,8 +1076,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1132,8 +1132,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1159,7 +1159,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { t.Fatal("AddAddressRange failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { @@ -1198,8 +1198,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1236,8 +1236,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1262,7 +1262,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { t.Fatal("AddAddressRange failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestNetworkOptions(t *testing.T) { @@ -1320,8 +1320,8 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S func TestAddresRangeAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1361,8 +1361,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } // Insert <canBe> primary and <never> never-primary addresses. @@ -1426,8 +1426,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1501,8 +1501,8 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1526,8 +1526,8 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1558,8 +1558,8 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1587,8 +1587,8 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1621,8 +1621,8 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { @@ -1639,7 +1639,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1653,9 +1653,9 @@ func TestNICStats(t *testing.T) { if err := sendTo(s, "\x01", payload); err != nil { t.Fatal("sendTo failed: ", err) } - want := uint64(linkEP1.Drain()) + want := uint64(ep1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) + t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { @@ -1669,16 +1669,16 @@ func TestNICForwarding(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) s.SetForwarding(true) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -1697,10 +1697,10 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) select { - case <-linkEP2.C: + case <-ep2.C: default: t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ca185279e..0e69ac7c8 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } @@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptInt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -278,9 +283,9 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -340,9 +345,9 @@ func TestTransportReceive(t *testing.T) { } func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -408,9 +413,9 @@ func TestTransportControlReceive(t *testing.T) { } func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") + linkEP := channel.New(10, defaultMTU, "") s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -497,16 +502,16 @@ func TestTransportForwarding(t *testing.T) { s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { + ep1 := loopback.New() + if err := s.CreateNIC(1, ep1); err != nil { t.Fatalf("CreateNIC #1 failed: %v", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress #1 failed: %v", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatalf("CreateNIC #2 failed: %v", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -545,7 +550,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.Inject(fakeNetNumber, req.ToVectorisedView()) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -559,7 +564,7 @@ func TestTransportForwarding(t *testing.T) { var p channel.PacketInfo select { - case p = <-linkEP2.C: + case p = <-ep2.C: default: t.Fatal("Response packet not forwarded") } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 418e771d2..c021c67ac 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -261,31 +261,34 @@ type FullAddress struct { Port uint16 } -// Payload provides an interface around data that is being sent to an endpoint. -// This allows the endpoint to request the amount of data it needs based on -// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data. -type Payload interface { - // Get returns a slice containing exactly 'min(size, p.Size())' bytes. - Get(size int) ([]byte, *Error) - - // Size returns the payload size. - Size() int +// Payloader is an interface that provides data. +// +// This interface allows the endpoint to request the amount of data it needs +// based on internal buffers without exposing them. +type Payloader interface { + // FullPayload returns all available bytes. + FullPayload() ([]byte, *Error) + + // Payload returns a slice containing at most size bytes. + Payload(size int) ([]byte, *Error) } -// SlicePayload implements Payload on top of slices for convenience. +// SlicePayload implements Payloader for slices. +// +// This is typically used for tests. type SlicePayload []byte -// Get implements Payload. -func (s SlicePayload) Get(size int) ([]byte, *Error) { - if size > s.Size() { - size = s.Size() - } - return s[:size], nil +// FullPayload implements Payloader.FullPayload. +func (s SlicePayload) FullPayload() ([]byte, *Error) { + return s, nil } -// Size implements Payload. -func (s SlicePayload) Size() int { - return len(s) +// Payload implements Payloader.Payload. +func (s SlicePayload) Payload(size int) ([]byte, *Error) { + if size > len(s) { + size = len(s) + } + return s[:size], nil } // A ControlMessages contains socket control messages for IP sockets. @@ -338,7 +341,7 @@ type Endpoint interface { // ErrNoLinkAddress and a notification channel is returned for the caller to // block. Channel is closed once address resolution is complete (success or // not). The channel is only non-nil in this case. - Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error) + Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) // Peek reads data without consuming it from the endpoint. // @@ -398,6 +401,10 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptInt sets a socket option, for simple cases where a value + // has the int type. + SetSockOptInt(opt SockOpt, v int) *Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error @@ -432,16 +439,33 @@ type WriteOptions struct { // EndOfRecord has the same semantics as Linux's MSG_EOR. EndOfRecord bool + + // Atomic means that all data fetched from Payloader must be written to the + // endpoint. If Atomic is false, then data fetched from the Payloader may be + // discarded if available endpoint buffer space is unsufficient. + Atomic bool } // SockOpt represents socket options which values have the int type. type SockOpt int const ( - // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of - // unread bytes in the input buffer should be returned. + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption SockOpt = iota + // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the send buffer size option. + SendBufferSizeOption + + // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to + // specify the receive buffer size option. + ReceiveBufferSizeOption + + // SendQueueSizeOption is used in GetSockOptInt to specify that the + // number of unread bytes in the output buffer should be returned. + SendQueueSizeOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -450,18 +474,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send -// buffer size option. -type SendBufferSizeOption int - -// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the -// receive buffer size option. -type ReceiveBufferSizeOption int - -// SendQueueSizeOption is used in GetSockOpt to specify that the number of -// unread bytes in the output buffer should be returned. -type SendQueueSizeOption int - // V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int @@ -600,9 +612,6 @@ func (r Route) String() string { return out.String() } -// LinkEndpointID represents a data link layer endpoint. -type LinkEndpointID uint64 - // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e1f622af6..a111fdb2a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -204,7 +204,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -289,7 +289,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } @@ -319,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// SetSockOptInt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { @@ -331,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -341,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 13e17e2a6..a02731a5d 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -207,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -220,9 +220,8 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 return 0, nil, tcpip.ErrInvalidEndpointState } - payloadBytes, err := payload.Get(payload.Size()) + payloadBytes, err := p.FullPayload() if err != nil { - ep.mu.RUnlock() return 0, nil, err } @@ -230,7 +229,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64 // destination address, route using that address. if !ep.associated { ip := header.IPv4(payloadBytes) - if !ip.IsValid(payload.Size()) { + if !ip.IsValid(len(payloadBytes)) { ep.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } @@ -493,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { @@ -505,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } ep.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSize + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption @@ -516,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - ep.mu.Lock() - *o = tcpip.SendBufferSizeOption(ep.sndBufSize) - ep.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) - ep.rcvMu.Unlock() - return nil - case *tcpip.KeepaliveEnabledOption: *o = 0 return nil diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 1ee1a53f8..39a839ab7 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "tcp_segment_list", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ac927569a..35b489c68 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -806,7 +806,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -821,47 +821,52 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, err } - e.sndBufMu.Unlock() - e.mu.RUnlock() - - // Nothing to do if the buffer is empty. - if p.Size() == 0 { - return 0, nil, nil + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + e.mu.RUnlock() } - // Copy in memory without holding sndBufMu so that worker goroutine can - // make progress independent of this operation. - v, perr := p.Get(avail) - if perr != nil { + // Fetch data. + v, perr := p.Payload(avail) + if perr != nil || len(v) == 0 { + if opts.Atomic { // See above. + e.sndBufMu.Unlock() + e.mu.RUnlock() + } + // Note that perr may be nil if len(v) == 0. return 0, nil, perr } - e.mu.RLock() - e.sndBufMu.Lock() + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a - // write. - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.mu.RUnlock() - return 0, nil, err - } + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + return 0, nil, err + } - // Discard any excess data copied in due to avail being reduced due to a - // simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } } // Add data to the send queue. - l := len(v) s := newSegmentFromView(&e.route, e.id, v) - e.sndBufUsed += l - e.sndBufInQueue += seqnum.Size(l) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() // Release the endpoint lock to prevent deadlocks due to lock // order inversion when acquiring workMu. @@ -875,7 +880,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return int64(l), nil, nil + + return int64(len(v)), nil, nil } // Peek reads data without consuming it from the endpoint. @@ -946,62 +952,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil - - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.reuseAddr = v != 0 - e.mu.Unlock() - return nil - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.mu.Lock() - e.userMSS = int(userMSS) - e.mu.Unlock() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - +// SetSockOptInt sets a socket option. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1065,6 +1018,67 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sndBufMu.Unlock() return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.DelayOption: + if v == 0 { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.delay, 1) + } + return nil + + case tcpip.CorkOption: + if v == 0 { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + + case tcpip.QuickAckOption: + if v == 0 { + atomic.StoreUint32(&e.slowAck, 1) + } else { + atomic.StoreUint32(&e.slowAck, 0) + } + return nil + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { @@ -1176,6 +1190,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + v := e.sndBufSize + e.sndBufMu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + v := e.rcvBufSize + e.rcvListMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -1198,18 +1224,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.SendBufferSizeOption: - e.sndBufMu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.sndBufMu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize) - e.rcvListMu.Unlock() - return nil - case *tcpip.DelayOption: *o = 0 if v := atomic.LoadUint32(&e.delay); v != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 272bbcdbd..9fa97528b 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { enableCUBIC(t, c) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) @@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 7 data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 32bb45224..7fa5cfb6e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.ActiveConnectionOpenings.Value() + 1 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { stats := c.Stack().Stats() want := stats.TCP.FailedConnectionAttempts.Value() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) } @@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { stats := c.Stack().Stats() // SYN and ACK want := stats.TCP.SegmentsSent.Value() + 2 - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) @@ -299,7 +299,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -323,7 +323,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) { want := stats.TCP.ResetsReceived.Value() + 1 iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(iss, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -344,14 +344,14 @@ func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) } func TestNonBlockingClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -367,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -417,7 +417,7 @@ func TestSimpleReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -469,7 +469,7 @@ func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -557,8 +557,7 @@ func TestOutOfOrderFlood(t *testing.T) { defer c.Cleanup() // Create a new connection with initial window size of 10. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -631,7 +630,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -700,7 +699,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -785,7 +784,7 @@ func TestShutdownRead(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) @@ -804,8 +803,7 @@ func TestFullWindowReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -872,11 +870,9 @@ func TestNoWindowShrinking(t *testing.T) { defer c.Cleanup() // Start off with a window size of 10, then shrink it to 5. - opt := tcpip.ReceiveBufferSizeOption(10) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 10) - opt = 5 - if err := c.EP.SetSockOpt(opt); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { t.Fatalf("SetSockOpt failed: %v", err) } @@ -976,7 +972,7 @@ func TestSimpleSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1017,7 +1013,7 @@ func TestZeroWindowSend(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 0, nil) + c.CreateConnected(789, 0, -1 /* epRcvBuf */) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1075,8 +1071,7 @@ func TestScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1110,8 +1105,7 @@ func TestNonScaledWindowConnect(t *testing.T) { defer c.Cleanup() // Set the window size greater than the maximum non-scaled window. - opt := tcpip.ReceiveBufferSizeOption(65535 * 3) - c.CreateConnected(789, 30000, &opt) + c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} view := buffer.NewView(len(data)) @@ -1151,7 +1145,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1224,7 +1218,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1293,8 +1287,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Set the window size such that a window scale of 4 will be used. const wnd = 65535 * 10 const ws = uint32(4) - opt := tcpip.ReceiveBufferSizeOption(wnd) - c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ header.TCPOptionWS, 3, 0, header.TCPOptionNOP, }) @@ -1399,7 +1392,7 @@ func TestSegmentMerging(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Prevent the endpoint from processing packets. test.stop(c.EP) @@ -1449,7 +1442,7 @@ func TestDelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1497,7 +1490,7 @@ func TestUndelay(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.DelayOption(1)) @@ -1579,7 +1572,7 @@ func TestMSSNotDelayed(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -1695,7 +1688,7 @@ func TestSendGreaterThanMTU(t *testing.T) { c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) testBrokenUpWrite(t, c, maxPayload) } @@ -1704,7 +1697,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) testBrokenUpWrite(t, c, maxPayload) @@ -1727,7 +1720,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1871,7 +1864,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 2 - if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { t.Fatalf("SetSockOpt failed failed: %v", err) } @@ -1973,7 +1966,7 @@ func TestReceiveOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2010,7 +2003,7 @@ func TestSendOnResetConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send RST segment. c.SendPacket(nil, &context.Headers{ @@ -2035,7 +2028,7 @@ func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2078,7 +2071,7 @@ func TestFinRetransmit(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -2132,7 +2125,7 @@ func TestFinWithNoPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. view := buffer.NewView(10) @@ -2203,7 +2196,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write enough segments to fill the congestion window before ACK'ing // any of them. @@ -2291,7 +2284,7 @@ func TestFinWithPendingData(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) @@ -2377,7 +2370,7 @@ func TestFinWithPartialAck(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. @@ -2509,7 +2502,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { defer c.Cleanup() maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize - c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), header.TCPOptionWS, 3, scale, header.TCPOptionNOP, }) @@ -2559,7 +2552,7 @@ func TestScaledSendWindow(t *testing.T) { func TestReceivedValidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ValidSegmentsReceived.Value() + 1 @@ -2580,7 +2573,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.InvalidSegmentsReceived.Value() + 1 vv := c.BuildSegment(nil, &context.Headers{ @@ -2604,7 +2597,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { func TestReceivedIncorrectChecksumIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) stats := c.Stack().Stats() want := stats.TCP.ChecksumErrors.Value() + 1 vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ @@ -2635,7 +2628,7 @@ func TestReceivedSegmentQueuing(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Send 200 segments. data := []byte{1, 2, 3} @@ -2681,7 +2674,7 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -2856,8 +2849,8 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2869,8 +2862,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - var s tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&s); err != nil { + s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -2945,26 +2938,26 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, 200) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil { + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) - if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil { + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { t.Fatalf("GetSockOpt failed: %v", err) } @@ -3231,7 +3224,7 @@ func TestPathMTUDiscovery(t *testing.T) { // Create new connection with MSS of 1460. const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize - c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), }) @@ -3308,7 +3301,7 @@ func TestTCPEndpointProbe(t *testing.T) { invoked <- struct{}{} }) - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -3482,7 +3475,7 @@ func TestKeepalive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - c.CreateConnected(789, 30000, nil) + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond)) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 18c707a57..78eff5c3a 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -150,11 +150,12 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - id, linkEP := channel.New(1000, mtu, "") + ep := channel.New(1000, mtu, "") + wep := stack.LinkEndpoint(ep) if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -180,7 +181,7 @@ func New(t *testing.T, mtu uint32) *Context { return &Context{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), } } @@ -511,7 +512,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { } // CreateConnected creates a connected TCP endpoint. -func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { +func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) { c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } @@ -589,7 +590,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // // It also sets the receive buffer for the endpoint to the specified // value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -597,8 +598,8 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("NewEndpoint failed: %v", err) } - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + if epRcvBuf != -1 { + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 4bec48c0f..43fcc27f0 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index ac2666f69..c1ca22b35 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "udp_packet_list", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 66455ef46..0bec7e62d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "math" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -277,17 +276,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - if p.Size() > math.MaxUint16 { - // Payload can't possibly fit in a packet. - return 0, nil, tcpip.ErrMessageTooLong - } - to := opts.To e.mu.RLock() @@ -370,10 +364,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } ttl := route.DefaultTTL() if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { @@ -391,7 +389,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: @@ -570,7 +573,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil } + return -1, tcpip.ErrUnknownProtocolOption } @@ -580,18 +596,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { @@ -747,6 +751,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { } e.state = StateBound } else { + if e.id.LocalPort != 0 { + // Release the ephemeral port. + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + } e.state = StateInitial } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 995d6e8a1..c6deab892 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -275,12 +275,13 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + ep := channel.New(256, mtu, "") + wep := stack.LinkEndpoint(ep) - id, linkEP := channel.New(256, mtu, "") if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -306,7 +307,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 98d51cc69..6afdb29b7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index cbd92fc05..8f6f180e5 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b7f505a84..b6bbb0ea2 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 9173dfd0f..8dc88becb 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,7 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_template_instance( name = "waiter_list", diff --git a/runsc/BUILD b/runsc/BUILD index a2a465e1e..5e7dacb87 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -13,7 +13,7 @@ go_binary( visibility = [ "//visibility:public", ], - x_defs = {"main.version": "{VERSION}"}, + x_defs = {"main.version": "{STABLE_VERSION}"}, deps = [ "//pkg/log", "//pkg/refs", @@ -46,7 +46,7 @@ go_binary( visibility = [ "//visibility:public", ], - x_defs = {"main.version": "{VERSION}"}, + x_defs = {"main.version": "{STABLE_VERSION}"}, deps = [ "//pkg/log", "//pkg/refs", @@ -101,3 +101,10 @@ pkg_deb( "//visibility:public", ], ) + +sh_test( + name = "version_test", + size = "small", + srcs = ["version_test.sh"], + data = [":runsc"], +) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 588bb8851..54d1ab129 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -109,6 +109,7 @@ go_test( "//pkg/sentry/arch:registers_go_proto", "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", + "//pkg/sentry/kernel/auth", "//pkg/unet", "//runsc/fsgofer", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", diff --git a/runsc/boot/config.go b/runsc/boot/config.go index 05b8f8761..31103367d 100644 --- a/runsc/boot/config.go +++ b/runsc/boot/config.go @@ -211,12 +211,6 @@ type Config struct { // RestoreFile is the path to the saved container image RestoreFile string - // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in - // tests. It allows runsc to start the sandbox process as the current - // user, and without chrooting the sandbox process. This can be - // necessary in test environments that have limited capabilities. - TestOnlyAllowRunAsCurrentUserWithoutChroot bool - // NumNetworkChannels controls the number of AF_PACKET sockets that map // to the same underlying network device. This allows netstack to better // scale for high throughput use cases. @@ -233,6 +227,19 @@ type Config struct { // ReferenceLeakMode sets reference leak check mode ReferenceLeakMode refs.LeakMode + + // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in + // tests. It allows runsc to start the sandbox process as the current + // user, and without chrooting the sandbox process. This can be + // necessary in test environments that have limited capabilities. + TestOnlyAllowRunAsCurrentUserWithoutChroot bool + + // TestOnlyTestNameEnv should only be used in tests. It looks up for the + // test name in the container environment variables and adds it to the debug + // log file name. This is done to help identify the log with the test when + // multiple tests are run in parallel, since there is no way to pass + // parameters to the runtime from docker. + TestOnlyTestNameEnv string } // ToFlags returns a slice of flags that correspond to the given Config. @@ -261,9 +268,12 @@ func (c *Config) ToFlags() []string { "--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr), "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode), } + // Only include these if set since it is never to be used by users. if c.TestOnlyAllowRunAsCurrentUserWithoutChroot { - // Only include if set since it is never to be used by users. - f = append(f, "-TESTONLY-unsafe-nonroot=true") + f = append(f, "--TESTONLY-unsafe-nonroot=true") + } + if len(c.TestOnlyTestNameEnv) != 0 { + f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv) } return f } diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 7ca776b3a..a2ecc6bcb 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -88,14 +88,24 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), seccomp.AllowAny{}, seccomp.AllowAny{}, - seccomp.AllowValue(0), }, { seccomp.AllowAny{}, seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), seccomp.AllowAny{}, + }, + // Non-private variants are included for flipcall support. They are otherwise + // unncessary, as the sentry will use only private futexes internally. + { + seccomp.AllowAny{}, + seccomp.AllowValue(linux.FUTEX_WAIT), + seccomp.AllowAny{}, + seccomp.AllowAny{}, + }, + { + seccomp.AllowAny{}, + seccomp.AllowValue(linux.FUTEX_WAKE), seccomp.AllowAny{}, - seccomp.AllowValue(0), }, }, syscall.SYS_GETPID: {}, diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 823a34619..d824d7dc5 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -20,7 +20,6 @@ import ( mrand "math/rand" "os" "runtime" - "strings" "sync" "sync/atomic" "syscall" @@ -535,23 +534,12 @@ func (l *Loader) run() error { return err } - // Read /etc/passwd for the user's HOME directory and set the HOME - // environment variable as required by POSIX if it is not overridden by - // the user. - hasHomeEnvv := false - for _, envv := range l.rootProcArgs.Envv { - if strings.HasPrefix(envv, "HOME=") { - hasHomeEnvv = true - } - } - if !hasHomeEnvv { - homeDir, err := getExecUserHome(ctx, l.rootProcArgs.MountNamespace, uint32(l.rootProcArgs.Credentials.RealKUID)) - if err != nil { - return fmt.Errorf("error reading exec user: %v", err) - } - - l.rootProcArgs.Envv = append(l.rootProcArgs.Envv, "HOME="+homeDir) + // Add the HOME enviroment variable if it is not already set. + envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv) + if err != nil { + return err } + l.rootProcArgs.Envv = envv // Create the root container init task. It will begin running // when the kernel is started. @@ -815,6 +803,16 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) { }) defer args.MountNamespace.DecRef() + // Add the HOME enviroment varible if it is not already set. + root := args.MountNamespace.Root() + defer root.DecRef() + ctx := fs.WithRoot(l.k.SupervisorContext(), root) + envv, err := maybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv) + if err != nil { + return 0, err + } + args.Envv = envv + // Start the process. proc := control.Proc{Kernel: l.k} args.PIDNamespace = tg.PIDNamespace() diff --git a/runsc/boot/network.go b/runsc/boot/network.go index ea0d9f790..32cba5ac1 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -121,10 +121,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct nicID++ nicids[link.Name] = nicID - linkEP := loopback.New() + ep := loopback.New() log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses) - if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, true /* loopback */); err != nil { return err } @@ -156,7 +156,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } mac := tcpip.LinkAddress(link.LinkAddress) - linkEP, err := fdbased.New(&fdbased.Options{ + ep, err := fdbased.New(&fdbased.Options{ FDs: FDs, MTU: uint32(link.MTU), EthernetHeader: true, @@ -170,7 +170,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) - if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, false /* loopback */); err != nil { return err } @@ -203,14 +203,14 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct // createNICWithAddrs creates a NIC in the network stack and adds the given // addresses. -func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, addrs []net.IP, loopback bool) error { +func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error { if loopback { - if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(linkEP)); err != nil { - return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, linkEP, err) + if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil { + return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v) failed: %v", id, name, err) } } else { - if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(linkEP)); err != nil { - return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, linkEP, err) + if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil { + return fmt.Errorf("CreateNamedNIC(%v, %v) failed: %v", id, name, err) } } diff --git a/runsc/boot/user.go b/runsc/boot/user.go index d1d423a5c..56cc12ee0 100644 --- a/runsc/boot/user.go +++ b/runsc/boot/user.go @@ -16,6 +16,7 @@ package boot import ( "bufio" + "fmt" "io" "strconv" "strings" @@ -23,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/usermem" ) @@ -42,7 +44,7 @@ func (r *fileReader) Read(buf []byte) (int, error) { // getExecUserHome returns the home directory of the executing user read from // /etc/passwd as read from the container filesystem. -func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32) (string, error) { +func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.KUID) (string, error) { // The default user home directory to return if no user matching the user // if found in the /etc/passwd found in the image. const defaultHome = "/" @@ -82,7 +84,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32 File: f, } - homeDir, err := findHomeInPasswd(uid, r, defaultHome) + homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome) if err != nil { return "", err } @@ -90,6 +92,28 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32 return homeDir, nil } +// maybeAddExecUserHome returns a new slice with the HOME enviroment variable +// set if the slice does not already contain it, otherwise it returns the +// original slice unmodified. +func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) { + // Check if the envv already contains HOME. + for _, env := range envv { + if strings.HasPrefix(env, "HOME=") { + // We have it. Return the original slice unmodified. + return envv, nil + } + } + + // Read /etc/passwd for the user's HOME directory and set the HOME + // environment variable as required by POSIX if it is not overridden by + // the user. + homeDir, err := getExecUserHome(ctx, mns, uid) + if err != nil { + return nil, fmt.Errorf("error reading exec user: %v", err) + } + return append(envv, "HOME="+homeDir), nil +} + // findHomeInPasswd parses a passwd file and returns the given user's home // directory. This function does it's best to replicate the runc's behavior. func findHomeInPasswd(uid uint32, passwd io.Reader, defaultHome string) (string, error) { diff --git a/runsc/boot/user_test.go b/runsc/boot/user_test.go index 906baf3e5..9aee2ad07 100644 --- a/runsc/boot/user_test.go +++ b/runsc/boot/user_test.go @@ -25,6 +25,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) func setupTempDir() (string, error) { @@ -68,7 +69,7 @@ func setupPasswd(contents string, perms os.FileMode) func() (string, error) { // TestGetExecUserHome tests the getExecUserHome function. func TestGetExecUserHome(t *testing.T) { tests := map[string]struct { - uid uint32 + uid auth.KUID createRoot func() (string, error) expected string }{ diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index e817eff77..bf1225e1c 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -127,6 +127,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) Fatalf("getting environment variables: %v", err) } } + if e.Capabilities == nil { // enableRaw is set to true to prevent the filtering out of // CAP_NET_RAW. This is the opposite of Create() because exec diff --git a/runsc/container/container.go b/runsc/container/container.go index bbb364214..a721c1c31 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -513,9 +513,16 @@ func (c *Container) Start(conf *boot.Config) error { return err } - // Adjust the oom_score_adj for sandbox and gofers. This must be done after + // Adjust the oom_score_adj for sandbox. This must be done after // save(). - return c.adjustOOMScoreAdj(conf) + err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false) + if err != nil { + return err + } + + // Set container's oom_score_adj to the gofer since it is dedicated to + // the container, in case the gofer uses up too much memory. + return c.adjustGoferOOMScoreAdj() } // Restore takes a container and replaces its kernel and file system @@ -782,6 +789,9 @@ func (c *Container) Destroy() error { } defer unlock() + // Stored for later use as stop() sets c.Sandbox to nil. + sb := c.Sandbox + if err := c.stop(); err != nil { err = fmt.Errorf("stopping container: %v", err) log.Warningf("%v", err) @@ -796,6 +806,16 @@ func (c *Container) Destroy() error { c.changeStatus(Stopped) + // Adjust oom_score_adj for the sandbox. This must be done after the + // container is stopped and the directory at c.Root is removed. + // We must test if the sandbox is nil because Destroy should be + // idempotent. + if sb != nil { + if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil { + errs = append(errs, err.Error()) + } + } + // "If any poststop hook fails, the runtime MUST log a warning, but the // remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec. // Based on the OCI, "The post-stop hooks MUST be called after the container is @@ -926,7 +946,14 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund } if conf.DebugLog != "" { - debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer") + test := "" + if len(conf.TestOnlyTestNameEnv) != 0 { + // Fetch test name if one is provided and the test only flag was set. + if t, ok := specutils.EnvVar(spec.Process.Env, conf.TestOnlyTestNameEnv); ok { + test = t + } + } + debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer", test) if err != nil { return nil, nil, fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err) } @@ -1139,35 +1166,82 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error { return fn() } -// adjustOOMScoreAdj sets the oom_score_adj for the sandbox and all gofers. +// adjustGoferOOMScoreAdj sets the oom_store_adj for the container's gofer. +func (c *Container) adjustGoferOOMScoreAdj() error { + if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil { + if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil { + return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) + } + } + + return nil +} + +// adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox. // oom_score_adj is set to the lowest oom_score_adj among the containers // running in the sandbox. // // TODO(gvisor.dev/issue/512): This call could race with other containers being // created at the same time and end up setting the wrong oom_score_adj to the // sandbox. -func (c *Container) adjustOOMScoreAdj(conf *boot.Config) error { - // If this container's OOMScoreAdj is nil then we can exit early as no - // change should be made to oom_score_adj for the sandbox. - if c.Spec.Process.OOMScoreAdj == nil { - return nil - } - - containers, err := loadSandbox(conf.RootDir, c.Sandbox.ID) +func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error { + containers, err := loadSandbox(rootDir, s.ID) if err != nil { return fmt.Errorf("loading sandbox containers: %v", err) } + // Do nothing if the sandbox has been terminated. + if len(containers) == 0 { + return nil + } + // Get the lowest score for all containers. var lowScore int scoreFound := false - for _, container := range containers { - if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) { + if len(containers) == 1 && len(containers[0].Spec.Annotations[specutils.ContainerdContainerTypeAnnotation]) == 0 { + // This is a single-container sandbox. Set the oom_score_adj to + // the value specified in the OCI bundle. + if containers[0].Spec.Process.OOMScoreAdj != nil { scoreFound = true - lowScore = *container.Spec.Process.OOMScoreAdj + lowScore = *containers[0].Spec.Process.OOMScoreAdj + } + } else { + for _, container := range containers { + // Special multi-container support for CRI. Ignore the root + // container when calculating oom_score_adj for the sandbox because + // it is the infrastructure (pause) container and always has a very + // low oom_score_adj. + // + // We will use OOMScoreAdj in the single-container case where the + // containerd container-type annotation is not present. + if container.Spec.Annotations[specutils.ContainerdContainerTypeAnnotation] == specutils.ContainerdContainerTypeSandbox { + continue + } + + if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) { + scoreFound = true + lowScore = *container.Spec.Process.OOMScoreAdj + } } } + // If the container is destroyed and remaining containers have no + // oomScoreAdj specified then we must revert to the oom_score_adj of the + // parent process. + if !scoreFound && destroy { + ppid, err := specutils.GetParentPid(s.Pid) + if err != nil { + return fmt.Errorf("getting parent pid of sandbox pid %d: %v", s.Pid, err) + } + pScore, err := specutils.GetOOMScoreAdj(ppid) + if err != nil { + return fmt.Errorf("getting oom_score_adj of parent %d: %v", ppid, err) + } + + scoreFound = true + lowScore = pScore + } + // Only set oom_score_adj if one of the containers has oom_score_adj set // in the OCI bundle. If not, we need to inherit the parent process's // oom_score_adj. @@ -1177,15 +1251,10 @@ func (c *Container) adjustOOMScoreAdj(conf *boot.Config) error { } // Set the lowest of all containers oom_score_adj to the sandbox. - if err := setOOMScoreAdj(c.Sandbox.Pid, lowScore); err != nil { - return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", c.Sandbox.ID, err) + if err := setOOMScoreAdj(s.Pid, lowScore); err != nil { + return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err) } - // Set container's oom_score_adj to the gofer since it is dedicated to the - // container, in case the gofer uses up too much memory. - if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil { - return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) - } return nil } diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go index 41f5fe1e8..e37ec0ffd 100644 --- a/runsc/dockerutil/dockerutil.go +++ b/runsc/dockerutil/dockerutil.go @@ -240,7 +240,7 @@ func (d *Docker) Stop() error { // Run calls 'docker run' with the arguments provided. The container starts // running in the background and the call returns immediately. func (d *Docker) Run(args ...string) error { - a := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-d"} + a := d.runArgs("-d") a = append(a, args...) _, err := do(a...) if err == nil { @@ -251,7 +251,7 @@ func (d *Docker) Run(args ...string) error { // RunWithPty is like Run but with an attached pty. func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) { - a := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-it"} + a := d.runArgs("-it") a = append(a, args...) return doWithPty(a...) } @@ -259,8 +259,7 @@ func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) { // RunFg calls 'docker run' with the arguments provided in the foreground. It // blocks until the container exits and returns the output. func (d *Docker) RunFg(args ...string) (string, error) { - a := []string{"run", "--runtime", d.Runtime, "--name", d.Name} - a = append(a, args...) + a := d.runArgs(args...) out, err := do(a...) if err == nil { d.logDockerID() @@ -268,6 +267,14 @@ func (d *Docker) RunFg(args ...string) (string, error) { return string(out), err } +func (d *Docker) runArgs(args ...string) []string { + // Environment variable RUNSC_TEST_NAME is picked up by the runtime and added + // to the log name, so one can easily identify the corresponding logs for + // this test. + rv := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-e", "RUNSC_TEST_NAME=" + d.Name} + return append(rv, args...) +} + // Logs calls 'docker logs'. func (d *Docker) Logs() (string, error) { return do("logs", d.Name) @@ -280,6 +287,14 @@ func (d *Docker) Exec(args ...string) (string, error) { return do(a...) } +// ExecAsUser calls 'docker exec' as the given user with the arguments +// provided. +func (d *Docker) ExecAsUser(user string, args ...string) (string, error) { + a := []string{"exec", "--user", user, d.Name} + a = append(a, args...) + return do(a...) +} + // ExecWithTerminal calls 'docker exec -it' with the arguments provided and // attaches a pty to stdio. func (d *Docker) ExecWithTerminal(args ...string) (*exec.Cmd, *os.File, error) { diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD index e2318a978..02168ad1b 100644 --- a/runsc/fsgofer/filter/BUILD +++ b/runsc/fsgofer/filter/BUILD @@ -17,6 +17,7 @@ go_library( ], deps = [ "//pkg/abi/linux", + "//pkg/flipcall", "//pkg/log", "//pkg/seccomp", "@org_golang_x_sys//unix:go_default_library", diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index 8ddfa77d6..2f3f2039a 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -83,6 +83,11 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.AllowAny{}, seccomp.AllowValue(syscall.F_GETFD), }, + // Used by flipcall.PacketWindowAllocator.Init(). + { + seccomp.AllowAny{}, + seccomp.AllowValue(unix.F_ADD_SEALS), + }, }, syscall.SYS_FSTAT: {}, syscall.SYS_FSTATFS: {}, @@ -103,6 +108,19 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.AllowAny{}, seccomp.AllowValue(0), }, + // Non-private futex used for flipcall. + seccomp.Rule{ + seccomp.AllowAny{}, + seccomp.AllowValue(linux.FUTEX_WAIT), + seccomp.AllowAny{}, + seccomp.AllowAny{}, + }, + seccomp.Rule{ + seccomp.AllowAny{}, + seccomp.AllowValue(linux.FUTEX_WAKE), + seccomp.AllowAny{}, + seccomp.AllowAny{}, + }, }, syscall.SYS_GETDENTS64: {}, syscall.SYS_GETPID: {}, @@ -112,6 +130,7 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_LINKAT: {}, syscall.SYS_LSEEK: {}, syscall.SYS_MADVISE: {}, + unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init(). syscall.SYS_MKDIRAT: {}, syscall.SYS_MMAP: []seccomp.Rule{ { @@ -160,6 +179,13 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_RT_SIGPROCMASK: {}, syscall.SYS_SCHED_YIELD: {}, syscall.SYS_SENDMSG: []seccomp.Rule{ + // Used by fdchannel.Endpoint.SendFD(). + { + seccomp.AllowAny{}, + seccomp.AllowAny{}, + seccomp.AllowValue(0), + }, + // Used by unet.SocketWriter.WriteVec(). { seccomp.AllowAny{}, seccomp.AllowAny{}, @@ -170,7 +196,15 @@ var allowedSyscalls = seccomp.SyscallRules{ {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)}, }, syscall.SYS_SIGALTSTACK: {}, - syscall.SYS_SYMLINKAT: {}, + // Used by fdchannel.NewConnectedSockets(). + syscall.SYS_SOCKETPAIR: { + { + seccomp.AllowValue(syscall.AF_UNIX), + seccomp.AllowValue(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC), + seccomp.AllowValue(0), + }, + }, + syscall.SYS_SYMLINKAT: {}, syscall.SYS_TGKILL: []seccomp.Rule{ { seccomp.AllowValue(uint64(os.Getpid())), diff --git a/runsc/main.go b/runsc/main.go index b6546717c..304d771c2 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -79,6 +79,7 @@ var ( // Test flags, not to be used outside tests, ever. testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") + testOnlyTestNameEnv = flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.") ) func main() { @@ -211,6 +212,7 @@ func main() { ReferenceLeakMode: refsLeakMode, TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, + TestOnlyTestNameEnv: *testOnlyTestNameEnv, } if len(*straceSyscalls) != 0 { conf.StraceSyscalls = strings.Split(*straceSyscalls, ",") @@ -244,7 +246,7 @@ func main() { e = newEmitter(*debugLogFormat, f) } else if *debugLog != "" { - f, err := specutils.DebugLogFile(*debugLog, subcommand) + f, err := specutils.DebugLogFile(*debugLog, subcommand, "" /* name */) if err != nil { cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err) } diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index df3c0c5ef..4c6c83fbd 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -351,7 +351,15 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF nextFD++ } if conf.DebugLog != "" { - debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot") + test := "" + if len(conf.TestOnlyTestNameEnv) == 0 { + // Fetch test name if one is provided and the test only flag was set. + if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok { + test = t + } + } + + debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot", test) if err != nil { return fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err) } diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index 2eec92349..cb9e58dfb 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -23,6 +23,7 @@ import ( "os" "path" "path/filepath" + "strconv" "strings" "syscall" "time" @@ -398,13 +399,15 @@ func WaitForReady(pid int, timeout time.Duration, ready func() (bool, error)) er // - %TIMESTAMP%: is replaced with a timestamp using the following format: // <yyyymmdd-hhmmss.uuuuuu> // - %COMMAND%: is replaced with 'command' -func DebugLogFile(logPattern, command string) (*os.File, error) { +// - %TEST%: is replaced with 'test' (omitted by default) +func DebugLogFile(logPattern, command, test string) (*os.File, error) { if strings.HasSuffix(logPattern, "/") { // Default format: <debug-log>/runsc.log.<yyyymmdd-hhmmss.uuuuuu>.<command> logPattern += "runsc.log.%TIMESTAMP%.%COMMAND%" } logPattern = strings.Replace(logPattern, "%TIMESTAMP%", time.Now().Format("20060102-150405.000000"), -1) logPattern = strings.Replace(logPattern, "%COMMAND%", command, -1) + logPattern = strings.Replace(logPattern, "%TEST%", test, -1) dir := filepath.Dir(logPattern) if err := os.MkdirAll(dir, 0775); err != nil { @@ -503,3 +506,53 @@ func RetryEintr(f func() (uintptr, uintptr, error)) (uintptr, uintptr, error) { } } } + +// GetOOMScoreAdj reads the given process' oom_score_adj +func GetOOMScoreAdj(pid int) (int, error) { + data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid)) + if err != nil { + return 0, err + } + return strconv.Atoi(strings.TrimSpace(string(data))) +} + +// GetParentPid gets the parent process ID of the specified PID. +func GetParentPid(pid int) (int, error) { + data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) + if err != nil { + return 0, err + } + + var cpid string + var name string + var state string + var ppid int + // Parse after the binary name. + _, err = fmt.Sscanf(string(data), + "%v %v %v %d", + // cpid is ignored. + &cpid, + // name is ignored. + &name, + // state is ignored. + &state, + &ppid) + + if err != nil { + return 0, err + } + + return ppid, nil +} + +// EnvVar looks for a varible value in the env slice assuming the following +// format: "NAME=VALUE". +func EnvVar(env []string, name string) (string, bool) { + prefix := name + "=" + for _, e := range env { + if strings.HasPrefix(e, prefix) { + return strings.TrimPrefix(e, prefix), true + } + } + return "", false +} diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index 57ab73d97..edf8b126c 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -26,12 +26,14 @@ import ( "io" "io/ioutil" "log" + "math" "math/rand" "net/http" "os" "os/exec" "os/signal" "path/filepath" + "strconv" "strings" "sync" "sync/atomic" @@ -438,3 +440,44 @@ func IsStatic(filename string) (bool, error) { } return true, nil } + +// TestBoundsForShard calculates the beginning and end indices for the test +// based on the TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. The +// returned ints are the beginning (inclusive) and end (exclusive) of the +// subslice corresponding to the shard. If either of the env vars are not +// present, then the function will return bounds that include all tests. If +// there are more shards than there are tests, then the returned list may be +// empty. +func TestBoundsForShard(numTests int) (int, int, error) { + var ( + begin = 0 + end = numTests + ) + indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS") + if indexStr == "" || totalStr == "" { + return begin, end, nil + } + + // Parse index and total to ints. + shardIndex, err := strconv.Atoi(indexStr) + if err != nil { + return 0, 0, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err) + } + shardTotal, err := strconv.Atoi(totalStr) + if err != nil { + return 0, 0, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err) + } + + // Calculate! + shardSize := int(math.Ceil(float64(numTests) / float64(shardTotal))) + begin = shardIndex * shardSize + end = ((shardIndex + 1) * shardSize) + if begin > numTests { + // Nothing to run. + return 0, 0, nil + } + if end > numTests { + end = numTests + } + return begin, end, nil +} diff --git a/runsc/version.go b/runsc/version.go index ce0573a9b..ab9194b9d 100644 --- a/runsc/version.go +++ b/runsc/version.go @@ -15,4 +15,4 @@ package main // version is set during linking. -var version = "" +var version = "VERSION_MISSING" diff --git a/runsc/version_test.sh b/runsc/version_test.sh new file mode 100755 index 000000000..cc0ca3f05 --- /dev/null +++ b/runsc/version_test.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -euf -x -o pipefail + +readonly runsc="${TEST_SRCDIR}/__main__/runsc/linux_amd64_pure_stripped/runsc" +readonly version=$($runsc --version) + +# Version should should not match VERSION, which is the default and which will +# also appear if something is wrong with workspace_status.sh script. +if [[ $version =~ "VERSION" ]]; then + echo "FAIL: Got bad version $version" + exit 1 +fi + +# Version should contain at least one number. +if [[ ! $version =~ [0-9] ]]; then + echo "FAIL: Got bad version $version" + exit 1 +fi + +echo "PASS: Got OK version $version" +exit 0 diff --git a/scripts/build.sh b/scripts/build.sh index 293d87093..b3a6e4e7a 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -16,6 +16,9 @@ source $(dirname $0)/common.sh +# Install required packages for make_repository.sh et al. +sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils + # Build runsc. runsc=$(build -c opt //runsc) @@ -24,16 +27,19 @@ pkg=$(build -c opt --host_force_python=py2 //runsc:runsc-debian) # Build a repository, if the key is available. if [[ -v KOKORO_REPO_KEY ]]; then - repo=$(tools/make_repository.sh "${KOKORO_REPO_KEY}" gvisor-bot@google.com) + repo=$(tools/make_repository.sh "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" gvisor-bot@google.com main ${pkg}) fi # Install installs artifacts. install() { - mkdir -p $1 - cp "${runsc}" "$1"/runsc - sha512sum "$1"/runsc | awk '{print $1 " runsc"}' > "$1"/runsc.sha512 + local -r binaries_dir="$1" + local -r repo_dir="$2" + mkdir -p "${binaries_dir}" + cp -f "${runsc}" "${binaries_dir}"/runsc + sha512sum "${binaries_dir}"/runsc | awk '{print $1 " runsc"}' > "${binaries_dir}"/runsc.sha512 if [[ -v repo ]]; then - cp -a "${repo}" "${latest_dir}"/repo + rm -rf "${repo_dir}" && mkdir -p "$(dirname "${repo_dir}")" + cp -a "${repo}" "${repo_dir}" fi } @@ -41,22 +47,33 @@ install() { # current date. If the current commit happens to correpond to a tag, then we # will also move everything into a directory named after the given tag. if [[ -v KOKORO_ARTIFACTS_DIR ]]; then - if [[ "${KOKORO_BUILD_NIGHTLY}" == "true" ]]; then + if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then # The "latest" directory and current date. - install "${KOKORO_ARTIFACTS_DIR}/nightly/latest" - install "${KOKORO_ARTIFACTS_DIR}/nightly/$(date -Idate)" + stamp="$(date -Idate)" + install "${KOKORO_ARTIFACTS_DIR}/nightly/latest" \ + "${KOKORO_ARTIFACTS_DIR}/dists/nightly/latest" + install "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" \ + "${KOKORO_ARTIFACTS_DIR}/dists/nightly/${stamp}" else # Is it a tagged release? Build that instead. In that case, we also try to # update the base release directory, in case this is an update. Finally, we # update the "release" directory, which has the last released version. - tag="$(git describe --exact-match --tags HEAD)" - if ! [[ -z "${tag}" ]]; then - install "${KOKORO_ARTIFACTS_DIR}/${tag}" - base=$(echo "${tag}" | cut -d'.' -f1) - if [[ "${base}" != "${tag}" ]]; then - install "${KOKORO_ARTIFACTS_DIR}/${base}" - fi - install "${KOKORO_ARTIFACTS_DIR}/release" + tags="$(git tag --points-at HEAD)" + if ! [[ -z "${tags}" ]]; then + # Note that a given commit can match any number of tags. We have to + # iterate through all possible tags and produce associated artifacts. + for tag in ${tags}; do + name=$(echo "${tag}" | cut -d'-' -f2) + base=$(echo "${name}" | cut -d'.' -f1) + install "${KOKORO_ARTIFACTS_DIR}/release/${name}" \ + "${KOKORO_ARTIFACTS_DIR}/dists/${name}" + if [[ "${base}" != "${tag}" ]]; then + install "${KOKORO_ARTIFACTS_DIR}/release/${base}" \ + "${KOKORO_ARTIFACTS_DIR}/dists/${base}" + fi + install "${KOKORO_ARTIFACTS_DIR}/release/latest" \ + "${KOKORO_ARTIFACTS_DIR}/dists/latest" + done fi fi fi diff --git a/scripts/common.sh b/scripts/common.sh index f2b9e24d8..6dabad141 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -14,10 +14,67 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -xeo pipefail +set -xeou pipefail if [[ -f $(dirname $0)/common_google.sh ]]; then source $(dirname $0)/common_google.sh else source $(dirname $0)/common_bazel.sh fi + +# Ensure it attempts to collect logs in all cases. +trap collect_logs EXIT + +function set_runtime() { + RUNTIME=${1:-runsc} + RUNSC_BIN=/tmp/"${RUNTIME}"/runsc + RUNSC_LOGS_DIR="$(dirname ${RUNSC_BIN})"/logs + RUNSC_LOGS="${RUNSC_LOGS_DIR}"/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND% +} + +function test_runsc() { + test --test_arg=--runtime=${RUNTIME} "$@" +} + +function install_runsc_for_test() { + local -r test_name=$1 + shift + if [[ -z "${test_name}" ]]; then + echo "Missing mandatory test name" + exit 1 + fi + + # Add test to the name, so it doesn't conflict with other runtimes. + set_runtime $(find_branch_name)_"${test_name}" + + # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name + # down to the runtime. + install_runsc "${RUNTIME}" \ + --TESTONLY-test-name-env=RUNSC_TEST_NAME \ + --debug \ + --strace \ + --log-packets \ + "$@" +} + +# Installs the runsc with given runtime name. set_runtime must have been called +# to set runtime and logs location. +function install_runsc() { + local -r runtime=$1 + shift + + # Prepare the runtime binary. + local -r output=$(build //runsc) + mkdir -p "$(dirname ${RUNSC_BIN})" + cp -f "${output}" "${RUNSC_BIN}" + chmod 0755 "${RUNSC_BIN}" + + # Install the runtime. + sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@" + + # Clear old logs files that may exist. + sudo rm -f "${RUNSC_LOGS_DIR}"/* + + # Restart docker to pick up the new runtime configuration. + sudo systemctl restart docker +} diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh index 42248cb25..dde0b51ed 100755 --- a/scripts/common_bazel.sh +++ b/scripts/common_bazel.sh @@ -48,20 +48,12 @@ fi # Wrap bazel. function build() { - bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" + bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 | + tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }' } function test() { - (bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" && rc=0) || rc=$? - - # Zip out everything into a convenient form. - if [[ -v KOKORO_ARTIFACTS_DIR ]]; then - find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | - tar --create --files-from - --transform 's/test\./sponge_log./' | - tar --extract --directory ${KOKORO_ARTIFACTS_DIR} - fi - - return $rc + bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" } function run() { @@ -75,3 +67,26 @@ function run_as_root() { shift bazel run --run_under="sudo" "${binary}" -- "$@" } + +function collect_logs() { + # Zip out everything into a convenient form. + if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then + # Move test logs to Kokoro directory. tar is used to conveniently perform + # renames while moving files. + find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | + tar --create --files-from - --transform 's/test\./sponge_log./' | + tar --extract --directory ${KOKORO_ARTIFACTS_DIR} + + # Collect sentry logs, if any. + if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then + local -r logs=$(ls "${RUNSC_LOGS_DIR}") + if [[ -z "${logs}" ]]; then + tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${RUNTIME}.tar.gz" -C "${RUNSC_LOGS_DIR}" . + fi + fi + fi +} + +function find_branch_name() { + git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename +} diff --git a/scripts/dev.sh b/scripts/dev.sh new file mode 100755 index 000000000..ee74dcb72 --- /dev/null +++ b/scripts/dev.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +# common.sh sets '-x', but it's annoying to see so much output. +set +x + +# Defaults +declare -i REFRESH=0 +declare NAME=$(find_branch_name) + +while [[ $# -gt 0 ]]; do + case "$1" in + --refresh) + REFRESH=1 + ;; + --help) + echo "Use this script to build and install runsc with Docker." + echo + echo "usage: $0 [--refresh] [runtime_name]" + exit 1 + ;; + *) + NAME=$1 + ;; + esac + shift +done + +set_runtime "${NAME}" +echo +echo "Using runtime=${RUNTIME}" +echo + +echo Building runsc... +# Build first and fail on error. $() prevents "set -e" from reporting errors. +build //runsc +declare OUTPUT="$(build //runsc)" + +if [[ ${REFRESH} -eq 0 ]]; then + install_runsc "${RUNTIME}" --net-raw + install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets + + echo + echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup." + echo "Use --runtime="${RUNTIME}" with your Docker command." + echo " docker run --rm --runtime="${RUNTIME}" --rm hello-world" + echo + echo "If you rebuild, use $0 --refresh." + +else + mkdir -p "$(dirname ${RUNSC_BIN})" + cp -f ${OUTPUT} "${RUNSC_BIN}" + + echo + echo "Runtime ${RUNTIME} refreshed." +fi + +echo "Logs are in: ${RUNSC_LOGS_DIR}" diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh index d6b18a35b..72ba05260 100755 --- a/scripts/docker_tests.sh +++ b/scripts/docker_tests.sh @@ -16,7 +16,5 @@ source $(dirname $0)/common.sh -# Install the runtime and perform basic tests. -run_as_root //runsc install --experimental=true -- --debug --strace --log-packets -sudo systemctl restart docker -test //test/image:image_test //test/e2e:integration_test +install_runsc_for_test docker +test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/go.sh b/scripts/go.sh index e49d76c6d..0dbfb7747 100755 --- a/scripts/go.sh +++ b/scripts/go.sh @@ -29,6 +29,15 @@ git checkout go && git clean -f go build ./... # Push, if required. -if [[ "${KOKORO_GO_PUSH}" == "true" ]]; then +if [[ -v KOKORO_GO_PUSH ]] && [[ "${KOKORO_GO_PUSH}" == "true" ]]; then + if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then + git config --global credential.helper cache + git credential approve <<EOF +protocol=https +host=github.com +username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}") +password=x-oauth-basic +EOF + fi git push origin go:go fi diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh index 0631c5510..41298293d 100755 --- a/scripts/hostnet_tests.sh +++ b/scripts/hostnet_tests.sh @@ -17,6 +17,5 @@ source $(dirname $0)/common.sh # Install the runtime and perform basic tests. -run_as_root //runsc install --experimental=true -- --debug --strace --log-packets --network=host -sudo systemctl restart docker -test --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test +install_runsc_for_test hostnet --network=host +test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh index 5cb7aa007..5662401df 100755 --- a/scripts/kvm_tests.sh +++ b/scripts/kvm_tests.sh @@ -20,11 +20,9 @@ source $(dirname $0)/common.sh (lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm sudo chmod a+rw /dev/kvm -# Run all KVM-tagged tests (locally). -test --test_strategy=standalone --test_tag_filters=requires-kvm //... -test --test_strategy=standalone //pkg/sentry/platform/kvm:kvm_test +# Run all KVM platform tests (locally). +run_as_root //pkg/sentry/platform/kvm:kvm_test # Install the KVM runtime and run all integration tests. -run_as_root //runsc install --experimental=true -- --debug --strace --log-packets --platform=kvm -sudo systemctl restart docker -test --test_strategy=standalone //test/image:image_test //test/e2e:integration_test +install_runsc_for_test kvm --platform=kvm +test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh index 651a51f70..2a1f12c0b 100755 --- a/scripts/overlay_tests.sh +++ b/scripts/overlay_tests.sh @@ -17,6 +17,5 @@ source $(dirname $0)/common.sh # Install the runtime and perform basic tests. -run_as_root //runsc install --experimental=true -- --debug --strace --log-packets --overlay -sudo systemctl restart docker -test //test/image:image_test //test/e2e:integration_test +install_runsc_for_test overlay --overlay +test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/release.sh b/scripts/release.sh index 422319500..b936bcc77 100755 --- a/scripts/release.sh +++ b/scripts/release.sh @@ -26,9 +26,13 @@ if ! [[ -v KOKORO_RELEASE_TAG ]]; then exit 1 fi +# Unless an explicit releaser is provided, use the bot e-mail. +declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot} +declare -r EMAIL=${EMAIL:-${KOKORO_RELEASE_AUTHOR}@google.com} + # Ensure we have an appropriate configuration for the tag. git config --get user.name || git config user.name "gVisor-bot" -git config --get user.email || git config user.email "gvisor-bot@google.com" +git config --get user.email || git config user.email "${EMAIL}" # Run the release tool, which pushes to the origin repository. tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}" diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh index e42c0e3ec..4e4fcc76b 100755 --- a/scripts/root_tests.sh +++ b/scripts/root_tests.sh @@ -26,6 +26,6 @@ chmod +x ${shim_path} sudo mv ${shim_path} /usr/local/bin/gvisor-containerd-shim # Run the tests that require root. -run_as_root //runsc install --experimental=true -- --debug --strace --log-packets -sudo systemctl restart docker -run_as_root //test/root:root_test +install_runsc_for_test root +run_as_root //test/root:root_test --runtime=${RUNTIME} + diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go index ce2c4f689..7238c2afe 100644 --- a/test/e2e/exec_test.go +++ b/test/e2e/exec_test.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package image provides end-to-end integration tests for runsc. These tests -// require docker and runsc to be installed on the machine. +// Package integration provides end-to-end integration tests for runsc. These +// tests require docker and runsc to be installed on the machine. // // Each test calls docker commands to start up a container, and tests that it // is behaving properly, with various runsc commands. The container is killed @@ -154,3 +154,68 @@ func TestExecError(t *testing.T) { t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want) } } + +// Test that exec inherits environment from run. +func TestExecEnv(t *testing.T) { + if err := dockerutil.Pull("alpine"); err != nil { + t.Fatalf("docker pull failed: %v", err) + } + d := dockerutil.MakeDocker("exec-env-test") + + // Start the container with env FOO=BAR. + if err := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000"); err != nil { + t.Fatalf("docker run failed: %v", err) + } + defer d.CleanUp() + + // Exec "echo $FOO". + got, err := d.Exec("/bin/sh", "-c", "echo $FOO") + if err != nil { + t.Fatalf("docker exec failed: %v", err) + } + if want := "BAR"; !strings.Contains(got, want) { + t.Errorf("wanted exec output to contain %q, got %q", want, got) + } +} + +// Test that exec always has HOME environment set, even when not set in run. +func TestExecEnvHasHome(t *testing.T) { + // Base alpine image does not have any environment variables set. + if err := dockerutil.Pull("alpine"); err != nil { + t.Fatalf("docker pull failed: %v", err) + } + d := dockerutil.MakeDocker("exec-env-test") + + // We will check that HOME is set for root user, and also for a new + // non-root user we will create. + newUID := 1234 + newHome := "/foo/bar" + + // Create a new user with a home directory, and then sleep. + script := fmt.Sprintf(` + mkdir -p -m 777 %s && \ + adduser foo -D -u %d -h %s && \ + sleep 1000`, newHome, newUID, newHome) + if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil { + t.Fatalf("docker run failed: %v", err) + } + defer d.CleanUp() + + // Exec "echo $HOME", and expect to see "/root". + got, err := d.Exec("/bin/sh", "-c", "echo $HOME") + if err != nil { + t.Fatalf("docker exec failed: %v", err) + } + if want := "/root"; !strings.Contains(got, want) { + t.Errorf("wanted exec output to contain %q, got %q", want, got) + } + + // Execute the same as uid 123 and expect newHome. + got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME") + if err != nil { + t.Fatalf("docker exec failed: %v", err) + } + if want := newHome; !strings.Contains(got, want) { + t.Errorf("wanted exec output to contain %q, got %q", want, got) + } +} diff --git a/test/root/BUILD b/test/root/BUILD index f130df2c7..d5dd9bca2 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -15,6 +15,11 @@ go_test( "cgroup_test.go", "chroot_test.go", "crictl_test.go", + "main_test.go", + "oom_score_adj_test.go", + ], + data = [ + "//runsc", ], embed = [":root"], tags = [ @@ -25,12 +30,15 @@ go_test( ], visibility = ["//:sandbox"], deps = [ + "//runsc/boot", "//runsc/cgroup", + "//runsc/container", "//runsc/criutil", "//runsc/dockerutil", "//runsc/specutils", "//runsc/testutil", "//test/root/testdata", + "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", ], ) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index cc7e8583e..76f1e4f2a 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -62,6 +62,12 @@ func TestCgroup(t *testing.T) { } d := dockerutil.MakeDocker("cgroup-test") + // This is not a comprehensive list of attributes. + // + // Note that we are specifically missing cpusets, which fail if specified. + // In any case, it's unclear if cpusets can be reliably tested here: these + // are often run on a single core virtual machine, and there is only a single + // CPU available in our current set, and every container's set. attrs := []struct { arg string ctrl string @@ -88,18 +94,6 @@ func TestCgroup(t *testing.T) { want: "3000", }, { - arg: "--cpuset-cpus=0", - ctrl: "cpuset", - file: "cpuset.cpus", - want: "0", - }, - { - arg: "--cpuset-mems=0", - ctrl: "cpuset", - file: "cpuset.mems", - want: "0", - }, - { arg: "--kernel-memory=100MB", ctrl: "memory", file: "memory.kmem.limit_in_bytes", diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go index f47f8e2c2..be0f63d18 100644 --- a/test/root/chroot_test.go +++ b/test/root/chroot_test.go @@ -16,19 +16,15 @@ package root import ( - "flag" "fmt" "io/ioutil" - "os" "os/exec" "path/filepath" "strconv" "strings" "testing" - "github.com/syndtr/gocapability/capability" "gvisor.dev/gvisor/runsc/dockerutil" - "gvisor.dev/gvisor/runsc/specutils" ) // TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned @@ -144,15 +140,3 @@ func TestChrootGofer(t *testing.T) { } } } - -func TestMain(m *testing.M) { - dockerutil.EnsureSupportedDockerVersion() - - if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) { - fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.") - os.Exit(1) - } - - flag.Parse() - os.Exit(m.Run()) -} diff --git a/test/root/main_test.go b/test/root/main_test.go new file mode 100644 index 000000000..d74dec85f --- /dev/null +++ b/test/root/main_test.go @@ -0,0 +1,49 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package root + +import ( + "flag" + "fmt" + "os" + "testing" + + "github.com/syndtr/gocapability/capability" + "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/runsc/specutils" +) + +// TestMain is the main function for root tests. This function checks the +// supported docker version, required capabilities, and configures the executable +// path for runsc. +func TestMain(m *testing.M) { + flag.Parse() + + if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) { + fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.") + os.Exit(1) + } + + dockerutil.EnsureSupportedDockerVersion() + + // Configure exe for tests. + path, err := dockerutil.RuntimePath() + if err != nil { + panic(err.Error()) + } + specutils.ExePath = path + + os.Exit(m.Run()) +} diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go new file mode 100644 index 000000000..6cd378a1b --- /dev/null +++ b/test/root/oom_score_adj_test.go @@ -0,0 +1,376 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package root + +import ( + "fmt" + "os" + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/container" + "gvisor.dev/gvisor/runsc/specutils" + "gvisor.dev/gvisor/runsc/testutil" +) + +var ( + maxOOMScoreAdj = 1000 + highOOMScoreAdj = 500 + lowOOMScoreAdj = -500 + minOOMScoreAdj = -1000 +) + +// Tests for oom_score_adj have to be run as root (rather than in a user +// namespace) because we need to adjust oom_score_adj for PIDs other than our +// own and test values below 0. + +// TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a +// single container sandbox. +func TestOOMScoreAdjSingle(t *testing.T) { + ppid, err := specutils.GetParentPid(os.Getpid()) + if err != nil { + t.Fatalf("getting parent pid: %v", err) + } + parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid) + if err != nil { + t.Fatalf("getting parent oom_score_adj: %v", err) + } + + testCases := []struct { + Name string + + // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then + // no value is set. + OOMScoreAdj *int + }{ + { + Name: "max", + OOMScoreAdj: &maxOOMScoreAdj, + }, + { + Name: "high", + OOMScoreAdj: &highOOMScoreAdj, + }, + { + Name: "low", + OOMScoreAdj: &lowOOMScoreAdj, + }, + { + Name: "min", + OOMScoreAdj: &minOOMScoreAdj, + }, + { + Name: "nil", + OOMScoreAdj: &parentOOMScoreAdj, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + id := testutil.UniqueContainerID() + s := testutil.NewSpecWithArgs("sleep", "1000") + s.Process.OOMScoreAdj = testCase.OOMScoreAdj + + conf := testutil.TestConfig() + containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id}) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() + + c := containers[0] + + // Verify the gofer's oom_score_adj + if testCase.OOMScoreAdj != nil { + goferScore, err := specutils.GetOOMScoreAdj(c.GoferPid) + if err != nil { + t.Fatalf("error reading gofer oom_score_adj: %v", err) + } + if goferScore != *testCase.OOMScoreAdj { + t.Errorf("gofer oom_score_adj got: %d, want: %d", goferScore, *testCase.OOMScoreAdj) + } + + // Verify the sandbox's oom_score_adj. + // + // The sandbox should be the same for all containers so just use + // the first one. + sandboxPid := c.Sandbox.Pid + sandboxScore, err := specutils.GetOOMScoreAdj(sandboxPid) + if err != nil { + t.Fatalf("error reading sandbox oom_score_adj: %v", err) + } + if sandboxScore != *testCase.OOMScoreAdj { + t.Errorf("sandbox oom_score_adj got: %d, want: %d", sandboxScore, *testCase.OOMScoreAdj) + } + } + }) + } +} + +// TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a +// multi-container sandbox. +func TestOOMScoreAdjMulti(t *testing.T) { + ppid, err := specutils.GetParentPid(os.Getpid()) + if err != nil { + t.Fatalf("getting parent pid: %v", err) + } + parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid) + if err != nil { + t.Fatalf("getting parent oom_score_adj: %v", err) + } + + testCases := []struct { + Name string + + // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then + // no value is set. One value for each container. The first value is the + // root container. + OOMScoreAdj []*int + + // Expected is the expected oom_score_adj of the sandbox. If nil, then + // this value is ignored. + Expected *int + + // Remove is a set of container indexes to remove from the sandbox. + Remove []int + + // ExpectedAfterRemove is the expected oom_score_adj of the sandbox + // after containers are removed. Ignored if nil. + ExpectedAfterRemove *int + }{ + // A single container CRI test case. This should not happen in + // practice as there should be at least one container besides the pause + // container. However, we include a test case to ensure sane behavior. + { + Name: "single", + OOMScoreAdj: []*int{&highOOMScoreAdj}, + Expected: &parentOOMScoreAdj, + }, + { + Name: "multi_no_value", + OOMScoreAdj: []*int{nil, nil, nil}, + Expected: &parentOOMScoreAdj, + }, + { + Name: "multi_non_nil_root", + OOMScoreAdj: []*int{&minOOMScoreAdj, nil, nil}, + Expected: &parentOOMScoreAdj, + }, + { + Name: "multi_value", + OOMScoreAdj: []*int{&minOOMScoreAdj, &highOOMScoreAdj, &lowOOMScoreAdj}, + // The lowest value excluding the root container is expected. + Expected: &lowOOMScoreAdj, + }, + { + Name: "multi_min_value", + OOMScoreAdj: []*int{&minOOMScoreAdj, &lowOOMScoreAdj}, + // The lowest value excluding the root container is expected. + Expected: &lowOOMScoreAdj, + }, + { + Name: "multi_max_value", + OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj}, + // The lowest value excluding the root container is expected. + Expected: &highOOMScoreAdj, + }, + { + Name: "remove_adjusted", + OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj}, + // The lowest value excluding the root container is expected. + Expected: &highOOMScoreAdj, + // Remove highOOMScoreAdj container. + Remove: []int{2}, + ExpectedAfterRemove: &maxOOMScoreAdj, + }, + { + // This test removes all non-root sandboxes with a specified oomScoreAdj. + Name: "remove_to_nil", + OOMScoreAdj: []*int{&minOOMScoreAdj, nil, &lowOOMScoreAdj}, + Expected: &lowOOMScoreAdj, + // Remove lowOOMScoreAdj container. + Remove: []int{2}, + // The oom_score_adj expected after remove is that of the parent process. + ExpectedAfterRemove: &parentOOMScoreAdj, + }, + { + Name: "remove_no_effect", + OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj}, + // The lowest value excluding the root container is expected. + Expected: &highOOMScoreAdj, + // Remove the maxOOMScoreAdj container. + Remove: []int{1}, + ExpectedAfterRemove: &highOOMScoreAdj, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + var cmds [][]string + var oomScoreAdj []*int + var toRemove []string + + for _, oomScore := range testCase.OOMScoreAdj { + oomScoreAdj = append(oomScoreAdj, oomScore) + cmds = append(cmds, []string{"sleep", "100"}) + } + + specs, ids := createSpecs(cmds...) + for i, spec := range specs { + // Ensure the correct value is set, including no value. + spec.Process.OOMScoreAdj = oomScoreAdj[i] + + for _, j := range testCase.Remove { + if i == j { + toRemove = append(toRemove, ids[i]) + } + } + } + + conf := testutil.TestConfig() + containers, cleanup, err := startContainers(conf, specs, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() + + for i, c := range containers { + if oomScoreAdj[i] != nil { + // Verify the gofer's oom_score_adj + score, err := specutils.GetOOMScoreAdj(c.GoferPid) + if err != nil { + t.Fatalf("error reading gofer oom_score_adj: %v", err) + } + if score != *oomScoreAdj[i] { + t.Errorf("gofer oom_score_adj got: %d, want: %d", score, *oomScoreAdj[i]) + } + } + } + + // Verify the sandbox's oom_score_adj. + // + // The sandbox should be the same for all containers so just use + // the first one. + sandboxPid := containers[0].Sandbox.Pid + if testCase.Expected != nil { + score, err := specutils.GetOOMScoreAdj(sandboxPid) + if err != nil { + t.Fatalf("error reading sandbox oom_score_adj: %v", err) + } + if score != *testCase.Expected { + t.Errorf("sandbox oom_score_adj got: %d, want: %d", score, *testCase.Expected) + } + } + + if len(toRemove) == 0 { + return + } + + // Remove containers. + for _, removeID := range toRemove { + for _, c := range containers { + if c.ID == removeID { + c.Destroy() + } + } + } + + // Check the new adjusted oom_score_adj. + if testCase.ExpectedAfterRemove != nil { + scoreAfterRemove, err := specutils.GetOOMScoreAdj(sandboxPid) + if err != nil { + t.Fatalf("error reading sandbox oom_score_adj: %v", err) + } + if scoreAfterRemove != *testCase.ExpectedAfterRemove { + t.Errorf("sandbox oom_score_adj got: %d, want: %d", scoreAfterRemove, *testCase.ExpectedAfterRemove) + } + } + }) + } +} + +func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { + var specs []*specs.Spec + var ids []string + rootID := testutil.UniqueContainerID() + + for i, cmd := range cmds { + spec := testutil.NewSpecWithArgs(cmd...) + if i == 0 { + spec.Annotations = map[string]string{ + specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeSandbox, + } + ids = append(ids, rootID) + } else { + spec.Annotations = map[string]string{ + specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer, + specutils.ContainerdSandboxIDAnnotation: rootID, + } + ids = append(ids, testutil.UniqueContainerID()) + } + specs = append(specs, spec) + } + return specs, ids +} + +func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { + // Setup root dir if one hasn't been provided. + if len(conf.RootDir) == 0 { + rootDir, err := testutil.SetupRootDir() + if err != nil { + return nil, nil, fmt.Errorf("error creating root dir: %v", err) + } + conf.RootDir = rootDir + } + + var containers []*container.Container + var bundles []string + cleanup := func() { + for _, c := range containers { + c.Destroy() + } + for _, b := range bundles { + os.RemoveAll(b) + } + os.RemoveAll(conf.RootDir) + } + for i, spec := range specs { + bundleDir, err := testutil.SetupBundleDir(spec) + if err != nil { + cleanup() + return nil, nil, fmt.Errorf("error setting up container: %v", err) + } + bundles = append(bundles, bundleDir) + + args := container.Args{ + ID: ids[i], + Spec: spec, + BundleDir: bundleDir, + } + cont, err := container.New(conf, args) + if err != nil { + cleanup() + return nil, nil, fmt.Errorf("error creating container: %v", err) + } + containers = append(containers, cont) + + if err := cont.Start(conf); err != nil { + cleanup() + return nil, nil, fmt.Errorf("error starting container: %v", err) + } + } + return containers, cleanup, nil +} diff --git a/test/root/root.go b/test/root/root.go index 349c752cc..0f1d29faf 100644 --- a/test/root/root.go +++ b/test/root/root.go @@ -12,5 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package root is empty. See chroot_test.go for description. +// Package root is used for tests that requires sysadmin privileges run. First, +// follow the setup instruction in runsc/test/README.md. You should also have +// docker, containerd, and crictl installed. To run these tests from the +// project root directory: +// +// ./scripts/root_tests.sh package root diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 5616a8b7b..dfb4e2a97 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,25 +1,41 @@ # These packages are used to run language runtime tests inside gVisor sandboxes. -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary") load("//test/runtimes:build_defs.bzl", "runtime_test") package(licenses = ["notice"]) -go_library( - name = "runtimes", - srcs = ["runtimes.go"], - importpath = "gvisor.dev/gvisor/test/runtimes", +go_binary( + name = "runner", + testonly = 1, + srcs = ["runner.go"], + deps = [ + "//runsc/dockerutil", + "//runsc/testutil", + ], ) runtime_test( - name = "runtimes_test", - size = "small", - srcs = ["runtimes_test.go"], - embed = [":runtimes"], - tags = [ - # Requires docker and runsc to be configured before the test runs. - "manual", - "local", - ], - deps = ["//runsc/testutil"], + image = "gcr.io/gvisor-presubmit/go1.12", + lang = "go", +) + +runtime_test( + image = "gcr.io/gvisor-presubmit/java11", + lang = "java", +) + +runtime_test( + image = "gcr.io/gvisor-presubmit/nodejs12.4.0", + lang = "nodejs", +) + +runtime_test( + image = "gcr.io/gvisor-presubmit/php7.3.6", + lang = "php", +) + +runtime_test( + image = "gcr.io/gvisor-presubmit/python3.7.3", + lang = "python", ) diff --git a/test/runtimes/README.md b/test/runtimes/README.md index 34d3507be..e41e78f77 100644 --- a/test/runtimes/README.md +++ b/test/runtimes/README.md @@ -16,10 +16,11 @@ The following runtimes are currently supported: 1) [Install and configure Docker](https://docs.docker.com/install/) -2) Build each Docker container from the runtimes directory: +2) Build each Docker container from the runtimes/images directory: ```bash -$ docker build -f $LANG/Dockerfile [-t $NAME] . +$ cd images +$ docker build -f Dockerfile_$LANG [-t $NAME] . ``` ### Testing: diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl index ac28cc037..5e3065342 100644 --- a/test/runtimes/build_defs.bzl +++ b/test/runtimes/build_defs.bzl @@ -1,19 +1,35 @@ """Defines a rule for runsc test targets.""" -load("@io_bazel_rules_go//go:def.bzl", _go_test = "go_test") - # runtime_test is a macro that will create targets to run the given test target # with different runtime options. -def runtime_test(**kwargs): - """Runs the given test target with different runtime options.""" - name = kwargs["name"] - _go_test(**kwargs) - kwargs["name"] = name + "_hostnet" - kwargs["args"] = ["--runtime-type=hostnet"] - _go_test(**kwargs) - kwargs["name"] = name + "_kvm" - kwargs["args"] = ["--runtime-type=kvm"] - _go_test(**kwargs) - kwargs["name"] = name + "_overlay" - kwargs["args"] = ["--runtime-type=overlay"] - _go_test(**kwargs) +def runtime_test( + lang, + image, + shard_count = 50, + size = "enormous"): + sh_test( + name = lang + "_test", + srcs = ["runner.sh"], + args = [ + "--lang", + lang, + "--image", + image, + ], + data = [ + ":runner", + ], + size = size, + shard_count = shard_count, + tags = [ + # Requires docker and runsc to be configured before the test runs. + "manual", + "local", + ], + ) + +def sh_test(**kwargs): + """Wraps the standard sh_test.""" + native.sh_test( + **kwargs + ) diff --git a/test/runtimes/common/BUILD b/test/runtimes/common/BUILD deleted file mode 100644 index b4740bb97..000000000 --- a/test/runtimes/common/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "common", - srcs = ["common.go"], - importpath = "gvisor.dev/gvisor/test/runtimes/common", - visibility = ["//:sandbox"], -) - -go_test( - name = "common_test", - size = "small", - srcs = ["common_test.go"], - deps = [ - ":common", - "//runsc/testutil", - ], -) diff --git a/test/runtimes/common/common.go b/test/runtimes/common/common.go deleted file mode 100644 index 0ff87fa8b..000000000 --- a/test/runtimes/common/common.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package common executes functions for proctor binaries. -package common - -import ( - "flag" - "fmt" - "os" - "path/filepath" - "regexp" -) - -var ( - list = flag.Bool("list", false, "list all available tests") - test = flag.String("test", "", "run a single test from the list of available tests") - version = flag.Bool("v", false, "print out the version of node that is installed") -) - -// TestRunner is an interface to be implemented in each proctor binary. -type TestRunner interface { - // ListTests returns a string slice of tests available to run. - ListTests() ([]string, error) - - // RunTest runs a single test. - RunTest(test string) error -} - -// LaunchFunc parses flags passed by a proctor binary and calls the requested behavior. -func LaunchFunc(tr TestRunner) error { - flag.Parse() - - if *list && *test != "" { - flag.PrintDefaults() - return fmt.Errorf("cannot specify 'list' and 'test' flags simultaneously") - } - if *list { - tests, err := tr.ListTests() - if err != nil { - return fmt.Errorf("failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return nil - } - if *version { - fmt.Println(os.Getenv("LANG_NAME"), "version:", os.Getenv("LANG_VER"), "is installed.") - return nil - } - if *test != "" { - if err := tr.RunTest(*test); err != nil { - return fmt.Errorf("test %q failed to run: %v", *test, err) - } - return nil - } - - if err := runAllTests(tr); err != nil { - return fmt.Errorf("error running all tests: %v", err) - } - return nil -} - -// Search uses filepath.Walk to perform a search of the disk for test files -// and returns a string slice of tests. -func Search(root string, testFilter *regexp.Regexp) ([]string, error) { - var testSlice []string - - err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { - name := filepath.Base(path) - - if info.IsDir() || !testFilter.MatchString(name) { - return nil - } - - relPath, err := filepath.Rel(root, path) - if err != nil { - return err - } - testSlice = append(testSlice, relPath) - return nil - }) - - if err != nil { - return nil, fmt.Errorf("walking %q: %v", root, err) - } - - return testSlice, nil -} - -func runAllTests(tr TestRunner) error { - tests, err := tr.ListTests() - if err != nil { - return fmt.Errorf("failed to list tests: %v", err) - } - for _, test := range tests { - if err := tr.RunTest(test); err != nil { - return fmt.Errorf("test %q failed to run: %v", test, err) - } - } - return nil -} diff --git a/test/runtimes/go/BUILD b/test/runtimes/go/BUILD deleted file mode 100644 index ce971ee9d..000000000 --- a/test/runtimes/go/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "proctor-go", - srcs = ["proctor-go.go"], - deps = ["//test/runtimes/common"], -) diff --git a/test/runtimes/go/Dockerfile b/test/runtimes/go/Dockerfile deleted file mode 100644 index 2d3477392..000000000 --- a/test/runtimes/go/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -FROM ubuntu:bionic -ENV LANG_VER=1.12.5 -ENV LANG_NAME=Go - -RUN apt-get update && apt-get install -y \ - curl \ - gcc \ - git - -WORKDIR /root - -# Download Go 1.4 to use as a bootstrap for building Go from the source. -RUN curl -o go1.4.linux-amd64.tar.gz https://dl.google.com/go/go1.4.linux-amd64.tar.gz -RUN curl -LJO https://github.com/golang/go/archive/go${LANG_VER}.tar.gz -RUN mkdir bootstr -RUN tar -C bootstr -xzf go1.4.linux-amd64.tar.gz -RUN tar -xzf go-go${LANG_VER}.tar.gz -RUN mv go-go${LANG_VER} go - -ENV GOROOT=/root/go -ENV GOROOT_BOOTSTRAP=/root/bootstr/go -ENV LANG_DIR=${GOROOT} - -WORKDIR ${LANG_DIR}/src -RUN ./make.bash -# Pre-compile the tests for faster execution -RUN ["/root/go/bin/go", "tool", "dist", "test", "-compile-only"] - -WORKDIR ${LANG_DIR} - -COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common -COPY go/proctor-go.go ${LANG_DIR} -RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-go.go"] - -ENTRYPOINT ["/root/go/bin/proctor"] diff --git a/test/runtimes/images/Dockerfile_go1.12 b/test/runtimes/images/Dockerfile_go1.12 new file mode 100644 index 000000000..ab9d6abf3 --- /dev/null +++ b/test/runtimes/images/Dockerfile_go1.12 @@ -0,0 +1,10 @@ +# Go is easy, since we already have everything we need to compile the proctor +# binary and run the tests in the golang Docker image. +FROM golang:1.12 +ADD ["proctor/", "/go/src/proctor/"] +RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] + +# Pre-compile the tests so we don't need to do so in each test run. +RUN ["go", "tool", "dist", "test", "-compile-only"] + +ENTRYPOINT ["/proctor", "--runtime=go"] diff --git a/test/runtimes/images/Dockerfile_java11 b/test/runtimes/images/Dockerfile_java11 new file mode 100644 index 000000000..9b7c3d5a3 --- /dev/null +++ b/test/runtimes/images/Dockerfile_java11 @@ -0,0 +1,30 @@ +# Compile the proctor binary. +FROM golang:1.12 AS golang +ADD ["proctor/", "/go/src/proctor/"] +RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] + +FROM ubuntu:bionic +RUN apt-get update && apt-get install -y \ + autoconf \ + build-essential \ + curl \ + make \ + openjdk-11-jdk \ + unzip \ + zip + +# Download the JDK test library. +WORKDIR /root +RUN set -ex \ + && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk11/archive/76072a077ee1.tar.gz/test \ + && tar -xzf /tmp/jdktests.tar.gz \ + && mv jdk11-76072a077ee1/test test \ + && rm -f /tmp/jdktests.tar.gz + +# Install jtreg and add to PATH. +RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz +RUN tar -xzf jtreg.tar.gz +ENV PATH="/root/jtreg/bin:$PATH" + +COPY --from=golang /proctor /proctor +ENTRYPOINT ["/proctor", "--runtime=java"] diff --git a/test/runtimes/images/Dockerfile_nodejs12.4.0 b/test/runtimes/images/Dockerfile_nodejs12.4.0 new file mode 100644 index 000000000..26f68b487 --- /dev/null +++ b/test/runtimes/images/Dockerfile_nodejs12.4.0 @@ -0,0 +1,28 @@ +# Compile the proctor binary. +FROM golang:1.12 AS golang +ADD ["proctor/", "/go/src/proctor/"] +RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] + +FROM ubuntu:bionic +RUN apt-get update && apt-get install -y \ + curl \ + dumb-init \ + g++ \ + make \ + python + +WORKDIR /root +ARG VERSION=v12.4.0 +RUN curl -o node-${VERSION}.tar.gz https://nodejs.org/dist/${VERSION}/node-${VERSION}.tar.gz +RUN tar -zxf node-${VERSION}.tar.gz + +WORKDIR /root/node-${VERSION} +RUN ./configure +RUN make +RUN make test-build + +COPY --from=golang /proctor /proctor + +# Including dumb-init emulates the Linux "init" process, preventing the failure +# of tests involving worker processes. +ENTRYPOINT ["/usr/bin/dumb-init", "/proctor", "--runtime=nodejs"] diff --git a/test/runtimes/images/Dockerfile_php7.3.6 b/test/runtimes/images/Dockerfile_php7.3.6 new file mode 100644 index 000000000..e6b4c6329 --- /dev/null +++ b/test/runtimes/images/Dockerfile_php7.3.6 @@ -0,0 +1,27 @@ +# Compile the proctor binary. +FROM golang:1.12 AS golang +ADD ["proctor/", "/go/src/proctor/"] +RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] + +FROM ubuntu:bionic +RUN apt-get update && apt-get install -y \ + autoconf \ + automake \ + bison \ + build-essential \ + curl \ + libtool \ + libxml2-dev \ + re2c + +WORKDIR /root +ARG VERSION=7.3.6 +RUN curl -o php-${VERSION}.tar.gz https://www.php.net/distributions/php-${VERSION}.tar.gz +RUN tar -zxf php-${VERSION}.tar.gz + +WORKDIR /root/php-${VERSION} +RUN ./configure +RUN make + +COPY --from=golang /proctor /proctor +ENTRYPOINT ["/proctor", "--runtime=php"] diff --git a/test/runtimes/images/Dockerfile_python3.7.3 b/test/runtimes/images/Dockerfile_python3.7.3 new file mode 100644 index 000000000..905cd22d7 --- /dev/null +++ b/test/runtimes/images/Dockerfile_python3.7.3 @@ -0,0 +1,30 @@ +# Compile the proctor binary. +FROM golang:1.12 AS golang +ADD ["proctor/", "/go/src/proctor/"] +RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"] + +FROM ubuntu:bionic + +RUN apt-get update && apt-get install -y \ + curl \ + gcc \ + libbz2-dev \ + libffi-dev \ + liblzma-dev \ + libreadline-dev \ + libssl-dev \ + make \ + zlib1g-dev + +# Use flags -LJO to follow the html redirect and download .tar.gz. +WORKDIR /root +ARG VERSION=3.7.3 +RUN curl -LJO https://github.com/python/cpython/archive/v${VERSION}.tar.gz +RUN tar -zxf cpython-${VERSION}.tar.gz + +WORKDIR /root/cpython-${VERSION} +RUN ./configure --with-pydebug +RUN make -s -j2 + +COPY --from=golang /proctor /proctor +ENTRYPOINT ["/proctor", "--runtime=python"] diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/images/proctor/BUILD new file mode 100644 index 000000000..09dc6c42f --- /dev/null +++ b/test/runtimes/images/proctor/BUILD @@ -0,0 +1,26 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") + +package(licenses = ["notice"]) + +go_binary( + name = "proctor", + srcs = [ + "go.go", + "java.go", + "nodejs.go", + "php.go", + "proctor.go", + "python.go", + ], + visibility = ["//test/runtimes/images:__subpackages__"], +) + +go_test( + name = "proctor_test", + size = "small", + srcs = ["proctor_test.go"], + embed = [":proctor"], + deps = [ + "//runsc/testutil", + ], +) diff --git a/test/runtimes/go/proctor-go.go b/test/runtimes/images/proctor/go.go index 3eb24576e..3e2d5d8db 100644 --- a/test/runtimes/go/proctor-go.go +++ b/test/runtimes/images/proctor/go.go @@ -12,50 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary proctor-go is a utility that facilitates language testing for Go. - -// There are two types of Go tests: "Go tool tests" and "Go tests on disk". -// "Go tool tests" are found and executed using `go tool dist test`. -// "Go tests on disk" are found in the /test directory and are -// executed using `go run run.go`. package main import ( "fmt" - "log" "os" "os/exec" - "path/filepath" "regexp" "strings" - - "gvisor.dev/gvisor/test/runtimes/common" ) var ( - dir = os.Getenv("LANG_DIR") - goBin = filepath.Join(dir, "bin/go") - testDir = filepath.Join(dir, "test") - testRegEx = regexp.MustCompile(`^.+\.go$`) + goTestRegEx = regexp.MustCompile(`^.+\.go$`) // Directories with .dir contain helper files for tests. // Exclude benchmarks and stress tests. - dirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`) + goDirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`) ) -type goRunner struct { -} +// Location of Go tests on disk. +const goTestDir = "/usr/local/go/test" -func main() { - if err := common.LaunchFunc(goRunner{}); err != nil { - log.Fatalf("Failed to start: %v", err) - } -} +// goRunner implements TestRunner for Go. +// +// There are two types of Go tests: "Go tool tests" and "Go tests on disk". +// "Go tool tests" are found and executed using `go tool dist test`. "Go tests +// on disk" are found in the /usr/local/go/test directory and are executed +// using `go run run.go`. +type goRunner struct{} + +var _ TestRunner = goRunner{} -func (g goRunner) ListTests() ([]string, error) { +// ListTests implements TestRunner.ListTests. +func (goRunner) ListTests() ([]string, error) { // Go tool dist test tests. args := []string{"tool", "dist", "test", "-list"} - cmd := exec.Command(filepath.Join(dir, "bin/go"), args...) + cmd := exec.Command("go", args...) cmd.Stderr = os.Stderr out, err := cmd.Output() if err != nil { @@ -67,14 +59,14 @@ func (g goRunner) ListTests() ([]string, error) { } // Go tests on disk. - diskSlice, err := common.Search(testDir, testRegEx) + diskSlice, err := search(goTestDir, goTestRegEx) if err != nil { return nil, err } // Remove items from /bench/, /stress/ and .dir files diskFiltered := diskSlice[:0] for _, file := range diskSlice { - if !dirFilter.MatchString(file) { + if !goDirFilter.MatchString(file) { diskFiltered = append(diskFiltered, file) } } @@ -82,24 +74,17 @@ func (g goRunner) ListTests() ([]string, error) { return append(toolSlice, diskFiltered...), nil } -func (g goRunner) RunTest(test string) error { +// TestCmd implements TestRunner.TestCmd. +func (goRunner) TestCmd(test string) *exec.Cmd { // Check if test exists on disk by searching for file of the same name. // This will determine whether or not it is a Go test on disk. if strings.HasSuffix(test, ".go") { // Test has suffix ".go" which indicates a disk test, run it as such. - cmd := exec.Command(goBin, "run", "run.go", "-v", "--", test) - cmd.Dir = testDir - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to run test: %v", err) - } - } else { - // No ".go" suffix, run as a tool test. - cmd := exec.Command(goBin, "tool", "dist", "test", "-run", test) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to run test: %v", err) - } + cmd := exec.Command("go", "run", "run.go", "-v", "--", test) + cmd.Dir = goTestDir + return cmd } - return nil + + // No ".go" suffix, run as a tool test. + return exec.Command("go", "tool", "dist", "test", "-run", test) } diff --git a/test/runtimes/java/proctor-java.go b/test/runtimes/images/proctor/java.go index 7f6a66f4f..8b362029d 100644 --- a/test/runtimes/java/proctor-java.go +++ b/test/runtimes/images/proctor/java.go @@ -12,40 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary proctor-java is a utility that facilitates language testing for Java. package main import ( "fmt" - "log" "os" "os/exec" - "path/filepath" "regexp" "strings" - - "gvisor.dev/gvisor/test/runtimes/common" ) -var ( - dir = os.Getenv("LANG_DIR") - hash = os.Getenv("LANG_HASH") - jtreg = filepath.Join(dir, "jtreg/bin/jtreg") - exclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`) -) +// Directories to exclude from tests. +var javaExclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`) -type javaRunner struct { -} +// Location of java tests. +const javaTestDir = "/root/test/jdk" -func main() { - if err := common.LaunchFunc(javaRunner{}); err != nil { - log.Fatalf("Failed to start: %v", err) - } -} +// javaRunner implements TestRunner for Java. +type javaRunner struct{} + +var _ TestRunner = javaRunner{} -func (j javaRunner) ListTests() ([]string, error) { +// ListTests implements TestRunner.ListTests. +func (javaRunner) ListTests() ([]string, error) { args := []string{ - "-dir:/root/jdk11-" + hash + "/test/jdk", + "-dir:" + javaTestDir, "-ignore:quiet", "-a", "-listtests", @@ -54,7 +45,7 @@ func (j javaRunner) ListTests() ([]string, error) { ":jdk_sound", ":jdk_imageio", } - cmd := exec.Command(jtreg, args...) + cmd := exec.Command("jtreg", args...) cmd.Stderr = os.Stderr out, err := cmd.Output() if err != nil { @@ -62,19 +53,19 @@ func (j javaRunner) ListTests() ([]string, error) { } var testSlice []string for _, test := range strings.Split(string(out), "\n") { - if !exclDirs.MatchString(test) { + if !javaExclDirs.MatchString(test) { testSlice = append(testSlice, test) } } return testSlice, nil } -func (j javaRunner) RunTest(test string) error { - args := []string{"-noreport", "-dir:/root/jdk11-" + hash + "/test/jdk", test} - cmd := exec.Command(jtreg, args...) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to run: %v", err) +// TestCmd implements TestRunner.TestCmd. +func (javaRunner) TestCmd(test string) *exec.Cmd { + args := []string{ + "-noreport", + "-dir:" + javaTestDir, + test, } - return nil + return exec.Command("jtreg", args...) } diff --git a/test/runtimes/images/proctor/nodejs.go b/test/runtimes/images/proctor/nodejs.go new file mode 100644 index 000000000..bd57db444 --- /dev/null +++ b/test/runtimes/images/proctor/nodejs.go @@ -0,0 +1,46 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "os/exec" + "path/filepath" + "regexp" +) + +var nodejsTestRegEx = regexp.MustCompile(`^test-[^-].+\.js$`) + +// Location of nodejs tests relative to working dir. +const nodejsTestDir = "test" + +// nodejsRunner implements TestRunner for NodeJS. +type nodejsRunner struct{} + +var _ TestRunner = nodejsRunner{} + +// ListTests implements TestRunner.ListTests. +func (nodejsRunner) ListTests() ([]string, error) { + testSlice, err := search(nodejsTestDir, nodejsTestRegEx) + if err != nil { + return nil, err + } + return testSlice, nil +} + +// TestCmd implements TestRunner.TestCmd. +func (nodejsRunner) TestCmd(test string) *exec.Cmd { + args := []string{filepath.Join("tools", "test.py"), test} + return exec.Command("/usr/bin/python", args...) +} diff --git a/test/runtimes/php/proctor-php.go b/test/runtimes/images/proctor/php.go index e6c5fabdf..9115040e1 100644 --- a/test/runtimes/php/proctor-php.go +++ b/test/runtimes/images/proctor/php.go @@ -12,47 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary proctor-php is a utility that facilitates language testing for PHP. package main import ( - "fmt" - "log" - "os" "os/exec" "regexp" - - "gvisor.dev/gvisor/test/runtimes/common" ) -var ( - dir = os.Getenv("LANG_DIR") - testRegEx = regexp.MustCompile(`^.+\.phpt$`) -) +var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`) -type phpRunner struct { -} +// phpRunner implements TestRunner for PHP. +type phpRunner struct{} -func main() { - if err := common.LaunchFunc(phpRunner{}); err != nil { - log.Fatalf("Failed to start: %v", err) - } -} +var _ TestRunner = phpRunner{} -func (p phpRunner) ListTests() ([]string, error) { - testSlice, err := common.Search(dir, testRegEx) +// ListTests implements TestRunner.ListTests. +func (phpRunner) ListTests() ([]string, error) { + testSlice, err := search(".", phpTestRegEx) if err != nil { return nil, err } return testSlice, nil } -func (p phpRunner) RunTest(test string) error { +// TestCmd implements TestRunner.TestCmd. +func (phpRunner) TestCmd(test string) *exec.Cmd { args := []string{"test", "TESTS=" + test} - cmd := exec.Command("make", args...) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to run: %v", err) - } - return nil + return exec.Command("make", args...) } diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/images/proctor/proctor.go new file mode 100644 index 000000000..e6178e82b --- /dev/null +++ b/test/runtimes/images/proctor/proctor.go @@ -0,0 +1,154 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary proctor runs the test for a particular runtime. It is meant to be +// included in Docker images for all runtime tests. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "os/exec" + "os/signal" + "path/filepath" + "regexp" + "syscall" +) + +// TestRunner is an interface that must be implemented for each runtime +// integrated with proctor. +type TestRunner interface { + // ListTests returns a string slice of tests available to run. + ListTests() ([]string, error) + + // TestCmd returns an *exec.Cmd that will run the given test. + TestCmd(test string) *exec.Cmd +} + +var ( + runtime = flag.String("runtime", "", "name of runtime") + list = flag.Bool("list", false, "list all available tests") + test = flag.String("test", "", "run a single test from the list of available tests") + pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") +) + +func main() { + flag.Parse() + + if *pause { + pauseAndReap() + panic("pauseAndReap should never return") + } + + if *runtime == "" { + log.Fatalf("runtime flag must be provided") + } + + tr, err := testRunnerForRuntime(*runtime) + if err != nil { + log.Fatalf("%v", err) + } + + // List tests. + if *list { + tests, err := tr.ListTests() + if err != nil { + log.Fatalf("failed to list tests: %v", err) + } + for _, test := range tests { + fmt.Println(test) + } + return + } + + // Run a single test. + if *test == "" { + log.Fatalf("test flag must be provided") + } + cmd := tr.TestCmd(*test) + cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr + if err := cmd.Run(); err != nil { + log.Fatalf("FAIL: %v", err) + } +} + +// testRunnerForRuntime returns a new TestRunner for the given runtime. +func testRunnerForRuntime(runtime string) (TestRunner, error) { + switch runtime { + case "go": + return goRunner{}, nil + case "java": + return javaRunner{}, nil + case "nodejs": + return nodejsRunner{}, nil + case "php": + return phpRunner{}, nil + case "python": + return pythonRunner{}, nil + } + return nil, fmt.Errorf("invalid runtime %q", runtime) +} + +// pauseAndReap is like init. It runs forever and reaps any children. +func pauseAndReap() { + // Get notified of any new children. + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGCHLD) + + for { + if _, ok := <-ch; !ok { + // Channel closed. This should not happen. + panic("signal channel closed") + } + + // Reap the child. + for { + if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 { + break + } + } + } +} + +// search is a helper function to find tests in the given directory that match +// the regex. +func search(root string, testFilter *regexp.Regexp) ([]string, error) { + var testSlice []string + + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + name := filepath.Base(path) + + if info.IsDir() || !testFilter.MatchString(name) { + return nil + } + + relPath, err := filepath.Rel(root, path) + if err != nil { + return err + } + testSlice = append(testSlice, relPath) + return nil + }) + if err != nil { + return nil, fmt.Errorf("walking %q: %v", root, err) + } + + return testSlice, nil +} diff --git a/test/runtimes/common/common_test.go b/test/runtimes/images/proctor/proctor_test.go index 65875b41b..6bb61d142 100644 --- a/test/runtimes/common/common_test.go +++ b/test/runtimes/images/proctor/proctor_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package common_test +package main import ( "io/ioutil" @@ -24,7 +24,6 @@ import ( "testing" "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/runtimes/common" ) func touch(t *testing.T, name string) { @@ -48,9 +47,9 @@ func TestSearchEmptyDir(t *testing.T) { var want []string testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := common.Search(td, testFilter) + got, err := search(td, testFilter) if err != nil { - t.Errorf("Search error: %v", err) + t.Errorf("search error: %v", err) } if !reflect.DeepEqual(got, want) { @@ -117,9 +116,9 @@ func TestSearch(t *testing.T) { } testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := common.Search(td, testFilter) + got, err := search(td, testFilter) if err != nil { - t.Errorf("Search error: %v", err) + t.Errorf("search error: %v", err) } if !reflect.DeepEqual(got, want) { diff --git a/test/runtimes/python/proctor-python.go b/test/runtimes/images/proctor/python.go index 35e28a7df..b9e0fbe6f 100644 --- a/test/runtimes/python/proctor-python.go +++ b/test/runtimes/images/proctor/python.go @@ -12,36 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary proctor-python is a utility that facilitates language testing for Pyhton. package main import ( "fmt" - "log" "os" "os/exec" - "path/filepath" "strings" - - "gvisor.dev/gvisor/test/runtimes/common" ) -var ( - dir = os.Getenv("LANG_DIR") -) +// pythonRunner implements TestRunner for Python. +type pythonRunner struct{} -type pythonRunner struct { -} +var _ TestRunner = pythonRunner{} -func main() { - if err := common.LaunchFunc(pythonRunner{}); err != nil { - log.Fatalf("Failed to start: %v", err) - } -} - -func (p pythonRunner) ListTests() ([]string, error) { +// ListTests implements TestRunner.ListTests. +func (pythonRunner) ListTests() ([]string, error) { args := []string{"-m", "test", "--list-tests"} - cmd := exec.Command(filepath.Join(dir, "python"), args...) + cmd := exec.Command("./python", args...) cmd.Stderr = os.Stderr out, err := cmd.Output() if err != nil { @@ -54,12 +42,8 @@ func (p pythonRunner) ListTests() ([]string, error) { return toolSlice, nil } -func (p pythonRunner) RunTest(test string) error { +// TestCmd implements TestRunner.TestCmd. +func (pythonRunner) TestCmd(test string) *exec.Cmd { args := []string{"-m", "test", test} - cmd := exec.Command(filepath.Join(dir, "python"), args...) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to run: %v", err) - } - return nil + return exec.Command("./python", args...) } diff --git a/test/runtimes/java/BUILD b/test/runtimes/java/BUILD deleted file mode 100644 index 8c39d39ec..000000000 --- a/test/runtimes/java/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "proctor-java", - srcs = ["proctor-java.go"], - deps = ["//test/runtimes/common"], -) diff --git a/test/runtimes/java/Dockerfile b/test/runtimes/java/Dockerfile deleted file mode 100644 index 1a61d9d8f..000000000 --- a/test/runtimes/java/Dockerfile +++ /dev/null @@ -1,36 +0,0 @@ -FROM ubuntu:bionic -# This hash is associated with a specific JDK release and needed for ensuring -# the same version is downloaded every time. -ENV LANG_HASH=76072a077ee1 -ENV LANG_VER=11 -ENV LANG_NAME=Java - -RUN apt-get update && apt-get install -y \ - autoconf \ - build-essential \ - curl \ - make \ - openjdk-${LANG_VER}-jdk \ - unzip \ - zip - -WORKDIR /root -RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz -RUN tar -zxf go.tar.gz - -# Download the JDK test library. -RUN set -ex \ - && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk${LANG_VER}/archive/${LANG_HASH}.tar.gz/test \ - && tar -xzf /tmp/jdktests.tar.gz -C /root \ - && rm -f /tmp/jdktests.tar.gz - -RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz -RUN tar -xzf jtreg.tar.gz - -ENV LANG_DIR=/root - -COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common -COPY java/proctor-java.go ${LANG_DIR} -RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-java.go"] - -ENTRYPOINT ["/root/go/bin/proctor"] diff --git a/test/runtimes/nodejs/BUILD b/test/runtimes/nodejs/BUILD deleted file mode 100644 index 0594c250b..000000000 --- a/test/runtimes/nodejs/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "proctor-nodejs", - srcs = ["proctor-nodejs.go"], - deps = ["//test/runtimes/common"], -) diff --git a/test/runtimes/nodejs/Dockerfile b/test/runtimes/nodejs/Dockerfile deleted file mode 100644 index ce2943af8..000000000 --- a/test/runtimes/nodejs/Dockerfile +++ /dev/null @@ -1,31 +0,0 @@ -FROM ubuntu:bionic -ENV LANG_VER=12.4.0 -ENV LANG_NAME=Node - -RUN apt-get update && apt-get install -y \ - curl \ - dumb-init \ - g++ \ - make \ - python - -WORKDIR /root -RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz -RUN tar -zxf go.tar.gz - -RUN curl -o node-v${LANG_VER}.tar.gz https://nodejs.org/dist/v${LANG_VER}/node-v${LANG_VER}.tar.gz -RUN tar -zxf node-v${LANG_VER}.tar.gz -ENV LANG_DIR=/root/node-v${LANG_VER} - -WORKDIR ${LANG_DIR} -RUN ./configure -RUN make -RUN make test-build - -COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common -COPY nodejs/proctor-nodejs.go ${LANG_DIR} -RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-nodejs.go"] - -# Including dumb-init emulates the Linux "init" process, preventing the failure -# of tests involving worker processes. -ENTRYPOINT ["/usr/bin/dumb-init", "/root/go/bin/proctor"] diff --git a/test/runtimes/nodejs/proctor-nodejs.go b/test/runtimes/nodejs/proctor-nodejs.go deleted file mode 100644 index 0624f6a0d..000000000 --- a/test/runtimes/nodejs/proctor-nodejs.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary proctor-nodejs is a utility that facilitates language testing for NodeJS. -package main - -import ( - "fmt" - "log" - "os" - "os/exec" - "path/filepath" - "regexp" - - "gvisor.dev/gvisor/test/runtimes/common" -) - -var ( - dir = os.Getenv("LANG_DIR") - testDir = filepath.Join(dir, "test") - testRegEx = regexp.MustCompile(`^test-[^-].+\.js$`) -) - -type nodejsRunner struct { -} - -func main() { - if err := common.LaunchFunc(nodejsRunner{}); err != nil { - log.Fatalf("Failed to start: %v", err) - } -} - -func (n nodejsRunner) ListTests() ([]string, error) { - testSlice, err := common.Search(testDir, testRegEx) - if err != nil { - return nil, err - } - return testSlice, nil -} - -func (n nodejsRunner) RunTest(test string) error { - args := []string{filepath.Join(dir, "tools", "test.py"), test} - cmd := exec.Command("/usr/bin/python", args...) - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to run: %v", err) - } - return nil -} diff --git a/test/runtimes/php/BUILD b/test/runtimes/php/BUILD deleted file mode 100644 index 31799b77a..000000000 --- a/test/runtimes/php/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "proctor-php", - srcs = ["proctor-php.go"], - deps = ["//test/runtimes/common"], -) diff --git a/test/runtimes/php/Dockerfile b/test/runtimes/php/Dockerfile deleted file mode 100644 index d79babe58..000000000 --- a/test/runtimes/php/Dockerfile +++ /dev/null @@ -1,31 +0,0 @@ -FROM ubuntu:bionic -ENV LANG_VER=7.3.6 -ENV LANG_NAME=PHP - -RUN apt-get update && apt-get install -y \ - autoconf \ - automake \ - bison \ - build-essential \ - curl \ - libtool \ - libxml2-dev \ - re2c - -WORKDIR /root -RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz -RUN tar -zxf go.tar.gz - -RUN curl -o php-${LANG_VER}.tar.gz https://www.php.net/distributions/php-${LANG_VER}.tar.gz -RUN tar -zxf php-${LANG_VER}.tar.gz -ENV LANG_DIR=/root/php-${LANG_VER} - -WORKDIR ${LANG_DIR} -RUN ./configure -RUN make - -COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common -COPY php/proctor-php.go ${LANG_DIR} -RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-php.go"] - -ENTRYPOINT ["/root/go/bin/proctor"] diff --git a/test/runtimes/python/BUILD b/test/runtimes/python/BUILD deleted file mode 100644 index 37fd6a0f2..000000000 --- a/test/runtimes/python/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "proctor-python", - srcs = ["proctor-python.go"], - deps = ["//test/runtimes/common"], -) diff --git a/test/runtimes/python/Dockerfile b/test/runtimes/python/Dockerfile deleted file mode 100644 index 5ae328890..000000000 --- a/test/runtimes/python/Dockerfile +++ /dev/null @@ -1,33 +0,0 @@ -FROM ubuntu:bionic -ENV LANG_VER=3.7.3 -ENV LANG_NAME=Python - -RUN apt-get update && apt-get install -y \ - curl \ - gcc \ - libbz2-dev \ - libffi-dev \ - liblzma-dev \ - libreadline-dev \ - libssl-dev \ - make \ - zlib1g-dev - -WORKDIR /root -RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz -RUN tar -zxf go.tar.gz - -# Use flags -LJO to follow the html redirect and download .tar.gz. -RUN curl -LJO https://github.com/python/cpython/archive/v${LANG_VER}.tar.gz -RUN tar -zxf cpython-${LANG_VER}.tar.gz -ENV LANG_DIR=/root/cpython-${LANG_VER} - -WORKDIR ${LANG_DIR} -RUN ./configure --with-pydebug -RUN make -s -j2 - -COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common -COPY python/proctor-python.go ${LANG_DIR} -RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-python.go"] - -ENTRYPOINT ["/root/go/bin/proctor"] diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go new file mode 100644 index 000000000..3a15f59a7 --- /dev/null +++ b/test/runtimes/runner.go @@ -0,0 +1,147 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary runner runs the runtime tests in a Docker container. +package main + +import ( + "flag" + "fmt" + "io" + "log" + "os" + "sort" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/runsc/testutil" +) + +var ( + lang = flag.String("lang", "", "language runtime to test") + image = flag.String("image", "", "docker image with runtime tests") +) + +// Wait time for each test to run. +const timeout = 5 * time.Minute + +func main() { + flag.Parse() + if *lang == "" || *image == "" { + fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") + os.Exit(1) + } + + os.Exit(runTests()) +} + +// runTests is a helper that is called by main. It exists so that we can run +// defered functions before exiting. It returns an exit code that should be +// passed to os.Exit. +func runTests() int { + // Create a single docker container that will be used for all tests. + d := dockerutil.MakeDocker("gvisor-" + *lang) + defer d.CleanUp() + + // Get a slice of tests to run. This will also start a single Docker + // container that will be used to run each test. The final test will + // stop the Docker container. + tests, err := getTests(d) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + return 1 + } + + m := testing.MainStart(testDeps{}, tests, nil, nil) + return m.Run() +} + +// getTests returns a slice of tests to run, subject to the shard size and +// index. +func getTests(d dockerutil.Docker) ([]testing.InternalTest, error) { + // Pull the image. + if err := dockerutil.Pull(*image); err != nil { + return nil, fmt.Errorf("docker pull %q failed: %v", *image, err) + } + + // Run proctor with --pause flag to keep container alive forever. + if err := d.Run(*image, "--pause"); err != nil { + return nil, fmt.Errorf("docker run failed: %v", err) + } + + // Get a list of all tests in the image. + list, err := d.Exec("/proctor", "--runtime", *lang, "--list") + if err != nil { + return nil, fmt.Errorf("docker exec failed: %v", err) + } + + // Calculate a subset of tests to run corresponding to the current + // shard. + tests := strings.Fields(list) + sort.Strings(tests) + begin, end, err := testutil.TestBoundsForShard(len(tests)) + if err != nil { + return nil, fmt.Errorf("TestsForShard() failed: %v", err) + } + log.Printf("Got bounds [%d:%d) for shard out of %d total tests", begin, end, len(tests)) + tests = tests[begin:end] + + var itests []testing.InternalTest + for _, tc := range tests { + // Capture tc in this scope. + tc := tc + itests = append(itests, testing.InternalTest{ + Name: tc, + F: func(t *testing.T) { + var ( + now = time.Now() + done = make(chan struct{}) + output string + err error + ) + go func() { + fmt.Printf("RUNNING %s...\n", tc) + output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc) + close(done) + }() + + select { + case <-done: + if err == nil { + fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now)) + return + } + t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output) + case <-time.After(timeout): + t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output) + } + }, + }) + } + return itests, nil +} + +// testDeps implements testing.testDeps (an unexported interface), and is +// required to use testing.MainStart. +type testDeps struct{} + +func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil } +func (f testDeps) StartCPUProfile(io.Writer) error { return nil } +func (f testDeps) StopCPUProfile() {} +func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } +func (f testDeps) ImportPath() string { return "" } +func (f testDeps) StartTestLog(io.Writer) {} +func (f testDeps) StopTestLog() error { return nil } diff --git a/kokoro/run_tests.sh b/test/runtimes/runner.sh index 5552da11c..a8d9a3460 100644..100755 --- a/kokoro/run_tests.sh +++ b/test/runtimes/runner.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright 2019 The gVisor Authors. +# Copyright 2018 The gVisor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,16 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -xeo pipefail +set -euf -x -o pipefail -# This file is a temporary bridge. We will create multiple independent Kokoro -# workflows that call each of the test scripts independently. +echo -- "$@" + +# Create outputs dir if it does not exist. +if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then + mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}" + chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}" +fi + +# Update the timestamp on the shard status file. Bazel looks for this. +touch "${TEST_SHARD_STATUS_FILE}" + +# Get location of runner binary. +readonly runner=$(find "${TEST_SRCDIR}" -name runner) + +# Pass the arguments of this script directly to the runner. +exec "${runner}" "$@" -# Run all the tests in sequence. -$(dirname $0)/../scripts/do_tests.sh -$(dirname $0)/../scripts/make_tests.sh -$(dirname $0)/../scripts/root_tests.sh -$(dirname $0)/../scripts/docker_tests.sh -$(dirname $0)/../scripts/overlay_tests.sh -$(dirname $0)/../scripts/hostnet_tests.sh -$(dirname $0)/../scripts/simple_tests.sh diff --git a/test/runtimes/runtimes_test.go b/test/runtimes/runtimes_test.go deleted file mode 100644 index 0ff5dda02..000000000 --- a/test/runtimes/runtimes_test.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package runtimes - -import ( - "strings" - "testing" - "time" - - "gvisor.dev/gvisor/runsc/testutil" -) - -// Wait time for each test to run. -const timeout = 180 * time.Second - -// Helper function to execute the docker container associated with the -// language passed. Captures the output of the list function and executes -// each test individually, supplying any errors recieved. -func testLang(t *testing.T, lang string) { - t.Helper() - - img := "gcr.io/gvisor-presubmit/" + lang - if err := testutil.Pull(img); err != nil { - t.Fatalf("docker pull failed: %v", err) - } - - c := testutil.MakeDocker("gvisor-list") - - list, err := c.RunFg(img, "--list") - if err != nil { - t.Fatalf("docker run failed: %v", err) - } - c.CleanUp() - - tests := strings.Fields(list) - - for _, tc := range tests { - tc := tc - t.Run(tc, func(t *testing.T) { - d := testutil.MakeDocker("gvisor-test") - if err := d.Run(img, "--test", tc); err != nil { - t.Fatalf("docker test %q failed to run: %v", tc, err) - } - defer d.CleanUp() - - status, err := d.Wait(timeout) - if err != nil { - t.Fatalf("docker test %q failed to wait: %v", tc, err) - } - if status == 0 { - t.Logf("test %q passed", tc) - return - } - logs, err := d.Logs() - if err != nil { - t.Fatalf("docker test %q failed to supply logs: %v", tc, err) - } - t.Errorf("test %q failed: %v", tc, logs) - }) - } -} - -func TestGo(t *testing.T) { - testLang(t, "go") -} - -func TestJava(t *testing.T) { - testLang(t, "java") -} - -func TestNodejs(t *testing.T) { - testLang(t, "nodejs") -} - -func TestPhp(t *testing.T) { - testLang(t, "php") -} - -func TestPython(t *testing.T) { - testLang(t, "python") -} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 0135435ea..63e4c63dd 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -321,6 +321,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:pty_root_test", +) + +syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwritev2_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index a964bef24..a4cebf46f 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,3 +1,4 @@ +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") load("//test/syscalls:build_defs.bzl", "select_for_linux") package( @@ -265,12 +266,15 @@ cc_binary( ], linkstatic = 1, deps = [ - # The heap check doesn't handle mremap properly. + # The heapchecker doesn't recognize that io_destroy munmaps. "@com_google_googletest//:gtest", "@com_google_absl//absl/strings", "//test/util:cleanup", "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:memory_util", "//test/util:posix_error", + "//test/util:proc_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -390,6 +394,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest", ], @@ -408,6 +413,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_googletest//:gtest", ], ) @@ -724,6 +730,7 @@ cc_binary( "//test/util:test_util", "//test/util:timer_util", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -972,6 +979,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", @@ -992,6 +1000,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", ], @@ -1278,8 +1287,10 @@ cc_binary( srcs = ["pty.cc"], linkstatic = 1, deps = [ + "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:posix_error", + "//test/util:pty_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", @@ -1292,6 +1303,23 @@ cc_binary( ) cc_binary( + name = "pty_root_test", + testonly = 1, + srcs = ["pty_root.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:posix_error", + "//test/util:pty_util", + "//test/util:test_main", + "//test/util:thread_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "partial_bad_buffer_test", testonly = 1, srcs = ["partial_bad_buffer.cc"], @@ -1402,6 +1430,7 @@ cc_binary( "//test/util:posix_error", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_googletest//:gtest", ], ) @@ -1418,6 +1447,7 @@ cc_binary( "//test/util:posix_error", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_googletest//:gtest", ], ) @@ -1601,6 +1631,7 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "//test/util:time_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], @@ -1863,7 +1894,9 @@ cc_binary( "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "//test/util:thread_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) @@ -1897,6 +1930,7 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) @@ -1949,6 +1983,24 @@ cc_binary( ) cc_binary( + name = "signalfd_test", + testonly = 1, + srcs = ["signalfd.cc"], + linkstatic = 1, + deps = [ + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:posix_error", + "//test/util:signal_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "sigprocmask_test", testonly = 1, srcs = ["sigprocmask.cc"], @@ -1971,6 +2023,7 @@ cc_binary( "//test/util:posix_error", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], @@ -3106,6 +3159,7 @@ cc_binary( "//test/util:signal_util", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], @@ -3185,6 +3239,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", ], @@ -3276,6 +3331,7 @@ cc_binary( "//test/util:multiprocess_util", "//test/util:test_util", "//test/util:time_util", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc index 68dc05417..b27d4e10a 100644 --- a/test/syscalls/linux/aio.cc +++ b/test/syscalls/linux/aio.cc @@ -14,31 +14,57 @@ #include <fcntl.h> #include <linux/aio_abi.h> -#include <string.h> #include <sys/mman.h> #include <sys/syscall.h> #include <sys/types.h> #include <unistd.h> +#include <algorithm> +#include <string> + #include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/memory_util.h" +#include "test/util/posix_error.h" +#include "test/util/proc_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" +using ::testing::_; + namespace gvisor { namespace testing { namespace { +// Returns the size of the VMA containing the given address. +PosixErrorOr<size_t> VmaSizeAt(uintptr_t addr) { + ASSIGN_OR_RETURN_ERRNO(std::string proc_self_maps, + GetContents("/proc/self/maps")); + ASSIGN_OR_RETURN_ERRNO(auto entries, ParseProcMaps(proc_self_maps)); + // Use binary search to find the first VMA that might contain addr. + ProcMapsEntry target = {}; + target.end = addr; + auto it = + std::upper_bound(entries.begin(), entries.end(), target, + [](const ProcMapsEntry& x, const ProcMapsEntry& y) { + return x.end < y.end; + }); + // Check that it actually contains addr. + if (it == entries.end() || addr < it->start) { + return PosixError(ENOENT, absl::StrCat("no VMA contains address ", addr)); + } + return it->end - it->start; +} + constexpr char kData[] = "hello world!"; int SubmitCtx(aio_context_t ctx, long nr, struct iocb** iocbpp) { return syscall(__NR_io_submit, ctx, nr, iocbpp); } -} // namespace - class AIOTest : public FileTest { public: AIOTest() : ctx_(0) {} @@ -124,10 +150,10 @@ TEST_F(AIOTest, BasicWrite) { EXPECT_EQ(events[0].res, strlen(kData)); // Verify that the file contains the contents. - char verify_buf[32] = {}; - ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)), - SyscallSucceeds()); - EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0); + char verify_buf[sizeof(kData)] = {}; + ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)), + SyscallSucceedsWithValue(strlen(kData))); + EXPECT_STREQ(verify_buf, kData); } TEST_F(AIOTest, BadWrite) { @@ -220,38 +246,25 @@ TEST_F(AIOTest, CloneVm) { TEST_F(AIOTest, Mremap) { // Setup a context that is 128 entries deep. ASSERT_THAT(SetupContext(128), SyscallSucceeds()); + const size_t ctx_size = + ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_))); struct iocb cb = CreateCallback(); struct iocb* cbs[1] = {&cb}; // Reserve address space for the mremap target so we have something safe to // map over. - // - // N.B. We reserve 2 pages because we'll attempt to remap to 2 pages below. - // That should fail with EFAULT, but will fail with EINVAL if this mmap - // returns the page immediately below ctx_, as - // [new_address, new_address+2*kPageSize) overlaps [ctx_, ctx_+kPageSize). - void* new_address = mmap(nullptr, 2 * kPageSize, PROT_READ, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - ASSERT_THAT(reinterpret_cast<intptr_t>(new_address), SyscallSucceeds()); - auto mmap_cleanup = Cleanup([new_address] { - EXPECT_THAT(munmap(new_address, 2 * kPageSize), SyscallSucceeds()); - }); - - // Test that remapping to a larger address fails. - void* res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, 2 * kPageSize, - MREMAP_FIXED | MREMAP_MAYMOVE, new_address); - ASSERT_THAT(reinterpret_cast<intptr_t>(res), SyscallFailsWithErrno(EFAULT)); + Mapping dst = + ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE)); // Remap context 'handle' to a different address. - res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, kPageSize, - MREMAP_FIXED | MREMAP_MAYMOVE, new_address); - ASSERT_THAT( - reinterpret_cast<intptr_t>(res), - SyscallSucceedsWithValue(reinterpret_cast<intptr_t>(new_address))); - mmap_cleanup.Release(); + ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(), + MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()), + IsPosixErrorOkAndHolds(dst.ptr())); aio_context_t old_ctx = ctx_; - ctx_ = reinterpret_cast<aio_context_t>(new_address); + ctx_ = reinterpret_cast<aio_context_t>(dst.addr()); + // io_destroy() will unmap dst now. + dst.release(); // Check that submitting the request with the old 'ctx_' fails. ASSERT_THAT(SubmitCtx(old_ctx, 1, cbs), SyscallFailsWithErrno(EINVAL)); @@ -260,18 +273,12 @@ TEST_F(AIOTest, Mremap) { ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1)); // Remap again. - new_address = - mmap(nullptr, kPageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - ASSERT_THAT(reinterpret_cast<int64_t>(new_address), SyscallSucceeds()); - auto mmap_cleanup2 = Cleanup([new_address] { - EXPECT_THAT(munmap(new_address, kPageSize), SyscallSucceeds()); - }); - res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, kPageSize, - MREMAP_FIXED | MREMAP_MAYMOVE, new_address); - ASSERT_THAT(reinterpret_cast<int64_t>(res), - SyscallSucceedsWithValue(reinterpret_cast<int64_t>(new_address))); - mmap_cleanup2.Release(); - ctx_ = reinterpret_cast<aio_context_t>(new_address); + dst = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE)); + ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(), + MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()), + IsPosixErrorOkAndHolds(dst.ptr())); + ctx_ = reinterpret_cast<aio_context_t>(dst.addr()); + dst.release(); // Get the reply with yet another 'ctx_' and verify it. struct io_event events[1]; @@ -281,51 +288,33 @@ TEST_F(AIOTest, Mremap) { EXPECT_EQ(events[0].res, strlen(kData)); // Verify that the file contains the contents. - char verify_buf[32] = {}; - ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)), - SyscallSucceeds()); - EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0); + char verify_buf[sizeof(kData)] = {}; + ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)), + SyscallSucceedsWithValue(strlen(kData))); + EXPECT_STREQ(verify_buf, kData); } -// Tests that AIO context can be replaced with a different mapping at the same -// address and continue working. Don't ask why, but Linux allows it. -TEST_F(AIOTest, MremapOver) { +// Tests that AIO context cannot be expanded with mremap. +TEST_F(AIOTest, MremapExpansion) { // Setup a context that is 128 entries deep. ASSERT_THAT(SetupContext(128), SyscallSucceeds()); + const size_t ctx_size = + ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_))); - struct iocb cb = CreateCallback(); - struct iocb* cbs[1] = {&cb}; - - ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1)); - - // Allocate a new VMA, copy 'ctx_' content over, and remap it on top - // of 'ctx_'. - void* new_address = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - ASSERT_THAT(reinterpret_cast<int64_t>(new_address), SyscallSucceeds()); - auto mmap_cleanup = Cleanup([new_address] { - EXPECT_THAT(munmap(new_address, kPageSize), SyscallSucceeds()); - }); - - memcpy(new_address, reinterpret_cast<void*>(ctx_), kPageSize); - void* res = - mremap(new_address, kPageSize, kPageSize, MREMAP_FIXED | MREMAP_MAYMOVE, - reinterpret_cast<void*>(ctx_)); - ASSERT_THAT(reinterpret_cast<int64_t>(res), SyscallSucceedsWithValue(ctx_)); - mmap_cleanup.Release(); - - // Everything continues to work just fine. - struct io_event events[1]; - ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1)); - EXPECT_EQ(events[0].data, 0x123); - EXPECT_EQ(events[0].obj, reinterpret_cast<long>(&cb)); - EXPECT_EQ(events[0].res, strlen(kData)); - - // Verify that the file contains the contents. - char verify_buf[32] = {}; - ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)), - SyscallSucceeds()); - EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0); + // Reserve address space for the mremap target so we have something safe to + // map over. + Mapping dst = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(ctx_size + kPageSize, PROT_NONE, MAP_PRIVATE)); + + // Test that remapping to a larger address range fails. + ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(), + MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()), + PosixErrorIs(EFAULT, _)); + + // mm/mremap.c:sys_mremap() => mremap_to() does do_munmap() of the destination + // before it hits the VM_DONTEXPAND check in vma_to_resize(), so we should no + // longer munmap it (another thread may have created a mapping there). + dst.release(); } // Tests that AIO calls fail if context's address is inaccessible. @@ -429,5 +418,7 @@ TEST_P(AIOVectorizedParamTest, BadIOVecs) { INSTANTIATE_TEST_SUITE_P(BadIOVecs, AIOVectorizedParamTest, ::testing::Values(IOCB_CMD_PREADV, IOCB_CMD_PWRITEV)); +} // namespace + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc index 2e82f0b3a..7a28b674d 100644 --- a/test/syscalls/linux/chown.cc +++ b/test/syscalls/linux/chown.cc @@ -16,10 +16,12 @@ #include <grp.h> #include <sys/types.h> #include <unistd.h> + #include <vector> #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "absl/synchronization/notification.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" @@ -29,9 +31,9 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_int32(scratch_uid1, 65534, "first scratch UID"); -DEFINE_int32(scratch_uid2, 65533, "second scratch UID"); -DEFINE_int32(scratch_gid, 65534, "first scratch GID"); +ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID"); +ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); +ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID"); namespace gvisor { namespace testing { @@ -100,10 +102,12 @@ TEST_P(ChownParamTest, ChownFilePermissionDenied) { // Change EUID and EGID. // // See note about POSIX below. - EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1), - SyscallSucceeds()); - EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid1, -1), - SyscallSucceeds()); + EXPECT_THAT( + syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1), + SyscallSucceeds()); + EXPECT_THAT( + syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid1), -1), + SyscallSucceeds()); EXPECT_THAT(GetParam()(file.path(), geteuid(), getegid()), PosixErrorIs(EPERM, ::testing::ContainsRegex("chown"))); @@ -125,8 +129,9 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) { // setresuid syscall. However, we want this thread to have its own set of // credentials different from the parent process, so we use the raw // syscall. - EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid2, -1), - SyscallSucceeds()); + EXPECT_THAT( + syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid2), -1), + SyscallSucceeds()); // Create file and immediately close it. FileDescriptor fd = @@ -143,12 +148,13 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) { fileCreated.WaitForNotification(); // Set file's owners to someone different. - EXPECT_NO_ERRNO(GetParam()(filename, FLAGS_scratch_uid1, FLAGS_scratch_gid)); + EXPECT_NO_ERRNO(GetParam()(filename, absl::GetFlag(FLAGS_scratch_uid1), + absl::GetFlag(FLAGS_scratch_gid))); struct stat s; EXPECT_THAT(stat(filename.c_str(), &s), SyscallSucceeds()); - EXPECT_EQ(s.st_uid, FLAGS_scratch_uid1); - EXPECT_EQ(s.st_gid, FLAGS_scratch_gid); + EXPECT_EQ(s.st_uid, absl::GetFlag(FLAGS_scratch_uid1)); + EXPECT_EQ(s.st_gid, absl::GetFlag(FLAGS_scratch_gid)); fileChowned.Notify(); } diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 2f8e7c9dd..8a45be12a 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -17,9 +17,12 @@ #include <syscall.h> #include <unistd.h> +#include <string> + #include "gtest/gtest.h" #include "absl/base/macros.h" #include "absl/base/port.h" +#include "absl/flags/flag.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/time/clock.h" @@ -33,18 +36,19 @@ #include "test/util/test_util.h" #include "test/util/timer_util.h" -DEFINE_string(child_setlock_on, "", - "Contains the path to try to set a file lock on."); -DEFINE_bool(child_setlock_write, false, - "Whether to set a writable lock (otherwise readable)"); -DEFINE_bool(blocking, false, - "Whether to set a blocking lock (otherwise non-blocking)."); -DEFINE_bool(retry_eintr, false, "Whether to retry in the subprocess on EINTR."); -DEFINE_uint64(child_setlock_start, 0, "The value of struct flock start"); -DEFINE_uint64(child_setlock_len, 0, "The value of struct flock len"); -DEFINE_int32(socket_fd, -1, - "A socket to use for communicating more state back " - "to the parent."); +ABSL_FLAG(std::string, child_setlock_on, "", + "Contains the path to try to set a file lock on."); +ABSL_FLAG(bool, child_setlock_write, false, + "Whether to set a writable lock (otherwise readable)"); +ABSL_FLAG(bool, blocking, false, + "Whether to set a blocking lock (otherwise non-blocking)."); +ABSL_FLAG(bool, retry_eintr, false, + "Whether to retry in the subprocess on EINTR."); +ABSL_FLAG(uint64_t, child_setlock_start, 0, "The value of struct flock start"); +ABSL_FLAG(uint64_t, child_setlock_len, 0, "The value of struct flock len"); +ABSL_FLAG(int32_t, socket_fd, -1, + "A socket to use for communicating more state back " + "to the parent."); namespace gvisor { namespace testing { @@ -918,25 +922,26 @@ TEST(FcntlTest, GetOwn) { int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - if (!FLAGS_child_setlock_on.empty()) { - int socket_fd = FLAGS_socket_fd; - int fd = open(FLAGS_child_setlock_on.c_str(), O_RDWR, 0666); + const std::string setlock_on = absl::GetFlag(FLAGS_child_setlock_on); + if (!setlock_on.empty()) { + int socket_fd = absl::GetFlag(FLAGS_socket_fd); + int fd = open(setlock_on.c_str(), O_RDWR, 0666); if (fd == -1 && errno != 0) { int err = errno; - std::cerr << "CHILD open " << FLAGS_child_setlock_on << " failed " << err + std::cerr << "CHILD open " << setlock_on << " failed " << err << std::endl; exit(err); } struct flock fl; - if (FLAGS_child_setlock_write) { + if (absl::GetFlag(FLAGS_child_setlock_write)) { fl.l_type = F_WRLCK; } else { fl.l_type = F_RDLCK; } fl.l_whence = SEEK_SET; - fl.l_start = FLAGS_child_setlock_start; - fl.l_len = FLAGS_child_setlock_len; + fl.l_start = absl::GetFlag(FLAGS_child_setlock_start); + fl.l_len = absl::GetFlag(FLAGS_child_setlock_len); // Test the fcntl, no need to log, the error is unambiguously // from fcntl at this point. @@ -946,8 +951,8 @@ int main(int argc, char** argv) { gvisor::testing::MonotonicTimer timer; timer.Start(); do { - ret = fcntl(fd, FLAGS_blocking ? F_SETLKW : F_SETLK, &fl); - } while (FLAGS_retry_eintr && ret == -1 && errno == EINTR); + ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl); + } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR); auto usec = absl::ToInt64Microseconds(timer.Duration()); if (ret == -1 && errno != 0) { diff --git a/test/syscalls/linux/kill.cc b/test/syscalls/linux/kill.cc index 18ad923b8..db29bd59c 100644 --- a/test/syscalls/linux/kill.cc +++ b/test/syscalls/linux/kill.cc @@ -21,6 +21,7 @@ #include <csignal> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" @@ -31,8 +32,8 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_int32(scratch_uid, 65534, "scratch UID"); -DEFINE_int32(scratch_gid, 65534, "scratch GID"); +ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID"); +ABSL_FLAG(int32_t, scratch_gid, 65534, "scratch GID"); using ::testing::Ge; @@ -255,8 +256,8 @@ TEST(KillTest, ProcessGroups) { TEST(KillTest, ChildDropsPrivsCannotKill) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - int uid = FLAGS_scratch_uid; - int gid = FLAGS_scratch_gid; + const int uid = absl::GetFlag(FLAGS_scratch_uid); + const int gid = absl::GetFlag(FLAGS_scratch_gid); // Create the child that drops privileges and tries to kill the parent. pid_t pid = fork(); @@ -331,8 +332,8 @@ TEST(KillTest, CanSIGCONTSameSession) { EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP) << "status " << status; - int uid = FLAGS_scratch_uid; - int gid = FLAGS_scratch_gid; + const int uid = absl::GetFlag(FLAGS_scratch_uid); + const int gid = absl::GetFlag(FLAGS_scratch_gid); // Drop privileges only in child process, or else this parent process won't be // able to open some log files after the test ends. diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc index a91703070..dd5352954 100644 --- a/test/syscalls/linux/link.cc +++ b/test/syscalls/linux/link.cc @@ -22,6 +22,7 @@ #include <string> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "absl/strings/str_cat.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" @@ -31,7 +32,7 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_int32(scratch_uid, 65534, "scratch UID"); +ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID"); namespace gvisor { namespace testing { @@ -92,7 +93,8 @@ TEST(LinkTest, PermissionDenied) { // threads have the same UIDs, so using the setuid wrapper sets all threads' // real UID. // Also drops capabilities. - EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)), + SyscallSucceeds()); EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), SyscallFailsWithErrno(EPERM)); diff --git a/test/syscalls/linux/mremap.cc b/test/syscalls/linux/mremap.cc index 64e435cb7..f0e5f7d82 100644 --- a/test/syscalls/linux/mremap.cc +++ b/test/syscalls/linux/mremap.cc @@ -35,17 +35,6 @@ namespace testing { namespace { -// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of -// void* isn't directly compatible with SyscallSucceeds. -PosixErrorOr<void*> Mremap(void* old_address, size_t old_size, size_t new_size, - int flags, void* new_address) { - void* rv = mremap(old_address, old_size, new_size, flags, new_address); - if (rv == MAP_FAILED) { - return PosixError(errno, "mremap failed"); - } - return rv; -} - // Fixture for mremap tests parameterized by mmap flags. using MremapParamTest = ::testing::TestWithParam<int>; diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 65afb90f3..10e2a6dfc 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -168,6 +168,20 @@ TEST_P(PipeTest, Write) { EXPECT_EQ(wbuf, rbuf); } +TEST_P(PipeTest, WritePage) { + SKIP_IF(!CreateBlocking()); + + std::vector<char> wbuf(kPageSize); + RandomizeBuffer(wbuf.data(), wbuf.size()); + std::vector<char> rbuf(wbuf.size()); + + ASSERT_THAT(write(wfd_.get(), wbuf.data(), wbuf.size()), + SyscallSucceedsWithValue(wbuf.size())); + ASSERT_THAT(read(rfd_.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), wbuf.size()), 0); +} + TEST_P(PipeTest, NonBlocking) { SKIP_IF(!CreateNonBlocking()); diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index bd1779557..d07571a5f 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -21,6 +21,7 @@ #include <string> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "test/util/capability_util.h" #include "test/util/cleanup.h" #include "test/util/multiprocess_util.h" @@ -28,9 +29,9 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_bool(prctl_no_new_privs_test_child, false, - "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) " - "plus an offset (see test source)."); +ABSL_FLAG(bool, prctl_no_new_privs_test_child, false, + "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) " + "plus an offset (see test source)."); namespace gvisor { namespace testing { @@ -220,7 +221,7 @@ TEST(PrctlTest, RootDumpability) { int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - if (FLAGS_prctl_no_new_privs_test_child) { + if (absl::GetFlag(FLAGS_prctl_no_new_privs_test_child)) { exit(gvisor::testing::kPrctlNoNewPrivsTestChildExitBase + prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)); } diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc index 00dd6523e..30f0d75b3 100644 --- a/test/syscalls/linux/prctl_setuid.cc +++ b/test/syscalls/linux/prctl_setuid.cc @@ -14,9 +14,11 @@ #include <sched.h> #include <sys/prctl.h> + #include <string> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "test/util/capability_util.h" #include "test/util/logging.h" #include "test/util/multiprocess_util.h" @@ -24,12 +26,12 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_int32(scratch_uid, 65534, "scratch UID"); +ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID"); // This flag is used to verify that after an exec PR_GET_KEEPCAPS // returns 0, the return code will be offset by kPrGetKeepCapsExitBase. -DEFINE_bool(prctl_pr_get_keepcaps, false, - "If true the test will verify that prctl with pr_get_keepcaps" - "returns 0. The test will exit with the result of that check."); +ABSL_FLAG(bool, prctl_pr_get_keepcaps, false, + "If true the test will verify that prctl with pr_get_keepcaps" + "returns 0. The test will exit with the result of that check."); // These tests exist seperately from prctl because we need to start // them as root. Setuid() has the behavior that permissions are fully @@ -113,10 +115,12 @@ TEST_F(PrctlKeepCapsSetuidTest, SetUidNoKeepCaps) { // call to only apply to this task. POSIX threads, however, require that // all threads have the same UIDs, so using the setuid wrapper sets all // threads' real UID. - EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)), + SyscallSucceeds()); // Verify that we changed uid. - EXPECT_THAT(getuid(), SyscallSucceedsWithValue(FLAGS_scratch_uid)); + EXPECT_THAT(getuid(), + SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid))); // Verify we lost the capability in the effective set, this always happens. TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie()); @@ -157,10 +161,12 @@ TEST_F(PrctlKeepCapsSetuidTest, SetUidKeepCaps) { // call to only apply to this task. POSIX threads, however, require that // all threads have the same UIDs, so using the setuid wrapper sets all // threads' real UID. - EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)), + SyscallSucceeds()); // Verify that we changed uid. - EXPECT_THAT(getuid(), SyscallSucceedsWithValue(FLAGS_scratch_uid)); + EXPECT_THAT(getuid(), + SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid))); // Verify we lost the capability in the effective set, this always happens. TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie()); @@ -253,7 +259,7 @@ TEST_F(PrctlKeepCapsSetuidTest, PrGetKeepCaps) { int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - if (FLAGS_prctl_pr_get_keepcaps) { + if (absl::GetFlag(FLAGS_prctl_pr_get_keepcaps)) { return gvisor::testing::kPrGetKeepCapsExitBase + prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0); } diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 2b753b7d1..6f07803d9 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1882,7 +1882,9 @@ void CheckDuplicatesRecursively(std::string path) { errno = 0; DIR* dir = opendir(path.c_str()); if (dir == nullptr) { - ASSERT_THAT(errno, ::testing::AnyOf(EPERM, EACCES)) << path; + // Ignore any directories we can't read or missing directories as the + // directory could have been deleted/mutated from the time the parent + // directory contents were read. return; } auto dir_closer = Cleanup([&dir]() { closedir(dir); }); diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index c097af196..efdaf202b 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -28,7 +28,7 @@ TEST(ProcNetIfInet6, Format) { EXPECT_THAT(ifinet6, ::testing::MatchesRegex( // Ex: "00000000000000000000000000000001 01 80 10 80 lo\n" - "^([a-f\\d]{32}( [a-f\\d]{2}){4} +[a-z][a-z\\d]*\\n)+$")); + "^([a-f0-9]{32}( [a-f0-9]{2}){4} +[a-z][a-z0-9]*\n)+$")); } TEST(ProcSysNetIpv4Sack, Exists) { diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index abf2b1a04..8f3800380 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -27,6 +27,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/logging.h" @@ -36,10 +37,10 @@ #include "test/util/thread_util.h" #include "test/util/time_util.h" -DEFINE_bool(ptrace_test_execve_child, false, - "If true, run the " - "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_" - "TraceExit child workload."); +ABSL_FLAG(bool, ptrace_test_execve_child, false, + "If true, run the " + "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_" + "TraceExit child workload."); namespace gvisor { namespace testing { @@ -1206,7 +1207,7 @@ TEST(PtraceTest, SeizeSetOptions) { int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - if (FLAGS_ptrace_test_execve_child) { + if (absl::GetFlag(FLAGS_ptrace_test_execve_child)) { gvisor::testing::RunExecveChild(); } diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index d1ab4703f..286388316 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -13,13 +13,17 @@ // limitations under the License. #include <fcntl.h> +#include <linux/capability.h> #include <linux/major.h> #include <poll.h> +#include <sched.h> +#include <signal.h> #include <sys/ioctl.h> #include <sys/mman.h> #include <sys/stat.h> #include <sys/sysmacros.h> #include <sys/types.h> +#include <sys/wait.h> #include <termios.h> #include <unistd.h> @@ -31,8 +35,10 @@ #include "absl/synchronization/notification.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" +#include "test/util/pty_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -370,25 +376,6 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, return PosixError(ETIMEDOUT, "Poll timed out"); } -// Opens the slave end of the passed master as R/W and nonblocking. -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { - // Get pty index. - int n; - int ret = ioctl(master.get(), TIOCGPTN, &n); - if (ret < 0) { - return PosixError(errno, "ioctl(TIOCGPTN) failed"); - } - - // Unlock pts. - int unlock = 0; - ret = ioctl(master.get(), TIOCSPTLCK, &unlock); - if (ret < 0) { - return PosixError(errno, "ioctl(TIOSPTLCK) failed"); - } - - return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK); -} - TEST(BasicPtyTest, StatUnopenedMaster) { struct stat s; ASSERT_THAT(stat("/dev/ptmx", &s), SyscallSucceeds()); @@ -1233,6 +1220,374 @@ TEST_F(PtyTest, SetMasterWindowSize) { EXPECT_EQ(retrieved_ws.ws_col, kCols); } +class JobControlTest : public ::testing::Test { + protected: + void SetUp() override { + master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); + slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + + // Make this a session leader, which also drops the controlling terminal. + // In the gVisor test environment, this test will be run as the session + // leader already (as the sentry init process). + if (!IsRunningOnGvisor()) { + ASSERT_THAT(setsid(), SyscallSucceeds()); + } + } + + // Master and slave ends of the PTY. Non-blocking. + FileDescriptor master_; + FileDescriptor slave_; +}; + +TEST_F(JobControlTest, SetTTYMaster) { + ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTYNonLeader) { + // Fork a process that won't be the session leader. + pid_t child = fork(); + if (!child) { + // We shouldn't be able to set the terminal. + TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, SetTTYBadArg) { + // Despite the man page saying arg should be 0 here, Linux doesn't actually + // check. + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetTTYDifferentSession) { + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should fail. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal. + TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, ReleaseTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Make sure we're ignoring SIGHUP, which will be sent to this process once we + // disconnect they TTY. + struct sigaction sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + struct sigaction old_sa; + EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds()); + EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); +} + +TEST_F(JobControlTest, ReleaseUnsetTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, ReleaseWrongTTY) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, ReleaseTTYNonLeader) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +TEST_F(JobControlTest, ReleaseTTYDifferentSession) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t child = fork(); + if (!child) { + // Join a new session, then try to disconnect. + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +// Used by the child process spawned in ReleaseTTYSignals to track received +// signals. +static int received; + +void sig_handler(int signum) { received |= signum; } + +// When the session leader releases its controlling terminal, the foreground +// process group gets SIGHUP, then SIGCONT. This test: +// - Spawns 2 threads +// - Has thread 1 return 0 if it gets both SIGHUP and SIGCONT +// - Has thread 2 leave the foreground process group, and return non-zero if it +// receives any signals. +// - Has the parent thread release its controlling terminal +// - Checks that thread 1 got both signals +// - Checks that thread 2 didn't get any signals. +TEST_F(JobControlTest, ReleaseTTYSignals) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + received = 0; + struct sigaction sa = { + .sa_handler = sig_handler, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + sigaddset(&sa.sa_mask, SIGHUP); + sigaddset(&sa.sa_mask, SIGCONT); + sigprocmask(SIG_BLOCK, &sa.sa_mask, NULL); + + pid_t same_pgrp_child = fork(); + if (!same_pgrp_child) { + // The child will wait for SIGHUP and SIGCONT, then return 0. It begins with + // SIGHUP and SIGCONT blocked. We install signal handlers for those signals, + // then use sigsuspend to wait for those specific signals. + TEST_PCHECK(!sigaction(SIGHUP, &sa, NULL)); + TEST_PCHECK(!sigaction(SIGCONT, &sa, NULL)); + sigset_t mask; + sigfillset(&mask); + sigdelset(&mask, SIGHUP); + sigdelset(&mask, SIGCONT); + while (received != (SIGHUP | SIGCONT)) { + sigsuspend(&mask); + } + _exit(0); + } + + // We don't want to block these anymore. + sigprocmask(SIG_UNBLOCK, &sa.sa_mask, NULL); + + // This child will return non-zero if either SIGHUP or SIGCONT are received. + pid_t diff_pgrp_child = fork(); + if (!diff_pgrp_child) { + TEST_PCHECK(!setpgid(0, 0)); + TEST_PCHECK(pause()); + _exit(1); + } + + EXPECT_THAT(setpgid(diff_pgrp_child, diff_pgrp_child), SyscallSucceeds()); + + // Make sure we're ignoring SIGHUP, which will be sent to this process once we + // disconnect they TTY. + struct sigaction sighup_sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sighup_sa.sa_mask); + struct sigaction old_sa; + EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds()); + + // Release the controlling terminal, sending SIGHUP and SIGCONT to all other + // processes in this process group. + EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + + EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); + + // The child in the same process group will get signaled. + int wstatus; + EXPECT_THAT(waitpid(same_pgrp_child, &wstatus, 0), + SyscallSucceedsWithValue(same_pgrp_child)); + EXPECT_EQ(wstatus, 0); + + // The other child will not get signaled. + EXPECT_THAT(waitpid(diff_pgrp_child, &wstatus, WNOHANG), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(kill(diff_pgrp_child, SIGKILL), SyscallSucceeds()); +} + +TEST_F(JobControlTest, GetForegroundProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + pid_t foreground_pgid; + pid_t pid; + ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + SyscallSucceeds()); + ASSERT_THAT(pid = getpid(), SyscallSucceeds()); + + ASSERT_EQ(foreground_pgid, pid); +} + +TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { + // At this point there's no controlling terminal, so TIOCGPGRP should fail. + pid_t foreground_pgid; + ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + SyscallFailsWithErrno(ENOTTY)); +} + +// This test: +// - sets itself as the foreground process group +// - creates a child process in a new process group +// - sets that child as the foreground process group +// - kills its child and sets itself as the foreground process group. +TEST_F(JobControlTest, SetForegroundProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. + struct sigaction sa = { + .sa_handler = SIG_IGN, + .sa_flags = 0, + }; + sigemptyset(&sa.sa_mask); + sigaction(SIGTTOU, &sa, NULL); + + // Set ourself as the foreground process group. + ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds()); + + // Create a new process that just waits to be signaled. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!pause()); + // We should never reach this. + _exit(1); + } + + // Make the child its own process group, then make it the controlling process + // group of the terminal. + ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); + ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds()); + + // Sanity check - we're still the controlling session. + ASSERT_EQ(getsid(0), getsid(child)); + + // Signal the child, wait for it to exit, then retake the terminal. + ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds()); + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_TRUE(WIFSIGNALED(wstatus)); + ASSERT_EQ(WTERMSIG(wstatus), SIGTERM); + + // Set ourself as the foreground process. + pid_t pgid; + ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds()); + ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds()); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { + pid_t pid = getpid(); + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + SyscallFailsWithErrno(ENOTTY)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + pid_t pid = -1; + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Create a new process, put it in a new process group, make that group the + // foreground process group, then have the process wait. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(!setpgid(0, 0)); + _exit(0); + } + + // Wait for the child to exit. + int wstatus; + EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + // The child's process group doesn't exist anymore - this should fail. + ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), + SyscallFailsWithErrno(ESRCH)); +} + +TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { + ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Create a new process and put it in a new session. + pid_t child = fork(); + if (!child) { + TEST_PCHECK(setsid() >= 0); + // Tell the parent we're in a new session. + TEST_PCHECK(!raise(SIGSTOP)); + TEST_PCHECK(!pause()); + _exit(1); + } + + // Wait for the child to tell us it's in a new session. + int wstatus; + EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED), + SyscallSucceedsWithValue(child)); + EXPECT_TRUE(WSTOPSIG(wstatus)); + + // Child is in a new session, so we can't make it the foregroup process group. + EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), + SyscallFailsWithErrno(EPERM)); + + EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds()); +} + +// Verify that we don't hang when creating a new session from an orphaned +// process group (b/139968068). Calling setsid() creates an orphaned process +// group, as process groups that contain the session's leading process are +// orphans. +// +// We create 2 sessions in this test. The init process in gVisor is considered +// not to be an orphan (see sessions.go), so we have to create a session from +// which to create a session. The latter session is being created from an +// orphaned process group. +TEST_F(JobControlTest, OrphanRegression) { + pid_t session_2_leader = fork(); + if (!session_2_leader) { + TEST_PCHECK(setsid() >= 0); + + pid_t session_3_leader = fork(); + if (!session_3_leader) { + TEST_PCHECK(setsid() >= 0); + + _exit(0); + } + + int wstatus; + TEST_PCHECK(waitpid(session_3_leader, &wstatus, 0) == session_3_leader); + TEST_PCHECK(wstatus == 0); + + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(session_2_leader, &wstatus, 0), + SyscallSucceedsWithValue(session_2_leader)); + ASSERT_EQ(wstatus, 0); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc new file mode 100644 index 000000000..14a4af980 --- /dev/null +++ b/test/syscalls/linux/pty_root.cc @@ -0,0 +1,68 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/ioctl.h> +#include <termios.h> + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/pty_util.h" + +namespace gvisor { +namespace testing { + +// These tests should be run as root. +namespace { + +TEST(JobControlRootTest, StealTTY) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + // Make this a session leader, which also drops the controlling terminal. + // In the gVisor test environment, this test will be run as the session + // leader already (as the sentry init process). + if (!IsRunningOnGvisor()) { + ASSERT_THAT(setsid(), SyscallSucceeds()); + } + + FileDescriptor master = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); + FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + + // Make slave the controlling terminal. + ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds()); + + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg + // of 1. + pid_t child = fork(); + if (!child) { + ASSERT_THAT(setsid(), SyscallSucceeds()); + // We shouldn't be able to steal the terminal with the wrong arg value. + TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); + // We should be able to steal it here. + TEST_PCHECK(!ioctl(slave.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int wstatus; + ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + ASSERT_EQ(wstatus, 0); +} + +} // namespace +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 9167ab066..4502e7fb4 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -19,9 +19,12 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -442,6 +445,72 @@ TEST(SendFileTest, SendToNotARegularFile) { EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0), SyscallFailsWithErrno(EINVAL)); } + +TEST(SendFileTest, SendPipeWouldBlock) { + // Create temp file. + constexpr char kData[] = + "The fool doth think he is wise, but the wise man knows himself to be a " + "fool."; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill up the pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector<char> buf(2 * pipe_size); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST(SendFileTest, SendPipeBlocks) { + // Create temp file. + constexpr char kData[] = + "The fault, dear Brutus, is not in our stars, but in ourselves."; + constexpr int kDataSize = sizeof(kData) - 1; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill up the pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector<char> buf(pipe_size); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(pipe_size)); + }); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize), + SyscallSucceedsWithValue(kDataSize)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc new file mode 100644 index 000000000..54c598627 --- /dev/null +++ b/test/syscalls/linux/signalfd.cc @@ -0,0 +1,333 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <poll.h> +#include <signal.h> +#include <stdio.h> +#include <string.h> +#include <sys/signalfd.h> +#include <unistd.h> + +#include <functional> +#include <vector> + +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "absl/synchronization/mutex.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/signal_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +using ::testing::KilledBySignal; + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kSigno = SIGUSR1; +constexpr int kSignoAlt = SIGUSR2; + +// Returns a new signalfd. +inline PosixErrorOr<FileDescriptor> NewSignalFD(sigset_t* mask, int flags = 0) { + int fd = signalfd(-1, mask, flags); + MaybeSave(); + if (fd < 0) { + return PosixError(errno, "signalfd"); + } + return FileDescriptor(fd); +} + +TEST(Signalfd, Basic) { + // Create the signalfd. + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, kSigno); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); + + // Deliver the blocked signal. + const auto scoped_sigmask = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); + ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + + // We should now read the signal. + struct signalfd_siginfo rbuf; + ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallSucceedsWithValue(sizeof(rbuf))); + EXPECT_EQ(rbuf.ssi_signo, kSigno); +} + +TEST(Signalfd, MaskWorks) { + // Create two signalfds with different masks. + sigset_t mask1, mask2; + sigemptyset(&mask1); + sigemptyset(&mask2); + sigaddset(&mask1, kSigno); + sigaddset(&mask2, kSignoAlt); + FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask1, 0)); + FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask2, 0)); + + // Deliver the two signals. + const auto scoped_sigmask1 = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); + const auto scoped_sigmask2 = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSignoAlt)); + ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + ASSERT_THAT(tgkill(getpid(), gettid(), kSignoAlt), SyscallSucceeds()); + + // We should see the signals on the appropriate signalfds. + // + // We read in the opposite order as the signals deliver above, to ensure that + // we don't happen to read the correct signal from the correct signalfd. + struct signalfd_siginfo rbuf1, rbuf2; + ASSERT_THAT(read(fd2.get(), &rbuf2, sizeof(rbuf2)), + SyscallSucceedsWithValue(sizeof(rbuf2))); + EXPECT_EQ(rbuf2.ssi_signo, kSignoAlt); + ASSERT_THAT(read(fd1.get(), &rbuf1, sizeof(rbuf1)), + SyscallSucceedsWithValue(sizeof(rbuf1))); + EXPECT_EQ(rbuf1.ssi_signo, kSigno); +} + +TEST(Signalfd, Cloexec) { + // Exec tests confirm that O_CLOEXEC has the intended effect. We just create a + // signalfd with the appropriate flag here and assert that the FD has it set. + sigset_t mask; + sigemptyset(&mask); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC)); + EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); +} + +TEST(Signalfd, Blocking) { + // Create the signalfd in blocking mode. + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, kSigno); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); + + // Shared tid variable. + absl::Mutex mu; + bool has_tid; + pid_t tid; + + // Start a thread reading. + ScopedThread t([&] { + // Copy the tid and notify the caller. + { + absl::MutexLock ml(&mu); + tid = gettid(); + has_tid = true; + } + + // Read the signal from the signalfd. + struct signalfd_siginfo rbuf; + ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallSucceedsWithValue(sizeof(rbuf))); + EXPECT_EQ(rbuf.ssi_signo, kSigno); + }); + + // Wait until blocked. + absl::MutexLock ml(&mu); + mu.Await(absl::Condition(&has_tid)); + + // Deliver the signal to either the waiting thread, or + // to this thread. N.B. this is a bug in the core gVisor + // behavior for signalfd, and needs to be fixed. + // + // See gvisor.dev/issue/139. + if (IsRunningOnGvisor()) { + ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + } else { + ASSERT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds()); + } + + // Ensure that it was received. + t.Join(); +} + +TEST(Signalfd, ThreadGroup) { + // Create the signalfd in blocking mode. + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, kSigno); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); + + // Shared variable. + absl::Mutex mu; + bool first = false; + bool second = false; + + // Start a thread reading. + ScopedThread t([&] { + // Read the signal from the signalfd. + struct signalfd_siginfo rbuf; + ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallSucceedsWithValue(sizeof(rbuf))); + EXPECT_EQ(rbuf.ssi_signo, kSigno); + + // Wait for the other thread. + absl::MutexLock ml(&mu); + first = true; + mu.Await(absl::Condition(&second)); + }); + + // Deliver the signal to the threadgroup. + ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds()); + + // Wait for the first thread to process. + { + absl::MutexLock ml(&mu); + mu.Await(absl::Condition(&first)); + } + + // Deliver to the thread group again (other thread still exists). + ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds()); + + // Ensure that we can also receive it. + struct signalfd_siginfo rbuf; + ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallSucceedsWithValue(sizeof(rbuf))); + EXPECT_EQ(rbuf.ssi_signo, kSigno); + + // Mark the test as done. + { + absl::MutexLock ml(&mu); + second = true; + } + + // The other thread should be joinable. + t.Join(); +} + +TEST(Signalfd, Nonblock) { + // Create the signalfd in non-blocking mode. + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, kSigno); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK)); + + // We should return if we attempt to read. + struct signalfd_siginfo rbuf; + ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Block and deliver the signal. + const auto scoped_sigmask = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); + ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + + // Ensure that a read actually works. + ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallSucceedsWithValue(sizeof(rbuf))); + EXPECT_EQ(rbuf.ssi_signo, kSigno); + + // Should block again. + EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST(Signalfd, SetMask) { + // Create the signalfd matching nothing. + sigset_t mask; + sigemptyset(&mask); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK)); + + // Block and deliver a signal. + const auto scoped_sigmask = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); + ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds()); + + // We should have nothing. + struct signalfd_siginfo rbuf; + ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Change the signal mask. + sigaddset(&mask, kSigno); + ASSERT_THAT(signalfd(fd.get(), &mask, 0), SyscallSucceeds()); + + // We should now have the signal. + ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallSucceedsWithValue(sizeof(rbuf))); + EXPECT_EQ(rbuf.ssi_signo, kSigno); +} + +TEST(Signalfd, Poll) { + // Create the signalfd. + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, kSigno); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0)); + + // Block the signal, and start a thread to deliver it. + const auto scoped_sigmask = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno)); + pid_t orig_tid = gettid(); + ScopedThread t([&] { + absl::SleepFor(absl::Seconds(5)); + ASSERT_THAT(tgkill(getpid(), orig_tid, kSigno), SyscallSucceeds()); + }); + + // Start polling for the signal. We expect that it is not available at the + // outset, but then becomes available when the signal is sent. We give a + // timeout of 10000ms (or the delay above + 5 seconds of additional grace + // time). + struct pollfd poll_fd = {fd.get(), POLLIN, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); + + // Actually read the signal to prevent delivery. + struct signalfd_siginfo rbuf; + EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)), + SyscallSucceedsWithValue(sizeof(rbuf))); +} + +TEST(Signalfd, KillStillKills) { + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, SIGKILL); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC)); + + // Just because there is a signalfd, we shouldn't see any change in behavior + // for unblockable signals. It's easier to test this with SIGKILL. + const auto scoped_sigmask = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL)); + EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), ""); +} + +} // namespace + +} // namespace testing +} // namespace gvisor + +int main(int argc, char** argv) { + // These tests depend on delivering signals. Block them up front so that all + // other threads created by TestInit will also have them blocked, and they + // will not interface with the rest of the test. + sigset_t set; + sigemptyset(&set); + sigaddset(&set, gvisor::testing::kSigno); + sigaddset(&set, gvisor::testing::kSignoAlt); + TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); + + gvisor::testing::TestInit(&argc, &argv); + + return RUN_ALL_TESTS(); +} diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc index 9c7210e17..7db57d968 100644 --- a/test/syscalls/linux/sigstop.cc +++ b/test/syscalls/linux/sigstop.cc @@ -17,6 +17,7 @@ #include <sys/select.h> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/multiprocess_util.h" @@ -24,8 +25,8 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_bool(sigstop_test_child, false, - "If true, run the SigstopTest child workload."); +ABSL_FLAG(bool, sigstop_test_child, false, + "If true, run the SigstopTest child workload."); namespace gvisor { namespace testing { @@ -141,7 +142,7 @@ void RunChild() { int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - if (FLAGS_sigstop_test_child) { + if (absl::GetFlag(FLAGS_sigstop_test_child)) { gvisor::testing::RunChild(); return 1; } diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index a43cf9bce..bfa7943b1 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -117,7 +117,7 @@ TEST_P(TCPSocketPairTest, RSTCausesPollHUP) { struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0}; ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs), SyscallSucceedsWithValue(1)); - ASSERT_NE(poll_fd.revents & (POLLHUP | POLLIN), 0); + ASSERT_NE(poll_fd3.revents & POLLHUP, 0); } // This test validates that even if a RST is sent the other end will not diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index e25f264f6..85232cb1f 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -14,12 +14,16 @@ #include <fcntl.h> #include <sys/eventfd.h> +#include <sys/resource.h> #include <sys/sendfile.h> +#include <sys/time.h> #include <unistd.h> #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -36,23 +40,23 @@ TEST(SpliceTest, TwoRegularFiles) { const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); // Open the input file as read only. - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); // Open the output file as write only. - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); // Verify that it is rejected as expected; regardless of offsets. loff_t in_offset = 0; loff_t out_offset = 0; - EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), &out_offset, 1, 0), + EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), &out_offset, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), &out_offset, 1, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), &out_offset, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), nullptr, 1, 0), + EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), nullptr, 1, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), nullptr, 1, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), nullptr, 1, 0), SyscallFailsWithErrno(EINVAL)); } @@ -75,8 +79,6 @@ TEST(SpliceTest, SamePipe) { } TEST(TeeTest, SamePipe) { - SKIP_IF(IsRunningOnGvisor()); - // Create a new pipe. int fds[2]; ASSERT_THAT(pipe(fds), SyscallSucceeds()); @@ -95,11 +97,9 @@ TEST(TeeTest, SamePipe) { } TEST(TeeTest, RegularFile) { - SKIP_IF(IsRunningOnGvisor()); - // Open some file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Create a new pipe. @@ -109,9 +109,9 @@ TEST(TeeTest, RegularFile) { const FileDescriptor wfd(fds[1]); // Attempt to tee from the file. - EXPECT_THAT(tee(inf.get(), wfd.get(), kPageSize, 0), + EXPECT_THAT(tee(in_fd.get(), wfd.get(), kPageSize, 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(tee(rfd.get(), inf.get(), kPageSize, 0), + EXPECT_THAT(tee(rfd.get(), in_fd.get(), kPageSize, 0), SyscallFailsWithErrno(EINVAL)); } @@ -142,7 +142,7 @@ TEST(SpliceTest, FromEventFD) { constexpr uint64_t kEventFDValue = 1; int efd; ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds()); - const FileDescriptor inf(efd); + const FileDescriptor in_fd(efd); // Create a new pipe. int fds[2]; @@ -152,7 +152,7 @@ TEST(SpliceTest, FromEventFD) { // Splice 8-byte eventfd value to pipe. constexpr int kEventFDSize = 8; - EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0), SyscallSucceedsWithValue(kEventFDSize)); // Contents should be equal. @@ -166,7 +166,7 @@ TEST(SpliceTest, FromEventFD) { TEST(SpliceTest, FromEventFDOffset) { int efd; ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor inf(efd); + const FileDescriptor in_fd(efd); // Create a new pipe. int fds[2]; @@ -179,7 +179,7 @@ TEST(SpliceTest, FromEventFDOffset) { // This is not allowed because eventfd doesn't support pread. constexpr int kEventFDSize = 8; loff_t in_off = 0; - EXPECT_THAT(splice(inf.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), + EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0), SyscallFailsWithErrno(EINVAL)); } @@ -200,28 +200,29 @@ TEST(SpliceTest, ToEventFDOffset) { int efd; ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds()); - const FileDescriptor outf(efd); + const FileDescriptor out_fd(efd); // Attempt to splice 8-byte eventfd value to pipe with offset. // // This is not allowed because eventfd doesn't support pwrite. loff_t out_off = 0; - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_off, kEventFDSize, 0), - SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0), + SyscallFailsWithErrno(EINVAL)); } TEST(SpliceTest, ToPipe) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Fill with some random data. std::vector<char> buf(kPageSize); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(inf.get(), buf.data(), buf.size()), + ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(lseek(inf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + ASSERT_THAT(lseek(in_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); // Create a new pipe. int fds[2]; @@ -230,7 +231,7 @@ TEST(SpliceTest, ToPipe) { const FileDescriptor wfd(fds[1]); // Splice to the pipe. - EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kPageSize, 0), + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0), SyscallSucceedsWithValue(kPageSize)); // Contents should be equal. @@ -243,13 +244,13 @@ TEST(SpliceTest, ToPipe) { TEST(SpliceTest, ToPipeOffset) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor inf = + const FileDescriptor in_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); // Fill with some random data. std::vector<char> buf(kPageSize); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(write(inf.get(), buf.data(), buf.size()), + ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); // Create a new pipe. @@ -261,7 +262,7 @@ TEST(SpliceTest, ToPipeOffset) { // Splice to the pipe. loff_t in_offset = kPageSize / 2; EXPECT_THAT( - splice(inf.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0), + splice(in_fd.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0), SyscallSucceedsWithValue(kPageSize / 2)); // Contents should be equal to only the second part. @@ -286,22 +287,22 @@ TEST(SpliceTest, FromPipe) { // Open the input file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); // Splice to the output file. - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), nullptr, kPageSize, 0), + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, kPageSize, 0), SyscallSucceedsWithValue(kPageSize)); // The offset of the output should be equal to kPageSize. We assert that and // reset to zero so that we can read the contents and ensure they match. - EXPECT_THAT(lseek(outf.get(), 0, SEEK_CUR), + EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(kPageSize)); - ASSERT_THAT(lseek(outf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); // Contents should be equal. std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()), + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), SyscallSucceedsWithValue(kPageSize)); EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } @@ -321,18 +322,19 @@ TEST(SpliceTest, FromPipeOffset) { // Open the input file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - const FileDescriptor outf = + const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); // Splice to the output file. loff_t out_offset = kPageSize / 2; - EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_offset, kPageSize, 0), - SyscallSucceedsWithValue(kPageSize)); + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kPageSize, 0), + SyscallSucceedsWithValue(kPageSize)); // Content should reflect the splice. We write to a specific offset in the // file, so the internals should now be allocated sparsely. std::vector<char> rbuf(kPageSize); - ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()), + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), SyscallSucceedsWithValue(kPageSize)); std::vector<char> zbuf(kPageSize / 2); memset(zbuf.data(), 0, zbuf.size()); @@ -404,8 +406,6 @@ TEST(SpliceTest, Blocking) { } TEST(TeeTest, Blocking) { - SKIP_IF(IsRunningOnGvisor()); - // Create two new pipes. int first[2], second[2]; ASSERT_THAT(pipe(first), SyscallSucceeds()); @@ -440,6 +440,49 @@ TEST(TeeTest, Blocking) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0); } +TEST(TeeTest, BlockingWrite) { + // Create two new pipes. + int first[2], second[2]; + ASSERT_THAT(pipe(first), SyscallSucceeds()); + const FileDescriptor rfd1(first[0]); + const FileDescriptor wfd1(first[1]); + ASSERT_THAT(pipe(second), SyscallSucceeds()); + const FileDescriptor rfd2(second[0]); + const FileDescriptor wfd2(second[1]); + + // Make some data available to be read. + std::vector<char> buf1(kPageSize); + RandomizeBuffer(buf1.data(), buf1.size()); + ASSERT_THAT(write(wfd1.get(), buf1.data(), buf1.size()), + SyscallSucceedsWithValue(kPageSize)); + + // Fill up the write pipe's buffer. + int pipe_size = -1; + ASSERT_THAT(pipe_size = fcntl(wfd2.get(), F_GETPIPE_SZ), SyscallSucceeds()); + std::vector<char> buf2(pipe_size); + ASSERT_THAT(write(wfd2.get(), buf2.data(), buf2.size()), + SyscallSucceedsWithValue(pipe_size)); + + ScopedThread t([&]() { + absl::SleepFor(absl::Milliseconds(100)); + ASSERT_THAT(read(rfd2.get(), buf2.data(), buf2.size()), + SyscallSucceedsWithValue(pipe_size)); + }); + + // Attempt a tee immediately; it should block. + EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0), + SyscallSucceedsWithValue(kPageSize)); + + // Thread should be joinable. + t.Join(); + + // Content should reflect the tee. + std::vector<char> rbuf(kPageSize); + ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_EQ(memcmp(rbuf.data(), buf1.data(), kPageSize), 0); +} + TEST(SpliceTest, NonBlocking) { // Create two new pipes. int first[2], second[2]; @@ -457,8 +500,6 @@ TEST(SpliceTest, NonBlocking) { } TEST(TeeTest, NonBlocking) { - SKIP_IF(IsRunningOnGvisor()); - // Create two new pipes. int first[2], second[2]; ASSERT_THAT(pipe(first), SyscallSucceeds()); @@ -473,6 +514,79 @@ TEST(TeeTest, NonBlocking) { SyscallFailsWithErrno(EAGAIN)); } +TEST(TeeTest, MultiPage) { + // Create two new pipes. + int first[2], second[2]; + ASSERT_THAT(pipe(first), SyscallSucceeds()); + const FileDescriptor rfd1(first[0]); + const FileDescriptor wfd1(first[1]); + ASSERT_THAT(pipe(second), SyscallSucceeds()); + const FileDescriptor rfd2(second[0]); + const FileDescriptor wfd2(second[1]); + + // Make some data available to be read. + std::vector<char> wbuf(8 * kPageSize); + RandomizeBuffer(wbuf.data(), wbuf.size()); + ASSERT_THAT(write(wfd1.get(), wbuf.data(), wbuf.size()), + SyscallSucceedsWithValue(wbuf.size())); + + // Attempt a tee immediately; it should complete. + EXPECT_THAT(tee(rfd1.get(), wfd2.get(), wbuf.size(), 0), + SyscallSucceedsWithValue(wbuf.size())); + + // Content should reflect the tee. + std::vector<char> rbuf(wbuf.size()); + ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0); + ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0); +} + +TEST(SpliceTest, FromPipeMaxFileSize) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Fill with some random data. + std::vector<char> buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize)); + + // Open the input file. + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor out_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); + + EXPECT_THAT(ftruncate(out_fd.get(), 13 << 20), SyscallSucceeds()); + EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_END), + SyscallSucceedsWithValue(13 << 20)); + + // Set our file size limit. + sigset_t set; + sigemptyset(&set); + sigaddset(&set, SIGXFSZ); + TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); + rlimit rlim = {}; + rlim.rlim_cur = rlim.rlim_max = (13 << 20); + EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &rlim), SyscallSucceeds()); + + // Splice to the output file. + EXPECT_THAT( + splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3 * kPageSize, 0), + SyscallFailsWithErrno(EFBIG)); + + // Contents should be equal. + std::vector<char> rbuf(kPageSize); + ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc index 59fb5dfe6..7e73325bf 100644 --- a/test/syscalls/linux/sticky.cc +++ b/test/syscalls/linux/sticky.cc @@ -17,9 +17,11 @@ #include <sys/prctl.h> #include <sys/types.h> #include <unistd.h> + #include <vector> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" @@ -27,8 +29,8 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_int32(scratch_uid, 65534, "first scratch UID"); -DEFINE_int32(scratch_gid, 65534, "first scratch GID"); +ABSL_FLAG(int32_t, scratch_uid, 65534, "first scratch UID"); +ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID"); namespace gvisor { namespace testing { @@ -52,10 +54,12 @@ TEST(StickyTest, StickyBitPermDenied) { } // Change EUID and EGID. - EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1), - SyscallSucceeds()); - EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid, -1), - SyscallSucceeds()); + EXPECT_THAT( + syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1), + SyscallSucceeds()); + EXPECT_THAT( + syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1), + SyscallSucceeds()); EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM)); }); @@ -78,8 +82,9 @@ TEST(StickyTest, StickyBitSameUID) { } // Change EGID. - EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1), - SyscallSucceeds()); + EXPECT_THAT( + syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1), + SyscallSucceeds()); // We still have the same EUID. EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds()); @@ -101,10 +106,12 @@ TEST(StickyTest, StickyBitCapFOWNER) { EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds()); // Change EUID and EGID. - EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1), - SyscallSucceeds()); - EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid, -1), - SyscallSucceeds()); + EXPECT_THAT( + syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1), + SyscallSucceeds()); + EXPECT_THAT( + syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1), + SyscallSucceeds()); EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true)); EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds()); diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index fd42e81e1..3db18d7ac 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -23,6 +23,7 @@ #include <atomic> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/cleanup.h" @@ -33,8 +34,8 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_bool(timers_test_sleep, false, - "If true, sleep forever instead of running tests."); +ABSL_FLAG(bool, timers_test_sleep, false, + "If true, sleep forever instead of running tests."); using ::testing::_; using ::testing::AnyOf; @@ -635,7 +636,7 @@ TEST(IntervalTimerTest, IgnoredSignalCountsAsOverrun) { int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - if (FLAGS_timers_test_sleep) { + if (absl::GetFlag(FLAGS_timers_test_sleep)) { while (true) { absl::SleepFor(absl::Seconds(10)); } diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index bf1ca8679..d48453a93 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -18,6 +18,7 @@ #include <unistd.h> #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "test/util/capability_util.h" @@ -25,10 +26,10 @@ #include "test/util/test_util.h" #include "test/util/thread_util.h" -DEFINE_int32(scratch_uid1, 65534, "first scratch UID"); -DEFINE_int32(scratch_uid2, 65533, "second scratch UID"); -DEFINE_int32(scratch_gid1, 65534, "first scratch GID"); -DEFINE_int32(scratch_gid2, 65533, "second scratch GID"); +ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID"); +ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); +ABSL_FLAG(int32_t, scratch_gid1, 65534, "first scratch GID"); +ABSL_FLAG(int32_t, scratch_gid2, 65533, "second scratch GID"); using ::testing::UnorderedElementsAreArray; @@ -146,7 +147,7 @@ TEST(UidGidRootTest, Setuid) { // real UID. EXPECT_THAT(syscall(SYS_setuid, -1), SyscallFailsWithErrno(EINVAL)); - const uid_t uid = FLAGS_scratch_uid1; + const uid_t uid = absl::GetFlag(FLAGS_scratch_uid1); EXPECT_THAT(syscall(SYS_setuid, uid), SyscallSucceeds()); // "If the effective UID of the caller is root (more precisely: if the // caller has the CAP_SETUID capability), the real UID and saved set-user-ID @@ -160,7 +161,7 @@ TEST(UidGidRootTest, Setgid) { EXPECT_THAT(setgid(-1), SyscallFailsWithErrno(EINVAL)); - const gid_t gid = FLAGS_scratch_gid1; + const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); ASSERT_THAT(setgid(gid), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); } @@ -168,7 +169,7 @@ TEST(UidGidRootTest, Setgid) { TEST(UidGidRootTest, SetgidNotFromThreadGroupLeader) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - const gid_t gid = FLAGS_scratch_gid1; + const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); // NOTE(b/64676707): Do setgid in a separate thread so that we can test if // info.si_pid is set correctly. ScopedThread([gid] { ASSERT_THAT(setgid(gid), SyscallSucceeds()); }); @@ -189,8 +190,8 @@ TEST(UidGidRootTest, Setreuid) { // cannot be opened by the `uid` set below after the test. After calling // setuid(non-zero-UID), there is no way to get root privileges back. ScopedThread([&] { - const uid_t ruid = FLAGS_scratch_uid1; - const uid_t euid = FLAGS_scratch_uid2; + const uid_t ruid = absl::GetFlag(FLAGS_scratch_uid1); + const uid_t euid = absl::GetFlag(FLAGS_scratch_uid2); // Use syscall instead of glibc setuid wrapper because we want this setuid // call to only apply to this task. posix threads, however, require that all @@ -211,8 +212,8 @@ TEST(UidGidRootTest, Setregid) { EXPECT_THAT(setregid(-1, -1), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0)); - const gid_t rgid = FLAGS_scratch_gid1; - const gid_t egid = FLAGS_scratch_gid2; + const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1); + const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2); ASSERT_THAT(setregid(rgid, egid), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid)); } diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc index b6f65e027..2040375c9 100644 --- a/test/syscalls/linux/unlink.cc +++ b/test/syscalls/linux/unlink.cc @@ -123,6 +123,8 @@ TEST(UnlinkTest, AtBad) { SyscallSucceeds()); EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", AT_REMOVEDIR), SyscallFailsWithErrno(ENOTDIR)); + EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile/", 0), + SyscallFailsWithErrno(ENOTDIR)); ASSERT_THAT(close(fd), SyscallSucceeds()); EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds()); diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc index f67b06f37..0aaba482d 100644 --- a/test/syscalls/linux/vfork.cc +++ b/test/syscalls/linux/vfork.cc @@ -22,14 +22,15 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/flags/flag.h" #include "absl/time/time.h" #include "test/util/logging.h" #include "test/util/multiprocess_util.h" #include "test/util/test_util.h" #include "test/util/time_util.h" -DEFINE_bool(vfork_test_child, false, - "If true, run the VforkTest child workload."); +ABSL_FLAG(bool, vfork_test_child, false, + "If true, run the VforkTest child workload."); namespace gvisor { namespace testing { @@ -186,7 +187,7 @@ int RunChild() { int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - if (FLAGS_vfork_test_child) { + if (absl::GetFlag(FLAGS_vfork_test_child)) { return gvisor::testing::RunChild(); } diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go index e900f8abc..c1e9ce22c 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/syscalls/syscall_test_runner.go @@ -20,12 +20,10 @@ import ( "flag" "fmt" "io/ioutil" - "math" "os" "os/exec" "os/signal" "path/filepath" - "strconv" "strings" "syscall" "testing" @@ -358,32 +356,14 @@ func main() { fatalf("ParseTestCases(%q) failed: %v", testBin, err) } - // If sharding, then get the subset of tests to run based on the shard index. - if indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS"); indexStr != "" && totalStr != "" { - // Parse index and total to ints. - index, err := strconv.Atoi(indexStr) - if err != nil { - fatalf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err) - } - total, err := strconv.Atoi(totalStr) - if err != nil { - fatalf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err) - } - // Calculate subslice of tests to run. - shardSize := int(math.Ceil(float64(len(testCases)) / float64(total))) - begin := index * shardSize - // Set end as begin of next subslice. - end := ((index + 1) * shardSize) - if begin > len(testCases) { - // Nothing to run. - return - } - if end > len(testCases) { - end = len(testCases) - } - testCases = testCases[begin:end] + // Get subset of tests corresponding to shard. + begin, end, err := testutil.TestBoundsForShard(len(testCases)) + if err != nil { + fatalf("TestsForShard() failed: %v", err) } + testCases = testCases[begin:end] + // Run the tests. var tests []testing.InternalTest for _, tc := range testCases { // Capture tc. diff --git a/test/util/BUILD b/test/util/BUILD index 8afd89d8d..25ed9c944 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,3 +1,4 @@ +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") load("//test/syscalls:build_defs.bzl", "select_for_linux") package( @@ -190,6 +191,17 @@ cc_test( ) cc_library( + name = "pty_util", + testonly = 1, + srcs = ["pty_util.cc"], + hdrs = ["pty_util.h"], + deps = [ + ":file_descriptor", + ":posix_error", + ], +) + +cc_library( name = "signal_util", testonly = 1, srcs = ["signal_util.cc"], @@ -227,8 +239,9 @@ cc_library( ":logging", ":posix_error", ":save_util", - "@com_github_gflags_gflags//:gflags", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", diff --git a/test/util/memory_util.h b/test/util/memory_util.h index 190c469b5..e189b73e8 100644 --- a/test/util/memory_util.h +++ b/test/util/memory_util.h @@ -118,6 +118,18 @@ inline PosixErrorOr<Mapping> MmapAnon(size_t length, int prot, int flags) { return Mmap(nullptr, length, prot, flags | MAP_ANONYMOUS, -1, 0); } +// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of +// void* isn't directly compatible with SyscallSucceeds. +inline PosixErrorOr<void*> Mremap(void* old_address, size_t old_size, + size_t new_size, int flags, + void* new_address) { + void* rv = mremap(old_address, old_size, new_size, flags, new_address); + if (rv == MAP_FAILED) { + return PosixError(errno, "mremap failed"); + } + return rv; +} + // Returns true if the page containing addr is mapped. inline bool IsMapped(uintptr_t addr) { int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)), diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc new file mode 100644 index 000000000..c0fd9a095 --- /dev/null +++ b/test/util/pty_util.cc @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/pty_util.h" + +#include <sys/ioctl.h> +#include <termios.h> + +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { + // Get pty index. + int n; + int ret = ioctl(master.get(), TIOCGPTN, &n); + if (ret < 0) { + return PosixError(errno, "ioctl(TIOCGPTN) failed"); + } + + // Unlock pts. + int unlock = 0; + ret = ioctl(master.get(), TIOCSPTLCK, &unlock); + if (ret < 0) { + return PosixError(errno, "ioctl(TIOSPTLCK) failed"); + } + + return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/pty_util.h b/test/util/pty_util.h new file mode 100644 index 000000000..367b14f15 --- /dev/null +++ b/test/util/pty_util.h @@ -0,0 +1,30 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_UTIL_PTY_UTIL_H_ +#define GVISOR_TEST_UTIL_PTY_UTIL_H_ + +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +// Opens the slave end of the passed master as R/W and nonblocking. +PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_UTIL_PTY_UTIL_H_ diff --git a/test/util/test_util.cc b/test/util/test_util.cc index e42bba04a..ba0dcf7d0 100644 --- a/test/util/test_util.cc +++ b/test/util/test_util.cc @@ -28,6 +28,8 @@ #include <vector> #include "absl/base/attributes.h" +#include "absl/flags/flag.h" // IWYU pragma: keep +#include "absl/flags/parse.h" // IWYU pragma: keep #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" @@ -224,7 +226,7 @@ bool Equivalent(uint64_t current, uint64_t target, double tolerance) { void TestInit(int* argc, char*** argv) { ::testing::InitGoogleTest(argc, *argv); - ::gflags::ParseCommandLineFlags(argc, argv, true); + ::absl::ParseCommandLine(*argc, *argv); // Always mask SIGPIPE as it's common and tests aren't expected to handle it. struct sigaction sa = {}; diff --git a/test/util/test_util.h b/test/util/test_util.h index cdbe8bfd1..b9d2dc2ba 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -185,7 +185,6 @@ #include <utility> #include <vector> -#include <gflags/gflags.h> #include "gmock/gmock.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" diff --git a/third_party/gvsync/downgradable_rwmutex_unsafe.go b/third_party/gvsync/downgradable_rwmutex_unsafe.go index 069939033..1f6007aa1 100644 --- a/third_party/gvsync/downgradable_rwmutex_unsafe.go +++ b/third_party/gvsync/downgradable_rwmutex_unsafe.go @@ -57,9 +57,6 @@ func (rw *DowngradableRWMutex) RLock() { // RUnlock undoes a single RLock call. func (rw *DowngradableRWMutex) RUnlock() { if RaceEnabled { - // TODO(jamieliu): Why does this need to be ReleaseMerge instead of - // Release? IIUC this establishes Unlock happens-before RUnlock, which - // seems unnecessary. RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) RaceDisable() } diff --git a/tools/go_branch.sh b/tools/go_branch.sh index d9e79401d..ddb9b6e7b 100755 --- a/tools/go_branch.sh +++ b/tools/go_branch.sh @@ -59,7 +59,11 @@ git checkout -b go "${go_branch}" # Start working on a merge commit that combines the previous history with the # current history. Note that we don't actually want any changes yet. -git merge --allow-unrelated-histories --no-commit --strategy ours ${head} +# +# N.B. The git behavior changed at some point and the relevant flag was added +# to allow for override, so try the only behavior first then pass the flag. +git merge --no-commit --strategy ours ${head} || \ + git merge --allow-unrelated-histories --no-commit --strategy ours ${head} # Sync the entire gopath_dir and go.mod. rsync --recursive --verbose --delete --exclude .git --exclude README.md -L "${gopath_dir}/" . diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD new file mode 100644 index 000000000..c862b277c --- /dev/null +++ b/tools/go_marshal/BUILD @@ -0,0 +1,14 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "go_marshal", + srcs = ["main.go"], + visibility = [ + "//:sandbox", + ], + deps = [ + "//tools/go_marshal/gomarshal", + ], +) diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md new file mode 100644 index 000000000..481575bd3 --- /dev/null +++ b/tools/go_marshal/README.md @@ -0,0 +1,164 @@ +This package implements the go_marshal utility. + +# Overview + +`go_marshal` is a code generation utility similar to `go_stateify` for +automatically generating code to marshal go data structures to memory. + +`go_marshal` attempts to improve on `binary.Write` and the sentry's +`binary.Marshal` by moving the go runtime reflection necessary to marshal a +struct to compile-time. + +`go_marshal` automatically generates implementations for `abi.Marshallable` and +`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall +implementations) can directly invoke `safemem.Reader.ReadToBlocks` and +`safemem.Writer.WriteFromBlocks`. Data structures that require custom +serialization will have manual implementations for these interfaces. + +Data structures can be flagged for code generation by adding a struct-level +comment `// +marshal`. + +# Usage + +See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`. + +The recommended way to generate a go library with marshalling is to use the +`go_library` with mostly identical configuration as the native go_library rule. + +``` +load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library") + +go_library( + name = "foo", + srcs = ["foo.go"], +) +``` + +Under the hood, the `go_marshal` rule is used to generate a file that will +appear in a Go target; the output file should appear explicitly in a srcs list. +For example (note that the above is the preferred method): + +``` +load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal") + +go_marshal( + name = "foo_abi", + srcs = ["foo.go"], + out = "foo_abi.go", + package = "foo", +) + +go_library( + name = "foo", + srcs = [ + "foo.go", + "foo_abi.go", + ], + deps = [ + "<PKGPATH>/gvisor/pkg/abi", + "<PKGPATH>/gvisor/pkg/sentry/safemem/safemem", + "<PKGPATH>/gvisor/pkg/sentry/usermem/usermem", + ], +) +``` + +As part of the interface generation, `go_marshal` also generates some tests for +sanity checking the struct definitions for potential alignment issues, and a +simple round-trip test through Marshal/Unmarshal to verify the implementation. +These tests use reflection to verify properties of the ABI struct, and should be +considered part of the generated interfaces (but are too expensive to execute at +runtime). Ensure these tests run at some point. + +``` +$ cat BUILD +load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library") + +go_library( + name = "foo", + srcs = ["foo.go"], +) +$ blaze build :foo +$ blaze query ... +<path-to-dir>:foo_abi_autogen +<path-to-dir>:foo_abi_autogen_test +$ blaze test :foo_abi_autogen_test +<test-output> +``` + +# Restrictions + +Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is +intended for ABI structs, which have these additional restrictions: + +- At the moment, `go_marshal` only supports struct declarations. + +- Structs are marshalled as packed types. This means no implicit padding is + inserted between fields shorter than the platform register size. For + alignment, manually insert padding fields. + +- Structs used with `go_marshal` must have a compile-time static size. This + means no dynamically sizes fields like slices or strings. Use statically + sized array (byte arrays for strings) instead. + +- No pointers, channel, map or function pointer fields, and no fields that are + arrays of these types. These don't make sense in an ABI data structure. + +- We could support opaque pointers as `uintptr`, but this is currently not + implemented. Implementing this would require handling the architecture + dependent native pointer size. + +- Fields must either be a primitive integer type (`byte`, + `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable. + +- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric + type. + +- `float*` fields are currently not supported, but could be if necessary. + +# Appendix + +## Working with Non-Packed Structs + +ABI structs must generally be packed types, meaning they should have no implicit +padding between short fields. However, if a field is tagged +`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower +mechanism to deal with potentially unaligned fields. + +Note that the non-packed property is inheritted by any other struct that embeds +this struct, since the `go_marshal` tool currently can't reason about alignments +for embedded structs that are not aligned. + +Because of this, it's generally best to avoid using `marshal:"unaligned"` and +insert explicit padding fields instead. + +## Debugging go_marshal + +To enable debugging output from the go marshal tool, pass the `-debug` flag to +the tool. When using the build rules from above, add a `debug = True` field to +the build rule like this: + +``` +load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library") + +go_library( + name = "foo", + srcs = ["foo.go"], + debug = True, +) +``` + +## Modifying the `go_marshal` Tool + +The following are some guidelines for modifying the `go_marshal` tool: + +- The `go_marshal` tool currently does a single pass over all types requesting + code generation, in arbitrary order. This means the generated code can't + directly obtain information about embedded marshallable types at + compile-time. One way to work around this restriction is to add a new + Marshallable interface method providing this piece of information, and + calling it from the generated code. Use this sparingly, as we want to rely + on compile-time information as much as possible for performance. + +- No runtime reflection in the code generated for the marshallable interface. + The entire point of the tool is to avoid runtime reflection. The generated + tests may use reflection. diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD new file mode 100644 index 000000000..c859ced77 --- /dev/null +++ b/tools/go_marshal/analysis/BUILD @@ -0,0 +1,13 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "analysis", + testonly = 1, + srcs = ["analysis_unsafe.go"], + importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis", + visibility = [ + "//:sandbox", + ], +) diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go new file mode 100644 index 000000000..9a9a4f298 --- /dev/null +++ b/tools/go_marshal/analysis/analysis_unsafe.go @@ -0,0 +1,175 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package analysis implements common functionality used by generated +// go_marshal tests. +package analysis + +// All functions in this package are unsafe and are not intended for general +// consumption. They contain sharp edge cases and the caller is responsible for +// ensuring none of them are hit. Callers must be carefully to pass in only sane +// arguments. Failure to do so may cause panics at best and arbitrary memory +// corruption at worst. +// +// Never use outside of tests. + +import ( + "fmt" + "math/rand" + "reflect" + "testing" + "unsafe" +) + +// RandomizeValue assigns random value(s) to an abitrary type. This is intended +// for used with ABI structs from go_marshal, meaning the typical restrictions +// apply (fixed-size types, no pointers, maps, channels, etc), and should only +// be used on zeroed values to avoid overwriting pointers to active go objects. +// +// Internally, we populate the type with random data by doing an unsafe cast to +// access the underlying memory of the type and filling it as if it were a byte +// slice. This almost gets us what we want, but padding fields named "_" are +// normally not accessible, so we walk the type and recursively zero all "_" +// fields. +// +// Precondition: x must be a pointer. x must not contain any valid +// pointers to active go objects (pointer fields aren't allowed in ABI +// structs anyways), or we'd be violating the go runtime contract and +// the GC may malfunction. +func RandomizeValue(x interface{}) { + v := reflect.Indirect(reflect.ValueOf(x)) + if !v.CanSet() { + panic("RandomizeType() called with an unaddressable value. You probably need to pass a pointer to the argument") + } + + // Cast the underlying memory for the type into a byte slice. + var b []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + // Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer. + hdr.Data = v.UnsafeAddr() + hdr.Len = int(v.Type().Size()) + hdr.Cap = hdr.Len + + // Fill the byte slice with random data, which in effect fills the type with + // random values. + n, err := rand.Read(b) + if err != nil || n != len(b) { + panic("unreachable") + } + + // Normally, padding fields are not accessible, so zero them out. + reflectZeroPaddingFields(v.Type(), b, false) +} + +// reflectZeroPaddingFields assigns zero values to padding fields for the value +// of type r, represented by the memory in data. Padding fields are defined as +// fields with the name "_". If zero is true, the immediate value itself is +// zeroed. In addition, the type is recursively scanned for padding fields in +// inner types. +// +// This is used for zeroing padding fields after calling RandomizeValue. +func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) { + if zero { + for i, _ := range data { + data[i] = 0 + } + } + switch r.Kind() { + case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64: + // These types are explicitly allowed in an ABI type, but we don't need + // to recurse further as they're scalar types. + case reflect.Struct: + for i, numFields := 0, r.NumField(); i < numFields; i++ { + f := r.Field(i) + off := f.Offset + len := f.Type.Size() + window := data[off : off+len] + reflectZeroPaddingFields(f.Type, window, f.Name == "_") + } + case reflect.Array: + eLen := int(r.Elem().Size()) + if int(r.Size()) != eLen*r.Len() { + panic("Array has unexpected size?") + } + for i, n := 0, r.Len(); i < n; i++ { + reflectZeroPaddingFields(r.Elem(), data[i*eLen:(i+1)*eLen], false) + } + default: + panic(fmt.Sprintf("Type %v not allowed in ABI struct", r.Kind())) + + } +} + +// AlignmentCheck ensures the definition of the type represented by typ doesn't +// cause the go compiler to emit implicit padding between elements of the type +// (i.e. fields in a struct). +// +// AlignmentCheck doesn't explicitly recurse for embedded structs because any +// struct present in an ABI struct must also be Marshallable, and therefore +// they're aligned by definition (or their alignment check would have failed). +func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) { + switch typ.Kind() { + case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64: + // Primitive types are always considered well aligned. Primitive types + // that are fields in structs are checked independently, this branch + // exists to handle recursive calls to alignmentCheck. + case reflect.Struct: + xOff := 0 + nextXOff := 0 + skipNext := false + for i, numFields := 0, typ.NumField(); i < numFields; i++ { + xOff = nextXOff + f := typ.Field(i) + fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size()) + nextXOff = int(f.Offset + f.Type.Size()) + + if f.Name == "_" { + // Padding fields need not be aligned. + fmt.Printf("Padding field of type %v\n", f.Type) + continue + } + + if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" { + skipNext = true + continue + } + + if skipNext { + skipNext = false + fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name) + continue + } + + if xOff != int(f.Offset) { + implicitPad := int(f.Offset) - xOff + t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad) + } + } + + // Ensure structs end on a byte explicitly defined by the type. + if typ.NumField() > 0 && nextXOff != int(typ.Size()) { + implicitPad := int(typ.Size()) - nextXOff + f := typ.Field(typ.NumField() - 1) // Final field + t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.", + typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name) + } + case reflect.Array: + // Independent arrays are also always considered well aligned. We only + // need to worry about their alignment when they're embedded in structs, + // which we handle above. + default: + t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind()) + } + return true, uint64(typ.Size()) +} diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl new file mode 100644 index 000000000..c32eb559f --- /dev/null +++ b/tools/go_marshal/defs.bzl @@ -0,0 +1,152 @@ +"""Marshal is a tool for generating marshalling interfaces for Go types. + +The recommended way is to use the go_library rule defined below with mostly +identical configuration as the native go_library rule. + +load("//tools/go_marshal:defs.bzl", "go_library") + +go_library( + name = "foo", + srcs = ["foo.go"], +) + +Under the hood, the go_marshal rule is used to generate a file that will +appear in a Go target; the output file should appear explicitly in a srcs list. +For example (the above is still the preferred way): + +load("//tools/go_marshal:defs.bzl", "go_marshal") + +go_marshal( + name = "foo_abi", + srcs = ["foo.go"], + out = "foo_abi.go", + package = "foo", +) + +go_library( + name = "foo", + srcs = [ + "foo.go", + "foo_abi.go", + ], + deps = [ + "//tools/go_marshal:marshal", + "//pkg/sentry/platform/safecopy", + "//pkg/sentry/usermem", + ], +) +""" + +load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test") + +def _go_marshal_impl(ctx): + """Execute the go_marshal tool.""" + output = ctx.outputs.lib + output_test = ctx.outputs.test + (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD") + + decl = "/".join(["gvisor.dev/gvisor", build_dir]) + + # Run the marshal command. + args = ["-output=%s" % output.path] + args += ["-pkg=%s" % ctx.attr.package] + args += ["-output_test=%s" % output_test.path] + args += ["-declarationPkg=%s" % decl] + + if ctx.attr.debug: + args += ["-debug"] + + args += ["--"] + for src in ctx.attr.srcs: + args += [f.path for f in src.files.to_list()] + ctx.actions.run( + inputs = ctx.files.srcs, + outputs = [output, output_test], + mnemonic = "GoMarshal", + progress_message = "go_marshal: %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + +# Generates save and restore logic from a set of Go files. +# +# Args: +# name: the name of the rule. +# srcs: the input source files. These files should include all structs in the +# package that need to be saved. +# imports: an optional list of extra, non-aliased, Go-style absolute import +# paths. +# out: the name of the generated file output. This must not conflict with any +# other files and must be added to the srcs of the relevant go_library. +# package: the package name for the input sources. +go_marshal = rule( + implementation = _go_marshal_impl, + attrs = { + "srcs": attr.label_list(mandatory = True, allow_files = True), + "libname": attr.string(mandatory = True), + "imports": attr.string_list(mandatory = False), + "package": attr.string(mandatory = True), + "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"), + "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")), + }, + outputs = { + "lib": "%{name}_unsafe.go", + "test": "%{name}_test.go", + }, +) + +def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs): + """wraps the standard go_library and does mashalling interface generation. + + Args: + name: Same as native go_library. + srcs: Same as native go_library. + deps: Same as native go_library. + imports: Extra import paths to pass to the go_marshal tool. + debug: Enables debugging output from the go_marshal tool. + **kwargs: Remaining args to pass to the native go_library rule unmodified. + """ + go_marshal( + name = name + "_abi_autogen", + libname = name, + srcs = [src for src in srcs if src.endswith(".go")], + debug = debug, + imports = imports, + package = name, + ) + + extra_deps = [ + "//tools/go_marshal/marshal", + "//pkg/sentry/platform/safecopy", + "//pkg/sentry/usermem", + ] + + all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] + all_deps = deps + [] # + extra_deps + + for extra in extra_deps: + if extra not in deps: + all_deps.append(extra) + + _go_library( + name = name, + srcs = all_srcs, + deps = all_deps, + **kwargs + ) + + # Don't pass importpath arg to go_test. + kwargs.pop("importpath", "") + + _go_test( + name = name + "_abi_autogen_test", + srcs = [name + "_abi_autogen_test.go"], + # Generated test has a fixed set of dependencies since we generate these + # tests. They should only depend on the library generated above, and the + # Marshallable interface. + deps = [ + ":" + name, + "//tools/go_marshal/analysis", + ], + **kwargs + ) diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD new file mode 100644 index 000000000..a0eae6492 --- /dev/null +++ b/tools/go_marshal/gomarshal/BUILD @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "gomarshal", + srcs = [ + "generator.go", + "generator_interfaces.go", + "generator_tests.go", + "util.go", + ], + importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal", + visibility = [ + "//:sandbox", + ], +) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go new file mode 100644 index 000000000..641ccd938 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator.go @@ -0,0 +1,382 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gomarshal implements the go_marshal code generator. See README.md. +package gomarshal + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "sort" +) + +const ( + marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal" + usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem" + safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" +) + +// List of identifiers we use in generated code, that may conflict a +// similarly-named source identifier. Avoid problems by refusing the generate +// code when we see these. +// +// This only applies to import aliases at the moment. All other identifiers +// are qualified by a receiver argument, since they're struct fields. +// +// All recievers are single letters, so we don't allow import aliases to be a +// single letter. +var badIdents = []string{ + "src", "srcs", "dst", "dsts", "blk", "buf", "err", + // All single-letter identifiers. +} + +// Generator drives code generation for a single invocation of the go_marshal +// utility. +// +// The Generator holds arguments passed to the tool, and drives parsing, +// processing and code Generator for all types marked with +marshal declared in +// the input files. +// +// See Generator.run() as the entry point. +type Generator struct { + // Paths to input go source files. + inputs []string + // Output file to write generated go source. + output *os.File + // Output file to write generated tests. + outputTest *os.File + // Package name for the generated file. + pkg string + // Go import path for package we're processing. This package should directly + // declare the type we're generating code for. + declaration string + // Set of extra packages to import in the generated file. + imports *importTable +} + +// NewGenerator creates a new code Generator. +func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) { + f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) + } + fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) + } + g := Generator{ + inputs: srcs, + output: f, + outputTest: fTest, + pkg: pkg, + declaration: declaration, + imports: newImportTable(), + } + for _, i := range imports { + // All imports on the extra imports list are unconditionally marked as + // used, so they're always added to the generated code. + g.imports.add(i).markUsed() + } + g.imports.add(marshalImport).markUsed() + // The follow imports may or may not be used by the generated + // code, depending what's required for the target types. Don't + // mark these imports as used by default. + g.imports.add(usermemImport) + g.imports.add(safecopyImport) + g.imports.add("unsafe") + + return &g, nil +} + +// writeHeader writes the header for the generated source file. The header +// includes the package name, package level comments and import statements. +func (g *Generator) writeHeader() error { + var b sourceBuffer + b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") + b.emit("package %s\n\n", g.pkg) + if err := b.write(g.output); err != nil { + return err + } + + return g.imports.write(g.output) +} + +// writeTypeChecks writes a statement to force the compiler to perform a type +// check for all Marshallable types referenced by the generated code. +func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { + if len(ms) == 0 { + return nil + } + + msl := make([]string, 0, len(ms)) + for m, _ := range ms { + msl = append(msl, m) + } + sort.Strings(msl) + + var buf bytes.Buffer + fmt.Fprint(&buf, "// Marshallable types used by this file.\n") + + for _, m := range msl { + fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) + } + fmt.Fprint(&buf, "\n") + + _, err := fmt.Fprint(g.output, buf.String()) + return err +} + +// parse processes all input files passed this generator and produces a set of +// parsed go ASTs. +func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { + debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) + for _, path := range g.inputs { + debugf(" %s\n", path) + } + + files := make([]*ast.File, 0, len(g.inputs)) + fsets := make([]*token.FileSet, 0, len(g.inputs)) + + for _, path := range g.inputs { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) + if err != nil { + // Not a valid input file? + return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err) + } + + if debugEnabled() { + debugf("AST for %q:\n", path) + ast.Print(fset, f) + } + + files = append(files, f) + fsets = append(fsets, fset) + } + + return files, fsets, nil +} + +// collectMarshallabeTypes walks the parsed AST and collects a list of type +// declarations for which we need to generate the Marshallable interface. +func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { + var types []*ast.TypeSpec + for _, decl := range a.Decls { + gdecl, ok := decl.(*ast.GenDecl) + // Type declaration? + if !ok || gdecl.Tok != token.TYPE { + debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") + continue + } + // Does it have a comment? + if gdecl.Doc == nil { + debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") + continue + } + // Does the comment contain a "+marshal" line? + marked := false + for _, c := range gdecl.Doc.List { + if c.Text == "// +marshal" { + marked = true + break + } + } + if !marked { + debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") + continue + } + for _, spec := range gdecl.Specs { + // We already confirmed we're in a type declaration earlier. + t := spec.(*ast.TypeSpec) + if _, ok := t.Type.(*ast.StructType); ok { + debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name) + types = append(types, t) + continue + } + debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl) + } + } + return types +} + +// collectImports collects all imports from all input source files. Some of +// these imports are copied to the generated output, if they're referenced by +// the generated code. +// +// collectImports de-duplicates imports while building the list, and ensures +// identifiers in the generated code don't conflict with any imported package +// names. +func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { + badImportNames := make(map[string]bool) + for _, i := range badIdents { + badImportNames[i] = true + } + + is := make(map[string]importStmt) + for _, decl := range a.Decls { + gdecl, ok := decl.(*ast.GenDecl) + // Import statement? + if !ok || gdecl.Tok != token.IMPORT { + continue + } + for _, spec := range gdecl.Specs { + i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) + debugf("Collected import '%s' as '%s'\n", i.path, i.name) + + // Make sure we have an import that doesn't use any local names that + // would conflict with identifiers in the generated code. + if len(i.name) == 1 { + abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) + } + if badImportNames[i.name] { + abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) + } + } + } + return is + +} + +func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { + // We're guaranteed to have only struct type specs by now. See + // Generator.collectMarshallabeTypes. + i := newInterfaceGenerator(t, fset) + i.validate() + i.emitMarshallable() + return i +} + +// generateOneTestSuite generates a test suite for the automatically generated +// implementations type t. +func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator { + i := newTestGenerator(t, g.declaration) + i.emitTests() + return i +} + +// Run is the entry point to code generation using g. +// +// Run parses all input source files specified in g and emits generated code. +func (g *Generator) Run() error { + // Parse our input source files into ASTs and token sets. + asts, fsets, err := g.parse() + if err != nil { + return err + } + + if len(asts) != len(fsets) { + panic("ASTs and FileSets don't match") + } + + // Map of imports in source files; key = local package name, value = import + // path. + is := make(map[string]importStmt) + for i, a := range asts { + // Collect all imports from the source files. We may need to copy some + // of these to the generated code if they're referenced. This has to be + // done before the loop below because we need to process all ASTs before + // we start requesting imports to be copied one by one as we encounter + // them in each generated source. + for name, i := range g.collectImports(a, fsets[i]) { + is[name] = i + } + } + + var impls []*interfaceGenerator + var ts []*testGenerator + // Set of Marshallable types referenced by generated code. + ms := make(map[string]struct{}) + for i, a := range asts { + // Collect type declarations marked for code generation and generate + // Marshallable interfaces. + for _, t := range g.collectMarshallabeTypes(a, fsets[i]) { + impl := g.generateOne(t, fsets[i]) + // Collect Marshallable types referenced by the generated code. + for ref, _ := range impl.ms { + ms[ref] = struct{}{} + } + impls = append(impls, impl) + // Collect imports referenced by the generated code and add them to + // the list of imports we need to copy to the generated code. + for name, _ := range impl.is { + if !g.imports.markUsed(name) { + panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name)) + } + } + ts = append(ts, g.generateOneTestSuite(t)) + } + } + + // Tool was invoked with input files with no data structures marked for code + // generation. This is probably not what the user intended. + if len(impls) == 0 { + var buf bytes.Buffer + fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n") + for _, i := range g.inputs { + fmt.Fprintf(&buf, " %s\n", i) + } + abort(buf.String()) + } + + // Write output file header. These include things like package name and + // import statements. + if err := g.writeHeader(); err != nil { + return err + } + + // Write type checks for referenced marshallable types to output file. + if err := g.writeTypeChecks(ms); err != nil { + return err + } + + // Write generated interfaces to output file. + for _, i := range impls { + if err := i.write(g.output); err != nil { + return err + } + } + + // Write generated tests to test file. + return g.writeTests(ts) +} + +// writeTests outputs tests for the generated interface implementations to a go +// source file. +func (g *Generator) writeTests(ts []*testGenerator) error { + var b sourceBuffer + b.emit("package %s_test\n\n", g.pkg) + if err := b.write(g.outputTest); err != nil { + return err + } + + imports := newImportTable() + for _, t := range ts { + imports.merge(t.imports) + } + + if err := imports.write(g.outputTest); err != nil { + return err + } + + for _, t := range ts { + if err := t.write(g.outputTest); err != nil { + return err + } + } + return nil +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go new file mode 100644 index 000000000..a712c14dc --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -0,0 +1,507 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "fmt" + "go/ast" + "go/token" + "strings" +) + +// interfaceGenerator generates marshalling interfaces for a single type. +// +// getState is not thread-safe. +type interfaceGenerator struct { + sourceBuffer + + // The type we're serializing. + t *ast.TypeSpec + + // Receiver argument for generated methods. + r string + + // FileSet containing the tokens for the type we're processing. + f *token.FileSet + + // is records external packages referenced by the generated implementation. + is map[string]struct{} + + // ms records Marshallable types referenced by the generated implementation + // of t's interfaces. + ms map[string]struct{} + + // as records embedded fields in t that are potentially not packed. The key + // is the accessor for the field. + as map[string]struct{} +} + +// typeName returns the name of the type this g represents. +func (g *interfaceGenerator) typeName() string { + return g.t.Name.Name +} + +// newinterfaceGenerator creates a new interface generator. +func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { + if _, ok := t.Type.(*ast.StructType); !ok { + panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) + } + g := &interfaceGenerator{ + t: t, + r: receiverName(t), + f: fset, + is: make(map[string]struct{}), + ms: make(map[string]struct{}), + as: make(map[string]struct{}), + } + g.recordUsedMarshallable(g.typeName()) + return g +} + +func (g *interfaceGenerator) recordUsedMarshallable(m string) { + g.ms[m] = struct{}{} + +} + +func (g *interfaceGenerator) recordUsedImport(i string) { + g.is[i] = struct{}{} + +} + +func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { + g.as[fieldName] = struct{}{} +} + +func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) { + // This is guaranteed to succeed because g.t is always a struct. + st := g.t.Type.(*ast.StructType) + for _, field := range st.Fields.List { + fn(field) + } +} + +func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { + return fmt.Sprintf("%s.%s", g.r, n.Name) +} + +// abortAt aborts the go_marshal tool with the given error message, with a +// reference position to the input source. Same as abortAt, but uses g to +// resolve p to position. +func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { + abortAt(g.f.Position(p), msg) +} + +// validate ensures the type we're working with can be marshalled. These checks +// are done ahead of time and in one place so we can make assumptions later. +func (g *interfaceGenerator) validate() { + g.forEachField(func(f *ast.Field) { + if len(f.Names) == 0 { + g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") + } + }) + + g.forEachField(func(f *ast.Field) { + fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we + // provide suggestions for some disallowed types and reject + // them, then attempt to marshal any remaining types by + // invoking the marshal.Marshallable interface on them. If + // these types don't actually implement + // marshal.Marshallable, compilation of the generated code + // will fail with an appropriate error message. + return + case "int": + g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } + }, + selector: func(_, _, _ *ast.Ident) { + // No validation to perform on selector fields. However this + // callback must still be provided. + }, + array: func(n, _ *ast.Ident, len int) { + a := f.Type.(*ast.ArrayType) + if a.Len == nil { + g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) + } + + if _, ok := a.Len.(*ast.BasicLit); !ok { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions")) + } + + if _, ok := a.Elt.(*ast.Ident); !ok { + g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) + } + + if len <= 0 { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) + } + }, + unhandled: func(_ *ast.Ident) { + g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) + }, + }.dispatch(f) + }) +} + +// scalarSize returns the size of type identified by t. If t isn't a primitive +// type, the size isn't known at code generation time, and must be resolved via +// the marshal.Marshallable interface. +func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) { + switch t.Name { + case "int8", "uint8", "byte": + return 1, false + case "int16", "uint16": + return 2, false + case "int32", "uint32": + return 4, false + case "int64", "uint64": + return 8, false + default: + return 0, true + } +} + +func (g *interfaceGenerator) shift(bufVar string, n int) { + g.emit("%s = %s[%d:]\n", bufVar, bufVar, n) +} + +func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { + g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) +} + +func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(%s)\n", bufVar, accessor) + g.shift(bufVar, 1) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor) + g.shift(bufVar, 2) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor) + g.shift(bufVar, 4) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor) + g.shift(bufVar, 8) + default: + g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) + g.shiftDynamic(bufVar, accessor) + } +} + +func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) { + switch typ { + case "int8": + g.emit("%s = int8(%s[0])\n", accessor, bufVar) + g.shift(bufVar, 1) + case "uint8": + g.emit("%s = uint8(%s[0])\n", accessor, bufVar) + g.shift(bufVar, 1) + case "byte": + g.emit("%s = %s[0]\n", accessor, bufVar) + g.shift(bufVar, 1) + + case "int16": + g.recordUsedImport("usermem") + g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar) + g.shift(bufVar, 2) + case "uint16": + g.recordUsedImport("usermem") + g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar) + g.shift(bufVar, 2) + + case "int32": + g.recordUsedImport("usermem") + g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar) + g.shift(bufVar, 4) + case "uint32": + g.recordUsedImport("usermem") + g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar) + g.shift(bufVar, 4) + + case "int64": + g.recordUsedImport("usermem") + g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar) + g.shift(bufVar, 8) + case "uint64": + g.recordUsedImport("usermem") + g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar) + g.shift(bufVar, 8) + default: + g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) + g.shiftDynamic(bufVar, accessor) + g.recordPotentiallyNonPackedField(accessor) + } +} + +// areFieldsPackedExpression returns a go expression checking whether g.t's fields are +// packed. Returns "", false if g.t has no fields that may be potentially +// packed, otherwise returns <clause>, true, where <clause> is an expression +// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". +func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { + if len(g.as) == 0 { + return "", false + } + + cs := make([]string, 0, len(g.as)) + for accessor, _ := range g.as { + cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) + } + return strings.Join(cs, " && "), true +} + +func (g *interfaceGenerator) emitMarshallable() { + // Is g.t a packed struct without consideing field types? + thisPacked := true + g.forEachField(func(f *ast.Field) { + if f.Tag != nil { + if f.Tag.Value == "`marshal:\"unaligned\"`" { + if thisPacked { + debugfAt(g.f.Position(g.t.Pos()), + fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) + thisPacked = false + } + } + } + }) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + g.forEachField(fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n))) + } + }, + selector: func(n, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(n, t *ast.Ident, len int) { + if len < 1 { + // Zero-length arrays should've been rejected by validate(). + panic("unreachable") + } + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size * len + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.forEachField(fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for i := 0; i < %d; i++ {\n", size) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.forEachField(fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for i := 0; i < %d; i++ {\n", size) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + expr, fieldsMaybePacked := g.areFieldsPackedExpression() + switch { + case !thisPacked: + g.emit("return false\n") + case fieldsMaybePacked: + g.emit("return %s\n", expr) + default: + g.emit("return true\n") + + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } + }) + g.emit("}\n\n") + +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go new file mode 100644 index 000000000..df25cb5b2 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -0,0 +1,154 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "fmt" + "go/ast" + "io" + "strings" +) + +var standardImports = []string{ + "fmt", + "reflect", + "testing", + "gvisor.dev/gvisor/tools/go_marshal/analysis", +} + +type testGenerator struct { + sourceBuffer + + // The type we're serializing. + t *ast.TypeSpec + + // Receiver argument for generated methods. + r string + + // Imports used by generated code. + imports *importTable + + // Import statement for the package declaring the type we generated code + // for. We need this to construct test instances for the type, since the + // tests aren't written in the same package. + decl *importStmt +} + +func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator { + if _, ok := t.Type.(*ast.StructType); !ok { + panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) + } + g := &testGenerator{ + t: t, + r: receiverName(t), + imports: newImportTable(), + } + + for _, i := range standardImports { + g.imports.add(i).markUsed() + } + g.decl = g.imports.add(declaration) + g.decl.markUsed() + + return g +} + +func (g *testGenerator) typeName() string { + return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name) +} + +func (g *testGenerator) forEachField(fn func(f *ast.Field)) { + // This is guaranteed to succeed because g.t is always a struct. + st := g.t.Type.(*ast.StructType) + for _, field := range st.Fields.List { + fn(field) + } +} + +func (g *testGenerator) testFuncName(base string) string { + return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) +} + +func (g *testGenerator) inTestFunction(name string, body func()) { + g.emit("func %s(t *testing.T) {\n", g.testFuncName(name)) + g.inIndent(body) + g.emit("}\n\n") +} + +func (g *testGenerator) emitTestNonZeroSize() { + g.inTestFunction("TestSizeNonZero", func() { + g.emit("x := &%s{}\n", g.typeName()) + g.emit("if x.SizeBytes() == 0 {\n") + g.inIndent(func() { + g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSuspectAlignment() { + g.inTestFunction("TestSuspectAlignment", func() { + g.emit("x := %s{}\n", g.typeName()) + g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") + }) +} + +func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { + g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() { + g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("buf := make([]byte, x.SizeBytes())\n") + g.emit("x.MarshalBytes(buf)\n") + g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n") + g.emit("x.MarshalUnsafe(bufUnsafe)\n\n") + + g.emit("y.UnmarshalBytes(buf)\n") + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n\n") + + g.emit("z.UnmarshalUnsafe(buf)\n") + g.emit("if !reflect.DeepEqual(x, z) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n") + }) + g.emit("}\n") + g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") + g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTests() { + g.emitTestNonZeroSize() + g.emitTestSuspectAlignment() + g.emitTestMarshalUnmarshalPreservesData() +} + +func (g *testGenerator) write(out io.Writer) error { + return g.sourceBuffer.write(out) +} diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go new file mode 100644 index 000000000..967537abf --- /dev/null +++ b/tools/go_marshal/gomarshal/util.go @@ -0,0 +1,387 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/token" + "io" + "os" + "path" + "reflect" + "sort" + "strconv" + "strings" +) + +var debug = flag.Bool("debug", false, "enables debugging output") + +// receiverName returns an appropriate receiver name given a type spec. +func receiverName(t *ast.TypeSpec) string { + if len(t.Name.Name) < 1 { + // Zero length type name? + panic("unreachable") + } + return strings.ToLower(t.Name.Name[:1]) +} + +// kindString returns a user-friendly representation of an AST expr type. +func kindString(e ast.Expr) string { + switch e.(type) { + case *ast.Ident: + return "scalar" + case *ast.ArrayType: + return "array" + case *ast.StructType: + return "struct" + case *ast.StarExpr: + return "pointer" + case *ast.FuncType: + return "function" + case *ast.InterfaceType: + return "interface" + case *ast.MapType: + return "map" + case *ast.ChanType: + return "channel" + default: + return reflect.TypeOf(e).String() + } +} + +// fieldDispatcher is a collection of callbacks for handling different types of +// fields in a struct declaration. +type fieldDispatcher struct { + primitive func(n, t *ast.Ident) + selector func(n, tX, tSel *ast.Ident) + array func(n, t *ast.Ident, size int) + unhandled func(n *ast.Ident) +} + +// Precondition: All dispatch callbacks that will be invoked must be +// provided. Embedded fields are not allowed, len(f.Names) >= 1. +func (fd fieldDispatcher) dispatch(f *ast.Field) { + // Each field declaration may actually be multiple declarations of the same + // type. For example, consider: + // + // type Point struct { + // x, y, z int + // } + // + // We invoke the call-backs once per such instance. Embedded fields are not + // allowed, and results in a panic. + if len(f.Names) < 1 { + panic("Precondition not met: attempted to dispatch on embedded field") + } + + for _, name := range f.Names { + switch v := f.Type.(type) { + case *ast.Ident: + fd.primitive(name, v) + case *ast.SelectorExpr: + fd.selector(name, v.X.(*ast.Ident), v.Sel) + case *ast.ArrayType: + len := 0 + if v.Len != nil { + // Non-literal array length is handled by generatorInterfaces.validate(). + if lenLit, ok := v.Len.(*ast.BasicLit); ok { + var err error + len, err = strconv.Atoi(lenLit.Value) + if err != nil { + panic(err) + } + } + } + switch t := v.Elt.(type) { + case *ast.Ident: + fd.array(name, t, len) + default: + fd.array(name, nil, len) + } + default: + fd.unhandled(name) + } + } +} + +// debugEnabled indicates whether debugging is enabled for gomarshal. +func debugEnabled() bool { + return *debug +} + +// abort aborts the go_marshal tool with the given error message. +func abort(msg string) { + if !strings.HasSuffix(msg, "\n") { + msg += "\n" + } + fmt.Print(msg) + os.Exit(1) +} + +// abortAt aborts the go_marshal tool with the given error message, with +// a reference position to the input source. +func abortAt(p token.Position, msg string) { + abort(fmt.Sprintf("%v:\n %s\n", p, msg)) +} + +// debugf conditionally prints a debug message. +func debugf(f string, a ...interface{}) { + if debugEnabled() { + fmt.Printf(f, a...) + } +} + +// debugfAt conditionally prints a debug message with a reference to a position +// in the input source. +func debugfAt(p token.Position, f string, a ...interface{}) { + if debugEnabled() { + fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) + } +} + +// emit generates a line of code in the output file. +// +// emit is a wrapper around writing a formatted string to the output +// buffer. emit can be invoked in one of two ways: +// +// (1) emit("some string") +// When emit is called with a single string argument, it is simply copied to +// the output buffer without any further formatting. +// (2) emit(fmtString, args...) +// emit can also be invoked in a similar fashion to *Printf() functions, +// where the first argument is a format string. +// +// Calling emit with a single argument that is not a string will result in a +// panic, as the caller's intent is ambiguous. +func emit(out io.Writer, indent int, a ...interface{}) { + const spacesPerIndentLevel = 4 + + if len(a) < 1 { + panic("emit() called with no arguments") + } + + if indent > 0 { + if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil { + // Writing to the emit output should not fail. Typically the output + // is a byte.Buffer; writes to these never fail. + panic(err) + } + } + + first, ok := a[0].(string) + if !ok { + // First argument must be either the string to emit (case 1 from + // function-level comment), or a format string (case 2). + panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0])) + } + + if len(a) == 1 { + // Single string argument. Assume no formatting requested. + if _, err := fmt.Fprint(out, first); err != nil { + // Writing to out should not fail. + panic(err) + } + return + + } + + // Formatting requested. + if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil { + // Writing to out should not fail. + panic(err) + } +} + +// sourceBuffer represents fragments of generated go source code. +// +// sourceBuffer provides a convenient way to build up go souce fragments in +// memory. May be safely zero-value initialized. Not thread-safe. +type sourceBuffer struct { + // Current indentation level. + indent int + + // Memory buffer containing contents while they're being generated. + b bytes.Buffer +} + +func (b *sourceBuffer) incIndent() { + b.indent++ +} + +func (b *sourceBuffer) decIndent() { + if b.indent <= 0 { + panic("decIndent() without matching incIndent()") + } + b.indent-- +} + +func (b *sourceBuffer) emit(a ...interface{}) { + emit(&b.b, b.indent, a...) +} + +func (b *sourceBuffer) emitNoIndent(a ...interface{}) { + emit(&b.b, 0 /*indent*/, a...) +} + +func (b *sourceBuffer) inIndent(body func()) { + b.incIndent() + body() + b.decIndent() +} + +func (b *sourceBuffer) write(out io.Writer) error { + _, err := fmt.Fprint(out, b.b.String()) + return err +} + +// Write implements io.Writer.Write. +func (b *sourceBuffer) Write(buf []byte) (int, error) { + return (b.b.Write(buf)) +} + +// importStmt represents a single import statement. +type importStmt struct { + // Local name of the imported package. + name string + // Import path. + path string + // Indicates whether the local name is an alias, or simply the final + // component of the path. + aliased bool + // Indicates whether this import was referenced by generated code. + used bool +} + +func newImport(p string) *importStmt { + name := path.Base(p) + return &importStmt{ + name: name, + path: p, + aliased: false, + } +} + +func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { + p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path. + name := path.Base(p) + if name == "" || name == "/" || name == "." { + panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)", + f.Position(spec.Path.Pos()), name)) + } + if spec.Name != nil { + name = spec.Name.Name + } + return &importStmt{ + name: name, + path: p, + aliased: spec.Name != nil, + } +} + +func (i *importStmt) String() string { + if i.aliased { + return fmt.Sprintf("%s \"%s\"", i.name, i.path) + } + return fmt.Sprintf("\"%s\"", i.path) +} + +func (i *importStmt) markUsed() { + i.used = true +} + +func (i *importStmt) equivalent(other *importStmt) bool { + return i == other +} + +// importTable represents a collection of importStmts. +type importTable struct { + // Map of imports and whether they should be copied to the output. + is map[string]*importStmt +} + +func newImportTable() *importTable { + return &importTable{ + is: make(map[string]*importStmt), + } +} + +// Merges import statements from other into i. Collisions in import statements +// result in a panic. +func (i *importTable) merge(other *importTable) { + for name, im := range other.is { + if dup, ok := i.is[name]; ok && dup.equivalent(im) { + panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) + } + + i.is[name] = im + } +} + +func (i *importTable) add(s string) *importStmt { + n := newImport(s) + i.is[n.name] = n + return n +} + +func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { + n := newImportFromSpec(spec, f) + i.is[n.name] = n + return n +} + +// Marks the import named n as used. If no such import is in the table, returns +// false. +func (i *importTable) markUsed(n string) bool { + if n, ok := i.is[n]; ok { + n.markUsed() + return true + } + return false +} + +func (i *importTable) clear() { + for _, i := range i.is { + i.used = false + } +} + +func (i *importTable) write(out io.Writer) error { + if len(i.is) == 0 { + // Nothing to import, we're done. + return nil + } + + imports := make([]string, 0, len(i.is)) + for _, i := range i.is { + if i.used { + imports = append(imports, i.String()) + } + } + sort.Strings(imports) + + var b sourceBuffer + b.emit("import (\n") + b.incIndent() + for _, i := range imports { + b.emit("%s\n", i) + } + b.decIndent() + b.emit(")\n\n") + + return b.write(out) +} diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go new file mode 100644 index 000000000..3d12eb93c --- /dev/null +++ b/tools/go_marshal/main.go @@ -0,0 +1,73 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// go_marshal is a code generation utility for automatically generating code to +// marshal go data structures to memory. +// +// This binary is typically run as part of the build process, and is invoked by +// the go_marshal bazel rule defined in defs.bzl. +// +// See README.md. +package main + +import ( + "flag" + "fmt" + "os" + "strings" + + "gvisor.dev/gvisor/tools/go_marshal/gomarshal" +) + +var ( + pkg = flag.String("pkg", "", "output package") + output = flag.String("output", "", "output file") + outputTest = flag.String("output_test", "", "output file for tests") + imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") + declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on") +) + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + os.Exit(1) + } + + if *pkg == "" { + flag.Usage() + fmt.Fprint(os.Stderr, "Flag -pkg must be provided.\n") + os.Exit(1) + } + + var extraImports []string + if len(*imports) > 0 { + // Note: strings.Split(s, sep) returns s if sep doesn't exist in s. Thus + // we check for an empty imports list to avoid emitting an empty string + // as an import. + extraImports = strings.Split(*imports, ",") + } + g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports) + if err != nil { + panic(err) + } + + if err := g.Run(); err != nil { + panic(err) + } +} diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD new file mode 100644 index 000000000..47dda97a1 --- /dev/null +++ b/tools/go_marshal/marshal/BUILD @@ -0,0 +1,14 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "marshal", + srcs = [ + "marshal.go", + ], + importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal", + visibility = [ + "//:sandbox", + ], +) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go new file mode 100644 index 000000000..a313a27ed --- /dev/null +++ b/tools/go_marshal/marshal/marshal.go @@ -0,0 +1,60 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package marshal defines the Marshallable interface for +// serialize/deserializing go data structures to/from memory, according to the +// Linux ABI. +// +// Implementations of this interface are typically automatically generated by +// tools/go_marshal. See the go_marshal README for details. +package marshal + +// Marshallable represents a type that can be marshalled to and from memory. +type Marshallable interface { + // SizeBytes is the size of the memory representation of a type in + // marshalled form. + SizeBytes() int + + // MarshalBytes serializes a copy of a type to dst. dst must be at least + // SizeBytes() long. + MarshalBytes(dst []byte) + + // UnmarshalBytes deserializes a type from src. src must be at least + // SizeBytes() long. + UnmarshalBytes(src []byte) + + // Packed returns true if the marshalled size of the type is the same as the + // size it occupies in memory. This happens when the type has no fields + // starting at unaligned addresses (should always be true by default for ABI + // structs, verified by automatically generated tests when using + // go_marshal), and has no fields marked `marshal:"unaligned"`. + Packed() bool + + // MarshalUnsafe serializes a type by bulk copying its in-memory + // representation to the dst buffer. This is only safe to do when the type + // has no implicit padding, see Marshallable.Packed. When Packed would + // return false, MarshalUnsafe should fall back to the safer but slower + // MarshalBytes. + MarshalUnsafe(dst []byte) + + // UnmarshalUnsafe deserializes a type directly to the underlying memory + // allocated for the object by the runtime. + // + // This allows much faster unmarshalling of types which have no implicit + // padding, see Marshallable.Packed. When Packed would return false, + // UnmarshalUnsafe should fall back to the safer but slower unmarshal + // mechanism implemented in UnmarshalBytes (usually by calling + // UnmarshalBytes directly). + UnmarshalUnsafe(src []byte) +} diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD new file mode 100644 index 000000000..fa82f8e9b --- /dev/null +++ b/tools/go_marshal/test/BUILD @@ -0,0 +1,31 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + +package(licenses = ["notice"]) + +load("//tools/go_marshal:defs.bzl", "go_library") + +package_group( + name = "gomarshal_test", + packages = [ + "//tools/go_marshal/test/...", + ], +) + +go_test( + name = "benchmark_test", + srcs = ["benchmark_test.go"], + deps = [ + ":test", + "//pkg/binary", + "//pkg/sentry/usermem", + "//tools/go_marshal/analysis", + ], +) + +go_library( + name = "test", + testonly = 1, + srcs = ["test.go"], + importpath = "gvisor.dev/gvisor/tools/go_marshal/test", + deps = ["//tools/go_marshal/test/external"], +) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go new file mode 100644 index 000000000..e70db06d8 --- /dev/null +++ b/tools/go_marshal/test/benchmark_test.go @@ -0,0 +1,178 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package benchmark_test + +import ( + "bytes" + encbin "encoding/binary" + "fmt" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/tools/go_marshal/analysis" + test "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// Marshalling using the standard encoding/binary package. +func BenchmarkEncodingBinary(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + size := encbin.Size(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := bytes.NewBuffer(make([]byte, size)) + buf.Reset() + if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil { + b.Error("Write:", err) + } + if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil { + b.Error("Read:", err) + } + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling using the sentry's binary.Marshal. +func BenchmarkBinary(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + size := binary.Size(s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, size) + buf = binary.Marshal(buf, usermem.ByteOrder, &s1) + binary.Unmarshal(buf, usermem.ByteOrder, &s2) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling field-by-field with manually-written code. +func BenchmarkMarshalManual(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, s1.SizeBytes()) + + // Marshal + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID) + buf = binary.AppendUint32(buf, usermem.ByteOrder, 0) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec)) + + // Unmarshal + s2.Dev = usermem.ByteOrder.Uint64(buf[0:8]) + s2.Ino = usermem.ByteOrder.Uint64(buf[8:16]) + s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24]) + s2.Mode = usermem.ByteOrder.Uint32(buf[24:28]) + s2.UID = usermem.ByteOrder.Uint32(buf[28:32]) + s2.GID = usermem.ByteOrder.Uint32(buf[32:36]) + // Padding: buf[36:40] + s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48]) + s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56])) + s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64])) + s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72])) + s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80])) + s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88])) + s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96])) + s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104])) + s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112])) + s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120])) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling with the go_marshal safe API. +func BenchmarkGoMarshalSafe(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, s1.SizeBytes()) + s1.MarshalBytes(buf) + s2.UnmarshalBytes(buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling with the go_marshal unsafe API. +func BenchmarkGoMarshalUnsafe(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, s1.SizeBytes()) + s1.MarshalUnsafe(buf) + s2.UnmarshalUnsafe(buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD new file mode 100644 index 000000000..8fb43179b --- /dev/null +++ b/tools/go_marshal/test/external/BUILD @@ -0,0 +1,11 @@ +package(licenses = ["notice"]) + +load("//tools/go_marshal:defs.bzl", "go_library") + +go_library( + name = "external", + testonly = 1, + srcs = ["external.go"], + importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external", + visibility = ["//tools/go_marshal/test:gomarshal_test"], +) diff --git a/test/runtimes/runtimes.go b/tools/go_marshal/test/external/external.go index 2568e07fe..4be3722f3 100644 --- a/test/runtimes/runtimes.go +++ b/tools/go_marshal/test/external/external.go @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package runtimes provides language tests for runsc runtimes. -// Each test calls docker commands to start up a container for each supported runtime, -// and tests that its respective language tests are behaving as expected, like -// connecting to a port or looking at the output. The container is killed and deleted -// at the end. -package runtimes +// Package external defines types we can import for testing. +package external + +// External is a public Marshallable type for use in testing. +// +// +marshal +type External struct { + j int64 +} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go new file mode 100644 index 000000000..8de02d707 --- /dev/null +++ b/tools/go_marshal/test/test.go @@ -0,0 +1,105 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test contains data structures for testing the go_marshal tool. +package test + +import ( + // We're intentionally using a package name alias here even though it's not + // necessary to test the code generator's ability to handle package aliases. + ex "gvisor.dev/gvisor/tools/go_marshal/test/external" +) + +// Type1 is a test data type. +// +// +marshal +type Type1 struct { + a Type2 + x, y int64 // Multiple field names. + b byte `marshal:"unaligned"` // Short field. + c uint64 + _ uint32 // Unnamed scalar field. + _ [6]byte // Unnamed vector field, typical padding. + _ [2]byte + xs [8]int32 + as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects. + ss Type3 +} + +// Type2 is a test data type. +// +// +marshal +type Type2 struct { + n int64 + c byte + _ [7]byte + m int64 + a int64 +} + +// Type3 is a test data type. +// +// +marshal +type Type3 struct { + s int64 + x ex.External // Type defined in another package. +} + +// Type4 is a test data type. +// +// +marshal +type Type4 struct { + c byte + x int64 `marshal:"unaligned"` + d byte + _ [7]byte +} + +// Type5 is a test data type. +// +// +marshal +type Type5 struct { + n int64 + t Type4 + m int64 +} + +// Timespec represents struct timespec in <time.h>. +// +// +marshal +type Timespec struct { + Sec int64 + Nsec int64 +} + +// Stat represents struct stat. +// +// +marshal +type Stat struct { + Dev uint64 + Ino uint64 + Nlink uint64 + Mode uint32 + UID uint32 + GID uint32 + _ int32 + Rdev uint64 + Size int64 + Blksize int64 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [3]int64 +} diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index aeba197e2..3ce36c1c8 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -35,7 +35,7 @@ go_library( ) """ -load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test") +load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library") def _go_stateify_impl(ctx): """Implementation for the stateify tool.""" @@ -60,28 +60,57 @@ def _go_stateify_impl(ctx): executable = ctx.executable._tool, ) -# Generates save and restore logic from a set of Go files. -# -# Args: -# name: the name of the rule. -# srcs: the input source files. These files should include all structs in the package that need to be saved. -# imports: an optional list of extra non-aliased, Go-style absolute import paths. -# out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library. -# package: the package name for the input sources. go_stateify = rule( implementation = _go_stateify_impl, + doc = "Generates save and restore logic from a set of Go files.", attrs = { - "srcs": attr.label_list(mandatory = True, allow_files = True), - "imports": attr.string_list(mandatory = False), - "package": attr.string(mandatory = True), - "out": attr.output(mandatory = True), - "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_stateify:stateify")), + "srcs": attr.label_list( + doc = """ +The input source files. These files should include all structs in the package +that need to be saved. +""", + mandatory = True, + allow_files = True, + ), + "imports": attr.string_list( + doc = """ +An optional list of extra non-aliased, Go-style absolute import paths required +for statified types. +""", + mandatory = False, + ), + "package": attr.string( + doc = "The package name for the input sources.", + mandatory = True, + ), + "out": attr.output( + doc = """ +The name of the generated file output. This must not conflict with any other +files and must be added to the srcs of the relevant go_library. +""", + mandatory = True, + ), + "_tool": attr.label( + executable = True, + cfg = "host", + default = Label("//tools/go_stateify:stateify"), + ), "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), }, ) def go_library(name, srcs, deps = [], imports = [], **kwargs): - """wraps the standard go_library and does stateification.""" + """Standard go_library wrapped which generates state source files. + + Args: + name: the name of the go_library rule. + srcs: sources of the go_library. Each will be processed for stateify + annotations. + deps: dependencies for the go_library. + imports: an optional list of extra non-aliased, Go-style absolute import + paths required for stateified types. + **kwargs: passed to go_library. + """ if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs: # Only do stateification for non-state packages without manual autogen. go_stateify( @@ -105,9 +134,3 @@ def go_library(name, srcs, deps = [], imports = [], **kwargs): deps = all_deps, **kwargs ) - -def go_test(**kwargs): - """Wraps the standard go_test.""" - _go_test( - **kwargs - ) diff --git a/tools/make_repository.sh b/tools/make_repository.sh index bf9c50d74..071f72b74 100755 --- a/tools/make_repository.sh +++ b/tools/make_repository.sh @@ -16,13 +16,14 @@ # Parse arguments. We require more than two arguments, which are the private # keyring, the e-mail associated with the signer, and the list of packages. -if [ "$#" -le 2 ]; then - echo "usage: $0 <private-key> <signer-email> <packages...>" +if [ "$#" -le 3 ]; then + echo "usage: $0 <private-key> <signer-email> <component> <packages...>" exit 1 fi declare -r private_key=$(readlink -e "$1") declare -r signer="$2" -shift; shift +declare -r component="$3" +shift; shift; shift # Verbose from this point. set -xeo pipefail @@ -37,13 +38,20 @@ cleanup() { rm -f "${keyring}" } trap cleanup EXIT -gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" - -# Export the public key from the keyring. -gpg --no-default-keyring --keyring "${keyring}" --armor --export "${signer}" > "${tmpdir}"/keyFile +gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" >&2 # Copy the packages, and ensure permissions are correct. -cp -a "$@" "${tmpdir}" && chmod 0644 "${tmpdir}"/* +for pkg in "$@"; do + name=$(basename "${pkg}" .deb) + name=$(basename "${name}" .changes) + arch=${name##*_} + if [[ "${name}" == "${arch}" ]]; then + continue # Not a regular package. + fi + mkdir -p "${tmpdir}"/"${component}"/binary-"${arch}" + cp -a "${pkg}" "${tmpdir}"/"${component}"/binary-"${arch}" +done +find "${tmpdir}" -type f -exec chmod 0644 {} \; # Ensure there are no symlinks hanging around; these may be remnants of the # build process. They may be useful for other things, but we are going to build @@ -51,19 +59,21 @@ cp -a "$@" "${tmpdir}" && chmod 0644 "${tmpdir}"/* find "${tmpdir}" -type l -exec rm -f {} \; # Sign all packages. -for file in "${tmpdir}"/*.deb; do - dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" +for file in "${tmpdir}"/"${component}"/binary-*/*.deb; do + dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" >&2 done # Build the package list. -(cd "${tmpdir}" && apt-ftparchive packages . | gzip > Packages.gz) +for dir in "${tmpdir}"/"${component}"/binary-*; do + (cd "${dir}" && apt-ftparchive packages . | gzip > Packages.gz) +done # Build the release list. (cd "${tmpdir}" && apt-ftparchive release . > Release) # Sign the release. -(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release) -(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release) +(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release >&2) +(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release >&2) # Show the results. echo "${tmpdir}" diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh index 64a905fc9..fb09ff331 100755 --- a/tools/workspace_status.sh +++ b/tools/workspace_status.sh @@ -14,4 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -echo VERSION $(git describe --always --tags --abbrev=12 --dirty) +# The STABLE_ prefix will trigger a re-link if it changes. +echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty) |