348 files changed, 10718 insertions, 4585 deletions
diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml index 3bd883035..f2d584ac0 100644 --- a/.github/workflows/issue_reviver.yml +++ b/.github/workflows/issue_reviver.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v2 if: github.repository == 'google/gvisor' - - run: make run TARGETS="//tools/github" ARGS="revive" + - run: make run TARGETS="//tools/github" ARGS="-path=. revive" if: github.repository == 'google/gvisor' env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index a9e0a4717..ce300869c 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -1,6 +1,6 @@ # The stale workflow closes stale issues and pull requests, unless specific # tags have been applied in order to keep them open. -name: "Close stale issues" +name: "Stale issues" "on": schedule: - cron: "0 0 * * *" @@ -14,7 +14,7 @@ jobs: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-label: 'stale' stale-pr-label: 'stale' - exempt-issue-labels: 'exported, type: bug, type: cleanup, type: enhancement, type: process, type: proposal, type: question' + exempt-issue-labels: 'revived, exported, type: bug, type: cleanup, type: enhancement, type: process, type: proposal, type: question' exempt-pr-labels: 'ready to pull, exported' stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.' stale-pr-message: 'This pull request is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.' @@ -4,6 +4,9 @@ [![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community) [![code search](https://img.shields.io/badge/code-search-blue)](https://cs.opensource.google/gvisor/gvisor) +[![Issue reviver](https://github.com/google/gvisor/actions/workflows/issue_reviver.yml/badge.svg)](https://github.com/google/gvisor/actions/workflows/issue_reviver.yml) +[![Stale issues](https://github.com/google/gvisor/actions/workflows/stale.yml/badge.svg)](https://github.com/google/gvisor/actions/workflows/stale.yml) + ## What is gVisor? **gVisor** is an application kernel, written in Go, that implements a @@ -42,10 +42,10 @@ http_archive( # binaries of symbols, which we don't want. "//tools:rules_go_symbols.patch", ], - sha256 = "7904dbecbaffd068651916dce77ff3437679f9d20e1a7956bff43826e7645fcc", + sha256 = "69de5c704a05ff37862f7e0f5534d4f479418afc21806c887db544a316f3cb6b", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.25.1/rules_go-v0.25.1.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/v0.25.1/rules_go-v0.25.1.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.27.0/rules_go-v0.27.0.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.27.0/rules_go-v0.27.0.tar.gz", ], ) @@ -58,10 +58,10 @@ http_archive( # slightly future proof this mechanism. 
"//tools:bazel_gazelle_generate.patch", ], - sha256 = "222e49f034ca7a1d1231422cdb67066b885819885c356673cb1f72f748a3c9d4", + sha256 = "62ca106be173579c0a167deb23358fdfe71ffa1e4cfdddf5582af26520f1c66f", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.3/bazel-gazelle-v0.22.3.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.3/bazel-gazelle-v0.22.3.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.23.0/bazel-gazelle-v0.23.0.tar.gz", + "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.23.0/bazel-gazelle-v0.23.0.tar.gz", ], ) @@ -69,7 +69,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe go_rules_dependencies() -go_register_toolchains(go_version = "1.15.7") +go_register_toolchains(go_version = "1.16.2") load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") diff --git a/g3doc/user_guide/containerd/configuration.md b/g3doc/user_guide/containerd/configuration.md index 011af3b10..a214fb0c7 100644 --- a/g3doc/user_guide/containerd/configuration.md +++ b/g3doc/user_guide/containerd/configuration.md @@ -14,6 +14,7 @@ cat <<EOF | sudo tee /etc/containerd/runsc.toml option = "value" [runsc_config] flag = "value" +EOF ``` The set of options that can be configured can be found in @@ -32,10 +33,12 @@ configuration. Here is an example: ```shell cat <<EOF | sudo tee /etc/containerd/config.toml -disabled_plugins = ["restart"] -[plugins.cri.containerd.runtimes.runsc] +version = 2 +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runc] + runtime_type = "io.containerd.runc.v2" +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc] runtime_type = "io.containerd.runsc.v1" -[plugins.cri.containerd.runtimes.runsc.options] +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc.options] TypeUrl = "io.containerd.runsc.v1.options" ConfigPath = "/etc/containerd/runsc.toml" EOF @@ -56,14 +59,16 @@ a containerd configuration file that enables both options: ```shell cat <<EOF | sudo tee /etc/containerd/config.toml -disabled_plugins = ["restart"] +version = 2 [debug] level = "debug" -[plugins.linux] +[plugins."io.containerd.runtime.v1.linux"] shim_debug = true -[plugins.cri.containerd.runtimes.runsc] +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runc] + runtime_type = "io.containerd.runc.v2" +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc] runtime_type = "io.containerd.runsc.v1" -[plugins.cri.containerd.runtimes.runsc.options] +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc.option] TypeUrl = "io.containerd.runsc.v1.options" ConfigPath = "/etc/containerd/runsc.toml" EOF @@ -93,4 +98,5 @@ log_level = "debug" [runsc_config] debug = "true" debug-log = "/var/log/runsc/%ID%/gvisor.%COMMAND%.log" +EOF ``` diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md index 02e82eb32..c742f225c 100644 --- a/g3doc/user_guide/containerd/quick_start.md +++ b/g3doc/user_guide/containerd/quick_start.md @@ -21,10 +21,12 @@ Update `/etc/containerd/config.toml`. 
Make sure `containerd-shim-runsc-v1` is in ```shell cat <<EOF | sudo tee /etc/containerd/config.toml -disabled_plugins = ["restart"] -[plugins.linux] +version = 2 +[plugins."io.containerd.runtime.v1.linux"] shim_debug = true -[plugins.cri.containerd.runtimes.runsc] +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runc] + runtime_type = "io.containerd.runc.v2" +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc] runtime_type = "io.containerd.runsc.v1" EOF ``` diff --git a/g3doc/user_guide/tutorials/BUILD b/g3doc/user_guide/tutorials/BUILD index f405349b3..a862c76f4 100644 --- a/g3doc/user_guide/tutorials/BUILD +++ b/g3doc/user_guide/tutorials/BUILD @@ -37,10 +37,19 @@ doc( ) doc( + name = "knative", + src = "knative.md", + category = "User Guide", + permalink = "/docs/tutorials/knative/", + subcategory = "Tutorials", + weight = "40", +) + +doc( name = "cni", src = "cni.md", category = "User Guide", permalink = "/docs/tutorials/cni/", subcategory = "Tutorials", - weight = "40", + weight = "50", ) diff --git a/g3doc/user_guide/tutorials/knative.md b/g3doc/user_guide/tutorials/knative.md new file mode 100644 index 000000000..3f5207fcc --- /dev/null +++ b/g3doc/user_guide/tutorials/knative.md @@ -0,0 +1,88 @@ +# Knative Services + +[Knative](https://knative.dev/) is a platform for running serverless workloads +on Kubernetes. This guide will show you how to run basic Knative workloads in +gVisor. + +## Prerequisites + +This guide assumes you have have a cluster that is capable of running gVisor +workloads. This could be a +[GKE Sandbox](https://cloud.google.com/kubernetes-engine/sandbox/) enabled +cluster on Google Cloud Platform or one you have set up yourself using +[containerd Quick Start](https://gvisor.dev/docs/user_guide/containerd/quick_start/). + +This guide will also assume you have Knative installed using +[Istio](https://istio.io/) as the network layer. You can follow the +[Knative installation guide](https://knative.dev/docs/install/install-serving-with-yaml/) +to install Knative. + +## Enable the RuntimeClass feature flag + +Knative allows the use of various parameters on Pods via +[feature flags](https://knative.dev/docs/serving/feature-flags/). We will enable +the +[runtimeClassName](https://knative.dev/docs/serving/feature-flags/#kubernetes-runtime-class) +feature flag to enable the use of the Kubernetes +[Runtime Class](https://kubernetes.io/docs/concepts/containers/runtime-class/). + +Edit the feature flags ConfigMap. + +```bash +kubectl edit configmap config-features -n knative-serving +``` + +Add the `kubernetes.podspec-runtimeclassname: enabled` to the `data` field. Once +you are finished the ConfigMap will look something like this (minus all the +system fields). + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: + name: config-features + namespace: knative-serving + labels: + serving.knative.dev/release: v0.22.0 +data: + kubernetes.podspec-runtimeclassname: enabled +``` + +## Deploy the Service + +After you have set the Runtime Class feature flag you can now create Knative +services that specify a `runtimeClassName` in the spec. + +```bash +cat <<EOF | kubectl apply -f - +apiVersion: serving.knative.dev/v1 +kind: Service +metadata: + name: helloworld-go +spec: + template: + spec: + runtimeClassName: gvisor + containers: + - image: gcr.io/knative-samples/helloworld-go + env: + - name: TARGET + value: "gVisor User" +EOF +``` + +You can see the pods running and their Runtime Class. 
+ +```bash +kubectl get pods -o=custom-columns='NAME:.metadata.name,RUNTIME CLASS:.spec.runtimeClassName,STATUS:.status.phase' +``` + +Output should look something like the following. Note that your service might +scale to zero. If you access it via it's URL you should get a new Pod. + +``` +NAME RUNTIME CLASS STATUS +helloworld-go-00002-deployment-646c87b7f5-5v68s gvisor Running +``` + +Congrats! Your Knative service is now running in gVisor! @@ -89,6 +89,7 @@ analyzers: - pkg/sentry/fsimpl/gofer/filesystem.go # unsupported usage. - pkg/sentry/fsimpl/gofer/gofer.go # unsupported usage. - pkg/sentry/fsimpl/gofer/regular_file.go # unsupported usage. + - pkg/sentry/fsimpl/gofer/revalidate.go # unsupported usage. - pkg/sentry/fsimpl/gofer/special_file.go # unsupported usage. - pkg/sentry/fsimpl/gofer/symlink.go # unsupported usage. - pkg/sentry/fsimpl/overlay/copy_up.go # unsupported usage. diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index ecaeb11ac..29ead20d0 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -15,6 +15,7 @@ go_library( "bpf.go", "capability.go", "clone.go", + "context.go", "dev.go", "elf.go", "epoll.go", @@ -76,8 +77,8 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/abi", - "//pkg/binary", "//pkg/bits", + "//pkg/context", "//pkg/marshal", "//pkg/marshal/primitive", ], @@ -86,9 +87,8 @@ go_library( go_test( name = "linux_test", size = "small", - srcs = ["netfilter_test.go"], - library = ":linux", - deps = [ - "//pkg/binary", + srcs = [ + "netfilter_test.go", ], + library = ":linux", ) diff --git a/pkg/abi/linux/context.go b/pkg/abi/linux/context.go new file mode 100644 index 000000000..d2dbba183 --- /dev/null +++ b/pkg/abi/linux/context.go @@ -0,0 +1,36 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/context" +) + +// contextID is the linux package's type for context.Context.Value keys. +type contextID int + +const ( + // CtxSignalNoInfoFunc is a Context.Value key for a function to send signals. + CtxSignalNoInfoFunc contextID = iota +) + +// SignalNoInfoFuncFromContext returns a callback function that can be used to send a +// signal to the given context. +func SignalNoInfoFuncFromContext(ctx context.Context) func(Signal) error { + if f := ctx.Value(CtxSignalNoInfoFunc); f != nil { + return f.(func(Signal) error) + } + return nil +} diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go index 7c9a02f20..c5713541f 100644 --- a/pkg/abi/linux/elf.go +++ b/pkg/abi/linux/elf.go @@ -106,3 +106,53 @@ const ( // NT_ARM_TLS is for ARM TLS register. NT_ARM_TLS = 0x401 ) + +// ElfHeader64 is the ELF64 file header. +// +// +marshal +type ElfHeader64 struct { + Ident [16]byte // File identification. + Type uint16 // File type. + Machine uint16 // Machine architecture. + Version uint32 // ELF format version. + Entry uint64 // Entry point. + Phoff uint64 // Program header file offset. + Shoff uint64 // Section header file offset. 
+ Flags uint32 // Architecture-specific flags. + Ehsize uint16 // Size of ELF header in bytes. + Phentsize uint16 // Size of program header entry. + Phnum uint16 // Number of program header entries. + Shentsize uint16 // Size of section header entry. + Shnum uint16 // Number of section header entries. + Shstrndx uint16 // Section name strings section. +} + +// ElfSection64 is the ELF64 Section header. +// +// +marshal +type ElfSection64 struct { + Name uint32 // Section name (index into the section header string table). + Type uint32 // Section type. + Flags uint64 // Section flags. + Addr uint64 // Address in memory image. + Off uint64 // Offset in file. + Size uint64 // Size in bytes. + Link uint32 // Index of a related section. + Info uint32 // Depends on section type. + Addralign uint64 // Alignment in bytes. + Entsize uint64 // Size of each entry in section. +} + +// ElfProg64 is the ELF64 Program header. +// +// +marshal +type ElfProg64 struct { + Type uint32 // Entry type. + Flags uint32 // Access permission flags. + Off uint64 // File offset of contents. + Vaddr uint64 // Virtual address in memory image. + Paddr uint64 // Physical address (not used). + Filesz uint64 // Size of contents in file. + Memsz uint64 // Size of contents in memory. + Align uint64 // Alignment in memory and file. +} diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go index 1121a1a92..67706f5aa 100644 --- a/pkg/abi/linux/epoll.go +++ b/pkg/abi/linux/epoll.go @@ -14,10 +14,6 @@ package linux -import ( - "gvisor.dev/gvisor/pkg/binary" -) - // Event masks. const ( EPOLLIN = 0x1 @@ -59,4 +55,4 @@ const ( ) // SizeOfEpollEvent is the size of EpollEvent struct. -var SizeOfEpollEvent = int(binary.Size(EpollEvent{})) +var SizeOfEpollEvent = (*EpollEvent)(nil).SizeBytes() diff --git a/pkg/abi/linux/errors.go b/pkg/abi/linux/errors.go index 93f85a864..b08b2687e 100644 --- a/pkg/abi/linux/errors.go +++ b/pkg/abi/linux/errors.go @@ -15,158 +15,149 @@ package linux // Errno represents a Linux errno value. -type Errno struct { - number int - name string -} - -// Number returns the errno number. -func (e *Errno) Number() int { - return e.number -} - -// String implements fmt.Stringer.String. -func (e *Errno) String() string { - return e.name -} +type Errno int // Errno values from include/uapi/asm-generic/errno-base.h. 
-var ( - EPERM = &Errno{1, "operation not permitted"} - ENOENT = &Errno{2, "no such file or directory"} - ESRCH = &Errno{3, "no such process"} - EINTR = &Errno{4, "interrupted system call"} - EIO = &Errno{5, "I/O error"} - ENXIO = &Errno{6, "no such device or address"} - E2BIG = &Errno{7, "argument list too long"} - ENOEXEC = &Errno{8, "exec format error"} - EBADF = &Errno{9, "bad file number"} - ECHILD = &Errno{10, "no child processes"} - EAGAIN = &Errno{11, "try again"} - ENOMEM = &Errno{12, "out of memory"} - EACCES = &Errno{13, "permission denied"} - EFAULT = &Errno{14, "bad address"} - ENOTBLK = &Errno{15, "block device required"} - EBUSY = &Errno{16, "device or resource busy"} - EEXIST = &Errno{17, "file exists"} - EXDEV = &Errno{18, "cross-device link"} - ENODEV = &Errno{19, "no such device"} - ENOTDIR = &Errno{20, "not a directory"} - EISDIR = &Errno{21, "is a directory"} - EINVAL = &Errno{22, "invalid argument"} - ENFILE = &Errno{23, "file table overflow"} - EMFILE = &Errno{24, "too many open files"} - ENOTTY = &Errno{25, "not a typewriter"} - ETXTBSY = &Errno{26, "text file busy"} - EFBIG = &Errno{27, "file too large"} - ENOSPC = &Errno{28, "no space left on device"} - ESPIPE = &Errno{29, "illegal seek"} - EROFS = &Errno{30, "read-only file system"} - EMLINK = &Errno{31, "too many links"} - EPIPE = &Errno{32, "broken pipe"} - EDOM = &Errno{33, "math argument out of domain of func"} - ERANGE = &Errno{34, "math result not representable"} +const ( + NOERRNO = iota + EPERM + ENOENT + ESRCH + EINTR + EIO + ENXIO + E2BIG + ENOEXEC + EBADF + ECHILD // 10 + EAGAIN + ENOMEM + EACCES + EFAULT + ENOTBLK + EBUSY + EEXIST + EXDEV + ENODEV + ENOTDIR // 20 + EISDIR + EINVAL + ENFILE + EMFILE + ENOTTY + ETXTBSY + EFBIG + ENOSPC + ESPIPE + EROFS // 30 + EMLINK + EPIPE + EDOM + ERANGE + // Errno values from include/uapi/asm-generic/errno.h. + EDEADLK + ENAMETOOLONG + ENOLCK + ENOSYS + ENOTEMPTY + ELOOP //40 + _ // Skip for EWOULDBLOCK = EAGAIN + ENOMSG //42 + EIDRM + ECHRNG + EL2NSYNC + EL3HLT + EL3RST + ELNRNG + EUNATCH + ENOCSI + EL2HLT // 50 + EBADE + EBADR + EXFULL + ENOANO + EBADRQC + EBADSLT + _ // Skip for EDEADLOCK = EDEADLK + EBFONT + ENOSTR // 60 + ENODATA + ETIME + ENOSR + ENONET + ENOPKG + EREMOTE + ENOLINK + EADV + ESRMNT + ECOMM // 70 + EPROTO + EMULTIHOP + EDOTDOT + EBADMSG + EOVERFLOW + ENOTUNIQ + EBADFD + EREMCHG + ELIBACC + ELIBBAD // 80 + ELIBSCN + ELIBMAX + ELIBEXEC + EILSEQ + ERESTART + ESTRPIPE + EUSERS + ENOTSOCK + EDESTADDRREQ + EMSGSIZE // 90 + EPROTOTYPE + ENOPROTOOPT + EPROTONOSUPPORT + ESOCKTNOSUPPORT + EOPNOTSUPP + EPFNOSUPPORT + EAFNOSUPPORT + EADDRINUSE + EADDRNOTAVAIL + ENETDOWN // 100 + ENETUNREACH + ENETRESET + ECONNABORTED + ECONNRESET + ENOBUFS + EISCONN + ENOTCONN + ESHUTDOWN + ETOOMANYREFS + ETIMEDOUT // 110 + ECONNREFUSED + EHOSTDOWN + EHOSTUNREACH + EALREADY + EINPROGRESS + ESTALE + EUCLEAN + ENOTNAM + ENAVAIL + EISNAM // 120 + EREMOTEIO + EDQUOT + ENOMEDIUM + EMEDIUMTYPE + ECANCELED + ENOKEY + EKEYEXPIRED + EKEYREVOKED + EKEYREJECTED + EOWNERDEAD // 130 + ENOTRECOVERABLE + ERFKILL + EHWPOISON ) -// Errno values from include/uapi/asm-generic/errno.h. 
-var ( - EDEADLK = &Errno{35, "resource deadlock would occur"} - ENAMETOOLONG = &Errno{36, "file name too long"} - ENOLCK = &Errno{37, "no record locks available"} - ENOSYS = &Errno{38, "invalid system call number"} - ENOTEMPTY = &Errno{39, "directory not empty"} - ELOOP = &Errno{40, "too many symbolic links encountered"} - EWOULDBLOCK = &Errno{EAGAIN.number, "operation would block"} - ENOMSG = &Errno{42, "no message of desired type"} - EIDRM = &Errno{43, "identifier removed"} - ECHRNG = &Errno{44, "channel number out of range"} - EL2NSYNC = &Errno{45, "level 2 not synchronized"} - EL3HLT = &Errno{46, "level 3 halted"} - EL3RST = &Errno{47, "level 3 reset"} - ELNRNG = &Errno{48, "link number out of range"} - EUNATCH = &Errno{49, "protocol driver not attached"} - ENOCSI = &Errno{50, "no CSI structure available"} - EL2HLT = &Errno{51, "level 2 halted"} - EBADE = &Errno{52, "invalid exchange"} - EBADR = &Errno{53, "invalid request descriptor"} - EXFULL = &Errno{54, "exchange full"} - ENOANO = &Errno{55, "no anode"} - EBADRQC = &Errno{56, "invalid request code"} - EBADSLT = &Errno{57, "invalid slot"} - EDEADLOCK = EDEADLK - EBFONT = &Errno{59, "bad font file format"} - ENOSTR = &Errno{60, "device not a stream"} - ENODATA = &Errno{61, "no data available"} - ETIME = &Errno{62, "timer expired"} - ENOSR = &Errno{63, "out of streams resources"} - ENONET = &Errno{64, "machine is not on the network"} - ENOPKG = &Errno{65, "package not installed"} - EREMOTE = &Errno{66, "object is remote"} - ENOLINK = &Errno{67, "link has been severed"} - EADV = &Errno{68, "advertise error"} - ESRMNT = &Errno{69, "srmount error"} - ECOMM = &Errno{70, "communication error on send"} - EPROTO = &Errno{71, "protocol error"} - EMULTIHOP = &Errno{72, "multihop attempted"} - EDOTDOT = &Errno{73, "RFS specific error"} - EBADMSG = &Errno{74, "not a data message"} - EOVERFLOW = &Errno{75, "value too large for defined data type"} - ENOTUNIQ = &Errno{76, "name not unique on network"} - EBADFD = &Errno{77, "file descriptor in bad state"} - EREMCHG = &Errno{78, "remote address changed"} - ELIBACC = &Errno{79, "can not access a needed shared library"} - ELIBBAD = &Errno{80, "accessing a corrupted shared library"} - ELIBSCN = &Errno{81, ".lib section in a.out corrupted"} - ELIBMAX = &Errno{82, "attempting to link in too many shared libraries"} - ELIBEXEC = &Errno{83, "cannot exec a shared library directly"} - EILSEQ = &Errno{84, "illegal byte sequence"} - ERESTART = &Errno{85, "interrupted system call should be restarted"} - ESTRPIPE = &Errno{86, "streams pipe error"} - EUSERS = &Errno{87, "too many users"} - ENOTSOCK = &Errno{88, "socket operation on non-socket"} - EDESTADDRREQ = &Errno{89, "destination address required"} - EMSGSIZE = &Errno{90, "message too long"} - EPROTOTYPE = &Errno{91, "protocol wrong type for socket"} - ENOPROTOOPT = &Errno{92, "protocol not available"} - EPROTONOSUPPORT = &Errno{93, "protocol not supported"} - ESOCKTNOSUPPORT = &Errno{94, "socket type not supported"} - EOPNOTSUPP = &Errno{95, "operation not supported on transport endpoint"} - EPFNOSUPPORT = &Errno{96, "protocol family not supported"} - EAFNOSUPPORT = &Errno{97, "address family not supported by protocol"} - EADDRINUSE = &Errno{98, "address already in use"} - EADDRNOTAVAIL = &Errno{99, "cannot assign requested address"} - ENETDOWN = &Errno{100, "network is down"} - ENETUNREACH = &Errno{101, "network is unreachable"} - ENETRESET = &Errno{102, "network dropped connection because of reset"} - ECONNABORTED = &Errno{103, "software caused connection 
abort"} - ECONNRESET = &Errno{104, "connection reset by peer"} - ENOBUFS = &Errno{105, "no buffer space available"} - EISCONN = &Errno{106, "transport endpoint is already connected"} - ENOTCONN = &Errno{107, "transport endpoint is not connected"} - ESHUTDOWN = &Errno{108, "cannot send after transport endpoint shutdown"} - ETOOMANYREFS = &Errno{109, "too many references: cannot splice"} - ETIMEDOUT = &Errno{110, "connection timed out"} - ECONNREFUSED = &Errno{111, "connection refused"} - EHOSTDOWN = &Errno{112, "host is down"} - EHOSTUNREACH = &Errno{113, "no route to host"} - EALREADY = &Errno{114, "operation already in progress"} - EINPROGRESS = &Errno{115, "operation now in progress"} - ESTALE = &Errno{116, "stale file handle"} - EUCLEAN = &Errno{117, "structure needs cleaning"} - ENOTNAM = &Errno{118, "not a XENIX named type file"} - ENAVAIL = &Errno{119, "no XENIX semaphores available"} - EISNAM = &Errno{120, "is a named type file"} - EREMOTEIO = &Errno{121, "remote I/O error"} - EDQUOT = &Errno{122, "quota exceeded"} - ENOMEDIUM = &Errno{123, "no medium found"} - EMEDIUMTYPE = &Errno{124, "wrong medium type"} - ECANCELED = &Errno{125, "operation Canceled"} - ENOKEY = &Errno{126, "required key not available"} - EKEYEXPIRED = &Errno{127, "key has expired"} - EKEYREVOKED = &Errno{128, "key has been revoked"} - EKEYREJECTED = &Errno{129, "key was rejected by service"} - EOWNERDEAD = &Errno{130, "owner died"} - ENOTRECOVERABLE = &Errno{131, "state not recoverable"} - ERFKILL = &Errno{132, "operation not possible due to RF-kill"} - EHWPOISON = &Errno{133, "memory page has hardware error"} +// errnos derived from other errnos +const ( + EWOULDBLOCK = EAGAIN + EDEADLOCK = EDEADLK ) diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index e11ca2d62..1e23850a9 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -19,7 +19,6 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/binary" ) // Constants for open(2). @@ -201,7 +200,7 @@ const ( ) // SizeOfStat is the size of a Stat struct. -var SizeOfStat = binary.Size(Stat{}) +var SizeOfStat = (*Stat)(nil).SizeBytes() // Flags for statx. const ( @@ -268,7 +267,7 @@ type Statx struct { } // SizeOfStatx is the size of a Statx struct. -var SizeOfStatx = binary.Size(Statx{}) +var SizeOfStatx = (*Statx)(nil).SizeBytes() // FileMode represents a mode_t. type FileMode uint16 diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go index 0faf015c7..51a39704b 100644 --- a/pkg/abi/linux/netdevice.go +++ b/pkg/abi/linux/netdevice.go @@ -14,8 +14,6 @@ package linux -import "gvisor.dev/gvisor/pkg/binary" - const ( // IFNAMSIZ is the size of the name field for IFReq. IFNAMSIZ = 16 @@ -66,7 +64,7 @@ func (ifr *IFReq) SetName(name string) { } // SizeOfIFReq is the binary size of an IFReq struct (40 bytes). -var SizeOfIFReq = binary.Size(IFReq{}) +var SizeOfIFReq = (*IFReq)(nil).SizeBytes() // IFMap contains interface hardware parameters. type IFMap struct { diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 35c632168..3fd05483a 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -245,6 +245,8 @@ const SizeOfXTCounters = 16 // include/uapi/linux/netfilter/x_tables.h. That struct contains a union // exposing different data to the user and kernel, but this struct holds only // the user data. +// +// +marshal type XTEntryMatch struct { MatchSize uint16 Name ExtensionName @@ -284,6 +286,8 @@ const SizeOfXTGetRevision = 30 // include/uapi/linux/netfilter/x_tables.h. 
That struct contains a union // exposing different data to the user and kernel, but this struct holds only // the user data. +// +// +marshal type XTEntryTarget struct { TargetSize uint16 Name ExtensionName @@ -306,6 +310,8 @@ type KernelXTEntryTarget struct { // XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE, // RETURN, or jump. It corresponds to struct xt_standard_target in // include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTStandardTarget struct { Target XTEntryTarget // A positive verdict indicates a jump, and is the offset from the @@ -322,6 +328,8 @@ const SizeOfXTStandardTarget = 40 // beginning of user-defined chains by putting the name of the chain in // ErrorName. It corresponds to struct xt_error_target in // include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTErrorTarget struct { Target XTEntryTarget Name ErrorName @@ -349,6 +357,8 @@ const ( // NfNATIPV4Range corresponds to struct nf_nat_ipv4_range // in include/uapi/linux/netfilter/nf_nat.h. The fields are in // network byte order. +// +// +marshal type NfNATIPV4Range struct { Flags uint32 MinIP [4]byte @@ -359,6 +369,8 @@ type NfNATIPV4Range struct { // NfNATIPV4MultiRangeCompat corresponds to struct // nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h. +// +// +marshal type NfNATIPV4MultiRangeCompat struct { RangeSize uint32 RangeIPV4 NfNATIPV4Range @@ -366,6 +378,8 @@ type NfNATIPV4MultiRangeCompat struct { // XTRedirectTarget triggers a redirect when reached. // Adding 4 bytes of padding to make the struct 8 byte aligned. +// +// +marshal type XTRedirectTarget struct { Target XTEntryTarget NfRange NfNATIPV4MultiRangeCompat @@ -377,6 +391,8 @@ const SizeOfXTRedirectTarget = 56 // XTSNATTarget triggers Source NAT when reached. // Adding 4 bytes of padding to make the struct 8 byte aligned. +// +// +marshal type XTSNATTarget struct { Target XTEntryTarget NfRange NfNATIPV4MultiRangeCompat @@ -463,6 +479,8 @@ var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTReplace struct { Name TableName ValidHooks uint32 @@ -502,6 +520,8 @@ func (tn TableName) String() string { // ErrorName holds the name of a netfilter error. These can also hold // user-defined chains. +// +// +marshal type ErrorName [XT_FUNCTION_MAXNAMELEN]byte // String implements fmt.Stringer. @@ -520,6 +540,8 @@ func goString(cstring []byte) string { // XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp // in include/uapi/linux/netfilter/xt_tcpudp.h. +// +// +marshal type XTTCP struct { // SourcePortStart specifies the inclusive start of the range of source // ports to which the matcher applies. @@ -573,6 +595,8 @@ const ( // XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp // in include/uapi/linux/netfilter/xt_tcpudp.h. +// +// +marshal type XTUDP struct { // SourcePortStart is the inclusive start of the range of source ports // to which the matcher applies. @@ -613,6 +637,8 @@ const ( // IPTOwnerInfo holds data for matching packets with owner. It corresponds // to struct ipt_owner_info in libxt_owner.c of iptables binary. +// +// +marshal type IPTOwnerInfo struct { // UID is user id which created the packet. UID uint32 @@ -634,7 +660,7 @@ type IPTOwnerInfo struct { Match uint8 // Invert flips the meaning of Match field. 
- Invert uint8 + Invert uint8 `marshal:"unaligned"` } // SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo. diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index f7c70b430..b088b207c 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -264,6 +264,8 @@ const ( // NFNATRange corresponds to struct nf_nat_range in // include/uapi/linux/netfilter/nf_nat.h. +// +// +marshal type NFNATRange struct { Flags uint32 MinAddr Inet6Addr diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index bf73271c6..600820a0b 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -15,9 +15,8 @@ package linux import ( + "encoding/binary" "testing" - - "gvisor.dev/gvisor/pkg/binary" ) func TestSizes(t *testing.T) { @@ -42,7 +41,7 @@ func TestSizes(t *testing.T) { } for _, tc := range testCases { - if calculated := binary.Size(tc.typ); calculated != tc.defined { + if calculated := uintptr(binary.Size(tc.typ)); calculated != tc.defined { t.Errorf("%T has a defined size of %d and calculated size of %d", tc.typ, tc.defined, calculated) } } diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index b41f94a69..232fee67e 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -53,6 +53,8 @@ type SockAddrNetlink struct { const SockAddrNetlinkSize = 12 // NetlinkMessageHeader is struct nlmsghdr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkMessageHeader struct { Length uint32 Type uint16 @@ -99,6 +101,8 @@ const NLMSG_ALIGNTO = 4 // NetlinkAttrHeader is the header of a netlink attribute, followed by payload. // // This is struct nlattr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkAttrHeader struct { Length uint16 Type uint16 @@ -126,6 +130,8 @@ const ( ) // NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkErrorMessage struct { Error int32 Header NetlinkMessageHeader diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index ceda0a8d3..581a11b24 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -85,6 +85,8 @@ const ( ) // InterfaceInfoMessage is struct ifinfomsg, from uapi/linux/rtnetlink.h. +// +// +marshal type InterfaceInfoMessage struct { Family uint8 _ uint8 @@ -164,6 +166,8 @@ const ( ) // InterfaceAddrMessage is struct ifaddrmsg, from uapi/linux/if_addr.h. +// +// +marshal type InterfaceAddrMessage struct { Family uint8 PrefixLen uint8 @@ -193,6 +197,8 @@ const ( ) // RouteMessage is struct rtmsg, from uapi/linux/rtnetlink.h. +// +// +marshal type RouteMessage struct { Family uint8 DstLen uint8 diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 185eee0bb..95871b8a5 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -15,7 +15,6 @@ package linux import ( - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal" ) @@ -251,18 +250,24 @@ type SockAddrInet struct { } // Inet6MulticastRequest is struct ipv6_mreq, from uapi/linux/in6.h. +// +// +marshal type Inet6MulticastRequest struct { MulticastAddr Inet6Addr InterfaceIndex int32 } // InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h. +// +// +marshal type InetMulticastRequest struct { MulticastAddr InetAddr InterfaceAddr InetAddr } // InetMulticastRequestWithNIC is struct ip_mreqn, from uapi/linux/in.h. 
+// +// +marshal type InetMulticastRequestWithNIC struct { InetMulticastRequest InterfaceIndex int32 @@ -491,7 +496,7 @@ type TCPInfo struct { } // SizeOfTCPInfo is the binary size of a TCPInfo struct. -var SizeOfTCPInfo = int(binary.Size(TCPInfo{})) +var SizeOfTCPInfo = (*TCPInfo)(nil).SizeBytes() // Control message types, from linux/socket.h. const ( @@ -502,6 +507,8 @@ const ( // A ControlMessageHeader is the header for a socket control message. // // ControlMessageHeader represents struct cmsghdr from linux/socket.h. +// +// +marshal type ControlMessageHeader struct { Length uint64 Level int32 @@ -510,7 +517,7 @@ type ControlMessageHeader struct { // SizeOfControlMessageHeader is the binary size of a ControlMessageHeader // struct. -var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) +var SizeOfControlMessageHeader = (*ControlMessageHeader)(nil).SizeBytes() // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // @@ -527,6 +534,7 @@ type ControlMessageCredentials struct { // // ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. // +// +marshal // +stateify savable type ControlMessageIPPacketInfo struct { NIC int32 @@ -536,7 +544,7 @@ type ControlMessageIPPacketInfo struct { // SizeOfControlMessageCredentials is the binary size of a // ControlMessageCredentials struct. -var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{})) +var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes() // A ControlMessageRights is an SCM_RIGHTS socket control message. type ControlMessageRights []int32 diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 1a30f6967..11072d4de 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -5,6 +5,8 @@ package(licenses = ["notice"]) go_library( name = "atomicbitops", srcs = [ + "aligned_32bit_unsafe.go", + "aligned_64bit.go", "atomicbitops.go", "atomicbitops_amd64.s", "atomicbitops_arm64.s", diff --git a/pkg/atomicbitops/aligned_32bit_unsafe.go b/pkg/atomicbitops/aligned_32bit_unsafe.go new file mode 100644 index 000000000..df706b453 --- /dev/null +++ b/pkg/atomicbitops/aligned_32bit_unsafe.go @@ -0,0 +1,96 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm mips 386 + +package atomicbitops + +import ( + "sync/atomic" + "unsafe" +) + +// AlignedAtomicInt64 is an atomic int64 that is guaranteed to be 64-bit +// aligned, even on 32-bit systems. +// +// Per https://golang.org/pkg/sync/atomic/#pkg-note-BUG: +// +// "On ARM, 386, and 32-bit MIPS, it is the caller's responsibility to arrange +// for 64-bit alignment of 64-bit words accessed atomically. The first word in +// a variable or in an allocated struct, array, or slice can be relied upon to +// be 64-bit aligned." 
+// +// +stateify savable +type AlignedAtomicInt64 struct { + value [15]byte +} + +func (aa *AlignedAtomicInt64) ptr() *int64 { + // In the 15-byte aa.value, there are guaranteed to be 8 contiguous + // bytes with 64-bit alignment. We find an address in this range by + // adding 7, then clear the 3 least significant bits to get its start. + return (*int64)(unsafe.Pointer((uintptr(unsafe.Pointer(&aa.value[0])) + 7) &^ 7)) +} + +// Load is analagous to atomic.LoadInt64. +func (aa *AlignedAtomicInt64) Load() int64 { + return atomic.LoadInt64(aa.ptr()) +} + +// Store is analagous to atomic.StoreInt64. +func (aa *AlignedAtomicInt64) Store(v int64) { + atomic.StoreInt64(aa.ptr(), v) +} + +// Add is analagous to atomic.AddInt64. +func (aa *AlignedAtomicInt64) Add(v int64) int64 { + return atomic.AddInt64(aa.ptr(), v) +} + +// AlignedAtomicUint64 is an atomic uint64 that is guaranteed to be 64-bit +// aligned, even on 32-bit systems. +// +// Per https://golang.org/pkg/sync/atomic/#pkg-note-BUG: +// +// "On ARM, 386, and 32-bit MIPS, it is the caller's responsibility to arrange +// for 64-bit alignment of 64-bit words accessed atomically. The first word in +// a variable or in an allocated struct, array, or slice can be relied upon to +// be 64-bit aligned." +// +// +stateify savable +type AlignedAtomicUint64 struct { + value [15]byte +} + +func (aa *AlignedAtomicUint64) ptr() *uint64 { + // In the 15-byte aa.value, there are guaranteed to be 8 contiguous + // bytes with 64-bit alignment. We find an address in this range by + // adding 7, then clear the 3 least significant bits to get its start. + return (*uint64)(unsafe.Pointer((uintptr(unsafe.Pointer(&aa.value[0])) + 7) &^ 7)) +} + +// Load is analagous to atomic.LoadUint64. +func (aa *AlignedAtomicUint64) Load() uint64 { + return atomic.LoadUint64(aa.ptr()) +} + +// Store is analagous to atomic.StoreUint64. +func (aa *AlignedAtomicUint64) Store(v uint64) { + atomic.StoreUint64(aa.ptr(), v) +} + +// Add is analagous to atomic.AddUint64. +func (aa *AlignedAtomicUint64) Add(v uint64) uint64 { + return atomic.AddUint64(aa.ptr(), v) +} diff --git a/pkg/atomicbitops/aligned_64bit.go b/pkg/atomicbitops/aligned_64bit.go new file mode 100644 index 000000000..1544c7814 --- /dev/null +++ b/pkg/atomicbitops/aligned_64bit.go @@ -0,0 +1,71 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !arm,!mips,!386 + +package atomicbitops + +import "sync/atomic" + +// AlignedAtomicInt64 is an atomic int64 that is guaranteed to be 64-bit +// aligned, even on 32-bit systems. On most architectures, it's just a regular +// int64. +// +// See aligned_unsafe.go in this directory for justification. +// +// +stateify savable +type AlignedAtomicInt64 struct { + value int64 +} + +// Load is analagous to atomic.LoadInt64. +func (aa *AlignedAtomicInt64) Load() int64 { + return atomic.LoadInt64(&aa.value) +} + +// Store is analagous to atomic.StoreInt64. 
+func (aa *AlignedAtomicInt64) Store(v int64) { + atomic.StoreInt64(&aa.value, v) +} + +// Add is analagous to atomic.AddInt64. +func (aa *AlignedAtomicInt64) Add(v int64) int64 { + return atomic.AddInt64(&aa.value, v) +} + +// AlignedAtomicUint64 is an atomic uint64 that is guaranteed to be 64-bit +// aligned, even on 32-bit systems. On most architectures, it's just a regular +// uint64. +// +// See aligned_unsafe.go in this directory for justification. +// +// +stateify savable +type AlignedAtomicUint64 struct { + value uint64 +} + +// Load is analagous to atomic.LoadUint64. +func (aa *AlignedAtomicUint64) Load() uint64 { + return atomic.LoadUint64(&aa.value) +} + +// Store is analagous to atomic.StoreUint64. +func (aa *AlignedAtomicUint64) Store(v uint64) { + atomic.StoreUint64(&aa.value, v) +} + +// Add is analagous to atomic.AddUint64. +func (aa *AlignedAtomicUint64) Add(v uint64) uint64 { + return atomic.AddUint64(&aa.value, v) +} diff --git a/pkg/bits/bits.go b/pkg/bits/bits.go index a26433ad6..d16448c3d 100644 --- a/pkg/bits/bits.go +++ b/pkg/bits/bits.go @@ -14,3 +14,13 @@ // Package bits includes all bit related types and operations. package bits + +// AlignUp rounds a length up to an alignment. align must be a power of 2. +func AlignUp(length int, align uint) int { + return (length + int(align) - 1) & ^(int(align) - 1) +} + +// AlignDown rounds a length down to an alignment. align must be a power of 2. +func AlignDown(length int, align uint) int { + return length & ^(int(align) - 1) +} diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index 2a6977f85..c17390522 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -26,6 +26,7 @@ go_test( library = ":bpf", deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/hostarch", + "//pkg/marshal", ], ) diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go index c85d786b9..f64a2dc50 100644 --- a/pkg/bpf/interpreter_test.go +++ b/pkg/bpf/interpreter_test.go @@ -15,10 +15,12 @@ package bpf import ( + "encoding/binary" "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" ) func TestCompilationErrors(t *testing.T) { @@ -750,29 +752,29 @@ func TestSimpleFilter(t *testing.T) { // desc is the test's description. desc string - // seccompData is the input data. - seccompData + // SeccompData is the input data. + data linux.SeccompData // expectedRet is the expected return value of the BPF program. expectedRet uint32 }{ { desc: "Invalid arch is rejected", - seccompData: seccompData{nr: 1 /* x86 exit */, arch: 0x40000003 /* AUDIT_ARCH_I386 */}, + data: linux.SeccompData{Nr: 1 /* x86 exit */, Arch: 0x40000003 /* AUDIT_ARCH_I386 */}, expectedRet: 0, }, { desc: "Disallowed syscall is rejected", - seccompData: seccompData{nr: 105 /* __NR_setuid */, arch: 0xc000003e}, + data: linux.SeccompData{Nr: 105 /* __NR_setuid */, Arch: 0xc000003e}, expectedRet: 0, }, { desc: "Allowed syscall is indeed allowed", - seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e}, + data: linux.SeccompData{Nr: 231 /* __NR_exit_group */, Arch: 0xc000003e}, expectedRet: 0x7fff0000, }, } { - ret, err := Exec(p, test.seccompData.asInput()) + ret, err := Exec(p, dataAsInput(&test.data)) if err != nil { t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) continue @@ -792,6 +794,6 @@ type seccompData struct { } // asInput converts a seccompData to a bpf.Input. 
-func (d *seccompData) asInput() Input { - return InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +func dataAsInput(data *linux.SeccompData) Input { + return InputBytes{marshal.Marshal(data), hostarch.ByteOrder} } diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD index 1186f788e..19cd28a32 100644 --- a/pkg/buffer/BUILD +++ b/pkg/buffer/BUILD @@ -21,7 +21,6 @@ go_library( "buffer.go", "buffer_list.go", "pool.go", - "safemem.go", "view.go", "view_unsafe.go", ], @@ -29,8 +28,6 @@ go_library( deps = [ "//pkg/context", "//pkg/log", - "//pkg/safemem", - "//pkg/usermem", ], ) @@ -38,13 +35,12 @@ go_test( name = "buffer_test", size = "small", srcs = [ + "buffer_test.go", "pool_test.go", - "safemem_test.go", "view_test.go", ], library = ":buffer", deps = [ - "//pkg/safemem", "//pkg/state", ], ) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index 311808ae9..5b77a6a3f 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -33,12 +33,40 @@ func (b *buffer) init(size int) { b.data = make([]byte, size) } +// initWithData initializes b with data, taking ownership. +func (b *buffer) initWithData(data []byte) { + b.data = data + b.read = 0 + b.write = len(data) +} + // Reset resets read and write locations, effectively emptying the buffer. func (b *buffer) Reset() { b.read = 0 b.write = 0 } +// Remove removes r from the unread portion. It returns false if r does not +// fully reside in b. +func (b *buffer) Remove(r Range) bool { + sz := b.ReadSize() + switch { + case r.Len() != r.Intersect(Range{end: sz}).Len(): + return false + case r.Len() == 0: + // Noop + case r.begin == 0: + b.read += r.end + case r.end == sz: + b.write -= r.Len() + default: + // Remove from the middle of b.data. + copy(b.data[b.read+r.begin:], b.data[b.read+r.end:b.write]) + b.write -= r.Len() + } + return true +} + // Full indicates the buffer is full. // // This indicates there is no capacity left to write. diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go new file mode 100644 index 000000000..32db841e4 --- /dev/null +++ b/pkg/buffer/buffer_test.go @@ -0,0 +1,111 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package buffer + +import ( + "bytes" + "testing" +) + +func TestBufferRemove(t *testing.T) { + sample := []byte("01234567") + + // Success cases + for _, tc := range []struct { + desc string + data []byte + rng Range + want []byte + }{ + { + desc: "empty slice", + }, + { + desc: "empty range", + data: sample, + want: sample, + }, + { + desc: "empty range with positive begin", + data: sample, + rng: Range{begin: 1, end: 1}, + want: sample, + }, + { + desc: "range at beginning", + data: sample, + rng: Range{begin: 0, end: 1}, + want: sample[1:], + }, + { + desc: "range in middle", + data: sample, + rng: Range{begin: 2, end: 4}, + want: []byte("014567"), + }, + { + desc: "range at end", + data: sample, + rng: Range{begin: 7, end: 8}, + want: sample[:7], + }, + { + desc: "range all", + data: sample, + rng: Range{begin: 0, end: 8}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var buf buffer + buf.initWithData(tc.data) + if ok := buf.Remove(tc.rng); !ok { + t.Errorf("buf.Remove(%#v) = false, want true", tc.rng) + } else if got := buf.ReadSlice(); !bytes.Equal(got, tc.want) { + t.Errorf("buf.ReadSlice() = %q, want %q", got, tc.want) + } + }) + } + + // Failure cases + for _, tc := range []struct { + desc string + data []byte + rng Range + }{ + { + desc: "begin out-of-range", + data: sample, + rng: Range{begin: -1, end: 4}, + }, + { + desc: "end out-of-range", + data: sample, + rng: Range{begin: 4, end: 9}, + }, + { + desc: "both out-of-range", + data: sample, + rng: Range{begin: -100, end: 100}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var buf buffer + buf.initWithData(tc.data) + if ok := buf.Remove(tc.rng); ok { + t.Errorf("buf.Remove(%#v) = true, want false", tc.rng) + } + }) + } +} diff --git a/pkg/buffer/pool.go b/pkg/buffer/pool.go index 7ad6132ab..2ec41dd4f 100644 --- a/pkg/buffer/pool.go +++ b/pkg/buffer/pool.go @@ -42,6 +42,13 @@ type pool struct { // get gets a new buffer from p. func (p *pool) get() *buffer { + buf := p.getNoInit() + buf.init(p.bufferSize) + return buf +} + +// get gets a new buffer from p without initializing it. +func (p *pool) getNoInit() *buffer { if p.avail == nil { p.avail = p.embeddedStorage[:] } @@ -52,7 +59,6 @@ func (p *pool) get() *buffer { p.bufferSize = defaultBufferSize } buf := &p.avail[0] - buf.init(p.bufferSize) p.avail = p.avail[1:] return buf } @@ -62,6 +68,7 @@ func (p *pool) put(buf *buffer) { // Remove reference to the underlying storage, allowing it to be garbage // collected. buf.data = nil + buf.Reset() } // setBufferSize sets the size of underlying storage buffer for future diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go deleted file mode 100644 index 8b42575b4..000000000 --- a/pkg/buffer/safemem.go +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "gvisor.dev/gvisor/pkg/safemem" -) - -// WriteBlock returns this buffer as a write Block. 
-func (b *buffer) WriteBlock() safemem.Block { - return safemem.BlockFromSafeSlice(b.WriteSlice()) -} - -// ReadBlock returns this buffer as a read Block. -func (b *buffer) ReadBlock() safemem.Block { - return safemem.BlockFromSafeSlice(b.ReadSlice()) -} - -// WriteFromSafememReader writes up to count bytes from r to v and advances the -// write index by the number of bytes written. It calls r.ReadToBlocks() at -// most once. -func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) { - if count == 0 { - return 0, nil - } - - var ( - dst safemem.BlockSeq - blocks []safemem.Block - ) - - // Need at least one buffer. - firstBuf := v.data.Back() - if firstBuf == nil { - firstBuf = v.pool.get() - v.data.PushBack(firstBuf) - } - - // Does the last block have sufficient capacity alone? - if l := uint64(firstBuf.WriteSize()); l >= count { - dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count)) - } else { - // Append blocks until sufficient. - count -= l - blocks = append(blocks, firstBuf.WriteBlock()) - for count > 0 { - emptyBuf := v.pool.get() - v.data.PushBack(emptyBuf) - block := emptyBuf.WriteBlock().TakeFirst64(count) - count -= uint64(block.Len()) - blocks = append(blocks, block) - } - dst = safemem.BlockSeqFromSlice(blocks) - } - - // Perform I/O. - n, err := r.ReadToBlocks(dst) - v.size += int64(n) - - // Update all indices. - for left := n; left > 0; firstBuf = firstBuf.Next() { - if l := firstBuf.WriteSize(); left >= uint64(l) { - firstBuf.WriteMove(l) // Whole block. - left -= uint64(l) - } else { - firstBuf.WriteMove(int(left)) // Partial block. - left = 0 - } - } - - return n, err -} - -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the -// write index by the number of bytes written. -func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes()) -} - -// ReadToSafememWriter reads up to count bytes from v to w. It does not advance -// the read index. It calls w.WriteFromBlocks() at most once. -func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) { - if count == 0 { - return 0, nil - } - - var ( - src safemem.BlockSeq - blocks []safemem.Block - ) - - firstBuf := v.data.Front() - if firstBuf == nil { - return 0, nil // No EOF. - } - - // Is all the data in a single block? - if l := uint64(firstBuf.ReadSize()); l >= count { - src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count)) - } else { - // Build a list of all the buffers. - count -= l - blocks = append(blocks, firstBuf.ReadBlock()) - for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() { - block := buf.ReadBlock().TakeFirst64(count) - count -= uint64(block.Len()) - blocks = append(blocks, block) - } - src = safemem.BlockSeqFromSlice(blocks) - } - - // Perform I/O. As documented, we don't advance the read index. - return w.WriteFromBlocks(src) -} - -// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the -// read index by the number of bytes read, such that it's only safe to call if -// the caller guarantees that ReadToBlocks will only be called once. 
-func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes()) -} diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go deleted file mode 100644 index 721cc5934..000000000 --- a/pkg/buffer/safemem_test.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package buffer - -import ( - "bytes" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/safemem" -) - -func TestSafemem(t *testing.T) { - const bufferSize = defaultBufferSize - - testCases := []struct { - name string - input string - output string - readLen int - op func(*View) - }{ - // Basic coverage. - { - name: "short", - input: "010", - output: "010", - }, - { - name: "long", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize) + "0", - }, - { - name: "short-read", - input: "0", - readLen: 100, // > size. - output: "0", - }, - { - name: "zero-read", - input: "0", - output: "", - }, - { - name: "read-empty", - input: "", - readLen: 1, // > size. - output: "", - }, - - // Ensure offsets work. - { - name: "offsets-short", - input: "012", - output: "2", - op: func(v *View) { - v.TrimFront(2) - }, - }, - { - name: "offsets-long0", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: strings.Repeat("1", bufferSize) + "0", - op: func(v *View) { - v.TrimFront(1) - }, - }, - { - name: "offsets-long1", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: strings.Repeat("1", bufferSize-1) + "0", - op: func(v *View) { - v.TrimFront(2) - }, - }, - { - name: "offsets-long2", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "10", - op: func(v *View) { - v.TrimFront(bufferSize) - }, - }, - - // Ensure truncation works. - { - name: "truncate-short", - input: "012", - output: "01", - op: func(v *View) { - v.Truncate(2) - }, - }, - { - name: "truncate-long0", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize), - op: func(v *View) { - v.Truncate(bufferSize + 1) - }, - }, - { - name: "truncate-long1", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "0" + strings.Repeat("1", bufferSize-1), - op: func(v *View) { - v.Truncate(bufferSize) - }, - }, - { - name: "truncate-long2", - input: "0" + strings.Repeat("1", bufferSize) + "0", - output: "01", - op: func(v *View) { - v.Truncate(2) - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Construct the new view. - var view View - bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input))) - n, err := view.WriteFromBlocks(bs) - if err != nil { - t.Errorf("expected err nil, got %v", err) - } - if n != uint64(len(tc.input)) { - t.Errorf("expected %d bytes, got %d", len(tc.input), n) - } - - // Run the operation. - if tc.op != nil { - tc.op(&view) - } - - // Read and validate. 
- readLen := tc.readLen - if readLen == 0 { - readLen = len(tc.output) // Default. - } - out := make([]byte, readLen) - bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out)) - n, err = view.ReadToBlocks(bs) - if err != nil { - t.Errorf("expected nil, got %v", err) - } - if n != uint64(len(tc.output)) { - t.Errorf("expected %d bytes, got %d", len(tc.output), n) - } - - // Ensure the contents are correct. - if !bytes.Equal(out[:n], []byte(tc.output[:n])) { - t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out)) - } - }) - } -} diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go index 00652d675..7bcfcd543 100644 --- a/pkg/buffer/view.go +++ b/pkg/buffer/view.go @@ -19,6 +19,9 @@ import ( "io" ) +// Buffer is an alias to View. +type Buffer = View + // View is a non-linear buffer. // // All methods are thread compatible. @@ -39,6 +42,51 @@ func (v *View) TrimFront(count int64) { } } +// Remove deletes data at specified location in v. It returns false if specified +// range does not fully reside in v. +func (v *View) Remove(offset, length int) bool { + if offset < 0 || length < 0 { + return false + } + tgt := Range{begin: offset, end: offset + length} + if tgt.Len() != tgt.Intersect(Range{end: int(v.size)}).Len() { + return false + } + + // Scan through each buffer and remove intersections. + var curr Range + for buf := v.data.Front(); buf != nil; { + origLen := buf.ReadSize() + curr.end = curr.begin + origLen + + if x := curr.Intersect(tgt); x.Len() > 0 { + if !buf.Remove(x.Offset(-curr.begin)) { + panic("buf.Remove() failed") + } + if buf.ReadSize() == 0 { + // buf fully removed, removing it from the list. + oldBuf := buf + buf = buf.Next() + v.data.Remove(oldBuf) + v.pool.put(oldBuf) + } else { + // Only partial data intersects, moving on to next one. + buf = buf.Next() + } + v.size -= int64(x.Len()) + } else { + // This buffer is not in range, moving on to next one. + buf = buf.Next() + } + + curr.begin += origLen + if curr.begin >= tgt.end { + break + } + } + return true +} + // ReadAt implements io.ReaderAt.ReadAt. func (v *View) ReadAt(p []byte, offset int64) (int, error) { var ( @@ -81,7 +129,6 @@ func (v *View) advanceRead(count int64) { oldBuf := buf buf = buf.Next() // Iterate. v.data.Remove(oldBuf) - oldBuf.Reset() v.pool.put(oldBuf) // Update counts. @@ -118,7 +165,6 @@ func (v *View) Truncate(length int64) { // Drop the buffer completely; see above. v.data.Remove(buf) - buf.Reset() v.pool.put(buf) v.size -= sz } @@ -224,6 +270,78 @@ func (v *View) Append(data []byte) { } } +// AppendOwned takes ownership of data and appends it to v. +func (v *View) AppendOwned(data []byte) { + if len(data) > 0 { + buf := v.pool.getNoInit() + buf.initWithData(data) + v.data.PushBack(buf) + v.size += int64(len(data)) + } +} + +// PullUp makes the specified range contiguous and returns the backing memory. +func (v *View) PullUp(offset, length int) ([]byte, bool) { + if length == 0 { + return nil, true + } + tgt := Range{begin: offset, end: offset + length} + if tgt.Intersect(Range{end: int(v.size)}).Len() != length { + return nil, false + } + + curr := Range{} + buf := v.data.Front() + for ; buf != nil; buf = buf.Next() { + origLen := buf.ReadSize() + curr.end = curr.begin + origLen + + if x := curr.Intersect(tgt); x.Len() == tgt.Len() { + // buf covers the whole requested target range. + sub := x.Offset(-curr.begin) + return buf.ReadSlice()[sub.begin:sub.end], true + } else if x.Len() > 0 { + // buf is pointing at the starting buffer we want to merge. 
+ break + } + + curr.begin += origLen + } + + // Calculate the total merged length. + totLen := 0 + for n := buf; n != nil; n = n.Next() { + totLen += n.ReadSize() + if curr.begin+totLen >= tgt.end { + break + } + } + + // Merge the buffers. + data := make([]byte, totLen) + off := 0 + for n := buf; n != nil && off < totLen; { + copy(data[off:], n.ReadSlice()) + off += n.ReadSize() + + // Remove buffers except for the first one, which will be reused. + if n == buf { + n = n.Next() + } else { + old := n + n = n.Next() + v.data.Remove(old) + v.pool.put(old) + } + } + + // Update the first buffer with merged data. + buf.initWithData(data) + + r := tgt.Offset(-curr.begin) + return buf.data[r.begin:r.end], true +} + // Flatten returns a flattened copy of this data. // // This method should not be used in any performance-sensitive paths. It may @@ -267,6 +385,27 @@ func (v *View) Apply(fn func([]byte)) { } } +// SubApply applies fn to a given range of data in v. Any part of the range +// outside of v is ignored. +func (v *View) SubApply(offset, length int, fn func([]byte)) { + for buf := v.data.Front(); length > 0 && buf != nil; buf = buf.Next() { + d := buf.ReadSlice() + if offset >= len(d) { + offset -= len(d) + continue + } + if offset > 0 { + d = d[offset:] + offset = 0 + } + if length < len(d) { + d = d[:length] + } + fn(d) + length -= len(d) + } +} + // Merge merges the provided View with this one. // // The other view will be appended to v, and other will be empty after this @@ -389,3 +528,39 @@ func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) { } return done, err } + +// A Range specifies a range of buffer. +type Range struct { + begin int + end int +} + +// Intersect returns the intersection of x and y. +func (x Range) Intersect(y Range) Range { + if x.begin < y.begin { + x.begin = y.begin + } + if x.end > y.end { + x.end = y.end + } + if x.begin >= x.end { + return Range{} + } + return x +} + +// Offset returns x offset by off. +func (x Range) Offset(off int) Range { + x.begin += off + x.end += off + return x +} + +// Len returns the length of x. +func (x Range) Len() int { + l := x.end - x.begin + if l < 0 { + l = 0 + } + return l +} diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go index 839af0223..796efa240 100644 --- a/pkg/buffer/view_test.go +++ b/pkg/buffer/view_test.go @@ -17,7 +17,9 @@ package buffer import ( "bytes" "context" + "fmt" "io" + "reflect" "strings" "testing" @@ -237,6 +239,18 @@ func TestView(t *testing.T) { }, }, + // AppendOwned. + { + name: "append-owned", + input: "hello", + output: "hello world", + op: func(t *testing.T, v *View) { + b := []byte("Xworld") + v.AppendOwned(b) + b[0] = ' ' + }, + }, + // Truncate. { name: "truncate", @@ -495,6 +509,267 @@ func TestView(t *testing.T) { } } +func TestViewPullUp(t *testing.T) { + for _, tc := range []struct { + desc string + inputs []string + offset int + length int + output string + failed bool + // lengths is the lengths of each buffer node after the pull up. 
+ lengths []int + }{ + { + desc: "whole empty view", + }, + { + desc: "zero pull", + inputs: []string{"hello", " world"}, + lengths: []int{5, 6}, + }, + { + desc: "whole view", + inputs: []string{"hello", " world"}, + offset: 0, + length: 11, + output: "hello world", + lengths: []int{11}, + }, + { + desc: "middle to end aligned", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 4, + length: 10, + output: "456789abcd", + lengths: []int{4, 10}, + }, + { + desc: "middle to end unaligned", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 6, + length: 8, + output: "6789abcd", + lengths: []int{4, 10}, + }, + { + desc: "middle aligned", + inputs: []string{"0123", "45678", "9abcd", "efgh"}, + offset: 6, + length: 5, + output: "6789a", + lengths: []int{4, 10, 4}, + }, + + // Failed cases. + { + desc: "empty view - length too long", + offset: 0, + length: 1, + failed: true, + }, + { + desc: "empty view - offset too large", + offset: 1, + length: 1, + failed: true, + }, + { + desc: "length too long", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 4, + length: 100, + failed: true, + lengths: []int{4, 5, 5}, + }, + { + desc: "offset too large", + inputs: []string{"0123", "45678", "9abcd"}, + offset: 100, + length: 1, + failed: true, + lengths: []int{4, 5, 5}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.inputs { + v.AppendOwned([]byte(s)) + } + + got, gotOk := v.PullUp(tc.offset, tc.length) + want, wantOk := []byte(tc.output), !tc.failed + if gotOk != wantOk || !bytes.Equal(got, want) { + t.Errorf("v.PullUp(%d, %d) = %q, %t; %q, %t", tc.offset, tc.length, got, gotOk, want, wantOk) + } + + var gotLengths []int + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + gotLengths = append(gotLengths, buf.ReadSize()) + } + if !reflect.DeepEqual(gotLengths, tc.lengths) { + t.Errorf("lengths = %v; want %v", gotLengths, tc.lengths) + } + }) + } +} + +func TestViewRemove(t *testing.T) { + // Success cases + for _, tc := range []struct { + desc string + // before is the contents for each buffer node initially. + before []string + // after is the contents for each buffer node after removal. 
+ after []string + offset int + length int + }{ + { + desc: "empty view", + }, + { + desc: "nothing removed", + before: []string{"hello", " world"}, + after: []string{"hello", " world"}, + }, + { + desc: "whole view", + before: []string{"hello", " world"}, + offset: 0, + length: 11, + }, + { + desc: "beginning to middle aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"9abcd"}, + offset: 0, + length: 9, + }, + { + desc: "beginning to middle unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"678", "9abcd"}, + offset: 0, + length: 6, + }, + { + desc: "middle to end aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123"}, + offset: 4, + length: 10, + }, + { + desc: "middle to end unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "45"}, + offset: 6, + length: 8, + }, + { + desc: "middle aligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "9abcd"}, + offset: 4, + length: 5, + }, + { + desc: "middle unaligned", + before: []string{"0123", "45678", "9abcd"}, + after: []string{"0123", "4578", "9abcd"}, + offset: 6, + length: 1, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.before { + v.AppendOwned([]byte(s)) + } + + if ok := v.Remove(tc.offset, tc.length); !ok { + t.Errorf("v.Remove(%d, %d) = false, want true", tc.offset, tc.length) + } + + var got []string + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + got = append(got, string(buf.ReadSlice())) + } + if !reflect.DeepEqual(got, tc.after) { + t.Errorf("after = %v; want %v", got, tc.after) + } + }) + } + + // Failure cases + for _, tc := range []struct { + desc string + // before is the contents for each buffer node initially. + before []string + offset int + length int + }{ + { + desc: "offset out-of-range", + before: []string{"hello", " world"}, + offset: -1, + length: 3, + }, + { + desc: "length too long", + before: []string{"hello", " world"}, + offset: 0, + length: 12, + }, + { + desc: "length too long with positive offset", + before: []string{"hello", " world"}, + offset: 3, + length: 9, + }, + { + desc: "length negative", + before: []string{"hello", " world"}, + offset: 0, + length: -1, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + var v View + for _, s := range tc.before { + v.AppendOwned([]byte(s)) + } + if ok := v.Remove(tc.offset, tc.length); ok { + t.Errorf("v.Remove(%d, %d) = true, want false", tc.offset, tc.length) + } + }) + } +} + +func TestViewSubApply(t *testing.T) { + var v View + v.AppendOwned([]byte("0123")) + v.AppendOwned([]byte("45678")) + v.AppendOwned([]byte("9abcd")) + + data := []byte("0123456789abcd") + + for i := 0; i <= len(data); i++ { + for j := i; j <= len(data); j++ { + t.Run(fmt.Sprintf("SubApply(%d,%d)", i, j), func(t *testing.T) { + var got []byte + v.SubApply(i, j-i, func(b []byte) { + got = append(got, b...) 
+ }) + if want := data[i:j]; !bytes.Equal(got, want) { + t.Errorf("got = %q; want %q", got, want) + } + }) + } + } +} + func doSaveAndLoad(t *testing.T, toSave, toLoad *View) { t.Helper() var buf bytes.Buffer @@ -542,3 +817,84 @@ func TestSaveRestoreView(t *testing.T) { t.Errorf("v.Flatten() = %x, want %x", got, data) } } + +func TestRangeIntersect(t *testing.T) { + for _, tc := range []struct { + desc string + x, y, want Range + }{ + { + desc: "empty intersects empty", + }, + { + desc: "empty intersection", + x: Range{end: 10}, + y: Range{begin: 10, end: 20}, + }, + { + desc: "some intersection", + x: Range{begin: 5, end: 20}, + y: Range{end: 10}, + want: Range{begin: 5, end: 10}, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + if got := tc.x.Intersect(tc.y); got != tc.want { + t.Errorf("(%#v).Intersect(%#v) = %#v; want %#v", tc.x, tc.y, got, tc.want) + } + if got := tc.y.Intersect(tc.x); got != tc.want { + t.Errorf("(%#v).Intersect(%#v) = %#v; want %#v", tc.y, tc.x, got, tc.want) + } + }) + } +} + +func TestRangeOffset(t *testing.T) { + for _, tc := range []struct { + input Range + offset int + output Range + }{ + { + input: Range{}, + offset: 0, + output: Range{}, + }, + { + input: Range{}, + offset: -1, + output: Range{begin: -1, end: -1}, + }, + { + input: Range{begin: 10, end: 20}, + offset: -1, + output: Range{begin: 9, end: 19}, + }, + { + input: Range{begin: 10, end: 20}, + offset: 2, + output: Range{begin: 12, end: 22}, + }, + } { + if got := tc.input.Offset(tc.offset); got != tc.output { + t.Errorf("(%#v).Offset(%d) = %#v, want %#v", tc.input, tc.offset, got, tc.output) + } + } +} + +func TestRangeLen(t *testing.T) { + for _, tc := range []struct { + r Range + want int + }{ + {r: Range{}, want: 0}, + {r: Range{begin: 1, end: 1}, want: 0}, + {r: Range{begin: -1, end: -1}, want: 0}, + {r: Range{end: 10}, want: 10}, + {r: Range{begin: 5, end: 10}, want: 5}, + } { + if got := tc.r.Len(); got != tc.want { + t.Errorf("(%#v).Len() = %d, want %d", tc.r, got, tc.want) + } + } +} diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index 1f75319a7..70018cf18 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -6,10 +6,7 @@ go_library( name = "compressio", srcs = ["compressio.go"], visibility = ["//:sandbox"], - deps = [ - "//pkg/binary", - "//pkg/sync", - ], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index b094c5662..615d7f134 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -48,12 +48,12 @@ import ( "compress/flate" "crypto/hmac" "crypto/sha256" + "encoding/binary" "errors" "hash" "io" "runtime" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sync" ) @@ -130,6 +130,10 @@ type worker struct { hashPool *hashPool input chan *chunk output chan result + + // scratch is a temporary buffer used for marshalling. This is declared + // unfront here to avoid reallocation. + scratch [4]byte } // work is the main work routine; see worker. @@ -167,7 +171,8 @@ func (w *worker) work(compress bool, level int) { // Write the hash, if enabled. if h != nil { - binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len())) + h.Write(w.scratch[:4]) c.h = h h = nil } @@ -175,7 +180,8 @@ func (w *worker) work(compress bool, level int) { // Check the hash of the compressed contents. 
 		if h != nil {
 			h.Write(c.compressed.Bytes())
-			binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len()))
+			binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len()))
+			h.Write(w.scratch[:4])
 			io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))
 
 			sum := h.Sum(nil)
@@ -352,6 +358,10 @@ type Reader struct {
 
 	// in is the source.
 	in io.Reader
+
+	// scratch is a temporary buffer used for marshalling. This is declared
+	// upfront here to avoid reallocation.
+	scratch [4]byte
 }
 
 var _ io.Reader = (*Reader)(nil)
@@ -368,14 +378,15 @@ func NewReader(in io.Reader, key []byte) (*Reader, error) {
 	// Use double buffering for read.
 	r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)
 
-	var err error
-	if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil {
+	if _, err := io.ReadFull(in, r.scratch[:4]); err != nil {
 		return nil, err
 	}
+	r.chunkSize = binary.BigEndian.Uint32(r.scratch[:4])
 
 	if r.hashPool != nil {
 		h := r.hashPool.getHash()
-		binary.WriteUint32(h, binary.BigEndian, r.chunkSize)
+		binary.BigEndian.PutUint32(r.scratch[:], r.chunkSize)
+		h.Write(r.scratch[:4])
 		r.lastSum = h.Sum(nil)
 		r.hashPool.putHash(h)
 		sum := make([]byte, len(r.lastSum))
@@ -467,8 +478,7 @@ func (r *Reader) Read(p []byte) (int, error) {
 		// reader. The length is used to limit the reader.
 		//
 		// See writer.flush.
-		l, err := binary.ReadUint32(r.in, binary.BigEndian)
-		if err != nil {
+		if _, err := io.ReadFull(r.in, r.scratch[:4]); err != nil {
 			// This is generally okay as long as there
 			// are still buffers outstanding. We actually
 			// just wait for completion of those buffers here
@@ -488,6 +498,7 @@ func (r *Reader) Read(p []byte) (int, error) {
 				return done, err
 			}
 		}
+		l := binary.BigEndian.Uint32(r.scratch[:4])
 
 		// Read this chunk and schedule decompression.
 		compressed := bufPool.Get().(*bytes.Buffer)
@@ -573,6 +584,10 @@ type Writer struct {
 
 	// closed indicates whether the file has been closed.
 	closed bool
+
+	// scratch is a temporary buffer used for marshalling. This is declared
+	// upfront here to avoid reallocation.
+	scratch [4]byte
 }
 
 var _ io.Writer = (*Writer)(nil)
@@ -594,13 +609,15 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer,
 	}
 	w.init(key, 1+runtime.GOMAXPROCS(0), true, level)
 
-	if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil {
+	binary.BigEndian.PutUint32(w.scratch[:], chunkSize)
+	if _, err := w.out.Write(w.scratch[:4]); err != nil {
 		return nil, err
 	}
 
 	if w.hashPool != nil {
 		h := w.hashPool.getHash()
-		binary.WriteUint32(h, binary.BigEndian, chunkSize)
+		binary.BigEndian.PutUint32(w.scratch[:], chunkSize)
+		h.Write(w.scratch[:4])
 		w.lastSum = h.Sum(nil)
 		w.hashPool.putHash(h)
 		if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
@@ -616,7 +633,9 @@ func (w *Writer) flush(c *chunk) error {
 	// Prefix each chunk with a length; this allows the reader to safely
 	// limit reads while buffering.
 	l := uint32(c.compressed.Len())
-	if err := binary.WriteUint32(w.out, binary.BigEndian, l); err != nil {
+
+	binary.BigEndian.PutUint32(w.scratch[:], l)
+	if _, err := w.out.Write(w.scratch[:4]); err != nil {
 		return err
 	}
diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go
index 74a55a123..514592b08 100644
--- a/pkg/crypto/crypto_stdlib.go
+++ b/pkg/crypto/crypto_stdlib.go
@@ -22,11 +22,11 @@ import (
 
 // EcdsaVerify verifies the signature in r, s of hash using ECDSA and the
 // public key, pub. Its return value records whether the signature is valid.
-func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool {
-	return ecdsa.Verify(pub, hash, r, s)
+func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) {
+	return ecdsa.Verify(pub, hash, r, s), nil
 }
 
 // SumSha384 returns the SHA384 checksum of the data.
-func SumSha384(data []byte) (sum384 [sha512.Size384]byte) {
-	return sha512.Sum384(data)
+func SumSha384(data []byte) ([sha512.Size384]byte, error) {
+	return sha512.Sum384(data), nil
 }
diff --git a/pkg/linuxerr/BUILD b/pkg/linuxerr/BUILD
new file mode 100644
index 000000000..c5abbd34f
--- /dev/null
+++ b/pkg/linuxerr/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+    name = "linuxerr",
+    srcs = ["linuxerr.go"],
+    visibility = ["//visibility:public"],
+    deps = ["//pkg/abi/linux"],
+)
+
+go_test(
+    name = "linuxerr_test",
+    srcs = ["linuxerr_test.go"],
+    deps = [
+        ":linuxerr",
+        "//pkg/syserror",
+        "@org_golang_x_sys//unix:go_default_library",
+    ],
+)
diff --git a/pkg/linuxerr/linuxerr.go b/pkg/linuxerr/linuxerr.go
new file mode 100644
index 000000000..f45caaadf
--- /dev/null
+++ b/pkg/linuxerr/linuxerr.go
@@ -0,0 +1,184 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package linuxerr contains syscall error codes exported as error interface
+// pointers. This allows for fast comparison and return operations comparable
+// to unix.Errno constants.
+package linuxerr
+
+import (
+    "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// Error represents a syscall errno with a descriptive message.
+type Error struct {
+    errno   linux.Errno
+    message string
+}
+
+func new(err linux.Errno, message string) *Error {
+    return &Error{
+        errno:   err,
+        message: message,
+    }
+}
+
+// Error implements error.Error.
+func (e *Error) Error() string { return e.message }
+
+// Errno returns the underlying linux.Errno value.
+func (e *Error) Errno() linux.Errno { return e.errno }
+
+// The following variables have the same meaning as their errno equivalents.
+
+// Errno values from include/uapi/asm-generic/errno-base.h.
+var ( + EPERM = new(linux.EPERM, "operation not permitted") + ENOENT = new(linux.ENOENT, "no such file or directory") + ESRCH = new(linux.ESRCH, "no such process") + EINTR = new(linux.EINTR, "interrupted system call") + EIO = new(linux.EIO, "I/O error") + ENXIO = new(linux.ENXIO, "no such device or address") + E2BIG = new(linux.E2BIG, "argument list too long") + ENOEXEC = new(linux.ENOEXEC, "exec format error") + EBADF = new(linux.EBADF, "bad file number") + ECHILD = new(linux.ECHILD, "no child processes") + EAGAIN = new(linux.EAGAIN, "try again") + ENOMEM = new(linux.ENOMEM, "out of memory") + EACCES = new(linux.EACCES, "permission denied") + EFAULT = new(linux.EFAULT, "bad address") + ENOTBLK = new(linux.ENOTBLK, "block device required") + EBUSY = new(linux.EBUSY, "device or resource busy") + EEXIST = new(linux.EEXIST, "file exists") + EXDEV = new(linux.EXDEV, "cross-device link") + ENODEV = new(linux.ENODEV, "no such device") + ENOTDIR = new(linux.ENOTDIR, "not a directory") + EISDIR = new(linux.EISDIR, "is a directory") + EINVAL = new(linux.EINVAL, "invalid argument") + ENFILE = new(linux.ENFILE, "file table overflow") + EMFILE = new(linux.EMFILE, "too many open files") + ENOTTY = new(linux.ENOTTY, "not a typewriter") + ETXTBSY = new(linux.ETXTBSY, "text file busy") + EFBIG = new(linux.EFBIG, "file too large") + ENOSPC = new(linux.ENOSPC, "no space left on device") + ESPIPE = new(linux.ESPIPE, "illegal seek") + EROFS = new(linux.EROFS, "read-only file system") + EMLINK = new(linux.EMLINK, "too many links") + EPIPE = new(linux.EPIPE, "broken pipe") + EDOM = new(linux.EDOM, "math argument out of domain of func") + ERANGE = new(linux.ERANGE, "math result not representable") +) + +// Errno values from include/uapi/asm-generic/errno.h. +var ( + EDEADLK = new(linux.EDEADLK, "resource deadlock would occur") + ENAMETOOLONG = new(linux.ENAMETOOLONG, "file name too long") + ENOLCK = new(linux.ENOLCK, "no record locks available") + ENOSYS = new(linux.ENOSYS, "invalid system call number") + ENOTEMPTY = new(linux.ENOTEMPTY, "directory not empty") + ELOOP = new(linux.ELOOP, "too many symbolic links encountered") + EWOULDBLOCK = new(linux.EWOULDBLOCK, "operation would block") + ENOMSG = new(linux.ENOMSG, "no message of desired type") + EIDRM = new(linux.EIDRM, "identifier removed") + ECHRNG = new(linux.ECHRNG, "channel number out of range") + EL2NSYNC = new(linux.EL2NSYNC, "level 2 not synchronized") + EL3HLT = new(linux.EL3HLT, "level 3 halted") + EL3RST = new(linux.EL3RST, "level 3 reset") + ELNRNG = new(linux.ELNRNG, "link number out of range") + EUNATCH = new(linux.EUNATCH, "protocol driver not attached") + ENOCSI = new(linux.ENOCSI, "no CSI structure available") + EL2HLT = new(linux.EL2HLT, "level 2 halted") + EBADE = new(linux.EBADE, "invalid exchange") + EBADR = new(linux.EBADR, "invalid request descriptor") + EXFULL = new(linux.EXFULL, "exchange full") + ENOANO = new(linux.ENOANO, "no anode") + EBADRQC = new(linux.EBADRQC, "invalid request code") + EBADSLT = new(linux.EBADSLT, "invalid slot") + EDEADLOCK = new(linux.EDEADLOCK, EDEADLK.message) + EBFONT = new(linux.EBFONT, "bad font file format") + ENOSTR = new(linux.ENOSTR, "device not a stream") + ENODATA = new(linux.ENODATA, "no data available") + ETIME = new(linux.ETIME, "timer expired") + ENOSR = new(linux.ENOSR, "out of streams resources") + ENONET = new(linux.ENOENT, "machine is not on the network") + ENOPKG = new(linux.ENOPKG, "package not installed") + EREMOTE = new(linux.EREMOTE, "object is remote") + ENOLINK = new(linux.ENOLINK, 
"link has been severed") + EADV = new(linux.EADV, "advertise error") + ESRMNT = new(linux.ESRMNT, "srmount error") + ECOMM = new(linux.ECOMM, "communication error on send") + EPROTO = new(linux.EPROTO, "protocol error") + EMULTIHOP = new(linux.EMULTIHOP, "multihop attempted") + EDOTDOT = new(linux.EDOTDOT, "RFS specific error") + EBADMSG = new(linux.EBADMSG, "not a data message") + EOVERFLOW = new(linux.EOVERFLOW, "value too large for defined data type") + ENOTUNIQ = new(linux.ENOTUNIQ, "name not unique on network") + EBADFD = new(linux.EBADFD, "file descriptor in bad state") + EREMCHG = new(linux.EREMCHG, "remote address changed") + ELIBACC = new(linux.ELIBACC, "can not access a needed shared library") + ELIBBAD = new(linux.ELIBBAD, "accessing a corrupted shared library") + ELIBSCN = new(linux.ELIBSCN, ".lib section in a.out corrupted") + ELIBMAX = new(linux.ELIBMAX, "attempting to link in too many shared libraries") + ELIBEXEC = new(linux.ELIBEXEC, "cannot exec a shared library directly") + EILSEQ = new(linux.EILSEQ, "illegal byte sequence") + ERESTART = new(linux.ERESTART, "interrupted system call should be restarted") + ESTRPIPE = new(linux.ESTRPIPE, "streams pipe error") + EUSERS = new(linux.EUSERS, "too many users") + ENOTSOCK = new(linux.ENOTSOCK, "socket operation on non-socket") + EDESTADDRREQ = new(linux.EDESTADDRREQ, "destination address required") + EMSGSIZE = new(linux.EMSGSIZE, "message too long") + EPROTOTYPE = new(linux.EPROTOTYPE, "protocol wrong type for socket") + ENOPROTOOPT = new(linux.ENOPROTOOPT, "protocol not available") + EPROTONOSUPPORT = new(linux.EPROTONOSUPPORT, "protocol not supported") + ESOCKTNOSUPPORT = new(linux.ESOCKTNOSUPPORT, "socket type not supported") + EOPNOTSUPP = new(linux.EOPNOTSUPP, "operation not supported on transport endpoint") + EPFNOSUPPORT = new(linux.EPFNOSUPPORT, "protocol family not supported") + EAFNOSUPPORT = new(linux.EAFNOSUPPORT, "address family not supported by protocol") + EADDRINUSE = new(linux.EADDRINUSE, "address already in use") + EADDRNOTAVAIL = new(linux.EADDRNOTAVAIL, "cannot assign requested address") + ENETDOWN = new(linux.ENETDOWN, "network is down") + ENETUNREACH = new(linux.ENETUNREACH, "network is unreachable") + ENETRESET = new(linux.ENETRESET, "network dropped connection because of reset") + ECONNABORTED = new(linux.ECONNABORTED, "software caused connection abort") + ECONNRESET = new(linux.ECONNRESET, "connection reset by peer") + ENOBUFS = new(linux.ENOBUFS, "no buffer space available") + EISCONN = new(linux.EISCONN, "transport endpoint is already connected") + ENOTCONN = new(linux.ENOTCONN, "transport endpoint is not connected") + ESHUTDOWN = new(linux.ESHUTDOWN, "cannot send after transport endpoint shutdown") + ETOOMANYREFS = new(linux.ETOOMANYREFS, "too many references: cannot splice") + ETIMEDOUT = new(linux.ETIMEDOUT, "connection timed out") + ECONNREFUSED = new(linux.ECONNREFUSED, "connection refused") + EHOSTDOWN = new(linux.EHOSTDOWN, "host is down") + EHOSTUNREACH = new(linux.EHOSTUNREACH, "no route to host") + EALREADY = new(linux.EALREADY, "operation already in progress") + EINPROGRESS = new(linux.EINPROGRESS, "operation now in progress") + ESTALE = new(linux.ESTALE, "stale file handle") + EUCLEAN = new(linux.EUCLEAN, "structure needs cleaning") + ENOTNAM = new(linux.ENOTNAM, "not a XENIX named type file") + ENAVAIL = new(linux.ENAVAIL, "no XENIX semaphores available") + EISNAM = new(linux.EISNAM, "is a named type file") + EREMOTEIO = new(linux.EREMOTEIO, "remote I/O error") + EDQUOT = 
new(linux.EDQUOT, "quota exceeded") + ENOMEDIUM = new(linux.ENOMEDIUM, "no medium found") + EMEDIUMTYPE = new(linux.EMEDIUMTYPE, "wrong medium type") + ECANCELED = new(linux.ECANCELED, "operation Canceled") + ENOKEY = new(linux.ENOKEY, "required key not available") + EKEYEXPIRED = new(linux.EKEYEXPIRED, "key has expired") + EKEYREVOKED = new(linux.EKEYREVOKED, "key has been revoked") + EKEYREJECTED = new(linux.EKEYREJECTED, "key was rejected by service") + EOWNERDEAD = new(linux.EOWNERDEAD, "owner died") + ENOTRECOVERABLE = new(linux.ENOTRECOVERABLE, "state not recoverable") + ERFKILL = new(linux.ERFKILL, "operation not possible due to RF-kill") + EHWPOISON = new(linux.EHWPOISON, "memory page has hardware error") +) diff --git a/pkg/syserror/syserror_test.go b/pkg/linuxerr/linuxerr_test.go index c141e5f6e..d34937e93 100644 --- a/pkg/syserror/syserror_test.go +++ b/pkg/linuxerr/linuxerr_test.go @@ -19,6 +19,7 @@ import ( "testing" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/linuxerr" "gvisor.dev/gvisor/pkg/syserror" ) @@ -30,7 +31,13 @@ func BenchmarkAssignErrno(b *testing.B) { } } -func BenchmarkAssignError(b *testing.B) { +func BenchmarkLinuxerrAssignError(b *testing.B) { + for i := b.N; i > 0; i-- { + globalError = linuxerr.EINVAL + } +} + +func BenchmarkAssignSyserrorError(b *testing.B) { for i := b.N; i > 0; i-- { globalError = syserror.EINVAL } @@ -46,7 +53,17 @@ func BenchmarkCompareErrno(b *testing.B) { } } -func BenchmarkCompareError(b *testing.B) { +func BenchmarkCompareLinuxerrError(b *testing.B) { + globalError = linuxerr.E2BIG + j := 0 + for i := b.N; i > 0; i-- { + if globalError == linuxerr.EINVAL { + j++ + } + } +} + +func BenchmarkCompareSyserrorError(b *testing.B) { globalError = syserror.EAGAIN j := 0 for i := b.N; i > 0; i-- { @@ -62,7 +79,7 @@ func BenchmarkSwitchErrno(b *testing.B) { for i := b.N; i > 0; i-- { switch globalError { case unix.EINVAL: - j += 1 + j++ case unix.EINTR: j += 2 case unix.EAGAIN: @@ -71,13 +88,28 @@ func BenchmarkSwitchErrno(b *testing.B) { } } -func BenchmarkSwitchError(b *testing.B) { +func BenchmarkSwitchLinuxerrError(b *testing.B) { + globalError = linuxerr.EPERM + j := 0 + for i := b.N; i > 0; i-- { + switch globalError { + case linuxerr.EINVAL: + j++ + case linuxerr.EINTR: + j += 2 + case linuxerr.EAGAIN: + j += 3 + } + } +} + +func BenchmarkSwitchSyserrorError(b *testing.B) { globalError = syserror.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { case syserror.EINVAL: - j += 1 + j++ case syserror.EINTR: j += 2 case syserror.EAGAIN: diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD index 7cd89e639..7a5002176 100644 --- a/pkg/marshal/BUILD +++ b/pkg/marshal/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "marshal.go", "marshal_impl_util.go", + "util.go", ], visibility = [ "//:sandbox", diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index 32c8ed138..6f38992b7 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -125,6 +125,81 @@ func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) { var _ marshal.Marshallable = (*ByteSlice)(nil) +// The following set of functions are convenient shorthands for wrapping a +// built-in type in a marshallable primitive type. For example: +// +// func useMarshallable(m marshal.Marshallable) { ... } +// +// // Compare: +// +// buf = []byte{...} +// // useMarshallable(&primitive.ByteSlice(buf)) // Not allowed, can't address temp value. 
+// bufP := primitive.ByteSlice(buf) +// useMarshallable(&bufP) +// +// // Vs: +// +// useMarshallable(AsByteSlice(buf)) +// +// Note that the argument to these function escapes, so avoid using them on very +// hot code paths. But generally if a function accepts an interface as an +// argument, the argument escapes anyways. + +// AllocateInt8 returns x as a marshallable. +func AllocateInt8(x int8) marshal.Marshallable { + p := Int8(x) + return &p +} + +// AllocateUint8 returns x as a marshallable. +func AllocateUint8(x uint8) marshal.Marshallable { + p := Uint8(x) + return &p +} + +// AllocateInt16 returns x as a marshallable. +func AllocateInt16(x int16) marshal.Marshallable { + p := Int16(x) + return &p +} + +// AllocateUint16 returns x as a marshallable. +func AllocateUint16(x uint16) marshal.Marshallable { + p := Uint16(x) + return &p +} + +// AllocateInt32 returns x as a marshallable. +func AllocateInt32(x int32) marshal.Marshallable { + p := Int32(x) + return &p +} + +// AllocateUint32 returns x as a marshallable. +func AllocateUint32(x uint32) marshal.Marshallable { + p := Uint32(x) + return &p +} + +// AllocateInt64 returns x as a marshallable. +func AllocateInt64(x int64) marshal.Marshallable { + p := Int64(x) + return &p +} + +// AllocateUint64 returns x as a marshallable. +func AllocateUint64(x uint64) marshal.Marshallable { + p := Uint64(x) + return &p +} + +// AsByteSlice returns b as a marshallable. Note that this allocates a new slice +// header, but does not copy the slice contents. +func AsByteSlice(b []byte) marshal.Marshallable { + bs := ByteSlice(b) + return &bs +} + // Below, we define some convenience functions for marshalling primitive types // using the newtypes above, without requiring superfluous casts. diff --git a/pkg/tcpip/time.s b/pkg/marshal/util.go index fb37360ac..c1e5475bd 100644 --- a/pkg/tcpip/time.s +++ b/pkg/marshal/util.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,4 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Empty assembly file so empty func definitions work. +package marshal + +// Marshal returns the serialized contents of m in a newly allocated +// byte slice. +func Marshal(m Marshallable) []byte { + buf := make([]byte, m.SizeBytes()) + m.MarshalUnsafe(buf) + return buf +} diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 6450f664c..ac7868ad9 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -36,7 +36,6 @@ const ( ) // DigestSize returns the size (in bytes) of a digest. -// TODO(b/156980949): Allow config SHA384. func DigestSize(hashAlgorithm int) int { switch hashAlgorithm { case linux.FS_VERITY_HASH_ALG_SHA256: @@ -69,7 +68,6 @@ func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) blockSize: hostarch.PageSize, } - // TODO(b/156980949): Allow config SHA384. switch hashAlgorithms { case linux.FS_VERITY_HASH_ALG_SHA256: layout.digestSize = sha256DigestSize @@ -429,8 +427,6 @@ func Verify(params *VerifyParams) (int64, error) { } // If this is the end of file, zero the remaining bytes in buf, // otherwise they are still from the previous block. - // TODO(b/162908070): Investigate possible issues with zero - // padding the data. 
if bytesRead < len(buf) { for j := bytesRead; j < len(buf); j++ { buf[j] = 0 diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index e822fe77d..fdeee3a5f 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -36,10 +36,17 @@ var ( // new metric after initialization. ErrInitializationDone = errors.New("metric cannot be created after initialization is complete") + // createdSentryMetrics indicates that the sentry metrics are created. + createdSentryMetrics = false + // WeirdnessMetric is a metric with fields created to track the number - // of weird occurrences such as time fallback, partial_result and - // vsyscall count. + // of weird occurrences such as time fallback, partial_result, vsyscall + // count, watchdog startup timeouts and stuck tasks. WeirdnessMetric *Uint64Metric + + // SuspiciousOperationsMetric is a metric with fields created to detect + // operations such as opening an executable file to write from a gofer. + SuspiciousOperationsMetric *Uint64Metric ) // Uint64Metric encapsulates a uint64 that represents some kind of metric to be @@ -388,13 +395,21 @@ func EmitMetricUpdate() { // CreateSentryMetrics creates the sentry metrics during kernel initialization. func CreateSentryMetrics() { - if WeirdnessMetric != nil { + if createdSentryMetrics { return } - WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result and vsyscalls invoked in the sandbox", + createdSentryMetrics = true + + WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.", Field{ name: "weirdness_type", - allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count"}, + allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"}, + }) + + SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.", + Field{ + name: "operation_type", + allowedValues: []string{"opened_write_execute_file"}, }) } diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 7abc82e1b..28396b0ea 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -121,6 +121,22 @@ func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, At return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil } +func (c *clientFile) MultiGetAttr(names []string) ([]FullStat, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, unix.EBADF + } + + if !versionSupportsTmultiGetAttr(c.client.version) { + return DefaultMultiGetAttr(c, names) + } + + rmultigetattr := Rmultigetattr{} + if err := c.client.sendRecv(&Tmultigetattr{FID: c.fid, Names: names}, &rmultigetattr); err != nil { + return nil, err + } + return rmultigetattr.Stats, nil +} + // StatFS implements File.StatFS. 
func (c *clientFile) StatFS() (FSStat, error) { if atomic.LoadUint32(&c.closed) != 0 { diff --git a/pkg/p9/file.go b/pkg/p9/file.go index c59c6a65b..97e0231d6 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -15,6 +15,8 @@ package p9 import ( + "errors" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/fd" ) @@ -72,6 +74,15 @@ type File interface { // On the server, WalkGetAttr has a read concurrency guarantee. WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) + // MultiGetAttr batches up multiple calls to GetAttr(). names is a list of + // path components similar to Walk(). If the first component name is empty, + // the current file is stat'd and included in the results. If the walk reaches + // a file that doesn't exist or not a directory, MultiGetAttr returns the + // partial result with no error. + // + // On the server, MultiGetAttr has a read concurrency guarantee. + MultiGetAttr(names []string) ([]FullStat, error) + // StatFS returns information about the file system associated with // this file. // @@ -306,6 +317,53 @@ func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { type DisallowServerCalls struct{} // Renamed implements File.Renamed. -func (*clientFile) Renamed(File, string) { +func (*DisallowServerCalls) Renamed(File, string) { panic("Renamed should not be called on the client") } + +// DefaultMultiGetAttr implements File.MultiGetAttr() on top of File. +func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { + stats := make([]FullStat, 0, len(names)) + parent := start + mask := AttrMaskAll() + for i, name := range names { + if len(name) == 0 && i == 0 { + qid, valid, attr, err := parent.GetAttr(mask) + if err != nil { + return nil, err + } + stats = append(stats, FullStat{ + QID: qid, + Valid: valid, + Attr: attr, + }) + continue + } + qids, child, valid, attr, err := parent.WalkGetAttr([]string{name}) + if parent != start { + _ = parent.Close() + } + if err != nil { + if errors.Is(err, unix.ENOENT) { + return stats, nil + } + return nil, err + } + stats = append(stats, FullStat{ + QID: qids[0], + Valid: valid, + Attr: attr, + }) + if attr.Mode.FileType() != ModeDirectory { + // Doesn't need to continue if entry is not a dir. Including symlinks + // that cannot be followed. + _ = child.Close() + break + } + parent = child + } + if parent != start { + _ = parent.Close() + } + return stats, nil +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 58312d0cc..758e11b13 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -1421,3 +1421,31 @@ func (t *Tchannel) handle(cs *connState) message { } return rchannel } + +// handle implements handler.handle. +func (t *Tmultigetattr) handle(cs *connState) message { + for i, name := range t.Names { + if len(name) == 0 && i == 0 { + // Empty name is allowed on the first entry to indicate that the current + // FID needs to be included in the result. + continue + } + if err := checkSafeName(name); err != nil { + return newErr(err) + } + } + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(unix.EBADF) + } + defer ref.DecRef() + + var stats []FullStat + if err := ref.safelyRead(func() (err error) { + stats, err = ref.file.MultiGetAttr(t.Names) + return err + }); err != nil { + return newErr(err) + } + return &Rmultigetattr{Stats: stats} +} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index cf13cbb69..2ff4694c0 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -254,8 +254,8 @@ func (r *Rwalk) decode(b *buffer) { // encode implements encoder.encode. 
func (r *Rwalk) encode(b *buffer) { b.Write16(uint16(len(r.QIDs))) - for _, q := range r.QIDs { - q.encode(b) + for i := range r.QIDs { + r.QIDs[i].encode(b) } } @@ -2243,8 +2243,8 @@ func (r *Rwalkgetattr) encode(b *buffer) { r.Valid.encode(b) r.Attr.encode(b) b.Write16(uint16(len(r.QIDs))) - for _, q := range r.QIDs { - q.encode(b) + for i := range r.QIDs { + r.QIDs[i].encode(b) } } @@ -2552,6 +2552,80 @@ func (r *Rchannel) String() string { return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) } +// Tmultigetattr is a multi-getattr request. +type Tmultigetattr struct { + // FID is the FID to be walked. + FID FID + + // Names are the set of names to be walked. + Names []string +} + +// decode implements encoder.decode. +func (t *Tmultigetattr) decode(b *buffer) { + t.FID = b.ReadFID() + n := b.Read16() + t.Names = t.Names[:0] + for i := 0; i < int(n); i++ { + t.Names = append(t.Names, b.ReadString()) + } +} + +// encode implements encoder.encode. +func (t *Tmultigetattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.Write16(uint16(len(t.Names))) + for _, name := range t.Names { + b.WriteString(name) + } +} + +// Type implements message.Type. +func (*Tmultigetattr) Type() MsgType { + return MsgTmultigetattr +} + +// String implements fmt.Stringer. +func (t *Tmultigetattr) String() string { + return fmt.Sprintf("Tmultigetattr{FID: %d, Names: %v}", t.FID, t.Names) +} + +// Rmultigetattr is a multi-getattr response. +type Rmultigetattr struct { + // Stats are the set of FullStat returned for each of the names in the + // request. + Stats []FullStat +} + +// decode implements encoder.decode. +func (r *Rmultigetattr) decode(b *buffer) { + n := b.Read16() + r.Stats = r.Stats[:0] + for i := 0; i < int(n); i++ { + var fs FullStat + fs.decode(b) + r.Stats = append(r.Stats, fs) + } +} + +// encode implements encoder.encode. +func (r *Rmultigetattr) encode(b *buffer) { + b.Write16(uint16(len(r.Stats))) + for i := range r.Stats { + r.Stats[i].encode(b) + } +} + +// Type implements message.Type. +func (*Rmultigetattr) Type() MsgType { + return MsgRmultigetattr +} + +// String implements fmt.Stringer. +func (r *Rmultigetattr) String() string { + return fmt.Sprintf("Rmultigetattr{Stats: %v}", r.Stats) +} + const maxCacheSize = 3 // msgFactory is used to reduce allocations by caching messages for reuse. @@ -2717,6 +2791,8 @@ func init() { msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} }) msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} }) + msgRegistry.register(MsgTmultigetattr, func() message { return &Tmultigetattr{} }) + msgRegistry.register(MsgRmultigetattr, func() message { return &Rmultigetattr{} }) msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 648cf4b49..3d452a0bd 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -402,6 +402,8 @@ const ( MsgRallocate MsgType = 139 MsgTsetattrclunk MsgType = 140 MsgRsetattrclunk MsgType = 141 + MsgTmultigetattr MsgType = 142 + MsgRmultigetattr MsgType = 143 MsgTchannel MsgType = 250 MsgRchannel MsgType = 251 ) @@ -1178,3 +1180,29 @@ func (a *AllocateMode) encode(b *buffer) { } b.Write32(mask) } + +// FullStat is used in the result of a MultiGetAttr call. +type FullStat struct { + QID QID + Valid AttrMask + Attr Attr +} + +// String implements fmt.Stringer. 
+func (f *FullStat) String() string { + return fmt.Sprintf("FullStat{QID: %v, Valid: %v, Attr: %v}", f.QID, f.Valid, f.Attr) +} + +// decode implements encoder.decode. +func (f *FullStat) decode(b *buffer) { + f.QID.decode(b) + f.Valid.decode(b) + f.Attr.decode(b) +} + +// encode implements encoder.encode. +func (f *FullStat) encode(b *buffer) { + f.QID.encode(b) + f.Valid.encode(b) + f.Attr.encode(b) +} diff --git a/pkg/p9/version.go b/pkg/p9/version.go index 8d7168ef5..950236162 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 12 + highestSupportedVersion uint32 = 13 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -179,3 +179,9 @@ func versionSupportsListRemoveXattr(v uint32) bool { func versionSupportsTsetattrclunk(v uint32) bool { return v >= 12 } + +// versionSupportsTmultiGetAttr returns true if version v supports +// the TmultiGetAttr message. +func versionSupportsTmultiGetAttr(v uint32) bool { + return v >= 13 +} diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 6992e1de8..4aecb8007 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -30,6 +30,9 @@ import ( // RefCounter is the interface to be implemented by objects that are reference // counted. +// +// TODO(gvisor.dev/issue/1624): Get rid of most of this package and replace it +// with refsvfs2. type RefCounter interface { // IncRef increments the reference counter on the object. IncRef() @@ -181,6 +184,9 @@ func (w *WeakRef) zap() { // AtomicRefCount keeps a reference count using atomic operations and calls the // destructor when the count reaches zero. // +// Do not use AtomicRefCount for new ref-counted objects! It is deprecated in +// favor of the refsvfs2 package. +// // N.B. To allow the zero-object to be initialized, the count is offset by // 1, that is, when refCount is n, there are really n+1 references. // @@ -215,8 +221,8 @@ type AtomicRefCount struct { // LeakMode configures the leak checker. type LeakMode uint32 -// TODO(gvisor.dev/issue/1624): Simplify down to two modes once vfs1 ref -// counting is gone. +// TODO(gvisor.dev/issue/1624): Simplify down to two modes (on/off) once vfs1 +// ref counting is gone. const ( // UninitializedLeakChecking indicates that the leak checker has not yet been initialized. UninitializedLeakChecking LeakMode = iota diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD index 0377c0876..7c1a8c792 100644 --- a/pkg/refsvfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -1,3 +1,5 @@ +# TODO(gvisor.dev/issue/1624): rename this directory/package to "refs" once VFS1 +# is gone and the current refs package can be deleted. load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template") diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index 3fbc91aa5..1102c8adc 100644 --- a/pkg/refsvfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -13,7 +13,7 @@ // limitations under the License. // Package refs_template defines a template that can be used by reference -// counted objects. +// counted objects. The template comes with leak checking capabilities. package refs_template import ( @@ -40,6 +40,14 @@ var obj *T // Refs implements refs.RefCounter. 
It keeps a reference count using atomic // operations and calls the destructor when the count reaches zero. // +// NOTE: Do not introduce additional fields to the Refs struct. It is used by +// many filesystem objects, and we want to keep it as small as possible (i.e., +// the same size as using an int64 directly) to avoid taking up extra cache +// space. In general, this template should not be extended at the cost of +// performance. If it does not offer enough flexibility for a particular object +// (example: b/187877947), we should implement the RefCounter/CheckedObject +// interfaces manually. +// // +stateify savable type Refs struct { // refCount is composed of two fields: diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 3f17fba49..9dac53c80 100644 --- a/pkg/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -322,3 +322,12 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc func (p *PageTables) MarkReadOnlyShared() { p.readOnlyShared = true } + +// PrefaultRootTable touches the root table page to be sure that its physical +// pages are mapped. +// +//go:nosplit +//go:noinline +func (p *PageTables) PrefaultRootTable() PTE { + return p.root[0] +} diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s index 290579e53..d513f16c9 100644 --- a/pkg/safecopy/atomic_amd64.s +++ b/pkg/safecopy/atomic_amd64.s @@ -24,12 +24,12 @@ TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 MOVL DI, sig+20(FP) RET -// swapUint32 atomically stores new into *addr and returns (the previous *addr +// swapUint32 atomically stores new into *ptr and returns (the previous ptr* // value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the // value of old is unspecified, and sig is the number of the signal that was // received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) TEXT ·swapUint32(SB), NOSPLIT, $0-24 @@ -38,7 +38,7 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24 // handleSwapUint32Fault will store a different value in this address. MOVL $0, sig+20(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVL new+8(FP), AX XCHGL AX, 0(DI) MOVL AX, old+16(FP) @@ -60,12 +60,12 @@ TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 MOVL DI, sig+24(FP) RET -// swapUint64 atomically stores new into *addr and returns (the previous *addr +// swapUint64 atomically stores new into *ptr and returns (the previous *ptr // value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the // value of old is unspecified, and sig is the number of the signal that was // received. // -// Preconditions: addr must be aligned to a 8-byte boundary. +// Preconditions: ptr must be aligned to a 8-byte boundary. // //func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) TEXT ·swapUint64(SB), NOSPLIT, $0-28 @@ -74,7 +74,7 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28 // handleSwapUint64Fault will store a different value in this address. MOVL $0, sig+24(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVQ new+8(FP), AX XCHGQ AX, 0(DI) MOVQ AX, old+16(FP) @@ -97,11 +97,11 @@ TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 RET // compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns -// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is +// (the value previously stored at ptr, 0). 
If a SIGSEGV or SIGBUS signal is // received during the operation, the value of prev is unspecified, and sig is // the number of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 @@ -111,7 +111,7 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 // address. MOVL $0, sig+20(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVL old+8(FP), AX MOVL new+12(FP), DX LOCK @@ -135,11 +135,11 @@ TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 MOVL DI, sig+12(FP) RET -// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS +// loadUint32 atomically loads *ptr and returns it. If a SIGSEGV or SIGBUS // signal is received, the value returned is unspecified, and sig is the number // of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) TEXT ·loadUint32(SB), NOSPLIT, $0-16 @@ -148,7 +148,7 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16 // handleLoadUint32Fault will store a different value in this address. MOVL $0, sig+12(FP) - MOVQ addr+0(FP), AX + MOVQ ptr+0(FP), AX MOVL (AX), BX MOVL BX, val+8(FP) RET diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s index 55c031a3c..246a049ba 100644 --- a/pkg/safecopy/atomic_arm64.s +++ b/pkg/safecopy/atomic_arm64.s @@ -25,7 +25,7 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24 // handleSwapUint32Fault will store a different value in this address. MOVW $0, sig+20(FP) again: - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVW new+8(FP), R1 LDAXRW (R0), R2 STLXRW R1, (R0), R3 @@ -60,7 +60,7 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28 // handleSwapUint64Fault will store a different value in this address. MOVW $0, sig+24(FP) again: - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVD new+8(FP), R1 LDAXR (R0), R2 STLXR R1, (R0), R3 @@ -96,7 +96,7 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 // address. MOVW $0, sig+20(FP) - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVW old+8(FP), R1 MOVW new+12(FP), R2 again: @@ -125,11 +125,11 @@ TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 MOVW R1, sig+12(FP) RET -// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS +// loadUint32 atomically loads *ptr and returns it. If a SIGSEGV or SIGBUS // signal is received, the value returned is unspecified, and sig is the number // of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) TEXT ·loadUint32(SB), NOSPLIT, $0-16 @@ -138,7 +138,7 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16 // handleLoadUint32Fault will store a different value in this address. MOVW $0, sig+12(FP) - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 LDARW (R0), R1 MOVW R1, val+8(FP) RET diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s index 1d63ca1fd..37316b2f5 100644 --- a/pkg/safecopy/memcpy_amd64.s +++ b/pkg/safecopy/memcpy_amd64.s @@ -51,8 +51,8 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36 // handleMemcpyFault will store a different value in this address. 
MOVL $0, sig+32(FP) - MOVQ to+0(FP), DI - MOVQ from+8(FP), SI + MOVQ dst+0(FP), DI + MOVQ src+8(FP), SI MOVQ n+16(FP), BX tail: diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s index 7b3f50aa5..50f5b754b 100644 --- a/pkg/safecopy/memcpy_arm64.s +++ b/pkg/safecopy/memcpy_arm64.s @@ -33,8 +33,8 @@ TEXT ·memcpy(SB), NOSPLIT, $-8-36 // handleMemcpyFault will store a different value in this address. MOVW $0, sig+32(FP) - MOVD to+0(FP), R3 - MOVD from+8(FP), R4 + MOVD dst+0(FP), R3 + MOVD src+8(FP), R4 MOVD n+16(FP), R5 CMP $0, R5 BNE check diff --git a/pkg/sentry/devices/quotedev/BUILD b/pkg/sentry/devices/quotedev/BUILD new file mode 100644 index 000000000..d09214e3e --- /dev/null +++ b/pkg/sentry/devices/quotedev/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "quotedev", + srcs = ["quotedev.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/devices/quotedev/quotedev.go b/pkg/sentry/devices/quotedev/quotedev.go new file mode 100644 index 000000000..6114cb724 --- /dev/null +++ b/pkg/sentry/devices/quotedev/quotedev.go @@ -0,0 +1,52 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package quotedev implements a vfs.Device for /dev/gvisor_quote. +package quotedev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const ( + quoteDevMinor = 0 +) + +// quoteDevice implements vfs.Device for /dev/gvisor_quote +// +// +stateify savable +type quoteDevice struct{} + +// Open implements vfs.Device.Open. +// TODO(b/157161182): Add support for attestation ioctls. +func (quoteDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + return nil, syserror.EIO +} + +// Register registers all devices implemented by this package in vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + return vfsObj.RegisterDevice(vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, quoteDevice{}, &vfs.RegisterDeviceOptions{ + GroupName: "gvisor_quote", + }) +} + +// CreateDevtmpfsFiles creates device special files in dev representing all +// devices implemented by this package. 
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { + return dev.CreateDeviceFile(ctx, "gvisor_quote", vfs.CharDevice, linux.UNNAMED_MAJOR, quoteDevMinor, 0666 /* mode */) +} diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index c4a069832..94cb05246 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/fd", "//pkg/hostarch", "//pkg/log", + "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/safemem", diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 8f5a87120..819e140bc 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -91,7 +92,7 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF } if flags.Write { if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil { - fsmetric.GoferOpensWX.Increment() + metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") log.Warningf("Opened a writable executable: %q", name) } } diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 1d09afdd7..4893af56b 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -403,7 +403,7 @@ type ipForwarding struct { // enabled stores the IPv4 forwarding state on save. // We must save/restore this here, since a netstack instance // is created on restore. - enabled *bool + enabled bool } func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { @@ -461,13 +461,8 @@ func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return 0, io.EOF } - if f.ipf.enabled == nil { - enabled := f.stack.Forwarding(ipv4.ProtocolNumber) - f.ipf.enabled = &enabled - } - val := "0\n" - if *f.ipf.enabled { + if f.ipf.enabled { // Technically, this is not quite compatible with Linux. Linux // stores these as an integer, so if you write "2" into // ip_forward, you should get 2 back. @@ -494,11 +489,8 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO if err != nil { return n, err } - if f.ipf.enabled == nil { - f.ipf.enabled = new(bool) - } - *f.ipf.enabled = v != 0 - return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) + f.ipf.enabled = v != 0 + return n, f.stack.SetForwarding(ipv4.ProtocolNumber, f.ipf.enabled) } // portRangeInode implements fs.InodeOperations. It provides and allows diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go index 4cb4741af..51d2be647 100644 --- a/pkg/sentry/fs/proc/sys_net_state.go +++ b/pkg/sentry/fs/proc/sys_net_state.go @@ -47,9 +47,7 @@ func (s *tcpSack) afterLoad() { // afterLoad is invoked by stateify. 
func (ipf *ipForwarding) afterLoad() { - if ipf.enabled != nil { - if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { - panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err)) - } + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, ipf.enabled); err != nil { + panic(fmt.Sprintf("ipf.stack.SetForwarding(%d, %t): %s", ipv4.ProtocolNumber, ipf.enabled, err)) } } diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go index 0f54888d8..6512e9cdb 100644 --- a/pkg/sentry/fsimpl/cgroupfs/base.go +++ b/pkg/sentry/fsimpl/cgroupfs/base.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -68,11 +67,6 @@ func (c *controllerCommon) Enabled() bool { return true } -// Filesystem implements kernel.CgroupController.Filesystem. -func (c *controllerCommon) Filesystem() *vfs.Filesystem { - return c.fs.VFSFilesystem() -} - // RootCgroup implements kernel.CgroupController.RootCgroup. func (c *controllerCommon) RootCgroup() kernel.Cgroup { return c.fs.rootCgroup() diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go index bd3e69757..54050de3c 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -109,7 +109,7 @@ type InternalData struct { DefaultControlValues map[string]int64 } -// filesystem implements vfs.FilesystemImpl. +// filesystem implements vfs.FilesystemImpl and kernel.cgroupFS. // // +stateify savable type filesystem struct { @@ -139,6 +139,11 @@ type filesystem struct { tasksMu sync.RWMutex `state:"nosave"` } +// InitializeHierarchyID implements kernel.cgroupFS.InitializeHierarchyID. +func (fs *filesystem) InitializeHierarchyID(hid uint32) { + fs.hierarchyID = hid +} + // Name implements vfs.FilesystemType.Name. func (FilesystemType) Name() string { return Name @@ -284,14 +289,12 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Register controllers. The registry may be modified concurrently, so if we // get an error, we raced with someone else who registered the same // controllers first. - hid, err := r.Register(fs.kcontrollers) - if err != nil { + if err := r.Register(fs.kcontrollers, fs); err != nil { ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err) rootD.DecRef(ctx) fs.VFSFilesystem().DecRef(ctx) return nil, nil, syserror.EBUSY } - fs.hierarchyID = hid // Move all existing tasks to the root of the new hierarchy. k.PopulateNewCgroupHierarchy(fs.rootCgroup()) diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index e6fe0fc0d..daff40cd5 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -36,7 +36,7 @@ const Name = "devtmpfs" // // +stateify savable type FilesystemType struct { - initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1664): not yet supported. + initOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. initErr error // fs is the tmpfs filesystem that backs all mounts of this FilesystemType. 
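The quotedev package added above only defines the device and two entry points, Register and CreateDevtmpfsFiles; it still has to be called from wherever the sentry sets up its devices. A minimal sketch of that wiring, with the surrounding package and function names assumed for illustration (they are not part of this change):

```go
package boot // illustrative placement; the real call sites live wherever devices are registered

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/sentry/devices/quotedev"
	"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// registerQuoteDevice registers /dev/gvisor_quote with the VFS and creates
// its node in devtmpfs, using the two entry points added in this change.
func registerQuoteDevice(ctx context.Context, vfsObj *vfs.VirtualFilesystem, a *devtmpfs.Accessor) error {
	if err := quotedev.Register(vfsObj); err != nil {
		return fmt.Errorf("registering gvisor_quote: %w", err)
	}
	return quotedev.CreateDevtmpfsFiles(ctx, a)
}
```

Until the attestation ioctls mentioned in the TODO land, opening the resulting device node simply fails with EIO, per quoteDevice.Open above.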
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 7b1eec3da..2dbc6bfd5 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -46,7 +46,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fd", "//pkg/fspath", diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 6d5258a9b..368272f12 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "host_named_pipe.go", "p9file.go", "regular_file.go", + "revalidate.go", "save_restore.go", "socket.go", "special_file.go", @@ -53,6 +54,7 @@ go_library( "//pkg/fspath", "//pkg/hostarch", "//pkg/log", + "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 4b5621043..91ec4a142 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -117,6 +117,17 @@ func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry { return ds } +// Precondition: !parent.isSynthetic() && !child.isSynthetic(). +func appendNewChildDentry(ds **[]*dentry, parent *dentry, child *dentry) { + // The new child was added to parent and took a ref on the parent (hence + // parent can be removed from cache). A new child has 0 refs for now. So + // checkCachingLocked() should be called on both. Call it first on the parent + // as it may create space in the cache for child to be inserted - hence + // avoiding a cache eviction. + *ds = appendDentry(*ds, parent) + *ds = appendDentry(*ds, child) +} + // Preconditions: ds != nil. func putDentrySlice(ds *[]*dentry) { // Allow dentries to be GC'd. @@ -169,167 +180,96 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[] // * fs.renameMu must be locked. // * d.dirMu must be locked. // * !rp.Done(). -// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up -// to date. +// * If !d.cachedMetadataAuthoritative(), then d and all children that are +// part of rp must have been revalidated. // // Postconditions: The returned dentry's cached metadata is up to date. -func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, bool, error) { if !d.isDir() { - return nil, syserror.ENOTDIR + return nil, false, syserror.ENOTDIR } if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, err + return nil, false, err } + followedSymlink := false afterSymlink: name := rp.Component() if name == "." { rp.Advance() - return d, nil + return d, followedSymlink, nil } if name == ".." { if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { - return nil, err + return nil, false, err } else if isRoot || d.parent == nil { rp.Advance() - return d, nil - } - // We must assume that d.parent is correct, because if d has been moved - // elsewhere in the remote filesystem so that its parent has changed, - // we have no way of determining its new parent's location in the - // filesystem. - // - // Call rp.CheckMount() before updating d.parent's metadata, since if - // we traverse to another mount then d.parent's metadata is irrelevant. 
- if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { - return nil, err + return d, followedSymlink, nil } - if d != d.parent && !d.cachedMetadataAuthoritative() { - if err := d.parent.updateFromGetattr(ctx); err != nil { - return nil, err - } + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + return nil, false, err } rp.Advance() - return d.parent, nil + return d.parent, followedSymlink, nil } - child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), d, name, ds) + child, err := fs.getChildLocked(ctx, d, name, ds) if err != nil { - return nil, err - } - if child == nil { - return nil, syserror.ENOENT + return nil, false, err } if err := rp.CheckMount(ctx, &child.vfsd); err != nil { - return nil, err + return nil, false, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx, rp.Mount()) if err != nil { - return nil, err + return nil, false, err } if err := rp.HandleSymlink(target); err != nil { - return nil, err + return nil, false, err } + followedSymlink = true goto afterSymlink // don't check the current directory again } rp.Advance() - return child, nil + return child, followedSymlink, nil } // getChildLocked returns a dentry representing the child of parent with the -// given name. If no such child exists, getChildLocked returns (nil, nil). +// given name. Returns ENOENT if the child doesn't exist. // // Preconditions: // * fs.renameMu must be locked. // * parent.dirMu must be locked. // * parent.isDir(). // * name is not "." or "..". -// -// Postconditions: If getChildLocked returns a non-nil dentry, its cached -// metadata is up to date. -func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { +// * dentry at name has been revalidated +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if len(name) > maxFilenameLen { return nil, syserror.ENAMETOOLONG } - child, ok := parent.children[name] - if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() { - // Whether child is nil or not, it is cached information that is - // assumed to be correct. + if child, ok := parent.children[name]; ok || parent.isSynthetic() { + if child == nil { + return nil, syserror.ENOENT + } return child, nil } - // We either don't have cached information or need to verify that it's - // still correct, either of which requires a remote lookup. Check if this - // name is valid before performing the lookup. - return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds) -} -// Preconditions: Same as getChildLocked, plus: -// * !parent.isSynthetic(). -func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { - if child != nil { - // Need to lock child.metadataMu because we might be updating child - // metadata. We need to hold the lock *before* getting metadata from the - // server and release it after updating local metadata. 
- child.metadataMu.Lock() - } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) - if err != nil && err != syserror.ENOENT { - if child != nil { - child.metadataMu.Unlock() + if err != nil { + if err == syserror.ENOENT { + parent.cacheNegativeLookupLocked(name) } return nil, err } - if child != nil { - if !file.isNil() && qid.Path == child.qidPath { - // The file at this path hasn't changed. Just update cached metadata. - file.close(ctx) - child.updateFromP9AttrsLocked(attrMask, &attr) - child.metadataMu.Unlock() - return child, nil - } - child.metadataMu.Unlock() - if file.isNil() && child.isSynthetic() { - // We have a synthetic file, and no remote file has arisen to - // replace it. - return child, nil - } - // The file at this path has changed or no longer exists. Mark the - // dentry invalidated, and re-evaluate its caching status (i.e. if it - // has 0 references, drop it). Wait to update parent.children until we - // know what to replace the existing dentry with (i.e. one of the - // returns below), to avoid a redundant map access. - vfsObj.InvalidateDentry(ctx, &child.vfsd) - if child.isSynthetic() { - // Normally we don't mark invalidated dentries as deleted since - // they may still exist (but at a different path), and also for - // consistency with Linux. However, synthetic files are guaranteed - // to become unreachable if their dentries are invalidated, so - // treat their invalidation as deletion. - child.setDeleted() - parent.syntheticChildren-- - child.decRefNoCaching() - parent.dirents = nil - } - *ds = appendDentry(*ds, child) - } - if file.isNil() { - // No file exists at this path now. Cache the negative lookup if - // allowed. - parent.cacheNegativeLookupLocked(name) - return nil, nil - } + // Create a new dentry representing the file. - child, err = fs.newDentry(ctx, file, qid, attrMask, &attr) + child, err := fs.newDentry(ctx, file, qid, attrMask, &attr) if err != nil { file.close(ctx) delete(parent.children, name) return nil, err } parent.cacheNewChildLocked(child, name) - // For now, child has 0 references, so our caller should call - // child.checkCachingLocked(). parent gained a ref so we should also call - // parent.checkCachingLocked() so it can be removed from the cache if needed. - *ds = appendDentry(*ds, child) - *ds = appendDentry(*ds, parent) + appendNewChildDentry(ds, parent, child) return child, nil } @@ -344,14 +284,22 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up // to date. func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { + if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil { + return nil, err + } for !rp.Final() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err } d = next + if followedSymlink { + if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil { + return nil, err + } + } } if !d.isDir() { return nil, syserror.ENOTDIR @@ -364,20 +312,22 @@ func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // Preconditions: fs.renameMu must be locked. 
func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { d := rp.Start().Impl().(*dentry) - if !d.cachedMetadataAuthoritative() { - // Get updated metadata for rp.Start() as required by fs.stepLocked(). - if err := d.updateFromGetattr(ctx); err != nil { - return nil, err - } + if err := fs.revalidatePath(ctx, rp, d, ds); err != nil { + return nil, err } for !rp.Done() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err } d = next + if followedSymlink { + if err := fs.revalidatePath(ctx, rp, d, ds); err != nil { + return nil, err + } + } } if rp.MustBeDir() && !d.isDir() { return nil, syserror.ENOTDIR @@ -397,13 +347,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return err - } - } parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return err @@ -421,25 +364,47 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if parent.isDeleted() { return syserror.ENOENT } + if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, name, &ds); err != nil { + return err + } parent.dirMu.Lock() defer parent.dirMu.Unlock() - child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds) - switch { - case err != nil && err != syserror.ENOENT: - return err - case child != nil: + if len(name) > maxFilenameLen { + return syserror.ENAMETOOLONG + } + // Check for existence only if caching information is available. Otherwise, + // don't check for existence just yet. We will check for existence if the + // checks for writability fail below. Existence check is done by the creation + // RPCs themselves. + if child, ok := parent.children[name]; ok && child != nil { return syserror.EEXIST } + checkExistence := func() error { + if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT { + return err + } else if child != nil { + return syserror.EEXIST + } + return nil + } mnt := rp.Mount() if err := mnt.CheckBeginWrite(); err != nil { + // Existence check takes precedence. + if existenceErr := checkExistence(); existenceErr != nil { + return existenceErr + } return err } defer mnt.EndWrite() if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + // Existence check takes precedence. + if existenceErr := checkExistence(); existenceErr != nil { + return existenceErr + } return err } if !dir && rp.MustBeDir() { @@ -489,13 +454,6 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). 
- if err := start.updateFromGetattr(ctx); err != nil { - return err - } - } parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return err @@ -521,33 +479,32 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b return syserror.EISDIR } } + vfsObj := rp.VirtualFilesystem() + if err := fs.revalidateOne(ctx, vfsObj, parent, rp.Component(), &ds); err != nil { + return err + } + mntns := vfs.MountNamespaceFromContext(ctx) defer mntns.DecRef(ctx) + parent.dirMu.Lock() defer parent.dirMu.Unlock() - child, ok := parent.children[name] - if ok && child == nil { - return syserror.ENOENT - } - - sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0 - if sticky { - if !ok { - // If the sticky bit is set, we need to retrieve the child to determine - // whether removing it is allowed. - child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) - if err != nil { - return err - } - } else if child != nil && !child.cachedMetadataAuthoritative() { - // Make sure the dentry representing the file at name is up to date - // before examining its metadata. - child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds) - if err != nil { - return err - } + // Load child if sticky bit is set because we need to determine whether + // deletion is allowed. + var child *dentry + if atomic.LoadUint32(&parent.mode)&linux.ModeSticky == 0 { + var ok bool + child, ok = parent.children[name] + if ok && child == nil { + // Hit a negative cached entry, child doesn't exist. + return syserror.ENOENT + } + } else { + child, _, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + if err != nil { + return err } if err := parent.mayDelete(rp.Credentials(), child); err != nil { return err @@ -556,11 +513,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b // If a child dentry exists, prepare to delete it. This should fail if it is // a mount point. We detect mount points by speculatively calling - // PrepareDeleteDentry, which fails if child is a mount point. However, we - // may need to revalidate the file in this case to make sure that it has not - // been deleted or replaced on the remote fs, in which case the mount point - // will have disappeared. If calling PrepareDeleteDentry fails again on the - // up-to-date dentry, we can be sure that it is a mount point. + // PrepareDeleteDentry, which fails if child is a mount point. // // Also note that if child is nil, then it can't be a mount point. 
if child != nil { @@ -575,23 +528,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b child.dirMu.Lock() defer child.dirMu.Unlock() if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { - // We can skip revalidation in several cases: - // - We are not in InteropModeShared - // - The parent directory is synthetic, in which case the child must also - // be synthetic - // - We already updated the child during the sticky bit check above - if parent.cachedMetadataAuthoritative() || sticky { - return err - } - child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds) - if err != nil { - return err - } - if child != nil { - if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { - return err - } - } + return err } } flags := uint32(0) @@ -723,13 +660,6 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } d, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return nil, err @@ -830,7 +760,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // to creating a synthetic one, i.e. one that is kept entirely in memory. // Check that we're not overriding an existing file with a synthetic one. - _, err = fs.stepLocked(ctx, rp, parent, true, ds) + _, _, err = fs.stepLocked(ctx, rp, parent, true, ds) switch { case err == nil: // Step succeeded, another file exists. @@ -891,12 +821,6 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf defer unlock() start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by fs.stepLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } if rp.Done() { // Reject attempts to open mount root directory with O_CREAT. if mayCreate && rp.MustBeDir() { @@ -905,6 +829,12 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } + if !start.cachedMetadataAuthoritative() { + // Refresh dentry's attributes before opening. + if err := start.updateFromGetattr(ctx); err != nil { + return nil, err + } + } start.IncRef() defer start.DecRef(ctx) unlock() @@ -926,9 +856,12 @@ afterTrailingSymlink: if mayCreate && rp.MustBeDir() { return nil, syserror.EISDIR } + if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, rp.Component(), &ds); err != nil { + return nil, err + } // Determine whether or not we need to create a file. parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) if err == syserror.ENOENT && mayCreate { if parent.isSynthetic() { parent.dirMu.Unlock() @@ -1028,7 +961,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open } return &fd.vfsfd, nil case linux.S_IFLNK: - // Can't open symlinks without O_PATH (which is unimplemented). + // Can't open symlinks without O_PATH, which is handled at the VFS layer. 
return nil, syserror.ELOOP case linux.S_IFSOCK: if d.isSynthetic() { @@ -1188,7 +1121,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } - *ds = appendDentry(*ds, child) // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { @@ -1212,7 +1144,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } // Insert the dentry into the tree. d.cacheNewChildLocked(child, name) - *ds = appendDentry(*ds, d) + appendNewChildDentry(ds, d, child) if d.cachedMetadataAuthoritative() { d.touchCMtime() d.dirents = nil @@ -1297,18 +1229,23 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { return err } + vfsObj := rp.VirtualFilesystem() + if err := fs.revalidateOne(ctx, vfsObj, newParent, newName, &ds); err != nil { + return err + } + if err := fs.revalidateOne(ctx, vfsObj, oldParent, oldName, &ds); err != nil { + return err + } + // We need a dentry representing the renamed file since, if it's a // directory, we need to check for write permission on it. oldParent.dirMu.Lock() defer oldParent.dirMu.Unlock() - renamed, err := fs.getChildLocked(ctx, vfsObj, oldParent, oldName, &ds) + renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) if err != nil { return err } - if renamed == nil { - return syserror.ENOENT - } if err := oldParent.mayDelete(creds, renamed); err != nil { return err } @@ -1337,8 +1274,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.isDeleted() { return syserror.ENOENT } - replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds) - if err != nil { + replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil && err != syserror.ENOENT { return err } var replacedVFSD *vfs.Dentry @@ -1402,9 +1339,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // parent isn't actually changing. if oldParent != newParent { oldParent.decRefNoCaching() - ds = appendDentry(ds, oldParent) newParent.IncRef() ds = appendDentry(ds, newParent) + ds = appendDentry(ds, oldParent) if renamed.isSynthetic() { oldParent.syntheticChildren-- newParent.syntheticChildren++ diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index fb42c5f62..21692d2ac 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -32,9 +32,9 @@ // specialFileFD.mu // specialFileFD.bufMu // -// Locking dentry.dirMu in multiple dentries requires that either ancestor -// dentries are locked before descendant dentries, or that filesystem.renameMu -// is locked for writing. +// Locking dentry.dirMu and dentry.metadataMu in multiple dentries requires that +// either ancestor dentries are locked before descendant dentries, or that +// filesystem.renameMu is locked for writing. 
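The lock ordering comment above is the invariant the new revalidation code leans on, since it now holds metadataMu on several dentries at once. A self-contained sketch of the discipline, using stand-in types rather than the real gofer ones:

```go
package main

import "sync"

// Stand-ins for the gofer filesystem and dentry; only the locks named in the
// ordering rule are modeled.
type filesystem struct{ renameMu sync.RWMutex }

type dentry struct {
	dirMu      sync.Mutex
	metadataMu sync.Mutex
}

// lockPair locks an ancestor before its descendant. This is valid either with
// renameMu held for reading and the dentries locked ancestor-first (as here),
// or with renameMu held for writing.
func lockPair(fs *filesystem, ancestor, descendant *dentry) (unlock func()) {
	fs.renameMu.RLock()
	ancestor.dirMu.Lock()
	descendant.dirMu.Lock()
	return func() {
		descendant.dirMu.Unlock()
		ancestor.dirMu.Unlock()
		fs.renameMu.RUnlock()
	}
}

func main() {
	fs := &filesystem{}
	parent, child := &dentry{}, &dentry{}
	defer lockPair(fs, parent, child)()
}
```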
package gofer import ( diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 21b4a96fe..b0a429d42 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -238,3 +238,10 @@ func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, err ctx.UninterruptibleSleepFinish(false) return fdobj, err } + +func (f p9file) multiGetAttr(ctx context.Context, names []string) ([]p9.FullStat, error) { + ctx.UninterruptibleSleepStart(false) + stats, err := f.file.MultiGetAttr(names) + ctx.UninterruptibleSleepFinish(false) + return stats, err +} diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index f0e7bbaf7..eed05e369 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -59,7 +60,7 @@ func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD, return nil, err } if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { - fsmetric.GoferOpensWX.Increment() + metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") } if atomic.LoadInt32(&d.mmapFD) >= 0 { fsmetric.GoferOpensHost.Increment() diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go new file mode 100644 index 000000000..8f81f0822 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/revalidate.go @@ -0,0 +1,386 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +type errPartialRevalidation struct{} + +// Error implements error.Error. +func (errPartialRevalidation) Error() string { + return "partial revalidation" +} + +type errRevalidationStepDone struct{} + +// Error implements error.Error. +func (errRevalidationStepDone) Error() string { + return "stop revalidation" +} + +// revalidatePath checks cached dentries for external modification. File +// attributes are refreshed and cache is invalidated in case the dentry has been +// deleted, or a new file/directory created in its place. +// +// Revalidation stops at symlinks and mount points. The caller is responsible +// for revalidating again after symlinks are resolved and after changing to +// different mounts. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidatePath(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error { + // Revalidation is done even if start is synthetic in case the path is + // something like: ../non_synthetic_file. 
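To make the new revalidation flow concrete: under InteropModeShared, a walk of dir1/dir2/file collects the already-cached dentries along the path and refreshes all of them with a single MultiGetAttr RPC, instead of issuing one getattr per component as the old revalidateChildLocked did. A toy model of how the name list is built (the real accumulation and stat handling are in the revalidate functions below and in revalidateHelper):

```go
package main

import "fmt"

// Toy model of revalidateState: the walk records one name per cached path
// component, plus a leading "" for the starting dentry itself.
type revalidateState struct{ names []string }

func (r *revalidateState) add(name string) { r.names = append(r.names, name) }

func main() {
	var state revalidateState
	state.add("") // refresh the start dentry's own attributes
	for _, component := range []string{"dir1", "dir2", "file"} {
		state.add(component)
	}
	// revalidateHelper hands this slice to p9file.multiGetAttr, so the whole
	// path costs a single MultiGetAttr round trip to the gofer.
	fmt.Println(state.names)
}
```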
+ if fs.opts.interop != InteropModeShared { + return nil + } + + // Copy resolving path to walk the path for revalidation. + rp := rpOrig.Copy() + err := fs.revalidate(ctx, rp, start, rp.Done, ds) + rp.Release(ctx) + return err +} + +// revalidateParentDir does the same as revalidatePath, but stops at the parent. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidateParentDir(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error { + // Revalidation is done even if start is synthetic in case the path is + // something like: ../non_synthetic_file and parent is non synthetic. + if fs.opts.interop != InteropModeShared { + return nil + } + + // Copy resolving path to walk the path for revalidation. + rp := rpOrig.Copy() + err := fs.revalidate(ctx, rp, start, rp.Final, ds) + rp.Release(ctx) + return err +} + +// revalidateOne does the same as revalidatePath, but checks a single dentry. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidateOne(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) error { + // Skip revalidation for interop mode different than InteropModeShared or + // if the parent is synthetic (child must be synthetic too, but it cannot be + // replaced without first replacing the parent). + if parent.cachedMetadataAuthoritative() { + return nil + } + + parent.dirMu.Lock() + child, ok := parent.children[name] + parent.dirMu.Unlock() + if !ok { + return nil + } + + state := makeRevalidateState(parent) + defer state.release() + + state.add(name, child) + return fs.revalidateHelper(ctx, vfsObj, state, ds) +} + +// revalidate revalidates path components in rp until done returns true, or +// until a mount point or symlink is reached. It may send multiple MultiGetAttr +// calls to the gofer to handle ".." in the path. +// +// Preconditions: +// * fs.renameMu must be locked. +// * InteropModeShared is in effect. +func (fs *filesystem) revalidate(ctx context.Context, rp *vfs.ResolvingPath, start *dentry, done func() bool, ds **[]*dentry) error { + state := makeRevalidateState(start) + defer state.release() + + // Skip synthetic dentries because the start dentry cannot be replaced in case + // it has been created in the remote file system. + if !start.isSynthetic() { + state.add("", start) + } + +done: + for cur := start; !done(); { + var err error + cur, err = fs.revalidateStep(ctx, rp, cur, state) + if err != nil { + switch err.(type) { + case errPartialRevalidation: + if err := fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds); err != nil { + return err + } + + // Reset state to release any remaining locks and restart from where + // stepping stopped. + state.reset() + state.start = cur + + // Skip synthetic dentries because the start dentry cannot be replaced in + // case it has been created in the remote file system. + if !cur.isSynthetic() { + state.add("", cur) + } + + case errRevalidationStepDone: + break done + + default: + return err + } + } + } + return fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds) +} + +// revalidateStep walks one element of the path and updates revalidationState +// with the dentry if needed. It may also stop the stepping or ask for a +// partial revalidation. Partial revalidation requires the caller to revalidate +// the current revalidationState, release all locks, and resume stepping. 
+// In case a symlink is hit, revalidation stops and the caller is responsible +// for calling revalidate again after the symlink is resolved. Revalidation may +// also stop for other reasons, like hitting a child not in the cache. +// +// Returns: +// * (dentry, nil): step worked, continue stepping.` +// * (dentry, errPartialRevalidation): revalidation should be done with the +// state gathered so far. Then continue stepping with the remainder of the +// path, starting at `dentry`. +// * (nil, errRevalidationStepDone): revalidation doesn't need to step any +// further. It hit a symlink, a mount point, or an uncached dentry. +// +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). +// * InteropModeShared is in effect (assumes no negative dentries). +func (fs *filesystem) revalidateStep(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, state *revalidateState) (*dentry, error) { + switch name := rp.Component(); name { + case ".": + // Do nothing. + + case "..": + // Partial revalidation is required when ".." is hit because metadata locks + // can only be acquired from parent to child to avoid deadlocks. + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { + return nil, errRevalidationStepDone{} + } else if isRoot || d.parent == nil { + rp.Advance() + return d, errPartialRevalidation{} + } + // We must assume that d.parent is correct, because if d has been moved + // elsewhere in the remote filesystem so that its parent has changed, + // we have no way of determining its new parent's location in the + // filesystem. + // + // Call rp.CheckMount() before updating d.parent's metadata, since if + // we traverse to another mount then d.parent's metadata is irrelevant. + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + return nil, errRevalidationStepDone{} + } + rp.Advance() + return d.parent, errPartialRevalidation{} + + default: + d.dirMu.Lock() + child, ok := d.children[name] + d.dirMu.Unlock() + if !ok { + // child is not cached, no need to validate any further. + return nil, errRevalidationStepDone{} + } + + state.add(name, child) + + // Symlink must be resolved before continuing with revalidation. + if child.isSymlink() { + return nil, errRevalidationStepDone{} + } + + d = child + } + + rp.Advance() + return d, nil +} + +// revalidateHelper calls the gofer to stat all dentries in `state`. It will +// update or invalidate dentries in the cache based on the result. +// +// Preconditions: +// * fs.renameMu must be locked. +// * InteropModeShared is in effect. +func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualFilesystem, state *revalidateState, ds **[]*dentry) error { + if len(state.names) == 0 { + return nil + } + // Lock metadata on all dentries *before* getting attributes for them. + state.lockAllMetadata() + stats, err := state.start.file.multiGetAttr(ctx, state.names) + if err != nil { + return err + } + + i := -1 + for d := state.popFront(); d != nil; d = state.popFront() { + i++ + found := i < len(stats) + if i == 0 && len(state.names[0]) == 0 { + if found && !d.isSynthetic() { + // First dentry is where the search is starting, just update attributes + // since it cannot be replaced. + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) + } + d.metadataMu.Unlock() + continue + } + + // Note that synthetic dentries will always fails the comparison check + // below. 
+ if !found || d.qidPath != stats[i].QID.Path { + d.metadataMu.Unlock() + if !found && d.isSynthetic() { + // We have a synthetic file, and no remote file has arisen to replace + // it. + return nil + } + // The file at this path has changed or no longer exists. Mark the + // dentry invalidated, and re-evaluate its caching status (i.e. if it + // has 0 references, drop it). The dentry will be reloaded next time it's + // accessed. + vfsObj.InvalidateDentry(ctx, &d.vfsd) + + name := state.names[i] + d.parent.dirMu.Lock() + + if d.isSynthetic() { + // Normally we don't mark invalidated dentries as deleted since + // they may still exist (but at a different path), and also for + // consistency with Linux. However, synthetic files are guaranteed + // to become unreachable if their dentries are invalidated, so + // treat their invalidation as deletion. + d.setDeleted() + d.decRefNoCaching() + *ds = appendDentry(*ds, d) + + d.parent.syntheticChildren-- + d.parent.dirents = nil + } + + // Since the dirMu was released and reacquired, re-check that the + // parent's child with this name is still the same. Do not touch it if + // it has been replaced with a different one. + if child := d.parent.children[name]; child == d { + // Invalidate dentry so it gets reloaded next time it's accessed. + delete(d.parent.children, name) + } + d.parent.dirMu.Unlock() + + return nil + } + + // The file at this path hasn't changed. Just update cached metadata. + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) + d.metadataMu.Unlock() + } + + return nil +} + +// revalidateStatePool caches revalidateState instances to save array +// allocations for dentries and names. +var revalidateStatePool = sync.Pool{ + New: func() interface{} { + return &revalidateState{} + }, +} + +// revalidateState keeps state related to a revalidation request. It keeps track +// of {name, dentry} list being revalidated, as well as metadata locks on the +// dentries. The list must be in ancestry order, in other words `n` must be +// `n-1` child. +type revalidateState struct { + // start is the dentry where to start the attributes search. + start *dentry + + // List of names of entries to refresh attributes. Names length must be the + // same as detries length. They are kept in separate slices because names is + // used to call File.MultiGetAttr(). + names []string + + // dentries is the list of dentries that correspond to the names above. + // dentry.metadataMu is acquired as each dentry is added to this list. + dentries []*dentry + + // locked indicates if metadata lock has been acquired on dentries. + locked bool +} + +func makeRevalidateState(start *dentry) *revalidateState { + r := revalidateStatePool.Get().(*revalidateState) + r.start = start + return r +} + +// release must be called after the caller is done with this object. It releases +// all metadata locks and resources. +func (r *revalidateState) release() { + r.reset() + revalidateStatePool.Put(r) +} + +// Preconditions: +// * d is a descendant of all dentries in r.dentries. 
+func (r *revalidateState) add(name string, d *dentry) { + r.names = append(r.names, name) + r.dentries = append(r.dentries, d) +} + +func (r *revalidateState) lockAllMetadata() { + for _, d := range r.dentries { + d.metadataMu.Lock() + } + r.locked = true +} + +func (r *revalidateState) popFront() *dentry { + if len(r.dentries) == 0 { + return nil + } + d := r.dentries[0] + r.dentries = r.dentries[1:] + return d +} + +// reset releases all metadata locks and resets all fields to allow this +// instance to be reused. +func (r *revalidateState) reset() { + if r.locked { + // Unlock any remaining dentries. + for _, d := range r.dentries { + d.metadataMu.Unlock() + } + r.locked = false + } + r.start = nil + r.names = r.names[:0] + r.dentries = r.dentries[:0] +} diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index ac3b5b621..c12444b7e 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fsmetric" @@ -100,7 +101,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*speci d.fs.specialFileFDs[fd] = struct{}{} d.fs.syncMu.Unlock() if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { - fsmetric.GoferOpensWX.Increment() + metric.SuspiciousOperationsMetric.Increment("opened_write_execute_file") } if h.fd >= 0 { fsmetric.GoferOpensHost.Increment() diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index badca4d9f..f50b0fb08 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -612,16 +612,24 @@ afterTrailingSymlink: // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - fs.mu.RLock() defer fs.processDeferredDecRefs(ctx) - defer fs.mu.RUnlock() + + fs.mu.RLock() d, err := fs.walkExistingLocked(ctx, rp) if err != nil { + fs.mu.RUnlock() return "", err } if !d.isSymlink() { + fs.mu.RUnlock() return "", syserror.EINVAL } + + // Inode.Readlink() cannot be called holding fs locks. + d.IncRef() + defer d.DecRef(ctx) + fs.mu.RUnlock() + return d.inode.Readlink(ctx, rp.Mount()) } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 16486eeae..6f699c9cd 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -534,6 +534,9 @@ func (d *Dentry) FSLocalPath() string { // - Checking that dentries passed to methods are of the appropriate file type. // - Checking permissions. // +// Inode functions may be called holding filesystem wide locks and are not +// allowed to call vfs functions that may reenter, unless otherwise noted. +// // Specific responsibilities of implementations are documented below. type Inode interface { // Methods related to reference counting. A generic implementation is @@ -680,6 +683,9 @@ type inodeDirectory interface { type inodeSymlink interface { // Readlink returns the target of a symbolic link. If an inode is not a // symlink, the implementation should return EINVAL. + // + // Readlink is called with no kernfs locks held, so it may reenter if needed + // to resolve symlink targets. 
Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) // Getlink returns the target of a symbolic link, as used by path diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 02bf74dbc..4718fac7a 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -221,6 +221,8 @@ func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) defer file.DecRef(ctx) root := vfs.RootFromContext(ctx) defer root.DecRef(ctx) + + // Note: it's safe to reenter kernfs from Readlink if needed to resolve path. return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry()) } diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 7c7543f14..cf905fae4 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -65,6 +65,7 @@ var _ kernfs.Inode = (*tasksInode)(nil) func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ + "cmdline": fs.newInode(ctx, root, 0444, &cmdLineData{}), "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index e1a8b4409..045ed7a2d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -336,15 +336,6 @@ var _ dynamicInode = (*versionData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { - k := kernel.KernelFromContext(ctx) - init := k.GlobalInit() - if init == nil { - // Attempted to read before the init Task is created. This can - // only occur during startup, which should never need to read - // this file. - panic("Attempted to read version before initial Task is available") - } - // /proc/version takes the form: // // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST) @@ -364,7 +355,7 @@ func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { // FIXME(mpratt): Using Version from the init task SyscallTable // disregards the different version a task may have (e.g., in a uts // namespace). - ver := init.Leader().SyscallTable().Version + ver := kernelVersion(ctx) fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version) return nil } @@ -400,3 +391,31 @@ func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error { r.GenerateProcCgroups(buf) return nil } + +// cmdLineData backs /proc/cmdline. +// +// +stateify savable +type cmdLineData struct { + dynamicBytesFileSetAttr +} + +var _ dynamicInode = (*cmdLineData)(nil) + +// Generate implements vfs.DynamicByteSource.Generate. +func (*cmdLineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "BOOT_IMAGE=/vmlinuz-%s-gvisor quiet\n", kernelVersion(ctx).Release) + return nil +} + +// kernelVersion returns the kernel version. +func kernelVersion(ctx context.Context) kernel.Version { + k := kernel.KernelFromContext(ctx) + init := k.GlobalInit() + if init == nil { + // Attempted to read before the init Task is created. This can + // only occur during startup, which should never need to read + // this file. 
+ panic("Attempted to read version before initial Task is available") + } + return init.Leader().SyscallTable().Version +} diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 9b14dd6b9..88ab49048 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -365,27 +365,22 @@ func (d *tcpMemData) writeSizeLocked(size inet.TCPBufferSize) error { } // ipForwarding implements vfs.WritableDynamicBytesSource for -// /proc/sys/net/ipv4/ip_forwarding. +// /proc/sys/net/ipv4/ip_forward. // // +stateify savable type ipForwarding struct { kernfs.DynamicBytesFile stack inet.Stack `state:"wait"` - enabled *bool + enabled bool } var _ vfs.WritableDynamicBytesSource = (*ipForwarding)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error { - if ipf.enabled == nil { - enabled := ipf.stack.Forwarding(ipv4.ProtocolNumber) - ipf.enabled = &enabled - } - val := "0\n" - if *ipf.enabled { + if ipf.enabled { // Technically, this is not quite compatible with Linux. Linux stores these // as an integer, so if you write "2" into tcp_sack, you should get 2 back. // Tough luck. @@ -414,11 +409,8 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs if err != nil { return 0, err } - if ipf.enabled == nil { - ipf.enabled = new(bool) - } - *ipf.enabled = v != 0 - if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + ipf.enabled = v != 0 + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, ipf.enabled); err != nil { return 0, err } return n, nil diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go index 6cee22823..19b012f7d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go @@ -132,7 +132,7 @@ func TestConfigureIPForwarding(t *testing.T) { t.Run(c.comment, func(t *testing.T) { s.IPForwarding = c.initial - file := &ipForwarding{stack: s, enabled: &c.initial} + file := &ipForwarding{stack: s, enabled: c.initial} // Write the values. src := usermem.BytesIOSequence([]byte(c.str)) diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index d6f076cd6..e534fbca8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -47,6 +47,7 @@ var ( var ( tasksStaticFiles = map[string]testutil.DirentType{ + "cmdline": linux.DT_REG, "cpuinfo": linux.DT_REG, "filesystems": linux.DT_REG, "loadavg": linux.DT_REG, diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 5fdca1d46..766289e60 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -465,7 +465,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open } return &fd.vfsfd, nil case *symlink: - // TODO(gvisor.dev/issue/2782): Can't open symlinks without O_PATH. + // Can't open symlinks without O_PATH, which is handled at the VFS layer. return nil, syserror.ELOOP case *namedPipe: return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags, &d.inode.locks) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index ca8090bbf..3582d14c9 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -168,10 +168,6 @@ afterSymlink: // Preconditions: // * fs.renameMu must be locked. // * d.dirMu must be locked. 
-// -// TODO(b/166474175): Investigate all possible errors returned in this -// function, and make sure we differentiate all errors that indicate unexpected -// modifications to the file system from the ones that are not harmful. func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() @@ -278,16 +274,15 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi var buf bytes.Buffer parent.hashMu.RLock() _, err = merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, - Children: parent.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + Children: parent.childrenNames, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), @@ -409,15 +404,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry var buf bytes.Buffer d.hashMu.RLock() params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - Children: d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + Children: d.childrenNames, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: 0, // Set read size to 0 so only the metadata is verified. @@ -991,8 +985,6 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts } // StatAt implements vfs.FilesystemImpl.StatAt. -// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should -// be verified. func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 458c7fcb6..fa7696ad6 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -840,7 +840,6 @@ func (fd *fileDescription) Release(ctx context.Context) { // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - // TODO(b/162788573): Add integrity check for metadata. stat, err := fd.lowerFD.Stat(ctx, opts) if err != nil { return linux.Statx{}, err @@ -960,10 +959,9 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui } params := &merkletree.GenerateParams{ - TreeReader: &merkleReader, - TreeWriter: &merkleWriter, - Children: fd.d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + TreeReader: &merkleReader, + TreeWriter: &merkleWriter, + Children: fd.d.childrenNames, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), Name: fd.d.name, Mode: uint32(stat.Mode), @@ -1192,8 +1190,6 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. 
case linux.FS_IOC_GETFLAGS: return fd.verityFlags(ctx, args[2].Pointer()) default: - // TODO(b/169682228): Investigate which ioctl commands should - // be allowed. return 0, syserror.ENOSYS } } @@ -1253,16 +1249,15 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of fd.d.hashMu.RLock() n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, - Children: fd.d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + Children: fd.d.childrenNames, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), @@ -1304,6 +1299,11 @@ func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapO return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts) } +// SupportsLocks implements vfs.FileDescriptionImpl.SupportsLocks. +func (fd *fileDescription) SupportsLocks() bool { + return fd.lowerFD.SupportsLocks() +} + // LockBSD implements vfs.FileDescriptionImpl.LockBSD. func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return fd.lowerFD.LockBSD(ctx, ownerPID, t, block) @@ -1333,7 +1333,7 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { ts, err := fd.lowerMappable.Translate(ctx, required, optional, at) if err != nil { - return ts, err + return nil, err } // dataSize is the size of the whole file. @@ -1346,17 +1346,17 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { - return ts, err + return nil, err } // The dataSize xattr should be an integer. If it's not, it indicates // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } merkleReader := FileReadWriteSeeker{ @@ -1389,7 +1389,7 @@ func (fd *fileDescription) Translate(ctx context.Context, required, optional mem DataAndTreeInSameFile: false, }) if err != nil { - return ts, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } } return ts, err diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go index 7e535b527..17d0d5025 100644 --- a/pkg/sentry/fsmetric/fsmetric.go +++ b/pkg/sentry/fsmetric/fsmetric.go @@ -42,7 +42,6 @@ var ( // Metrics that only apply to fs/gofer and fsimpl/gofer. 
var ( - GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.") GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.") GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.") GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 6b71bd3a9..80dda1559 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -88,9 +88,6 @@ type Stack interface { // for restoring a stack after a save. RestoreCleanupEndpoints([]stack.TransportEndpoint) - // Forwarding returns if packet forwarding between NICs is enabled. - Forwarding(protocol tcpip.NetworkProtocolNumber) bool - // SetForwarding enables or disables packet forwarding between NICs. SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 03e2608c2..218d9dafc 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -154,11 +154,6 @@ func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} -// Forwarding implements inet.Stack.Forwarding. -func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - return s.IPForwarding -} - // SetForwarding implements inet.Stack.SetForwarding. func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { s.IPForwarding = enable diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go index 1f1c63f37..c93ef6ac1 100644 --- a/pkg/sentry/kernel/cgroup.go +++ b/pkg/sentry/kernel/cgroup.go @@ -48,10 +48,6 @@ type CgroupController interface { // attached to. Returned value is valid for the lifetime of the controller. HierarchyID() uint32 - // Filesystem returns the filesystem this controller is attached to. - // Returned value is valid for the lifetime of the controller. - Filesystem() *vfs.Filesystem - // RootCgroup returns the root cgroup for this controller. Returned value is // valid for the lifetime of the controller. RootCgroup() Cgroup @@ -124,6 +120,19 @@ func (h *hierarchy) match(ctypes []CgroupControllerType) bool { return true } +// cgroupFS is the public interface to cgroupfs. This lets the kernel package +// refer to cgroupfs.filesystem methods without directly depending on the +// cgroupfs package, which would lead to a circular dependency. +type cgroupFS interface { + // Returns the vfs.Filesystem for the cgroupfs. + VFSFilesystem() *vfs.Filesystem + + // InitializeHierarchyID sets the hierarchy ID for this filesystem during + // filesystem creation. May only be called before the filesystem is visible + // to the vfs layer. + InitializeHierarchyID(hid uint32) +} + // CgroupRegistry tracks the active set of cgroup controllers on the system. 
// // +stateify savable @@ -172,7 +181,23 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files for _, h := range r.hierarchies { if h.match(ctypes) { - h.fs.IncRef() + if !h.fs.TryIncRef() { + // Racing with filesystem destruction, namely h.fs.Release. + // Since we hold r.mu, we know the hierarchy hasn't been + // unregistered yet, but its associated filesystem is tearing + // down. + // + // If we simply indicate the hierarchy wasn't found without + // cleaning up the registry, the caller can race with the + // unregister and find itself temporarily unable to create a new + // hierarchy with a subset of the relevant controllers. + // + // To keep the result of FindHierarchy consistent with the + // uniqueness of controllers enforced by Register, drop the + // dying hierarchy now. The eventual unregister by the FS + // teardown will become a no-op. + return nil + } return h.fs } } @@ -182,31 +207,35 @@ func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Files // Register registers the provided set of controllers with the registry as a new // hierarchy. If any controller is already registered, the function returns an -// error without modifying the registry. The hierarchy can be later referenced -// by the returned id. -func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) { +// error without modifying the registry. Register sets the hierarchy ID for the +// filesystem on success. +func (r *CgroupRegistry) Register(cs []CgroupController, fs cgroupFS) error { r.mu.Lock() defer r.mu.Unlock() if len(cs) == 0 { - return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers") + return fmt.Errorf("can't register hierarchy with no controllers") } for _, c := range cs { if _, ok := r.controllers[c.Type()]; ok { - return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy") + return fmt.Errorf("controllers may only be mounted on a single hierarchy") } } hid, err := r.nextHierarchyID() if err != nil { - return hid, err + return err } + // Must not fail below here, once we publish the hierarchy ID. + + fs.InitializeHierarchyID(hid) + h := hierarchy{ id: hid, controllers: make(map[CgroupControllerType]CgroupController), - fs: cs[0].Filesystem(), + fs: fs.VFSFilesystem(), } for _, c := range cs { n := c.Type() @@ -214,15 +243,20 @@ func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) { h.controllers[n] = c } r.hierarchies[hid] = h - return hid, nil + return nil } -// Unregister removes a previously registered hierarchy from the registry. If -// the controller was not previously registered, Unregister is a no-op. +// Unregister removes a previously registered hierarchy from the registry. If no +// such hierarchy is registered, Unregister is a no-op. func (r *CgroupRegistry) Unregister(hid uint32) { r.mu.Lock() - defer r.mu.Unlock() + r.unregisterLocked(hid) + r.mu.Unlock() +} +// Precondition: Caller must hold r.mu. +// +checklocks:r.mu +func (r *CgroupRegistry) unregisterLocked(hid uint32) { if h, ok := r.hierarchies[hid]; ok { for name, _ := range h.controllers { delete(r.controllers, name) @@ -253,6 +287,11 @@ func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[C for name, ctl := range r.controllers { if _, ok := ctlSet[name]; !ok { cg := ctl.RootCgroup() + // Multiple controllers may share the same hierarchy, so may have + // the same root cgroup. Grab a single ref per hierarchy root. 
+ if _, ok := cgset[cg]; ok { + continue + } cg.IncRef() // Ref transferred to caller. cgset[cg] = struct{}{} } diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 10885688c..62777faa8 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -154,9 +154,11 @@ func (f *FDTable) drop(ctx context.Context, file *fs.File) { // dropVFS2 drops the table reference. func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { // Release any POSIX lock possibly held by the FDTable. - err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF}) - if err != nil && err != syserror.ENOLCK { - panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) + if file.SupportsLocks() { + err := file.UnlockPOSIX(ctx, f, lock.LockRange{0, lock.LockEOF}) + if err != nil && err != syserror.ENOLCK { + panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) + } } // Drop the table's reference. diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 2d89b9ccd..24e467e93 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -86,6 +86,12 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) if n > 0 { p.Notify(waiter.ReadableEvents) } + if err == unix.EPIPE { + // If we are returning EPIPE send SIGPIPE to the task. + if sendSig := linux.SignalNoInfoFuncFromContext(ctx); sendSig != nil { + sendSig(linux.SIGPIPE) + } + } return n, err } diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index fe2ab1662..3c5bd8ff7 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -702,7 +702,9 @@ func (s *Set) checkPerms(creds *auth.Credentials, reqPerms fs.PermMask) bool { return s.checkCapability(creds) } -// destroy destroys the set. Caller must hold 's.mu'. +// destroy destroys the set. +// +// Preconditions: Caller must hold 's.mu'. func (s *Set) destroy() { // Notify all waiters. They will fail on the next attempt to execute // operations and return error. 
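The EPIPE handling added to Pipe.Write above works together with the task_context.go change below: the pipe package cannot depend on the kernel package, so the task exposes a signal-sending callback through its context (linux.CtxSignalNoInfoFunc) and the pipe only calls whatever function the context provides. The following is a minimal, self-contained sketch of that pattern; the names (signalFuncKey, write, the "SIGPIPE" string) are hypothetical stand-ins, not gVisor's actual types.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// signalFuncKey is the context key under which the task layer publishes its
// signal-sending callback (a stand-in for linux.CtxSignalNoInfoFunc).
type signalFuncKey struct{}

// signalNoInfoFuncFromContext mirrors the lookup helper: it returns nil when
// the context carries no signal callback, and callers must tolerate that.
func signalNoInfoFuncFromContext(ctx context.Context) func(sig string) error {
	if f, ok := ctx.Value(signalFuncKey{}).(func(sig string) error); ok {
		return f
	}
	return nil
}

var errEPIPE = errors.New("EPIPE")

// write simulates Pipe.Write: when the write fails with EPIPE, it sends
// SIGPIPE to the calling task if the context provides a way to do so.
func write(ctx context.Context, readerClosed bool) error {
	if !readerClosed {
		return nil
	}
	if sendSig := signalNoInfoFuncFromContext(ctx); sendSig != nil {
		_ = sendSig("SIGPIPE")
	}
	return errEPIPE
}

func main() {
	ctx := context.WithValue(context.Background(), signalFuncKey{},
		func(sig string) error { fmt.Println("delivering", sig); return nil })
	fmt.Println(write(ctx, true)) // prints "delivering SIGPIPE" then "EPIPE"
}
```

The indirection keeps signal-delivery policy in the kernel layer while lower layers depend only on the context lookup; if no callback is installed, the write simply returns EPIPE.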
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 70b0699dc..c82d9e82b 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -17,6 +17,7 @@ package kernel import ( "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -113,6 +114,10 @@ func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { return t.k.RealtimeClock() case limits.CtxLimits: return t.tg.limits + case linux.CtxSignalNoInfoFunc: + return func(sig linux.Signal) error { + return t.SendSignal(SignalInfoNoInfo(sig, t, t)) + } case pgalloc.CtxMemoryFile: return t.k.mf case pgalloc.CtxMemoryFileProvider: diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index ecb6603a1..4c65215fa 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -11,11 +11,12 @@ go_library( "vdso.go", "vdso_state.go", ], + marshal = True, + marshal_debug = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/cpuid", "//pkg/hostarch", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index e92d9fdc3..8fc3e2a79 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/hostarch" @@ -47,10 +46,10 @@ const ( var ( // header64Size is the size of elf.Header64. - header64Size = int(binary.Size(elf.Header64{})) + header64Size = (*linux.ElfHeader64)(nil).SizeBytes() // Prog64Size is the size of elf.Prog64. - prog64Size = int(binary.Size(elf.Prog64{})) + prog64Size = (*linux.ElfProg64)(nil).SizeBytes() ) func progFlagsAsPerms(f elf.ProgFlag) hostarch.AccessType { @@ -136,7 +135,6 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { log.Infof("Unsupported ELF endianness: %v", endian) return elfInfo{}, syserror.ENOEXEC } - byteOrder := binary.LittleEndian if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT { log.Infof("Unsupported ELF version: %v", version) @@ -145,7 +143,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // EI_OSABI is ignored by Linux, which is the only OS supported. os := abi.Linux - var hdr elf.Header64 + var hdr linux.ElfHeader64 hdrBuf := make([]byte, header64Size) _, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0) if err != nil { @@ -156,7 +154,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { } return elfInfo{}, err } - binary.Unmarshal(hdrBuf, byteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBuf) // We support amd64 and arm64. var a arch.Arch @@ -213,8 +211,8 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { phdrs := make([]elf.ProgHeader, hdr.Phnum) for i := range phdrs { - var prog64 elf.Prog64 - binary.Unmarshal(phdrBuf[:prog64Size], byteOrder, &prog64) + var prog64 linux.ElfProg64 + prog64.UnmarshalUnsafe(phdrBuf[:prog64Size]) phdrBuf = phdrBuf[prog64Size:] phdrs[i] = elf.ProgHeader{ Type: elf.ProgType(prog64.Type), diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 99f036bba..1b5d5f66e 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -75,6 +75,9 @@ type machine struct { // nextID is the next vCPU ID. 
nextID uint32 + + // machineArchState is the architecture-specific state. + machineArchState } const ( @@ -196,12 +199,7 @@ func newMachine(vm int) (*machine, error) { m.available.L = &m.mu // Pull the maximum vCPUs. - maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) - if errno != 0 { - m.maxVCPUs = _KVM_NR_VCPUS - } else { - m.maxVCPUs = int(maxVCPUs) - } + m.getMaxVCPU() log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) @@ -427,9 +425,8 @@ func (m *machine) Get() *vCPU { } } - // Create a new vCPU (maybe). - if int(m.nextID) < m.maxVCPUs { - c := m.newVCPU() + // Get a new vCPU (maybe). + if c := m.getNewVCPU(); c != nil { c.lock() m.vCPUsByTID[tid] = c m.mu.Unlock() diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index d7abfefb4..9a2337654 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -63,6 +63,9 @@ func (m *machine) initArchState() error { return nil } +type machineArchState struct { +} + type vCPUArchState struct { // PCIDs is the set of PCIDs for this vCPU. // @@ -351,6 +354,10 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) // allocations occur. entersyscall() bluepill(c) + // The root table physical page has to be mapped to not fault in iret + // or sysret after switching into a user address space. sysret and + // iret are in the upper half that is global and already mapped. + switchOpts.PageTables.PrefaultRootTable() prefaultFloatingPointState(switchOpts.FloatingPointState) vector = c.CPU.SwitchToUser(switchOpts) exitsyscall() @@ -495,3 +502,22 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { physical) } } + +// getMaxVCPU gets the maximum number of vCPUs for this machine. +func (m *machine) getMaxVCPU() { + maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) + if errno != 0 { + m.maxVCPUs = _KVM_NR_VCPUS + } else { + m.maxVCPUs = int(maxVCPUs) + } +} + +// getNewVCPU creates a new vCPU if the vCPU limit has not been reached, and returns nil otherwise. +func (m *machine) getNewVCPU() *vCPU { + if int(m.nextID) < m.maxVCPUs { + c := m.newVCPU() + return c + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index cd912f922..8926b1d9f 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -17,6 +17,10 @@ package kvm import ( + "runtime" + "sync/atomic" + + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" @@ -25,6 +29,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" ) +type machineArchState struct { + // initialvCPUs is the set of machine vCPUs that have been initialized but are not yet in use. + initialvCPUs map[int]*vCPU +} + type vCPUArchState struct { // PCIDs is the set of PCIDs for this vCPU. // @@ -182,3 +191,30 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, return accessType, platform.ErrContextSignal } + +// getMaxVCPU gets the maximum number of vCPUs for this machine. +func (m *machine) getMaxVCPU() { + rmaxVCPUs := runtime.NumCPU() + smaxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) + // Compare the maximum vCPU counts reported by the runtime and by the KVM_CHECK_EXTENSION syscall, and use the smaller one.
+ if errno != 0 { + m.maxVCPUs = rmaxVCPUs + } else { + if rmaxVCPUs < int(smaxVCPUs) { + m.maxVCPUs = rmaxVCPUs + } else { + m.maxVCPUs = int(smaxVCPUs) + } + } +} + +// getNewVCPU scans initialvCPUs for an available vCPU and returns nil if none is ready. +func (m *machine) getNewVCPU() *vCPU { + for CID, c := range m.initialvCPUs { + if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) { + delete(m.initialvCPUs, CID) + return c + } + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 634e55ec0..92edc992b 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" + ktime "gvisor.dev/gvisor/pkg/sentry/time" ) type kvmVcpuInit struct { @@ -47,6 +48,19 @@ func (m *machine) initArchState() error { uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 { panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno)) } + + // Initialize all vCPUs on ARM64 up front, unlike on x86_64. + // The reason for the difference is that ARM64 and x86_64 have different KVM timer mechanisms. + // If we created vCPUs dynamically on ARM64, the vCPU timer would be inconsistent for a short time. + // For more detail, please refer to https://github.com/google/gvisor/issues/5739 + m.initialvCPUs = make(map[int]*vCPU) + m.mu.Lock() + for int(m.nextID) < m.maxVCPUs-1 { + c := m.newVCPU() + c.state = 0 + m.initialvCPUs[c.id] = c + } + m.mu.Unlock() return nil } @@ -174,9 +188,58 @@ func (c *vCPU) setTSC(value uint64) error { return nil } +// getTSC gets the guest counter value, i.e. the physical counter minus the virtual offset. +func (c *vCPU) getTSC() error { + var ( + reg kvmOneReg + data uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + reg.id = _KVM_ARM64_REGS_TIMER_CNT + + if err := c.getOneRegister(&reg); err != nil { + return err + } + + return nil +} + // setSystemTime sets the vCPU to the system time. func (c *vCPU) setSystemTime() error { - return c.setSystemTimeLegacy() + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Use getTSC to get an estimate of where the counter will be + // on the host during a "fast" system call iteration. + // Could replacing getTSC with another setOneRegister syscall give a more accurate value? + start := uint64(ktime.Rdtsc()) + if err := c.getTSC(); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum?
+ upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && (current <= upperThreshold || minimum < 50) { + // Try to set the TSC + if err := c.setTSC(end + (minimum / 2)); err != nil { + return err + } + return nil + } + } } //go:nosplit @@ -203,7 +266,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { uintptr(c.fd), _KVM_GET_ONE_REG, uintptr(unsafe.Pointer(reg))); errno != 0 { - return fmt.Errorf("error setting one register: %v", errno) + return fmt.Errorf("error getting one register: %v", errno) } return nil } diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 080859125..7ee89a735 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/hostarch", "//pkg/marshal", diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 0e0e82365..2029e7cf4 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -14,9 +14,11 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/context", "//pkg/hostarch", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 45a05cd63..235b9c306 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -18,9 +18,11 @@ package control import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -193,7 +195,7 @@ func putUint32(buf []byte, n uint32) []byte { // putCmsg writes a control message header and as much data as will fit into // the unused capacity of a buffer. func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) { - space := binary.AlignDown(cap(buf)-len(buf), 4) + space := bits.AlignDown(cap(buf)-len(buf), 4) // We can't write to space that doesn't exist, so if we are going to align // the available space, we must align down. @@ -230,7 +232,7 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ return alignSlice(buf, align), flags } -func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte { +func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data marshal.Marshallable) []byte { if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader { return buf } @@ -241,8 +243,7 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = putUint32(buf, msgType) hdrBuf := buf - - buf = binary.Marshal(buf, hostarch.ByteOrder, data) + buf = append(buf, marshal.Marshal(data)...) // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { @@ -288,7 +289,7 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int // alignSlice extends a slice's length (up to the capacity) to align it. 
func alignSlice(buf []byte, align uint) []byte { - aligned := binary.AlignUp(len(buf), align) + aligned := bits.AlignUp(len(buf), align) if aligned > cap(buf) { // Linux allows unaligned data if there isn't room for alignment. // Since there isn't room for alignment, there isn't room for any @@ -300,12 +301,13 @@ func alignSlice(buf []byte, align uint) []byte { // PackTimestamp packs a SO_TIMESTAMP socket control message. func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { + timestampP := linux.NsecToTimeval(timestamp) return putCmsgStruct( buf, linux.SOL_SOCKET, linux.SO_TIMESTAMP, t.Arch().Width(), - linux.NsecToTimeval(timestamp), + &timestampP, ) } @@ -316,7 +318,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { linux.SOL_TCP, linux.TCP_INQ, t.Arch().Width(), - inq, + primitive.AllocateInt32(inq), ) } @@ -327,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { linux.SOL_IP, linux.IP_TOS, t.Arch().Width(), - tos, + primitive.AllocateUint8(tos), ) } @@ -338,7 +340,7 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { linux.SOL_IPV6, linux.IPV6_TCLASS, t.Arch().Width(), - tClass, + primitive.AllocateUint32(tClass), ) } @@ -423,7 +425,7 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt // cmsgSpace is equivalent to CMSG_SPACE in Linux. func cmsgSpace(t *kernel.Task, dataLen int) int { - return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width()) + return linux.SizeOfControlMessageHeader + bits.AlignUp(dataLen, t.Arch().Width()) } // CmsgsSpace returns the number of bytes needed to fit the control messages @@ -475,7 +477,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) + h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, syserror.EINVAL @@ -491,7 +493,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.SOL_SOCKET: switch h.Type { case linux.SCM_RIGHTS: - rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) numRights := rightsSize / linux.SizeOfControlMessageRight if len(fds)+numRights > linux.SCM_MAX_FD { @@ -502,7 +504,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) } - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -510,23 +512,23 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) + creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { return socket.ControlMessages{}, syserror.EINVAL } var ts linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &ts) + 
ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) cmsgs.IP.Timestamp = ts.ToNsecCapped() cmsgs.IP.HasTimestamp = true - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: // Unknown message type. @@ -539,8 +541,10 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTOS = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &cmsgs.IP.TOS) - i += binary.AlignUp(length, width) + var tos primitive.Uint8 + tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + cmsgs.IP.TOS = uint8(tos) + i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -549,19 +553,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) + packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) cmsgs.IP.PacketInfo = packetInfo - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -571,7 +575,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) cmsgs.IP.SockErr = &errCmsg - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL @@ -583,17 +587,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTClass = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &cmsgs.IP.TClass) - i += binary.AlignUp(length, width) + var tclass primitive.Uint32 + tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + cmsgs.IP.TClass = uint32(tclass) + i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -603,7 +609,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) cmsgs.IP.SockErr = &errCmsg - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a5c2155a2..3c6511ead 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -17,7 +17,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", 
"//pkg/fdnotifier", "//pkg/hostarch", @@ -40,8 +39,6 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/network/ipv4", - "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 0d3b23643..52ae4bc9c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -19,7 +19,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" @@ -529,7 +528,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} - ts.UnmarshalBytes(unixCmsg.Data[:linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) controlMessages.IP.Timestamp = ts.ToNsecCapped() } @@ -537,17 +536,19 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.IP_TOS: controlMessages.IP.HasTOS = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &controlMessages.IP.TOS) + var tos primitive.Uint8 + tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()]) + controlMessages.IP.TOS = uint8(tos) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) + packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()]) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -560,11 +561,13 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &controlMessages.IP.TClass) + var tclass primitive.Uint32 + tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()]) + controlMessages.IP.TClass = uint32(tclass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -577,7 +580,9 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.TCP_INQ: controlMessages.IP.HasInq = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], hostarch.ByteOrder, &controlMessages.IP.Inq) + var inq primitive.Int32 + inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq]) + controlMessages.IP.Inq = int32(inq) } } } @@ -691,7 +696,7 @@ func (s *socketOpsCommon) State() uint32 { return 0 } - binary.Unmarshal(buf, hostarch.ByteOrder, &info) + info.UnmarshalUnsafe(buf[:info.SizeBytes()]) return uint32(info.State) } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 26e8ae17a..cbb1e905d 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ 
b/pkg/sentry/socket/hostinet/stack.go @@ -15,6 +15,7 @@ package hostinet import ( + "encoding/binary" "fmt" "io" "io/ioutil" @@ -26,16 +27,14 @@ import ( "syscall" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -65,8 +64,6 @@ type Stack struct { tcpSACKEnabled bool netDevFile *os.File netSNMPFile *os.File - ipv4Forwarding bool - ipv6Forwarding bool } // NewStack returns an empty Stack containing no configuration. @@ -126,13 +123,6 @@ func (s *Stack) Configure() error { s.netSNMPFile = f } - s.ipv6Forwarding = false - if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil { - s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" - } else { - log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false") - } - return nil } @@ -147,8 +137,8 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli if len(link.Data) < unix.SizeofIfInfomsg { return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } - var ifinfo unix.IfInfomsg - binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], hostarch.ByteOrder, &ifinfo) + var ifinfo linux.InterfaceInfoMessage + ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()]) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -178,11 +168,11 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli if len(addr.Data) < unix.SizeofIfAddrmsg { return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } - var ifaddr unix.IfAddrmsg - binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], hostarch.ByteOrder, &ifaddr) + var ifaddr linux.InterfaceAddrMessage + ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()]) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, - PrefixLen: ifaddr.Prefixlen, + PrefixLen: ifaddr.PrefixLen, Flags: ifaddr.Flags, } attrs, err := syscall.ParseNetlinkRouteAttr(&addr) @@ -210,13 +200,13 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) continue } - var ifRoute unix.RtMsg - binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], hostarch.ByteOrder, &ifRoute) + var ifRoute linux.RouteMessage + ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()]) inetRoute := inet.Route{ Family: ifRoute.Family, - DstLen: ifRoute.Dst_len, - SrcLen: ifRoute.Src_len, - TOS: ifRoute.Tos, + DstLen: ifRoute.DstLen, + SrcLen: ifRoute.SrcLen, + TOS: ifRoute.TOS, Table: ifRoute.Table, Protocol: ifRoute.Protocol, Scope: ifRoute.Scope, @@ -245,7 +235,9 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) if len(attr.Value) != expected { return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected) } - binary.Unmarshal(attr.Value, hostarch.ByteOrder, &inetRoute.OutputInterface) + var outputIF 
primitive.Int32 + outputIF.UnmarshalUnsafe(attr.Value) + inetRoute.OutputInterface = int32(outputIF) } } @@ -489,19 +481,6 @@ func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} -// Forwarding implements inet.Stack.Forwarding. -func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - switch protocol { - case ipv4.ProtocolNumber: - return s.ipv4Forwarding - case ipv6.ProtocolNumber: - return s.ipv6Forwarding - default: - log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol) - return false - } -} - // SetForwarding implements inet.Stack.SetForwarding. func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 4381dfa06..61b2c9755 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -14,14 +14,16 @@ go_library( "tcp_matcher.go", "udp_matcher.go", ], + marshal = True, # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/hostarch", "//pkg/log", + "//pkg/marshal", "//pkg/sentry/kernel", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 4bd305a44..6fc7781ad 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -79,7 +78,7 @@ func marshalEntryMatch(name string, data []byte) []byte { nflog("marshaling matcher %q", name) // We have to pad this struct size to a multiple of 8 bytes. - size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) + size := bits.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) matcher := linux.KernelXTEntryMatch{ XTEntryMatch: linux.XTEntryMatch{ MatchSize: uint16(size), @@ -88,9 +87,11 @@ func marshalEntryMatch(name string, data []byte) []byte { } copy(matcher.Name[:], name) - buf := make([]byte, 0, size) - buf = binary.Marshal(buf, hostarch.ByteOrder, matcher) - return append(buf, make([]byte, size-len(buf))...) 
+ buf := make([]byte, size) + entryLen := matcher.XTEntryMatch.SizeBytes() + matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen]) + copy(buf[entryLen:], matcher.Data) + return buf } func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 1fc4cb651..cb78ef60b 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -18,8 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -141,10 +139,9 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } var entry linux.IPTEntry - buf := optVal[:linux.SizeOfIPTEntry] - binary.Unmarshal(buf, hostarch.ByteOrder, &entry) + entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIPTEntry:] + optVal = optVal[entry.SizeBytes():] if entry.TargetOffset < linux.SizeOfIPTEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 67a52b628..5cb7fe4aa 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -18,8 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -144,10 +142,9 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } var entry linux.IP6TEntry - buf := optVal[:linux.SizeOfIP6TEntry] - binary.Unmarshal(buf, hostarch.ByteOrder, &entry) + entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIP6TEntry:] + optVal = optVal[entry.SizeBytes():] if entry.TargetOffset < linux.SizeOfIP6TEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index c6fa3fd16..f42d73178 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -22,7 +22,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -121,7 +120,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if binary.Size(entries) > uintptr(outLen) { + if entries.SizeBytes() > outLen { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } @@ -146,7 +145,7 @@ func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe nflog("couldn't read entries: %v", err) return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } - if binary.Size(entries) > uintptr(outLen) { + if entries.SizeBytes() > outLen { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } @@ -179,7 +178,7 @@ func SetEntries(stk *stack.Stack, 
optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] - binary.Unmarshal(replaceBuf, hostarch.ByteOrder, &replace) + replace.UnmarshalBytes(replaceBuf) // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table @@ -309,8 +308,8 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch - buf := optVal[:linux.SizeOfXTEntryMatch] - binary.Unmarshal(buf, hostarch.ByteOrder, &match) + buf := optVal[:match.SizeBytes()] + match.UnmarshalUnsafe(buf) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index b2cc6be20..60845cab3 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -59,8 +58,8 @@ func (ownerMarshaler) marshal(mr matcher) []byte { } } - buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) - return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, hostarch.ByteOrder, iptOwnerInfo)) + buf := marshal.Marshal(&iptOwnerInfo) + return marshalEntryMatch(matcherNameOwner, buf) } // unmarshal implements matchMaker.unmarshal. @@ -72,7 +71,7 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 4ae1592b2..fa5456eee 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,11 +15,12 @@ package netfilter import ( + "encoding/binary" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -189,8 +190,7 @@ func (*standardTargetMaker) marshal(target target) []byte { Verdict: verdict, } - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -199,8 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } var standardTarget linux.XTStandardTarget - buf = buf[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &standardTarget) + standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()]) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. 
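A recurring pattern in the netfilter, control-message, and netlink changes in this series is the replacement of the reflection-based pkg/binary helpers (binary.Size, binary.Marshal, binary.Unmarshal) with the marshal package, where each wire-format struct reports its own SizeBytes and copies itself via MarshalUnsafe/UnmarshalUnsafe. The sketch below is a toy analogue of that contract under an assumed 32-byte layout; it is not gVisor's generated marshalling code.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// marshallable is a stand-in for the subset of marshal.Marshallable used in
// the diffs above: size reporting plus raw byte copies in both directions.
type marshallable interface {
	SizeBytes() int
	MarshalUnsafe(dst []byte)
	UnmarshalUnsafe(src []byte)
}

// xtEntryTarget mimics a fixed-size netfilter header struct (illustrative
// field layout, not the actual linux.XTEntryTarget definition).
type xtEntryTarget struct {
	TargetSize uint16
	Name       [29]byte
	Revision   uint8
}

func (*xtEntryTarget) SizeBytes() int { return 2 + 29 + 1 }

func (t *xtEntryTarget) MarshalUnsafe(dst []byte) {
	binary.LittleEndian.PutUint16(dst[0:2], t.TargetSize)
	copy(dst[2:31], t.Name[:])
	dst[31] = t.Revision
}

func (t *xtEntryTarget) UnmarshalUnsafe(src []byte) {
	t.TargetSize = binary.LittleEndian.Uint16(src[0:2])
	copy(t.Name[:], src[2:31])
	t.Revision = src[31]
}

// marshalBytes mirrors marshal.Marshal: allocate exactly SizeBytes and fill it.
func marshalBytes(m marshallable) []byte {
	buf := make([]byte, m.SizeBytes())
	m.MarshalUnsafe(buf)
	return buf
}

func main() {
	in := xtEntryTarget{TargetSize: 32, Revision: 1}
	copy(in.Name[:], "REDIRECT")
	buf := marshalBytes(&in)

	var out xtEntryTarget
	// Like the unmarshal paths above, slice the input to exactly SizeBytes.
	out.UnmarshalUnsafe(buf[:out.SizeBytes()])
	fmt.Printf("%d %q %d\n", out.TargetSize, out.Name[:8], out.Revision)
}
```

Because SizeBytes is known per type without reflection, buffers can be allocated to the exact size, and slicing such as buf[:target.SizeBytes()] makes the length expectations explicit at each call site.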
@@ -245,8 +244,7 @@ func (*errorTargetMaker) marshal(target target) []byte { copy(xt.Name[:], errorName) copy(xt.Target.Name[:], ErrorTargetName) - ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -256,7 +254,7 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &errTgt) + errTgt.UnmarshalUnsafe(buf) // Error targets are used in 2 cases: // * An actual error case. These rules have an error named @@ -299,12 +297,11 @@ func (*redirectTargetMaker) marshal(target target) []byte { } copy(xt.Target.Name[:], RedirectTargetName) - ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) xt.NfRange.RangeSize = 1 xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED xt.NfRange.RangeIPV4.MinPort = htons(rt.Port) xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -320,7 +317,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &rt) + rt.UnmarshalUnsafe(buf) // Copy linux.XTRedirectTarget to stack.RedirectTarget. target := redirectTarget{RedirectTarget: stack.RedirectTarget{ @@ -359,6 +356,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return &target, nil } +// +marshal type nfNATTarget struct { Target linux.XTEntryTarget Range linux.NFNATRange @@ -394,8 +392,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte { nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto - ret := make([]byte, 0, nfNATMarshalledSize) - return binary.Marshal(ret, hostarch.ByteOrder, nt) + return marshal.Marshal(&nt) } func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -411,7 +408,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar var natRange linux.NFNATRange buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - binary.Unmarshal(buf, hostarch.ByteOrder, &natRange) + natRange.UnmarshalUnsafe(buf) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -468,8 +465,7 @@ func (*snatTargetMakerV4) marshal(target target) []byte { xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr) copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr) - ret := make([]byte, 0, linux.SizeOfXTSNATTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -485,7 +481,7 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta var st linux.XTSNATTarget buf = buf[:linux.SizeOfXTSNATTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &st) + st.UnmarshalUnsafe(buf) // Copy linux.XTSNATTarget to stack.SNATTarget. 
target := snatTarget{SNATTarget: stack.SNATTarget{ @@ -550,8 +546,7 @@ func (*snatTargetMakerV6) marshal(target target) []byte { nt.Range.MinProto = htons(st.Port) nt.Range.MaxProto = nt.Range.MinProto - ret := make([]byte, 0, nfNATMarshalledSize) - return binary.Marshal(ret, hostarch.ByteOrder, nt) + return marshal.Marshal(&nt) } func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -567,9 +562,9 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta var natRange linux.NFNATRange buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - binary.Unmarshal(buf, hostarch.ByteOrder, &natRange) + natRange.UnmarshalUnsafe(buf) - // TODO(gvisor.dev/issue/5689): Support port or address ranges. + // TODO(gvisor.dev/issue/5697): Support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { nflog("snatTargetMakerV6: MinAddr and MaxAddr are different") return nil, syserr.ErrInvalidArgument @@ -631,8 +626,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget - buf := optVal[:linux.SizeOfXTEntryTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &target) + target.UnmarshalUnsafe(optVal[:target.SizeBytes()]) return unmarshalTarget(target, filter, optVal) } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 69557f515..95bb9826e 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,8 +46,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte { DestinationPortStart: matcher.destinationPortStart, DestinationPortEnd: matcher.destinationPortEnd, } - buf := make([]byte, 0, linux.SizeOfXTTCP) - return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, hostarch.ByteOrder, xttcp)) + return marshalEntryMatch(matcherNameTCP, marshal.Marshal(&xttcp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +58,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. 
var matchData linux.XTTCP - binary.Unmarshal(buf[:linux.SizeOfXTTCP], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 6a60e6bd6..fb8be27e6 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,8 +46,7 @@ func (udpMarshaler) marshal(mr matcher) []byte { DestinationPortStart: matcher.destinationPortStart, DestinationPortEnd: matcher.destinationPortEnd, } - buf := make([]byte, 0, linux.SizeOfXTUDP) - return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, hostarch.ByteOrder, xtudp)) + return marshalEntryMatch(matcherNameUDP, marshal.Marshal(&xtudp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +58,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. var matchData linux.XTUDP - binary.Unmarshal(buf[:linux.SizeOfXTUDP], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 171b95c63..64cd263da 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -14,7 +14,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/context", "//pkg/hostarch", "//pkg/marshal", @@ -50,5 +50,7 @@ go_test( deps = [ ":netlink", "//pkg/abi/linux", + "//pkg/marshal", + "//pkg/marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index ab0e68af7..80385bfdc 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -19,15 +19,17 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" ) // alignPad returns the length of padding required for alignment. // // Preconditions: align is a power of two. func alignPad(length int, align uint) int { - return binary.AlignUp(length, align) - length + return bits.AlignUp(length, align) - length } // Message contains a complete serialized netlink message. @@ -42,7 +44,7 @@ type Message struct { func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ hdr: hdr, - buf: binary.Marshal(nil, hostarch.ByteOrder, hdr), + buf: marshal.Marshal(&hdr), } } @@ -58,7 +60,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { return } var hdr linux.NetlinkMessageHeader - binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBytes) // Msg portion. totalMsgLen := int(hdr.Length) @@ -92,7 +94,7 @@ func (m *Message) Header() linux.NetlinkMessageHeader { // GetData unmarshals the payload message header from this netlink message, and // returns the attributes portion. 
-func (m *Message) GetData(msg interface{}) (AttrsView, bool) { +func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) { b := BytesView(m.buf) _, ok := b.Extract(linux.NetlinkMessageHeaderSize) @@ -100,12 +102,12 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) { return nil, false } - size := int(binary.Size(msg)) + size := msg.SizeBytes() msgBytes, ok := b.Extract(size) if !ok { return nil, false } - binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg) + msg.UnmarshalUnsafe(msgBytes) numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) // Linux permits the last message not being aligned, just consume all of it. @@ -131,7 +133,7 @@ func (m *Message) Finalize() []byte { // Align the message. Note that the message length in the header (set // above) is the useful length of the message, not the total aligned // length. See net/netlink/af_netlink.c:__nlmsg_put. - aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) + aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) m.putZeros(aligned - len(m.buf)) return m.buf } @@ -145,45 +147,45 @@ func (m *Message) putZeros(n int) { } // Put serializes v into the message. -func (m *Message) Put(v interface{}) { - m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v) +func (m *Message) Put(v marshal.Marshallable) { + m.buf = append(m.buf, marshal.Marshal(v)...) } // PutAttr adds v to the message as a netlink attribute. // // Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize + -// binary.Size(v) fits in math.MaxUint16 bytes. -func (m *Message) PutAttr(atype uint16, v interface{}) { - l := linux.NetlinkAttrHeaderSize + int(binary.Size(v)) +// v.SizeBytes()) fits in math.MaxUint16 bytes. +func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) { + l := linux.NetlinkAttrHeaderSize + v.SizeBytes() if l > math.MaxUint16 { panic(fmt.Sprintf("attribute too large: %d", l)) } - m.Put(linux.NetlinkAttrHeader{ + m.Put(&linux.NetlinkAttrHeader{ Type: atype, Length: uint16(l), }) m.Put(v) // Align the attribute. - aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) + aligned := bits.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } // PutAttrString adds s to the message as a netlink attribute. func (m *Message) PutAttrString(atype uint16, s string) { l := linux.NetlinkAttrHeaderSize + len(s) + 1 - m.Put(linux.NetlinkAttrHeader{ + m.Put(&linux.NetlinkAttrHeader{ Type: atype, Length: uint16(l), }) // String + NUL-termination. - m.Put([]byte(s)) + m.Put(primitive.AsByteSlice([]byte(s))) m.putZeros(1) // Align the attribute. 
- aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) + aligned := bits.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } @@ -251,7 +253,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest if !ok { return } - binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBytes) value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) if !ok { diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index ef13d9386..968968469 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -20,13 +20,31 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) type dummyNetlinkMsg struct { + marshal.StubMarshallable Foo uint16 } +func (*dummyNetlinkMsg) SizeBytes() int { + return 2 +} + +func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { + p := primitive.Uint16(m.Foo) + p.MarshalUnsafe(dst) +} + +func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { + var p primitive.Uint16 + p.UnmarshalUnsafe(src) + m.Foo = uint16(p) +} + func TestParseMessage(t *testing.T) { tests := []struct { desc string diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 744fc74f4..c6c04b4e3 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -11,6 +11,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 5a2255db3..86f6419dc 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -167,7 +168,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { Type: linux.RTM_NEWLINK, }) - m.Put(linux.InterfaceInfoMessage{ + m.Put(&linux.InterfaceInfoMessage{ Family: linux.AF_UNSPEC, Type: i.DeviceType, Index: idx, @@ -175,7 +176,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { }) m.PutAttrString(linux.IFLA_IFNAME, i.Name) - m.PutAttr(linux.IFLA_MTU, i.MTU) + m.PutAttr(linux.IFLA_MTU, primitive.AllocateUint32(i.MTU)) mac := make([]byte, 6) brd := mac @@ -183,8 +184,8 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { mac = i.Addr brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) } - m.PutAttr(linux.IFLA_ADDRESS, mac) - m.PutAttr(linux.IFLA_BROADCAST, brd) + m.PutAttr(linux.IFLA_ADDRESS, primitive.AsByteSlice(mac)) + m.PutAttr(linux.IFLA_BROADCAST, primitive.AsByteSlice(brd)) // TODO(gvisor.dev/issue/578): There are many more attributes. 
} @@ -216,14 +217,15 @@ func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netl Type: linux.RTM_NEWADDR, }) - m.Put(linux.InterfaceAddrMessage{ + m.Put(&linux.InterfaceAddrMessage{ Family: a.Family, PrefixLen: a.PrefixLen, Index: uint32(id), }) - m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr)) - m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr)) + addr := primitive.ByteSlice([]byte(a.Addr)) + m.PutAttr(linux.IFA_LOCAL, &addr) + m.PutAttr(linux.IFA_ADDRESS, &addr) // TODO(gvisor.dev/issue/578): There are many more attributes. } @@ -366,7 +368,7 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net Type: linux.RTM_NEWROUTE, }) - m.Put(linux.RouteMessage{ + m.Put(&linux.RouteMessage{ Family: rt.Family, DstLen: rt.DstLen, SrcLen: rt.SrcLen, @@ -382,18 +384,18 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net Flags: rt.Flags, }) - m.PutAttr(254, []byte{123}) + m.PutAttr(254, primitive.AsByteSlice([]byte{123})) if rt.DstLen > 0 { - m.PutAttr(linux.RTA_DST, rt.DstAddr) + m.PutAttr(linux.RTA_DST, primitive.AsByteSlice(rt.DstAddr)) } if rt.SrcLen > 0 { - m.PutAttr(linux.RTA_SRC, rt.SrcAddr) + m.PutAttr(linux.RTA_SRC, primitive.AsByteSlice(rt.SrcAddr)) } if rt.OutputInterface != 0 { - m.PutAttr(linux.RTA_OIF, rt.OutputInterface) + m.PutAttr(linux.RTA_OIF, primitive.AllocateInt32(rt.OutputInterface)) } if len(rt.GatewayAddr) > 0 { - m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr) + m.PutAttr(linux.RTA_GATEWAY, primitive.AsByteSlice(rt.GatewayAddr)) } // TODO(gvisor.dev/issue/578): There are many more attributes. @@ -503,7 +505,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms hdr := msg.Header() // All messages start with a 1 byte protocol family. - var family uint8 + var family primitive.Uint8 if _, ok := msg.GetData(&family); !ok { // Linux ignores messages missing the protocol family. See // net/core/rtnetlink.c:rtnetlink_rcv_msg. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 30c297149..280563d09 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -20,7 +20,6 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -223,7 +222,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa) + sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument @@ -338,16 +337,14 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - sendBufferSizeP := primitive.Int32(s.sendBufferSize) - return &sendBufferSizeP, nil + return primitive.AllocateInt32(int32(s.sendBufferSize)), nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. 
- recvBufferSizeP := primitive.Int32(math.MaxInt32) - return &recvBufferSizeP, nil + return primitive.AllocateInt32(math.MaxInt32), nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -484,7 +481,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * Family: linux.AF_NETLINK, PortID: uint32(s.portID), } - return sa, uint32(binary.Size(sa)), nil + return sa, uint32(sa.SizeBytes()), nil } // GetPeerName implements socket.Socket.GetPeerName. @@ -495,7 +492,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * // must be the kernel. PortID: 0, } - return sa, uint32(binary.Size(sa)), nil + return sa, uint32(sa.SizeBytes()), nil } // RecvMsg implements socket.Socket.RecvMsg. @@ -504,7 +501,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Family: linux.AF_NETLINK, PortID: 0, } - fromLen := uint32(binary.Size(from)) + fromLen := uint32(from.SizeBytes()) trunc := flags&linux.MSG_TRUNC != 0 @@ -640,7 +637,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys }) // Add the dump_done_errno payload. - m.Put(int64(0)) + m.Put(primitive.AllocateInt64(0)) _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { @@ -658,8 +655,8 @@ func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ - Error: int32(-err.ToLinux().Number()), + m.Put(&linux.NetlinkErrorMessage{ + Error: int32(-err.ToLinux()), Header: hdr, }) } @@ -668,7 +665,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ + m.Put(&linux.NetlinkErrorMessage{ Error: 0, Header: hdr, }) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 0b39a5b67..9561b7c25 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -19,7 +19,6 @@ go_library( ], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 312f5f85a..037ccfec8 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "encoding/binary" "fmt" "io" "io/ioutil" @@ -35,7 +36,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" @@ -77,41 +77,59 @@ func mustCreateGauge(name, description string) *tcpip.StatCounter { // Metrics contains metrics exported by netstack. 
var Metrics = tcpip.Stats{ - UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."), - MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), - DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), + DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped at the transport layer."), + NICs: tcpip.NICStats{ + UnknownL3ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l3_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L3 protocol."), + UnknownL4ProtocolRcvdPackets: mustCreateMetric("/netstack/nic/unknown_l4_protocol_received_packets", "Number of packets received that were for an unknown or unsupported L4 protocol."), + MalformedL4RcvdPackets: mustCreateMetric("/netstack/nic/malformed_l4_received_packets", "Number of packets received that failed L4 header parsing."), + Tx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/tx/packets", "Number of packets transmitted."), + Bytes: mustCreateMetric("/netstack/nic/tx/bytes", "Number of bytes transmitted."), + }, + Rx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/rx/packets", "Number of packets received."), + Bytes: mustCreateMetric("/netstack/nic/rx/bytes", "Number of bytes received."), + }, + DisabledRx: tcpip.NICPacketStats{ + Packets: mustCreateMetric("/netstack/nic/disabled_rx/packets", "Number of packets received on disabled NICs."), + Bytes: mustCreateMetric("/netstack/nic/disabled_rx/bytes", "Number of bytes received on disabled NICs."), + }, + Neighbor: tcpip.NICNeighborStats{ + UnreachableEntryLookups: mustCreateMetric("/netstack/nic/neighbor/unreachable_entry_loopups", "Number of lookups performed on a neighbor entry in Unreachable state."), + }, + }, ICMP: tcpip.ICMPStats{ V4: tcpip.ICMPv4Stats{ PacketsSent: tcpip.ICMPv4SentPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information 
request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_request", "Number of ICMPv4 echo request packets sent."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Number of ICMPv4 echo reply packets sent."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Number of ICMPv4 destination unreachable packets sent."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Number of ICMPv4 source quench packets sent."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Number of ICMPv4 redirect packets sent."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Number of ICMPv4 time exceeded packets sent."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Number of ICMPv4 parameter problem packets sent."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Number of ICMPv4 timestamp packets sent."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Number of ICMPv4 timestamp reply packets sent."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Number of ICMPv4 information request packets sent."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Number of ICMPv4 information reply packets sent."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped by netstack due to link layer errors."), - RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped by netstack due to rate limit being exceeded."), + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Number of ICMPv4 packets dropped due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v4/packets_sent/rate_limited", "Number of ICMPv4 packets dropped due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of 
ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_request", "Number of ICMPv4 echo request packets received."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Number of ICMPv4 echo reply packets received."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Number of ICMPv4 destination unreachable packets received."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Number of ICMPv4 source quench packets received."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Number of ICMPv4 redirect packets received."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Number of ICMPv4 time exceeded packets received."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Number of ICMPv4 parameter problem packets received."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Number of ICMPv4 timestamp packets received."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Number of ICMPv4 timestamp reply packets received."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Number of ICMPv4 information request packets received."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Number of ICMPv4 information reply packets received."), }, Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Number of ICMPv4 packets received that the transport layer could not parse."), }, @@ -119,40 +137,40 @@ var Metrics = tcpip.Stats{ V6: tcpip.ICMPv6Stats{ PacketsSent: tcpip.ICMPv6SentPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: 
mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent by netstack."), - MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent by netstack."), - MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), - MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Number of ICMPv6 echo request packets sent."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Number of ICMPv6 echo reply packets sent."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Number of ICMPv6 destination unreachable packets sent."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Number of ICMPv6 packet too big packets sent."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Number of ICMPv6 time exceeded packets sent."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Number of ICMPv6 parameter problem packets sent."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Number of ICMPv6 router solicit packets sent."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Number of ICMPv6 router advert packets sent."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets sent."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Number of ICMPv6 neighbor advert packets sent."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Number of ICMPv6 redirect message packets sent."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_query", "Number of ICMPv6 multicast listener query packets sent."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_sent/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped by netstack due to link layer errors."), - RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped by netstack due to rate limit being exceeded."), + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Number of ICMPv6 packets dropped due to link layer errors."), + RateLimited: mustCreateMetric("/netstack/icmp/v6/packets_sent/rate_limited", "Number of ICMPv6 packets dropped due to rate limit being exceeded."), }, PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets 
received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received by netstack."), - MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received by netstack."), - MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent by netstack."), - MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent by netstack."), + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Number of ICMPv6 echo request packets received."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Number of ICMPv6 echo reply packets received."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Number of ICMPv6 destination unreachable packets received."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Number of ICMPv6 packet too big packets received."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Number of ICMPv6 time exceeded packets received."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Number of ICMPv6 parameter problem packets received."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Number of ICMPv6 router solicit packets received."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Number of ICMPv6 router advert packets received."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Number of ICMPv6 neighbor solicit packets received."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Number of ICMPv6 neighbor advert packets received."), + RedirectMsg: 
mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Number of ICMPv6 redirect message packets received."), + MulticastListenerQuery: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_query", "Number of ICMPv6 multicast listener query packets received."), + MulticastListenerReport: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_report", "Number of ICMPv6 multicast listener report packets sent."), + MulticastListenerDone: mustCreateMetric("/netstack/icmp/v6/packets_received/multicast_listener_done", "Number of ICMPv6 multicast listener done packets sent."), }, Unrecognized: mustCreateMetric("/netstack/icmp/v6/packets_received/unrecognized", "Number of ICMPv6 packets received that the transport layer does not know how to parse."), Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Number of ICMPv6 packets received that the transport layer could not parse."), @@ -163,23 +181,23 @@ var Metrics = tcpip.Stats{ IGMP: tcpip.IGMPStats{ PacketsSent: tcpip.IGMPSentPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Number of IGMP Membership Query messages sent."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Number of IGMPv1 Membership Report messages sent."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Number of IGMPv2 Membership Report messages sent."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Number of IGMP Leave Group messages sent."), }, - Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Number of IGMP packets dropped due to link layer errors."), }, PacketsReceived: tcpip.IGMPReceivedPacketStats{ IGMPPacketStats: tcpip.IGMPPacketStats{ - MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received by netstack."), - V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received by netstack."), - V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received by netstack."), - LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received by netstack."), + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Number of IGMP Membership Query messages received."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Number of IGMPv1 Membership Report messages received."), + V2MembershipReport: 
mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Number of IGMPv2 Membership Report messages received."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Number of IGMP Leave Group messages received."), }, - Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received by netstack that could not be parsed."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Number of IGMP packets received that could not be parsed."), ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Number of received IGMP packets with bad checksums."), - Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received by netstack."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Number of unrecognized IGMP packets received."), }, }, IP: tcpip.IPStats{ @@ -199,6 +217,16 @@ var Metrics = tcpip.Stats{ OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Number of record route options found in received IP packets."), OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Number of router alert options found in received IP packets."), OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Number of unknown options found in received IP packets."), + Forwarding: tcpip.IPForwardingStats{ + Unrouteable: mustCreateMetric("/netstack/ip/forwarding/unrouteable", "Number of IP packets received which couldn't be routed and thus were not forwarded."), + ExhaustedTTL: mustCreateMetric("/netstack/ip/forwarding/exhausted_ttl", "Number of IP packets received which could not be forwarded due to an exhausted TTL."), + LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), + LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), + ExtensionHeaderProblem: mustCreateMetric("/netstack/ip/forwarding/extension_header_problem", "Number of IP packets received which could not be forwarded due to a problem processing their IPv6 extension headers."), + PacketTooBig: mustCreateMetric("/netstack/ip/forwarding/packet_too_big", "Number of IP packets received which could not be forwarded because they could not fit within the outgoing MTU."), + HostUnreachable: mustCreateMetric("/netstack/ip/forwarding/host_unreachable", "Number of IP packets received which could not be forwarded due to unresolvable next hop."), + Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), + }, }, ARP: tcpip.ARPStats{ PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."), @@ -375,9 +403,9 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue }), nil } -var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{})) -var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{})) -var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{})) +var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes() +var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes() +var sockAddrLinkSize = 
(*linux.SockAddrLink)(nil).SizeBytes() // bytesToIPAddress converts an IPv4 or IPv6 address from the user to the // netstack representation taking any addresses into account. @@ -613,7 +641,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - binary.Unmarshal(sockaddr[:sockAddrLinkSize], hostarch.ByteOrder, &a) + a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) if a.Protocol != uint16(s.protocol) { return syserr.ErrInvalidArgument @@ -843,7 +871,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return &optP, nil } - optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux()) return &optP, nil case linux.SO_PEERCRED: @@ -1117,7 +1145,14 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, // TODO(b/64800844): Translate fields once they are added to // tcpip.TCPInfoOption. - info := linux.TCPInfo{} + info := linux.TCPInfo{ + State: uint8(v.State), + RTO: uint32(v.RTO / time.Microsecond), + RTT: uint32(v.RTT / time.Microsecond), + RTTVar: uint32(v.RTTVar / time.Microsecond), + SndSsthresh: v.SndSsthresh, + SndCwnd: v.SndCwnd, + } switch v.CcState { case tcpip.RTORecovery: info.CaState = linux.TCP_CA_Loss @@ -1128,11 +1163,6 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, case tcpip.Open: info.CaState = linux.TCP_CA_Open } - info.RTO = uint32(v.RTO / time.Microsecond) - info.RTT = uint32(v.RTT / time.Microsecond) - info.RTTVar = uint32(v.RTTVar / time.Microsecond) - info.SndSsthresh = v.SndSsthresh - info.SndCwnd = v.SndCwnd // In netstack reorderSeen is updated only when RACK is enabled. 
// We only track whether the reordering is seen, which is @@ -1312,7 +1342,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return &v, nil case linux.IP6T_ORIGINAL_DST: - if outLen < int(binary.Size(linux.SockAddrInet6{})) { + if outLen < sockAddrInet6Size { return nil, syserr.ErrInvalidArgument } @@ -1509,7 +1539,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return &v, nil case linux.SO_ORIGINAL_DST: - if outLen < int(binary.Size(linux.SockAddrInet{})) { + if outLen < sockAddrInetSize { return nil, syserr.ErrInvalidArgument } @@ -1742,7 +1772,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1755,7 +1785,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1791,7 +1821,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - binary.Unmarshal(optVal[:linux.SizeOfLinger], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfLinger]) + + if v != (linux.Linger{}) { + socket.SetSockOptEmitUnimplementedEvent(t, name) + } ep.SocketOptions().SetLinger(tcpip.LingerOption{ Enabled: v.OnOff != 0, @@ -2090,9 +2124,9 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } var ( - inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{})) - inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{})) - inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{})) + inetMulticastRequestSize = (*linux.InetMulticastRequest)(nil).SizeBytes() + inetMulticastRequestWithNICSize = (*linux.InetMulticastRequestWithNIC)(nil).SizeBytes() + inet6MulticastRequestSize = (*linux.Inet6MulticastRequest)(nil).SizeBytes() ) // copyInMulticastRequest copies in a variable-size multicast request. The @@ -2117,12 +2151,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], hostarch.ByteOrder, &req) + req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize]) return req, nil } var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestSize], hostarch.ByteOrder, &req.InetMulticastRequest) + req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize]) return req, nil } @@ -2132,7 +2166,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - binary.Unmarshal(optVal[:inet6MulticastRequestSize], hostarch.ByteOrder, &req) + req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize]) return req, nil } @@ -3101,8 +3135,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe continue } // Populate ifr.ifr_netmask (type sockaddr). 
- hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET)) - hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0) + hostarch.ByteOrder.PutUint16(ifr.Data[0:], uint16(linux.AF_INET)) + hostarch.ByteOrder.PutUint16(ifr.Data[2:], 0) var mask uint32 = 0xffffffff << (32 - addr.PrefixLen) // Netmask is expected to be returned as a big endian // value. diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index b215067cf..eef5e6519 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -458,23 +458,10 @@ func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { s.Stack.RestoreCleanupEndpoints(es) } -// Forwarding implements inet.Stack.Forwarding. -func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - switch protocol { - case ipv4.ProtocolNumber, ipv6.ProtocolNumber: - return s.Stack.Forwarding(protocol) - default: - panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol)) - } -} - // SetForwarding implements inet.Stack.SetForwarding. func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { - switch protocol { - case ipv4.ProtocolNumber, ipv6.ProtocolNumber: - s.Stack.SetForwarding(protocol, enable) - default: - panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) + if err := s.Stack.SetForwardingDefaultAndAllNICs(protocol, enable); err != nil { + return fmt.Errorf("SetForwardingDefaultAndAllNICs(%d, %t): %s", protocol, enable, err) } return nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 4c3d48096..353f4ade0 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -24,7 +24,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -81,7 +80,7 @@ func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { } ee := linux.SockExtendedErr{ - Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), + Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux()), Origin: errOriginToLinux(sockErr.Cause.Origin()), Type: sockErr.Cause.Type(), Code: sockErr.Cause.Code(), @@ -572,19 +571,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - binary.Unmarshal(data[:unix.SizeofSockaddrInet4], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - binary.Unmarshal(data[:unix.SizeofSockaddrInet6], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - binary.Unmarshal(data[:unix.SizeofSockaddrUnix], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -716,7 +715,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInetSize], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrInetSize]) 
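
Here `AddressAndFamily` now slices the input by the precomputed sockaddr sizes and decodes it with `UnmarshalUnsafe` instead of `binary.Unmarshal`. The structure being decoded for `AF_INET` is the fixed 16-byte `sockaddr_in`; a rough, self-contained equivalent of that decode using only `encoding/binary` (the struct and helper names below are illustrative, not the gVisor types):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// sockAddrInet mirrors the 16-byte sockaddr_in layout for illustration.
type sockAddrInet struct {
	Family uint16  // host byte order
	Port   uint16  // network byte order
	Addr   [4]byte // IPv4 address
	_      [8]byte // zero padding
}

// parseSockAddrInet decodes a sockaddr_in from raw bytes, as a userspace
// caller would have handed it to bind(2) or connect(2).
func parseSockAddrInet(b []byte) (sockAddrInet, bool) {
	var sa sockAddrInet
	if len(b) < 16 {
		return sa, false
	}
	sa.Family = binary.LittleEndian.Uint16(b[0:2]) // host order on x86
	sa.Port = binary.BigEndian.Uint16(b[2:4])
	copy(sa.Addr[:], b[4:8])
	return sa, true
}

func main() {
	// AF_INET (2), port 8080, address 127.0.0.1.
	raw := []byte{2, 0, 0x1f, 0x90, 127, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}
	sa, ok := parseSockAddrInet(raw)
	fmt.Println(ok, sa.Family, sa.Port, sa.Addr)
}
```

Length-checking against the full fixed size before decoding is what the `len(addr) < sockAddrInetSize` guard above does; the marshalled types then read the bytes directly without reflection.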
out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -729,7 +728,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInet6Size], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrInet6Size]) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -745,7 +744,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrLinkSize], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 2ebd77f82..1fbbd133c 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -25,7 +25,6 @@ go_library( ":strace_go_proto", "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/bits", "//pkg/eventchannel", "//pkg/hostarch", diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index 71b92eaee..d66befe81 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -371,6 +371,7 @@ var linuxAMD64 = SyscallMap{ 433: makeSyscallInfo("fspick", FD, Path, Hex), 434: makeSyscallInfo("pidfd_open", Hex, Hex), 435: makeSyscallInfo("clone3", Hex, Hex), + 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet), } func init() { diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index bd7361a52..1a2d7d75f 100644 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go @@ -312,6 +312,7 @@ var linuxARM64 = SyscallMap{ 433: makeSyscallInfo("fspick", FD, Path, Hex), 434: makeSyscallInfo("pidfd_open", Hex, Hex), 435: makeSyscallInfo("clone3", Hex, Hex), + 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet), } func init() { diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index e5b7f9b96..f4aab25b0 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -20,14 +20,13 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - - "gvisor.dev/gvisor/pkg/hostarch" ) // SocketFamily are the possible socket(2) families. 
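
Both strace syscall tables above gain entry 441 with the annotation `FD, EpollEvents, Hex, Timespec, SigSet`: unlike `epoll_pwait`, `epoll_pwait2` takes a `struct timespec` timeout and an optional signal mask. An illustrative way to exercise the syscall directly on an amd64 Linux 5.11+ host (this is a test snippet using the raw syscall number, not gVisor code):

```go
package main

import (
	"fmt"
	"unsafe"

	"golang.org/x/sys/unix"
)

func main() {
	epfd, err := unix.EpollCreate1(0)
	if err != nil {
		panic(err)
	}
	defer unix.Close(epfd)

	var events [8]unix.EpollEvent
	timeout := unix.Timespec{Sec: 0, Nsec: 100 * 1000 * 1000} // 100ms

	n, _, errno := unix.Syscall6(
		441, // __NR_epoll_pwait2 on amd64
		uintptr(epfd),
		uintptr(unsafe.Pointer(&events[0])),
		uintptr(len(events)),
		uintptr(unsafe.Pointer(&timeout)),
		0, // NULL sigmask: keep the current signal mask
		0, // sigsetsize (ignored when sigmask is NULL)
	)
	if errno != 0 {
		fmt.Println("epoll_pwait2:", errno)
		return
	}
	fmt.Println("ready events:", n)
}
```

With no registered file descriptors this simply times out after 100ms and reports zero ready events, which is the nanosecond-granularity path the sentry implementation below converts into `timeoutInNanos`.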
@@ -162,6 +161,15 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } +func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights { + count := len(src) / linux.SizeOfControlMessageRight + cmr := make(linux.ControlMessageRights, count) + for i, _ := range cmr { + cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:])) + } + return cmr +} + func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) string { if length > maxBytes { return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length) @@ -181,7 +189,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) + h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) var skipData bool level := "SOL_SOCKET" @@ -221,18 +229,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) continue } switch h.Type { case linux.SCM_RIGHTS: - rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) - - numRights := rightsSize / linux.SizeOfControlMessageRight - fds := make(linux.ControlMessageRights, numRights) - binary.Unmarshal(buf[i:i+rightsSize], hostarch.ByteOrder, &fds) - + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) + fds := unmarshalControlMessageRights(buf[i : i+rightsSize]) rights := make([]string, 0, len(fds)) for _, fd := range fds { rights = append(rights, fmt.Sprint(fd)) @@ -258,7 +262,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) + creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", @@ -282,7 +286,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var tv linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &tv) + tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", @@ -296,7 +300,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) default: panic("unreachable") } - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go index e115683f8..3b4d79889 100644 --- a/pkg/sentry/syscalls/epoll.go +++ b/pkg/sentry/syscalls/epoll.go @@ -119,7 +119,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error { } // WaitEpoll implements the epoll_wait(2) linux syscall. -func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) { +func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux.EpollEvent, error) { // Get epoll from the file descriptor. 
epollfile := t.GetFile(fd) if epollfile == nil { @@ -136,7 +136,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve // Try to read events and return right away if we got them or if the // caller requested a non-blocking "wait". r := e.ReadEvents(max) - if len(r) != 0 || timeout == 0 { + if len(r) != 0 || timeoutInNanos == 0 { return r, nil } @@ -144,8 +144,8 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve // and register with the epoll object for readability events. var haveDeadline bool var deadline ktime.Time - if timeout > 0 { - timeoutDur := time.Duration(timeout) * time.Millisecond + if timeoutInNanos > 0 { + timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) haveDeadline = true } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 2d2212605..090c5ffcb 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -404,6 +404,7 @@ var AMD64 = &kernel.SyscallTable{ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{ 0xffffffffff600000: 96, // vsyscall gettimeofday(2) @@ -722,6 +723,7 @@ var ARM64 = &kernel.SyscallTable{ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{}, Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index 7f460d30b..69cbc98d0 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" @@ -104,14 +105,8 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } -// EpollWait implements the epoll_wait(2) linux syscall. -func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - epfd := args[0].Int() - eventsAddr := args[1].Pointer() - maxEvents := int(args[2].Int()) - timeout := int(args[3].Int()) - - r, err := syscalls.WaitEpoll(t, epfd, maxEvents, timeout) +func waitEpoll(t *kernel.Task, fd int32, eventsAddr hostarch.Addr, max int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { + r, err := syscalls.WaitEpoll(t, fd, max, timeoutInNanos) if err != nil { return 0, nil, syserror.ConvertIntr(err, syserror.EINTR) } @@ -123,6 +118,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } return uintptr(len(r)), nil, nil + +} + +// EpollWait implements the epoll_wait(2) linux syscall. +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + // Convert milliseconds to nanoseconds. 
+ timeoutInNanos := int64(args[3].Int()) * 1000000 + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) } // EpollPwait implements the epoll_pwait(2) linux syscall. @@ -144,4 +150,38 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } +// EpollPwait2 implements the epoll_pwait(2) linux syscall. +func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutPtr := args[3].Pointer() + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + haveTimeout := timeoutPtr != 0 + + var timeoutInNanos int64 = -1 + if haveTimeout { + timeout, err := copyTimespecIn(t, timeoutPtr) + if err != nil { + return 0, nil, err + } + timeoutInNanos = timeout.ToNsec() + + } + + if maskAddr != 0 { + mask, err := CopyInSigSet(t, maskAddr, maskSize) + if err != nil { + return 0, nil, err + } + + oldmask := t.SignalMask() + t.SetSignalMask(mask) + t.SetSavedSignalMask(oldmask) + } + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) +} + // LINT.ThenChange(vfs2/epoll.go) diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 5e9e940df..e07917613 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -463,8 +463,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(v.SizeBytes()) - if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index b980aa43e..047d955b6 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -118,13 +119,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } -// EpollWait implements Linux syscall epoll_wait(2). -func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - epfd := args[0].Int() - eventsAddr := args[1].Pointer() - maxEvents := int(args[2].Int()) - timeout := int(args[3].Int()) - +func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS { return 0, nil, syserror.EINVAL @@ -158,7 +153,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } return 0, nil, err } - if timeout == 0 { + if timeoutInNanos == 0 { return 0, nil, nil } // In the first iteration of this loop, register with the epoll @@ -173,8 +168,8 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys defer epfile.EventUnregister(&w) } else { // Set up the timer if a timeout was specified. 
- if timeout > 0 && !haveDeadline { - timeoutDur := time.Duration(timeout) * time.Millisecond + if timeoutInNanos > 0 && !haveDeadline { + timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) haveDeadline = true } @@ -186,6 +181,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } } } + +} + +// EpollWait implements Linux syscall epoll_wait(2). +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutInNanos := int64(args[3].Int()) * 1000000 + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) } // EpollPwait implements Linux syscall epoll_pwait(2). @@ -199,3 +205,29 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } + +// EpollPwait2 implements Linux syscall epoll_pwait(2). +func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutPtr := args[3].Pointer() + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + haveTimeout := timeoutPtr != 0 + + var timeoutInNanos int64 = -1 + if haveTimeout { + var timeout linux.Timespec + if _, err := timeout.CopyIn(t, timeoutPtr); err != nil { + return 0, nil, err + } + timeoutInNanos = timeout.ToNsec() + } + + if err := setTempSignalSet(t, maskAddr, maskSize); err != nil { + return 0, nil, err + } + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 6edde0ed1..69f69e3af 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -467,8 +467,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(v.SizeBytes()) - if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index c50fd97eb..0fc81e694 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -159,6 +159,7 @@ func Override() { s.Table[327] = syscalls.Supported("preadv2", Preadv2) s.Table[328] = syscalls.Supported("pwritev2", Pwritev2) s.Table[332] = syscalls.Supported("statx", Statx) + s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2) s.Init() // Override ARM64. @@ -269,6 +270,7 @@ func Override() { s.Table[286] = syscalls.Supported("preadv2", Preadv2) s.Table[287] = syscalls.Supported("pwritev2", Pwritev2) s.Table[291] = syscalls.Supported("statx", Statx) + s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2) s.Init() } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index f612a71b2..ef8d8a813 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -454,6 +454,9 @@ type FileDescriptionImpl interface { // RemoveXattr removes the given extended attribute from the file. RemoveXattr(ctx context.Context, name string) error + // SupportsLocks indicates whether file locks are supported. 
+ SupportsLocks() bool + // LockBSD tries to acquire a BSD-style advisory file lock. LockBSD(ctx context.Context, uid lock.UniqueID, ownerPID int32, t lock.LockType, block lock.Blocker) error @@ -524,7 +527,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St Start: fd.vd, }) stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, err } return fd.impl.Stat(ctx, opts) @@ -539,7 +542,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.SetStat(ctx, opts) @@ -555,7 +558,7 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vd, }) statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, err } return fd.impl.StatFS(ctx) @@ -701,7 +704,7 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string Start: fd.vd, }) names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return names, err } names, err := fd.impl.ListXattr(ctx, size) @@ -730,7 +733,7 @@ func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions) Start: fd.vd, }) val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return val, err } return fd.impl.GetXattr(ctx, *opts) @@ -746,7 +749,7 @@ func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions) Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.SetXattr(ctx, *opts) @@ -762,7 +765,7 @@ func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error { Start: fd.vd, }) err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.RemoveXattr(ctx, name) @@ -818,6 +821,11 @@ func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) e return fd.Sync(ctx) } +// SupportsLocks indicates whether file locks are supported. +func (fd *FileDescription) SupportsLocks() bool { + return fd.impl.SupportsLocks() +} + // LockBSD tries to acquire a BSD-style advisory file lock. func (fd *FileDescription) LockBSD(ctx context.Context, ownerPID int32, lockType lock.LockType, blocker lock.Blocker) error { atomic.StoreUint32(&fd.usedLockBSD, 1) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index b87d9690a..2b6f47b4b 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -413,6 +413,11 @@ type LockFD struct { locks *FileLocks } +// SupportsLocks implements FileDescriptionImpl.SupportsLocks. +func (LockFD) SupportsLocks() bool { + return true +} + // Init initializes fd with FileLocks to use. func (fd *LockFD) Init(locks *FileLocks) { fd.locks = locks @@ -423,28 +428,28 @@ func (fd *LockFD) Locks() *FileLocks { return fd.locks } -// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +// LockBSD implements FileDescriptionImpl.LockBSD. 
func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return fd.locks.LockBSD(ctx, uid, ownerPID, t, block) } -// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. +// UnlockBSD implements FileDescriptionImpl.UnlockBSD. func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { fd.locks.UnlockBSD(uid) return nil } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +// LockPOSIX implements FileDescriptionImpl.LockPOSIX. func (fd *LockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { return fd.locks.LockPOSIX(ctx, uid, ownerPID, t, r, block) } -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX. func (fd *LockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { return fd.locks.UnlockPOSIX(ctx, uid, r) } -// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +// TestPOSIX implements FileDescriptionImpl.TestPOSIX. func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { return fd.locks.TestPOSIX(ctx, uid, t, r) } @@ -455,27 +460,68 @@ func (fd *LockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.L // +stateify savable type NoLockFD struct{} -// LockBSD implements vfs.FileDescriptionImpl.LockBSD. +// SupportsLocks implements FileDescriptionImpl.SupportsLocks. +func (NoLockFD) SupportsLocks() bool { + return false +} + +// LockBSD implements FileDescriptionImpl.LockBSD. func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return syserror.ENOLCK } -// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD. +// UnlockBSD implements FileDescriptionImpl.UnlockBSD. func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { return syserror.ENOLCK } -// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +// LockPOSIX implements FileDescriptionImpl.LockPOSIX. func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { return syserror.ENOLCK } -// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX. func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { return syserror.ENOLCK } -// TestPOSIX implements vfs.FileDescriptionImpl.TestPOSIX. +// TestPOSIX implements FileDescriptionImpl.TestPOSIX. func (NoLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { return linux.Flock{}, syserror.ENOLCK } + +// BadLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface +// returning EBADF. +// +// +stateify savable +type BadLockFD struct{} + +// SupportsLocks implements FileDescriptionImpl.SupportsLocks. +func (BadLockFD) SupportsLocks() bool { + return false +} + +// LockBSD implements FileDescriptionImpl.LockBSD. +func (BadLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { + return syserror.EBADF +} + +// UnlockBSD implements FileDescriptionImpl.UnlockBSD. 
+func (BadLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error { + return syserror.EBADF +} + +// LockPOSIX implements FileDescriptionImpl.LockPOSIX. +func (BadLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, r fslock.LockRange, block fslock.Blocker) error { + return syserror.EBADF +} + +// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX. +func (BadLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, r fslock.LockRange) error { + return syserror.EBADF +} + +// TestPOSIX implements FileDescriptionImpl.TestPOSIX. +func (BadLockFD) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, r fslock.LockRange) (linux.Flock, error) { + return linux.Flock{}, syserror.EBADF +} diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go index 39fbac987..e9651b631 100644 --- a/pkg/sentry/vfs/opath.go +++ b/pkg/sentry/vfs/opath.go @@ -24,96 +24,96 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// opathFD implements vfs.FileDescriptionImpl for a file description opened with O_PATH. +// opathFD implements FileDescriptionImpl for a file description opened with O_PATH. // // +stateify savable type opathFD struct { vfsfd FileDescription FileDescriptionDefaultImpl - NoLockFD + BadLockFD } -// Release implements vfs.FileDescriptionImpl.Release. +// Release implements FileDescriptionImpl.Release. func (fd *opathFD) Release(context.Context) { // noop } -// Allocate implements vfs.FileDescriptionImpl.Allocate. +// Allocate implements FileDescriptionImpl.Allocate. func (fd *opathFD) Allocate(ctx context.Context, mode, offset, length uint64) error { return syserror.EBADF } -// PRead implements vfs.FileDescriptionImpl.PRead. +// PRead implements FileDescriptionImpl.PRead. func (fd *opathFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { return 0, syserror.EBADF } -// Read implements vfs.FileDescriptionImpl.Read. +// Read implements FileDescriptionImpl.Read. func (fd *opathFD) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { return 0, syserror.EBADF } -// PWrite implements vfs.FileDescriptionImpl.PWrite. +// PWrite implements FileDescriptionImpl.PWrite. func (fd *opathFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { return 0, syserror.EBADF } -// Write implements vfs.FileDescriptionImpl.Write. +// Write implements FileDescriptionImpl.Write. func (fd *opathFD) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { return 0, syserror.EBADF } -// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +// Ioctl implements FileDescriptionImpl.Ioctl. func (fd *opathFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { return 0, syserror.EBADF } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +// IterDirents implements FileDescriptionImpl.IterDirents. func (fd *opathFD) IterDirents(ctx context.Context, cb IterDirentsCallback) error { return syserror.EBADF } -// Seek implements vfs.FileDescriptionImpl.Seek. +// Seek implements FileDescriptionImpl.Seek. func (fd *opathFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { return 0, syserror.EBADF } -// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap. 
func (fd *opathFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return syserror.EBADF } -// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +// ListXattr implements FileDescriptionImpl.ListXattr. func (fd *opathFD) ListXattr(ctx context.Context, size uint64) ([]string, error) { return nil, syserror.EBADF } -// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +// GetXattr implements FileDescriptionImpl.GetXattr. func (fd *opathFD) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) { return "", syserror.EBADF } -// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +// SetXattr implements FileDescriptionImpl.SetXattr. func (fd *opathFD) SetXattr(ctx context.Context, opts SetXattrOptions) error { return syserror.EBADF } -// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +// RemoveXattr implements FileDescriptionImpl.RemoveXattr. func (fd *opathFD) RemoveXattr(ctx context.Context, name string) error { return syserror.EBADF } -// Sync implements vfs.FileDescriptionImpl.Sync. +// Sync implements FileDescriptionImpl.Sync. func (fd *opathFD) Sync(ctx context.Context) error { return syserror.EBADF } -// SetStat implements vfs.FileDescriptionImpl.SetStat. +// SetStat implements FileDescriptionImpl.SetStat. func (fd *opathFD) SetStat(ctx context.Context, opts SetStatOptions) error { return syserror.EBADF } -// Stat implements vfs.FileDescriptionImpl.Stat. +// Stat implements FileDescriptionImpl.Stat. func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { vfsObj := fd.vfsfd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ @@ -121,7 +121,7 @@ func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, err Start: fd.vfsfd.vd, }) stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, err } @@ -134,6 +134,6 @@ func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vfsfd.vd, }) statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, err } diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index e4fd55012..97b898aba 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -44,13 +44,10 @@ type ResolvingPath struct { start *Dentry pit fspath.Iterator - flags uint16 - mustBeDir bool // final file must be a directory? - mustBeDirOrig bool - symlinks uint8 // number of symlinks traversed - symlinksOrig uint8 - curPart uint8 // index into parts - numOrigParts uint8 + flags uint16 + mustBeDir bool // final file must be a directory? + symlinks uint8 // number of symlinks traversed + curPart uint8 // index into parts creds *auth.Credentials @@ -60,14 +57,9 @@ type ResolvingPath struct { nextStart *Dentry // ref held if not nil absSymlinkTarget fspath.Path - // ResolvingPath must track up to two relative paths: the "current" - // relative path, which is updated whenever a relative symlink is - // encountered, and the "original" relative path, which is updated from the - // current relative path by handleError() when resolution must change - // filesystems (due to reaching a mount boundary or absolute symlink) and - // overwrites the current relative path when Restart() is called. 
- parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
- origParts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
+ // ResolvingPath tracks the current relative path, which is updated whenever
+ // a relative symlink is encountered.
+ parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator
}

const (
@@ -120,6 +112,8 @@ var resolvingPathPool = sync.Pool{
},
}

+// getResolvingPath gets a new ResolvingPath from the pool. Caller must call
+// ResolvingPath.Release() when done.
func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath {
rp := resolvingPathPool.Get().(*ResolvingPath)
rp.vfs = vfs
@@ -132,17 +126,37 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat
rp.flags |= rpflagsFollowFinalSymlink
}
rp.mustBeDir = pop.Path.Dir
- rp.mustBeDirOrig = pop.Path.Dir
rp.symlinks = 0
rp.curPart = 0
- rp.numOrigParts = 1
rp.creds = creds
rp.parts[0] = pop.Path.Begin
- rp.origParts[0] = pop.Path.Begin
return rp
}

-func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) {
+// Copy creates another ResolvingPath with the same state as the original.
+// Copies are independent; using the copy does not change the original and
+// vice-versa.
+//
+// Caller must call Release() when done.
+func (rp *ResolvingPath) Copy() *ResolvingPath {
+ copy := resolvingPathPool.Get().(*ResolvingPath)
+ *copy = *rp // All fields are shallow copyable.
+
+ // Take extra references for the copy if the original held them.
+ if copy.flags&rpflagsHaveStartRef != 0 {
+ copy.start.IncRef()
+ }
+ if copy.flags&rpflagsHaveMountRef != 0 {
+ copy.mount.IncRef()
+ }
+ // Reset error state.
+ copy.nextStart = nil
+ copy.nextMount = nil
+ return copy
+}
+
+// Release decrements references if needed and returns the object to the pool.
+func (rp *ResolvingPath) Release(ctx context.Context) {
rp.root = VirtualDentry{}
rp.decRefStartAndMount(ctx)
rp.mount = nil
@@ -240,25 +254,6 @@ func (rp *ResolvingPath) Advance() {
}
}

-// Restart resets the stream of path components represented by rp to its state
-// on entry to the current FilesystemImpl method.
-func (rp *ResolvingPath) Restart(ctx context.Context) {
- rp.pit = rp.origParts[rp.numOrigParts-1]
- rp.mustBeDir = rp.mustBeDirOrig
- rp.symlinks = rp.symlinksOrig
- rp.curPart = rp.numOrigParts - 1
- copy(rp.parts[:], rp.origParts[:rp.numOrigParts])
- rp.releaseErrorState(ctx)
-}
-
-func (rp *ResolvingPath) relpathCommit() {
- rp.mustBeDirOrig = rp.mustBeDir
- rp.symlinksOrig = rp.symlinks
- rp.numOrigParts = rp.curPart + 1
- copy(rp.origParts[:rp.curPart], rp.parts[:])
- rp.origParts[rp.curPart] = rp.pit
-}
-
// CheckRoot is called before resolving the parent of the Dentry d. If the
// Dentry is contextually a VFS root, such that path resolution should treat
// d's parent as itself, CheckRoot returns (true, nil). If the Dentry is the
@@ -405,11 +400,10 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef
rp.nextMount = nil
rp.nextStart = nil
- // Commit the previous FileystemImpl's progress through the relative
- // path. (Don't consume the path component that caused us to traverse
+ // Don't consume the path component that caused us to traverse
// through the mount root - i.e. the ".." - because we still need to
- // resolve the mount point's parent in the new FilesystemImpl.)
- rp.relpathCommit()
+ // resolve the mount point's parent in the new FilesystemImpl.
+ //
// Restart path resolution on the new Mount.
Don't bother calling // rp.releaseErrorState() since we already set nextMount and nextStart // to nil above. @@ -425,9 +419,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.nextMount = nil // Consume the path component that represented the mount point. rp.Advance() - // Commit the previous FilesystemImpl's progress through the relative - // path. - rp.relpathCommit() // Restart path resolution on the new Mount. rp.releaseErrorState(ctx) return true @@ -442,9 +433,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.Advance() // Prepend the symlink target to the relative path. rp.relpathPrepend(rp.absSymlinkTarget) - // Commit the previous FilesystemImpl's progress through the relative - // path, including the symlink target we just prepended. - rp.relpathCommit() // Restart path resolution on the new Mount. rp.releaseErrorState(ctx) return true diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 00f1847d8..87fdcf403 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -208,11 +208,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -230,11 +230,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede dentry: d, } rp.mount.IncRef() - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return vd, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return VirtualDentry{}, err } } @@ -252,7 +252,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } rp.mount.IncRef() name := rp.Component() - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return parentVD, name, nil } if checkInvariants { @@ -261,7 +261,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return VirtualDentry{}, "", err } } @@ -292,7 +292,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential for { err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldVD.DecRef(ctx) return nil } @@ -302,7 +302,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldVD.DecRef(ctx) return err } @@ -331,7 +331,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -340,7 +340,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -366,7 +366,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -375,7 +375,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) 
return err } } @@ -425,7 +425,6 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential rp := vfs.getResolvingPath(creds, pop) if opts.Flags&linux.O_DIRECTORY != 0 { rp.mustBeDir = true - rp.mustBeDirOrig = true } // Ignore O_PATH for verity, as verity performs extra operations on the fd for verification. // The underlying filesystem that verity wraps opens the fd with O_PATH. @@ -444,7 +443,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential for { fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) if opts.FileExec { if fd.Mount().Flags.NoExec { @@ -468,7 +467,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential return fd, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -480,11 +479,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden for { target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return target, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return "", err } } @@ -533,7 +532,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldParentVD.DecRef(ctx) return nil } @@ -543,7 +542,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldParentVD.DecRef(ctx) return err } @@ -569,7 +568,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.RmdirAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -578,7 +577,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -590,11 +589,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -606,11 +605,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential for { stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return linux.Statx{}, err } } @@ -623,11 +622,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti for { statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return linux.Statfs{}, err } } @@ -652,7 +651,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -661,7 +660,7 @@ func (vfs 
*VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -686,7 +685,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.UnlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -695,7 +694,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -707,7 +706,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C for { bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return bep, nil } if checkInvariants { @@ -716,7 +715,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -729,7 +728,7 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede for { names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return names, nil } if err == syserror.ENOTSUP { @@ -737,11 +736,11 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by // default don't exist. - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -754,11 +753,11 @@ func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Creden for { val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return val, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return "", err } } @@ -771,11 +770,11 @@ func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Creden for { err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -787,11 +786,11 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre for { err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 8e3146d8d..dfe85f31d 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -243,6 +243,7 @@ func (w *Watchdog) waitForStart() { } stuckStartup.Increment() + metric.WeirdnessMetric.Increment("watchdog_stuck_startup") var buf bytes.Buffer buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout)) @@ -312,10 +313,11 @@ func (w *Watchdog) runTurn() { // New stuck task detected. // // Note that tasks blocked doing IO may be considered stuck in kernel, - // unless they are surrounded b + // unless they are surrounded by // Task.UninterruptibleSleepStart/Finish. 
tc = &offender{lastUpdateTime: lastUpdateTime} stuckTasks.Increment() + metric.WeirdnessMetric.Increment("watchdog_stuck_tasks") newTaskFound = true } newOffenders[t] = tc diff --git a/pkg/shim/BUILD b/pkg/shim/BUILD index 4f7c02f5d..fd6127b97 100644 --- a/pkg/shim/BUILD +++ b/pkg/shim/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -41,7 +41,19 @@ go_library( "@com_github_containerd_fifo//:go_default_library", "@com_github_containerd_typeurl//:go_default_library", "@com_github_gogo_protobuf//types:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_sirupsen_logrus//:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "shim_test", + size = "small", + srcs = ["service_test.go"], + library = ":shim", + deps = [ + "//pkg/shim/utils", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) diff --git a/pkg/shim/service.go b/pkg/shim/service.go index 9d9fa8ef6..1f9adcb65 100644 --- a/pkg/shim/service.go +++ b/pkg/shim/service.go @@ -22,6 +22,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync" "time" @@ -44,6 +45,7 @@ import ( "github.com/containerd/containerd/sys/reaper" "github.com/containerd/typeurl" "github.com/gogo/protobuf/types" + specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cleanup" @@ -944,9 +946,19 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C if err != nil { return nil, fmt.Errorf("read oci spec: %w", err) } - if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + + updated, err := utils.UpdateVolumeAnnotations(spec) + if err != nil { return nil, fmt.Errorf("update volume annotations: %w", err) } + updated = updateCgroup(spec) || updated + + if updated { + if err := utils.WriteSpec(r.Bundle, spec); err != nil { + return nil, err + } + } + runsc.FormatRunscLogPath(r.ID, options.RunscConfig) runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig) p := proc.New(r.ID, runtime, stdio.Stdio{ @@ -966,3 +978,39 @@ func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.C p.Monitor = reaper.Default return p, nil } + +// updateCgroup updates cgroup path for the sandbox to make the sandbox join the +// pod cgroup and not the pause container cgroup. Returns true if the spec was +// modified. Ex.: +// /kubepods/burstable/pod123/abc => kubepods/burstable/pod123 +// +func updateCgroup(spec *specs.Spec) bool { + if !utils.IsSandbox(spec) { + return false + } + if spec.Linux == nil || len(spec.Linux.CgroupsPath) == 0 { + return false + } + + // Search backwards for the pod cgroup path to make the sandbox use it, + // instead of the pause container's cgroup. + parts := strings.Split(spec.Linux.CgroupsPath, string(filepath.Separator)) + for i := len(parts) - 1; i >= 0; i-- { + if strings.HasPrefix(parts[i], "pod") { + var path string + for j := 0; j <= i; j++ { + path = filepath.Join(path, parts[j]) + } + // Add back the initial '/' that may have been lost above. 
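+ // (filepath.Join ignores empty elements, so splitting an absolute path on the separator and rejoining it yields a relative path.)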
+ if filepath.IsAbs(spec.Linux.CgroupsPath) { + path = string(filepath.Separator) + path + } + if spec.Linux.CgroupsPath == path { + return false + } + spec.Linux.CgroupsPath = path + return true + } + } + return false +} diff --git a/pkg/shim/service_test.go b/pkg/shim/service_test.go new file mode 100644 index 000000000..2d9f07e02 --- /dev/null +++ b/pkg/shim/service_test.go @@ -0,0 +1,121 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/shim/utils" +) + +func TestCgroupPath(t *testing.T) { + for _, tc := range []struct { + name string + path string + want string + }{ + { + name: "simple", + path: "foo/pod123/container", + want: "foo/pod123", + }, + { + name: "absolute", + path: "/foo/pod123/container", + want: "/foo/pod123", + }, + { + name: "no-container", + path: "foo/pod123", + want: "foo/pod123", + }, + { + name: "no-container-absolute", + path: "/foo/pod123", + want: "/foo/pod123", + }, + { + name: "double-pod", + path: "/foo/podium/pod123/container", + want: "/foo/podium/pod123", + }, + { + name: "start-pod", + path: "pod123/container", + want: "pod123", + }, + { + name: "start-pod-absolute", + path: "/pod123/container", + want: "/pod123", + }, + { + name: "slashes", + path: "///foo/////pod123//////container", + want: "/foo/pod123", + }, + { + name: "no-pod", + path: "/foo/nopod123/container", + want: "/foo/nopod123/container", + }, + } { + t.Run(tc.name, func(t *testing.T) { + spec := specs.Spec{ + Linux: &specs.Linux{ + CgroupsPath: tc.path, + }, + } + updated := updateCgroup(&spec) + if spec.Linux.CgroupsPath != tc.want { + t.Errorf("updateCgroup(%q), want: %q, got: %q", tc.path, tc.want, spec.Linux.CgroupsPath) + } + if shouldUpdate := tc.path != tc.want; shouldUpdate != updated { + t.Errorf("updateCgroup(%q)=%v, want: %v", tc.path, updated, shouldUpdate) + } + }) + } +} + +// Test cases that cgroup path should not be updated. +func TestCgroupNoUpdate(t *testing.T) { + for _, tc := range []struct { + name string + spec *specs.Spec + }{ + { + name: "empty", + spec: &specs.Spec{}, + }, + { + name: "subcontainer", + spec: &specs.Spec{ + Linux: &specs.Linux{ + CgroupsPath: "foo/pod123/container", + }, + Annotations: map[string]string{ + utils.ContainerTypeAnnotation: utils.ContainerTypeContainer, + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + if updated := updateCgroup(tc.spec); updated { + t.Errorf("updateCgroup(%+v), got: %v, want: false", tc.spec.Linux, updated) + } + }) + } +} diff --git a/pkg/shim/utils/annotations.go b/pkg/shim/utils/annotations.go index 1e9d3f365..c744800bb 100644 --- a/pkg/shim/utils/annotations.go +++ b/pkg/shim/utils/annotations.go @@ -19,7 +19,9 @@ package utils // These are vendor due to import conflicts. 
const (
sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory"
- containerTypeAnnotation = "io.kubernetes.cri.container-type"
+ // ContainerTypeAnnotation is the key that defines sandbox or container.
+ ContainerTypeAnnotation = "io.kubernetes.cri.container-type"
containerTypeSandbox = "sandbox"
- containerTypeContainer = "container"
+ // ContainerTypeContainer is the value for container.
+ ContainerTypeContainer = "container"
)
diff --git a/pkg/shim/utils/utils.go b/pkg/shim/utils/utils.go
index 7b1cd983e..f183b1bbc 100644
--- a/pkg/shim/utils/utils.go
+++ b/pkg/shim/utils/utils.go
@@ -18,19 +18,16 @@ package utils
import (
"encoding/json"
"io/ioutil"
- "os"
"path/filepath"

specs "github.com/opencontainers/runtime-spec/specs-go"
)

+const configFilename = "config.json"
+
// ReadSpec reads OCI spec from the bundle directory.
func ReadSpec(bundle string) (*specs.Spec, error) {
- f, err := os.Open(filepath.Join(bundle, "config.json"))
- if err != nil {
- return nil, err
- }
- b, err := ioutil.ReadAll(f)
+ b, err := ioutil.ReadFile(filepath.Join(bundle, configFilename))
if err != nil {
return nil, err
}
@@ -41,9 +38,18 @@ func ReadSpec(bundle string) (*specs.Spec, error) {
return &spec, nil
}

+// WriteSpec writes OCI spec to the bundle directory.
+func WriteSpec(bundle string, spec *specs.Spec) error {
+ b, err := json.Marshal(spec)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filepath.Join(bundle, configFilename), b, 0666)
+}
+
// IsSandbox checks whether a container is a sandbox container.
func IsSandbox(spec *specs.Spec) bool {
- t, ok := spec.Annotations[containerTypeAnnotation]
+ t, ok := spec.Annotations[ContainerTypeAnnotation]
return !ok || t == containerTypeSandbox
}
diff --git a/pkg/shim/utils/volumes.go b/pkg/shim/utils/volumes.go
index 52a428179..6bc75139d 100644
--- a/pkg/shim/utils/volumes.go
+++ b/pkg/shim/utils/volumes.go
@@ -15,9 +15,7 @@ package utils
import (
- "encoding/json"
"fmt"
- "io/ioutil"
"path/filepath"
"strings"
@@ -89,18 +87,16 @@ func isVolumePath(volume, path string) (bool, error) {
}

// UpdateVolumeAnnotations add necessary OCI annotations for gvisor
-// volume optimization.
-func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
- var (
- uid string
- err error
- )
+// volume optimization. Returns true if the spec was modified.
+func UpdateVolumeAnnotations(s *specs.Spec) (bool, error) {
+ var uid string
if IsSandbox(s) {
+ var err error
uid, err = podUID(s)
if err != nil {
// Skip if we can't get pod UID, because this doesn't work
// for containerd 1.1.
- return nil
+ return false, nil
}
}
var updated bool
@@ -116,40 +112,48 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
// This is a sandbox.
path, err := volumePath(volume, uid)
if err != nil {
- return fmt.Errorf("get volume path for %q: %w", volume, err)
+ return false, fmt.Errorf("get volume path for %q: %w", volume, err)
}
s.Annotations[volumeSourceKey(volume)] = path
updated = true
} else {
// This is a container.
for i := range s.Mounts {
- // An error is returned for sandbox if source
- // annotation is not successfully applied, so
- // it is guaranteed that the source annotation
- // for sandbox has already been successfully
- // applied at this point.
+ // An error is returned for sandbox if source annotation is not
+ // successfully applied, so it is guaranteed that the source annotation
+ // for sandbox has already been successfully applied at this point.
// - // The volume name is unique inside a pod, so - // matching without podUID is fine here. + // The volume name is unique inside a pod, so matching without podUID + // is fine here. // - // TODO: Pass podUID down to shim for containers to do - // more accurate matching. + // TODO: Pass podUID down to shim for containers to do more accurate + // matching. if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes { - // gVisor requires the container mount type to match - // sandbox mount type. - s.Mounts[i].Type = v + // Container mount type must match the sandbox's mount type. + changeMountType(&s.Mounts[i], v) updated = true } } } } - if !updated { - return nil - } - // Update bundle. - b, err := json.Marshal(s) - if err != nil { - return err + return updated, nil +} + +func changeMountType(m *specs.Mount, newType string) { + m.Type = newType + + // OCI spec allows bind mounts to be specified in options only. So if new type + // is not bind, remove bind/rbind from options. + // + // "For bind mounts (when options include either bind or rbind), the type is + // a dummy, often "none" (not listed in /proc/filesystems)." + if newType != "bind" { + newOpts := make([]string, 0, len(m.Options)) + for _, opt := range m.Options { + if opt != "rbind" && opt != "bind" { + newOpts = append(newOpts, opt) + } + } + m.Options = newOpts } - return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666) } diff --git a/pkg/shim/utils/volumes_test.go b/pkg/shim/utils/volumes_test.go index 3e02c6151..5db43cdf1 100644 --- a/pkg/shim/utils/volumes_test.go +++ b/pkg/shim/utils/volumes_test.go @@ -15,11 +15,9 @@ package utils import ( - "encoding/json" "fmt" "io/ioutil" "os" - "path/filepath" "reflect" "testing" @@ -47,60 +45,60 @@ func TestUpdateVolumeAnnotations(t *testing.T) { } for _, test := range []struct { - desc string + name string spec *specs.Spec expected *specs.Spec expectErr bool expectUpdate bool }{ { - desc: "volume annotations for sandbox", + name: "volume annotations for sandbox", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + sandboxLogDirAnnotation: testLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", - "dev.gvisor.spec.mount." 
+ testVolumeName + ".source": testVolumePath, + sandboxLogDirAnnotation: testLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, }, }, expectUpdate: true, }, { - desc: "volume annotations for sandbox with legacy log path", + name: "volume annotations for sandbox with legacy log path", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLegacyLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + sandboxLogDirAnnotation: testLegacyLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLegacyLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", - "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + sandboxLogDirAnnotation: testLegacyLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, }, }, expectUpdate: true, }, { - desc: "tmpfs: volume annotations for container", + name: "tmpfs: volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -117,10 +115,10 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ @@ -139,16 +137,16 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." 
+ testVolumeName + ".options": "ro", + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expectUpdate: true, }, { - desc: "bind: volume annotations for container", + name: "bind: volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -159,10 +157,10 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "container", + volumeKeyPrefix + testVolumeName + ".type": "bind", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ @@ -175,63 +173,63 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "container", + volumeKeyPrefix + testVolumeName + ".type": "bind", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expectUpdate: true, }, { - desc: "should not return error without pod log directory", + name: "should not return error without pod log directory", spec: &specs.Spec{ Annotations: map[string]string{ - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." 
+ testVolumeName + ".options": "ro", + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, }, { - desc: "should return error if volume path does not exist", + name: "should return error if volume path does not exist", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount.notexist.share": "pod", - "dev.gvisor.spec.mount.notexist.type": "tmpfs", - "dev.gvisor.spec.mount.notexist.options": "ro", + sandboxLogDirAnnotation: testLogDirPath, + ContainerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + "notexist.share": "pod", + volumeKeyPrefix + "notexist.type": "tmpfs", + volumeKeyPrefix + "notexist.options": "ro", }, }, expectErr: true, }, { - desc: "no volume annotations for sandbox", + name: "no volume annotations for sandbox", spec: &specs.Spec{ Annotations: map[string]string{ sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, + ContainerTypeAnnotation: containerTypeSandbox, }, }, expected: &specs.Spec{ Annotations: map[string]string{ sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, + ContainerTypeAnnotation: containerTypeSandbox, }, }, }, { - desc: "no volume annotations for container", + name: "no volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -248,7 +246,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, + ContainerTypeAnnotation: ContainerTypeContainer, }, }, expected: &specs.Spec{ @@ -267,17 +265,51 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, + ContainerTypeAnnotation: ContainerTypeContainer, }, }, }, + { + name: "bind options removed", + spec: &specs.Spec{ + Annotations: map[string]string{ + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, + }, + Mounts: []specs.Mount{ + { + Destination: "/dst", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro", "bind", "rbind"}, + }, + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + ContainerTypeAnnotation: ContainerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, + }, + Mounts: []specs.Mount{ + { + Destination: "/dst", + Type: "tmpfs", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + }, + expectUpdate: true, + }, } { - t.Run(test.desc, func(t *testing.T) { - bundle, err := ioutil.TempDir(dir, "test-bundle") - if err != nil { - t.Fatalf("Create test bundle: %v", err) - } - err = UpdateVolumeAnnotations(bundle, test.spec) + t.Run(test.name, func(t *testing.T) { + updated, err := UpdateVolumeAnnotations(test.spec) if test.expectErr { if err == nil { t.Fatal("Expected error, but got nil") @@ -290,18 +322,8 @@ func TestUpdateVolumeAnnotations(t *testing.T) { if !reflect.DeepEqual(test.expected, test.spec) { 
t.Fatalf("Expected %+v, got %+v", test.expected, test.spec) } - if test.expectUpdate { - b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json")) - if err != nil { - t.Fatalf("Read spec from bundle: %v", err) - } - var spec specs.Spec - if err := json.Unmarshal(b, &spec); err != nil { - t.Fatalf("Unmarshal spec: %v", err) - } - if !reflect.DeepEqual(test.expected, &spec) { - t.Fatalf("Expected %+v, got %+v", test.expected, &spec) - } + if test.expectUpdate != updated { + t.Errorf("Expected %v, got %v", test.expected, updated) } }) } diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index d6c89c7e9..08d06e37b 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -7,7 +7,6 @@ go_library( srcs = ["statefile.go"], visibility = ["//:sandbox"], deps = [ - "//pkg/binary", "//pkg/compressio", "//pkg/state/wire", ], diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go index bdfb800fb..d27c8c8a8 100644 --- a/pkg/state/statefile/statefile.go +++ b/pkg/state/statefile/statefile.go @@ -48,6 +48,7 @@ import ( "compress/flate" "crypto/hmac" "crypto/sha256" + "encoding/binary" "encoding/json" "fmt" "hash" @@ -55,7 +56,6 @@ import ( "strings" "time" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/compressio" "gvisor.dev/gvisor/pkg/state/wire" ) @@ -90,6 +90,13 @@ type WriteCloser interface { io.Closer } +func writeMetadataLen(w io.Writer, val uint64) error { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], val) + _, err := w.Write(buf[:]) + return err +} + // NewWriter returns a state data writer for a statefile. // // Note that the returned WriteCloser must be closed. @@ -127,7 +134,7 @@ func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser } // Metadata length. - if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil { + if err := writeMetadataLen(mw, uint64(len(b))); err != nil { return nil, err } // Metadata bytes; io.MultiWriter will return a short write error if @@ -158,6 +165,14 @@ func MetadataUnsafe(r io.Reader) (map[string]string, error) { return metadata(r, nil) } +func readMetadataLen(r io.Reader) (uint64, error) { + var buf [8]byte + if _, err := io.ReadFull(r, buf[:]); err != nil { + return 0, err + } + return binary.BigEndian.Uint64(buf[:]), nil +} + // metadata validates the magic header and reads out the metadata from a state // data stream. 
func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { @@ -183,7 +198,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { } }() - metadataLen, err := binary.ReadUint64(r, binary.BigEndian) + metadataLen, err := readMetadataLen(r) if err != nil { return nil, err } diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 79e564de6..90be24e15 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -38,7 +38,7 @@ var ( ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE) ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL) ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE) - ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), nil) + ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), linux.NOERRNO) ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT) ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE) ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS) diff --git a/pkg/syserr/syserr.go b/pkg/syserr/syserr.go index b5881ea3c..d70521f32 100644 --- a/pkg/syserr/syserr.go +++ b/pkg/syserr/syserr.go @@ -34,24 +34,19 @@ type Error struct { // linux.Errno. noTranslation bool - // errno is the linux.Errno this Error should be translated to. nil means - // that this Error should be translated to a nil linux.Errno. - errno *linux.Errno + // errno is the linux.Errno this Error should be translated to. + errno linux.Errno } // New creates a new Error and adds a translation for it. // // New must only be called at init. -func New(message string, linuxTranslation *linux.Errno) *Error { +func New(message string, linuxTranslation linux.Errno) *Error { err := &Error{message: message, errno: linuxTranslation} - if linuxTranslation == nil { - return err - } - // TODO(b/34162363): Remove this. - errno := linuxTranslation.Number() - if errno <= 0 || errno >= len(linuxBackwardsTranslations) { + errno := linuxTranslation + if errno < 0 || int(errno) >= len(linuxBackwardsTranslations) { panic(fmt.Sprint("invalid errno: ", errno)) } @@ -74,7 +69,7 @@ func New(message string, linuxTranslation *linux.Errno) *Error { // NewDynamic should only be used sparingly and not be used for static error // messages. Errors with static error messages should be declared with New as // global variables. 
-func NewDynamic(message string, linuxTranslation *linux.Errno) *Error { +func NewDynamic(message string, linuxTranslation linux.Errno) *Error { return &Error{message: message, errno: linuxTranslation} } @@ -87,7 +82,7 @@ func NewWithoutTranslation(message string) *Error { return &Error{message: message, noTranslation: true} } -func newWithHost(message string, linuxTranslation *linux.Errno, hostErrno unix.Errno) *Error { +func newWithHost(message string, linuxTranslation linux.Errno, hostErrno unix.Errno) *Error { e := New(message, linuxTranslation) addLinuxHostTranslation(hostErrno, e) return e @@ -119,10 +114,10 @@ func (e *Error) ToError() error { if e.noTranslation { panic(fmt.Sprintf("error %q does not support translation", e.message)) } - if e.errno == nil { + errno := int(e.errno) + if errno == linux.NOERRNO { return nil } - errno := e.errno.Number() if errno <= 0 || errno >= len(linuxBackwardsTranslations) || !linuxBackwardsTranslations[errno].ok { panic(fmt.Sprintf("unknown error %q (%d)", e.message, errno)) } @@ -131,7 +126,7 @@ func (e *Error) ToError() error { // ToLinux converts the Error to a Linux ABI error that can be returned to the // application. -func (e *Error) ToLinux() *linux.Errno { +func (e *Error) ToLinux() linux.Errno { if e.noTranslation { panic(fmt.Sprintf("No Linux ABI translation available for %q", e.message)) } diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index 7d2f5adf6..76bee5a64 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,12 +8,3 @@ go_library( visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) - -go_test( - name = "syserror_test", - srcs = ["syserror_test.go"], - deps = [ - ":syserror", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index aa30cfc85..ed4d7e958 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -22,12 +22,14 @@ go_library( "errors.go", "sock_err_list.go", "socketops.go", + "stdclock.go", + "stdclock_state.go", "tcpip.go", - "time_unsafe.go", "timer.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/atomicbitops", "//pkg/sync", "//pkg/tcpip/buffer", "//pkg/waiter", @@ -37,12 +39,30 @@ go_library( deps_test( name = "netstack_deps_test", allowed = [ + # gVisor deps. + "//pkg/atomicbitops", + "//pkg/buffer", + "//pkg/context", + "//pkg/gohacks", + "//pkg/goid", + "//pkg/ilist", + "//pkg/iovec", + "//pkg/linewriter", + "//pkg/log", + "//pkg/rand", + "//pkg/sleep", + "//pkg/state", + "//pkg/state/wire", + "//pkg/sync", + "//pkg/waiter", + + # Other deps. 
"@com_github_google_btree//:go_default_library", "@org_golang_x_sys//unix:go_default_library", "@org_golang_x_time//rate:go_default_library", ], allowed_prefixes = [ - "//", + "//pkg/tcpip", "@org_golang_x_sys//internal/unsafeheader", ], targets = [ diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 12c39dfa3..bab640faf 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -701,7 +701,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp if !ok { return } - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) foundTS := false tsVal := uint32(0) @@ -748,12 +748,6 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does -// not contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - // TCPSACKBlockChecker creates a checker that verifies that the segment does // contain the specified SACK blocks in the TCP options. func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { @@ -765,7 +759,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } var gotSACKBlocks []header.SACKBlock - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) for i := 0; i < limit; { switch opts[i] { @@ -1607,6 +1601,17 @@ func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { } } +// IPv6UnknownOption validates that an extension header option is the +// unknown header option. +func IPv6UnknownOption() IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + _, ok := opt.(*header.IPv6UnknownExtHdrOption) + if !ok { + t.Errorf("got = %T, want = header.IPv6UnknownExtHdrOption", opt) + } + } +} + // IgnoreCmpPath returns a cmp.Option that ignores listed field paths. func IgnoreCmpPath(paths ...string) cmp.Option { ignores := map[string]struct{}{} diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index fb819d7a8..c4dbe8440 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -218,6 +218,12 @@ func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { } } +// RunImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func (mc *ManualClock) RunImmediatelyScheduledJobs() { + mc.Advance(0) +} + // Advance executes all work that have been scheduled to execute within d from // the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 3d1bccd15..d6cad3a94 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -77,12 +77,12 @@ const ( // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag // field in the flags byte within an NDPPrefixInformation. - ndpPrefixInformationOnLinkFlagMask = (1 << 7) + ndpPrefixInformationOnLinkFlagMask = 1 << 7 // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the // Autonomous Address-Configuration flag field in the flags byte within // an NDPPrefixInformation. - ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6 // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 // field in the flags byte within an NDPPrefixInformation. 
@@ -451,7 +451,7 @@ func (o NDPNonceOption) String() string { // Nonce returns the nonce value this option holds. func (o NDPNonceOption) Nonce() []byte { - return []byte(o) + return o } // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index ebb4b2c1d..1c913b5e1 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -60,9 +60,13 @@ func IPv4(pkt *stack.PacketBuffer) bool { return false } ipHdr = header.IPv4(hdr) + length := int(ipHdr.TotalLength()) - len(hdr) + if length < 0 { + return false + } pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr)) + pkt.Data().CapLength(length) return true } diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index f75ee34ab..f26c857eb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -123,6 +123,9 @@ func (q *queue) RemoveNotify(handle *NotificationHandle) { q.notify = notify } +var _ stack.LinkEndpoint = (*Endpoint)(nil) +var _ stack.GSOEndpoint = (*Endpoint)(nil) + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { @@ -130,6 +133,7 @@ type Endpoint struct { mtu uint32 linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities + SupportedGSOKind stack.SupportedGSO // Outbound packet queue. q *queue @@ -211,11 +215,16 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.LinkEPCapabilities } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (*Endpoint) GSOMaxSize() uint32 { return 1 << 15 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { + return e.SupportedGSOKind +} + // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*Endpoint) MaxHeaderLength() uint16 { @@ -279,5 +288,5 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index f042df82e..d971194e6 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,7 +14,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/binary", "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index feb79fe0e..bddb1d0a2 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -45,7 +45,6 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -98,6 +97,9 @@ func (p PacketDispatchMode) String() string { } } +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.GSOEndpoint = (*endpoint)(nil) + type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. 
The endpoint will dispatch from all inbound channels as well as @@ -134,6 +136,9 @@ type endpoint struct { // wg keeps track of running goroutines. wg sync.WaitGroup + + // gsoKind is the supported kind of GSO. + gsoKind stack.SupportedGSO } // Options specify the details about the fd-based endpoint to be created. @@ -255,9 +260,9 @@ func New(opts *Options) (stack.LinkEndpoint, error) { if isSocket { if opts.GSOMaxSize != 0 { if opts.SoftwareGSOEnabled { - e.caps |= stack.CapabilitySoftwareGSO + e.gsoKind = stack.SWGSOSupported } else { - e.caps |= stack.CapabilityHardwareGSO + e.gsoKind = stack.HWGSOSupported } e.gsoMaxSize = opts.GSOMaxSize } @@ -403,6 +408,35 @@ type virtioNetHdr struct { csumOffset uint16 } +// marshal serializes h to a newly-allocated byte slice, in little-endian byte +// order. +// +// Note: Virtio v1.0 onwards specifies little-endian as the byte ordering used +// for general serialization. This makes it difficult to use go-marshal for +// virtio types, as go-marshal implicitly uses the native byte ordering. +func (h *virtioNetHdr) marshal() []byte { + buf := [virtioNetHdrSize]byte{ + 0: byte(h.flags), + 1: byte(h.gsoType), + + // Manually lay out the fields in little-endian byte order. Little endian => + // least significant byte goes to the lower address. + + 2: byte(h.hdrLen), + 3: byte(h.hdrLen >> 8), + + 4: byte(h.gsoSize), + 5: byte(h.gsoSize >> 8), + + 6: byte(h.csumStart), + 7: byte(h.csumStart >> 8), + + 8: byte(h.csumOffset), + 9: byte(h.csumOffset >> 8), + } + return buf[:] +} + // These constants are declared in linux/virtio_net.h. const ( _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1 @@ -441,7 +475,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol var builder iovec.Builder fd := e.fds[pkt.Hash%uint32(len(e.fds))] - if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} if pkt.GSOOptions.Type != stack.GSONone { vnetHdr.hdrLen = uint16(pkt.HeaderSize()) @@ -463,7 +497,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol } } - vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + vnetHdrBuf := vnetHdr.marshal() builder.Add(vnetHdrBuf) } @@ -482,7 +516,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp } var vnetHdrBuf []byte - if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} if pkt.GSOOptions.Type != stack.GSONone { vnetHdr.hdrLen = uint16(pkt.HeaderSize()) @@ -503,7 +537,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + vnetHdrBuf = vnetHdr.marshal() } var builder iovec.Builder @@ -602,11 +636,16 @@ func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { } } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// SupportedGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + return e.gsoKind +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (e *endpoint) ARPHardwareType() header.ARPHardwareType { if e.hdrSize > 0 { diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index a7adf822b..4b7ef3aac 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -128,7 +128,7 @@ type readVDispatcher struct { func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &readVDispatcher{fd: fd, e: e} - skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) return d, nil } @@ -212,7 +212,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { bufs: make([]*iovecBuffer, MaxMsgsPerRecv), msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } - skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported for i := range d.bufs { d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) } diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 89df35822..3e816b0c7 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -135,6 +135,14 @@ func (e *Endpoint) GSOMaxSize() uint32 { return 0 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { + if e, ok := e.child.(stack.GSOEndpoint); ok { + return e.SupportedGSO() + } + return stack.GSONotSupported +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { return e.child.ARPHardwareType() } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index bba6a6973..b1a28491d 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -25,6 +25,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.GSOEndpoint = (*endpoint)(nil) + // endpoint represents a LinkEndpoint which implements a FIFO queue for all // outgoing packets. endpoint can have 1 or more underlying queueDispatchers. // All outgoing packets are consistently hashed to a single underlying queue @@ -141,7 +144,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.lower.LinkAddress() } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (e *endpoint) GSOMaxSize() uint32 { if gso, ok := e.lower.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() @@ -149,6 +152,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + if gso, ok := e.lower.(stack.GSOEndpoint); ok { + return gso.SupportedGSO() + } + return stack.GSONotSupported +} + // WritePacket implements stack.LinkEndpoint.WritePacket.
// // The packet must have the following fields populated: diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index 6905b9ccb..a72eb1aad 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -47,7 +47,7 @@ go_test( library = ":arp", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 94209b026..f89b55561 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -396,7 +396,7 @@ func TestDirectRequest(t *testing.T) { nicID: nicID, entry: stack.NeighborEntry{ Addr: test.senderAddr, - LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), + LinkAddr: test.senderLinkAddr, State: stack.Stale, }, } diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index e867b3c3f..0df39ae81 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -19,8 +19,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ stack.NetworkInterface = (*testInterface)(nil) diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 7daf64b4a..dadfc28cc 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -275,15 +275,23 @@ func TestMemoryLimits(t *testing.T) { highLimit := 3 * lowLimit // Allow at most 3 such packets. f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) + if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) + if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) + if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { + t.Fatal(err) + } if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -300,9 +308,13 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { memSize := pkt(1, "0").MemSize() f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send the same packet again. 
- f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 90075a70c..56b76a284 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -167,8 +167,7 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s resPkt := r.holes[0].pkt for i := 1; i < len(r.holes); i++ { - fragData := r.holes[i].pkt.Data() - resPkt.Data().ReadFromData(fragData, fragData.Size()) + stack.MergeFragment(resPkt, r.holes[i].pkt) } return resPkt, r.proto, true, memConsumed, nil } diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD index d21b4c7ef..fd944ce99 100644 --- a/pkg/tcpip/network/internal/ip/BUILD +++ b/pkg/tcpip/network/internal/ip/BUILD @@ -6,6 +6,7 @@ go_library( name = "ip", srcs = [ "duplicate_address_detection.go", + "errors.go", "generic_multicast_protocol.go", "stats.go", ], diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index eed49f5d2..5123b7d6a 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -83,6 +83,8 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize)) } + configs.Validate() + *d = DAD{ opts: opts, configs: configs, diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index a22b712c6..24687cf06 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -133,7 +133,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) } // Wait for any initially fired timers to complete. 
- clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -147,7 +147,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -156,7 +156,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -170,7 +170,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -208,7 +208,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -247,7 +247,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -272,7 +272,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -347,7 +347,7 @@ func TestNonce(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() for i, want := range test.expectedResults { if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go new file mode 100644 index 000000000..94f1cd1cb --- /dev/null +++ b/pkg/tcpip/network/internal/ip/errors.go @@ -0,0 +1,85 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// ForwardingError represents an error that occurred while trying to forward +// a packet. +type ForwardingError interface { + isForwardingError() + fmt.Stringer +} + +// ErrTTLExceeded indicates that the received packet's TTL has been exceeded. +type ErrTTLExceeded struct{} + +func (*ErrTTLExceeded) isForwardingError() {} + +func (*ErrTTLExceeded) String() string { return "ttl exceeded" } + +// ErrParameterProblem indicates the received packet had a problem with an IP +// parameter. +type ErrParameterProblem struct{} + +func (*ErrParameterProblem) isForwardingError() {} + +func (*ErrParameterProblem) String() string { return "parameter problem" } + +// ErrLinkLocalSourceAddress indicates the received packet had a link-local +// source address. +type ErrLinkLocalSourceAddress struct{} + +func (*ErrLinkLocalSourceAddress) isForwardingError() {} + +func (*ErrLinkLocalSourceAddress) String() string { return "link local source address" } + +// ErrLinkLocalDestinationAddress indicates the received packet had a link-local +// destination address. +type ErrLinkLocalDestinationAddress struct{} + +func (*ErrLinkLocalDestinationAddress) isForwardingError() {} + +func (*ErrLinkLocalDestinationAddress) String() string { return "link local destination address" } + +// ErrNoRoute indicates that a route for the received packet couldn't be found. +type ErrNoRoute struct{} + +func (*ErrNoRoute) isForwardingError() {} + +func (*ErrNoRoute) String() string { return "no route" } + +// ErrMessageTooLong indicates the packet was too big for the outgoing MTU. +// +// +stateify savable +type ErrMessageTooLong struct{} + +func (*ErrMessageTooLong) isForwardingError() {} + +func (*ErrMessageTooLong) String() string { return "message too long" } + +// ErrOther indicates the packet could not be forwarded for a reason +// captured by the contained error. +type ErrOther struct { + Err tcpip.Error +} + +func (*ErrOther) isForwardingError() {} + +func (e *ErrOther) String() string { return fmt.Sprintf("other tcpip error: %s", e.Err) } diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index ac35d81e7..d22974b12 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ip holds IPv4/IPv6 common utilities. package ip import ( diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index d06b26309..40ab21cb6 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -16,80 +16,150 @@ package ip import "gvisor.dev/gvisor/pkg/tcpip" +// LINT.IfChange(MultiCounterIPForwardingStats) + +// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter +// may have several versions.
+type MultiCounterIPForwardingStats struct { + // Unrouteable is the number of IP packets received which were dropped + // because the netstack could not construct a route to their + // destination. + Unrouteable tcpip.MultiCounterStat + + // ExhaustedTTL is the number of IP packets received which were dropped + // because their TTL was exhausted. + ExhaustedTTL tcpip.MultiCounterStat + + // LinkLocalSource is the number of IP packets which were dropped + // because they contained a link-local source address. + LinkLocalSource tcpip.MultiCounterStat + + // LinkLocalDestination is the number of IP packets which were dropped + // because they contained a link-local destination address. + LinkLocalDestination tcpip.MultiCounterStat + + // PacketTooBig is the number of IP packets which were dropped because they + // were too big for the outgoing MTU. + PacketTooBig tcpip.MultiCounterStat + + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable tcpip.MultiCounterStat + + // ExtensionHeaderProblem is the number of IP packets which were dropped + // because of a problem encountered when processing an IPv6 extension + // header. + ExtensionHeaderProblem tcpip.MultiCounterStat + + // Errors is the number of IP packets received which could not be + // successfully forwarded. + Errors tcpip.MultiCounterStat +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { + m.Unrouteable.Init(a.Unrouteable, b.Unrouteable) + m.Errors.Init(a.Errors, b.Errors) + m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource) + m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination) + m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) + m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig) + m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) + m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable) +} + +// LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) + // LINT.IfChange(MultiCounterIPStats) // MultiCounterIPStats holds IP statistics, each counter may have several // versions. type MultiCounterIPStats struct { - // PacketsReceived is the number of IP packets received from the link layer. + // PacketsReceived is the number of IP packets received from the link + // layer. PacketsReceived tcpip.MultiCounterStat - // DisabledPacketsReceived is the number of IP packets received from the link - // layer when the IP layer is disabled. + // ValidPacketsReceived is the number of valid IP packets that reached the IP + // layer. + ValidPacketsReceived tcpip.MultiCounterStat + + // DisabledPacketsReceived is the number of IP packets received from + // the link layer when the IP layer is disabled. DisabledPacketsReceived tcpip.MultiCounterStat - // InvalidDestinationAddressesReceived is the number of IP packets received - // with an unknown or invalid destination address. + // InvalidDestinationAddressesReceived is the number of IP packets + // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived tcpip.MultiCounterStat - // InvalidSourceAddressesReceived is the number of IP packets received with a - // source address that should never have been received on the wire. 
+ // InvalidSourceAddressesReceived is the number of IP packets received + // with a source address that should never have been received on the + // wire. InvalidSourceAddressesReceived tcpip.MultiCounterStat - // PacketsDelivered is the number of incoming IP packets that are successfully + // PacketsDelivered is the number of incoming IP packets successfully // delivered to the transport layer. PacketsDelivered tcpip.MultiCounterStat // PacketsSent is the number of IP packets sent via WritePacket. PacketsSent tcpip.MultiCounterStat - // OutgoingPacketErrors is the number of IP packets which failed to write to a - // link-layer endpoint. + // OutgoingPacketErrors is the number of IP packets which failed to + // write to a link-layer endpoint. OutgoingPacketErrors tcpip.MultiCounterStat - // MalformedPacketsReceived is the number of IP Packets that were dropped due - // to the IP packet header failing validation checks. + // MalformedPacketsReceived is the number of IP Packets that were + // dropped due to the IP packet header failing validation checks. MalformedPacketsReceived tcpip.MultiCounterStat - // MalformedFragmentsReceived is the number of IP Fragments that were dropped - // due to the fragment failing validation checks. + // MalformedFragmentsReceived is the number of IP Fragments that were + // dropped due to the fragment failing validation checks. MalformedFragmentsReceived tcpip.MultiCounterStat // IPTablesPreroutingDropped is the number of IP packets dropped in the // Prerouting chain. IPTablesPreroutingDropped tcpip.MultiCounterStat - // IPTablesInputDropped is the number of IP packets dropped in the Input - // chain. + // IPTablesInputDropped is the number of IP packets dropped in the + // Input chain. IPTablesInputDropped tcpip.MultiCounterStat - // IPTablesOutputDropped is the number of IP packets dropped in the Output - // chain. + // IPTablesForwardDropped is the number of IP packets dropped in the + // Forward chain. + IPTablesForwardDropped tcpip.MultiCounterStat + + // IPTablesOutputDropped is the number of IP packets dropped in the + // Output chain. IPTablesOutputDropped tcpip.MultiCounterStat - // IPTablesPostroutingDropped is the number of IP packets dropped in the - // Postrouting chain. + // IPTablesPostroutingDropped is the number of IP packets dropped in + // the Postrouting chain. IPTablesPostroutingDropped tcpip.MultiCounterStat - // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out - // of IPStats. + // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option + // stats out of IPStats. // OptionTimestampReceived is the number of Timestamp options seen. OptionTimestampReceived tcpip.MultiCounterStat - // OptionRecordRouteReceived is the number of Record Route options seen. + // OptionRecordRouteReceived is the number of Record Route options + // seen. OptionRecordRouteReceived tcpip.MultiCounterStat - // OptionRouterAlertReceived is the number of Router Alert options seen. + // OptionRouterAlertReceived is the number of Router Alert options + // seen. OptionRouterAlertReceived tcpip.MultiCounterStat // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived tcpip.MultiCounterStat + + // Forwarding collects stats related to IP forwarding. + Forwarding MultiCounterIPForwardingStats } // Init sets internal counters to track a and b counters. 
func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.ValidPacketsReceived.Init(a.ValidPacketsReceived, b.ValidPacketsReceived) m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) @@ -100,12 +170,14 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) + m.IPTablesForwardDropped.Init(a.IPTablesForwardDropped, b.IPTablesForwardDropped) m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped) m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived) m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived) m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived) m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) + m.Forwarding.Init(&a.Forwarding, &b.Forwarding) } // LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index 1c4f583c7..a180e5c75 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -4,15 +4,13 @@ package(licenses = ["notice"]) go_library( name = "testutil", - srcs = [ - "testutil.go", - "testutil_unsafe.go", - ], + srcs = ["testutil.go"], visibility = [ "//pkg/tcpip/network/arp:__pkg__", "//pkg/tcpip/network/internal/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", + "//pkg/tcpip/tests/integration:__pkg__", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index e2cf24b67..605e9ef8d 100644 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -19,8 +19,6 @@ package testutil import ( "fmt" "math/rand" - "reflect" - "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -129,69 +127,3 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi } return pkt } - -func checkFieldCounts(ref, multi reflect.Value) error { - refTypeName := ref.Type().Name() - multiTypeName := multi.Type().Name() - refNumField := ref.NumField() - multiNumField := multi.NumField() - - if refNumField != multiNumField { - return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) - } - - return nil -} - -func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { - s, ok := ref.Addr().Interface().(**tcpip.StatCounter) - if !ok { - return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) - } - - // The field names are expected to match (case insensitive). 
- if !strings.EqualFold(refName, multiName) { - return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) - } - - base := (*s).Value() - m.Increment() - if (*s).Value() != base+1 { - return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) - } - - return nil -} - -// ValidateMultiCounterStats verifies that every counter stored in multi is -// correctly tracking its counterpart in the given counters. -func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { - for _, c := range counters { - if err := checkFieldCounts(c, multi); err != nil { - return err - } - } - - for i := 0; i < multi.NumField(); i++ { - multiName := multi.Type().Field(i).Name - multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) - - if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { - for _, c := range counters { - if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { - return err - } - } - } else { - var countersNextField []reflect.Value - for _, c := range counters { - countersNextField = append(countersNextField, c.Field(i)) - } - if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { - return err - } - } - } - - return nil -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 74aad126c..bd63e0289 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -1996,8 +1996,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) @@ -2005,8 +2005,8 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) } - if err := s.SetForwarding(test.netProto, false); err != nil { - t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netProto, false); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, false): %s", test.netProto, err) } if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7ee0495d9..c90974693 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -62,7 +62,7 @@ go_test( library = ":ipv4", deps = [ "//pkg/tcpip", - "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index f663fdc0b..5f6b0c6af 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet return } - // Skip the ip header, then deliver the error. - pkt.Data().TrimFront(hlen) + // Keep needed information before trimming header. 
p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) + dstAddr := hdr.DestinationAddress() + // Skip the ip header, then deliver the error. + pkt.Data().DeleteFront(hlen) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { @@ -220,7 +222,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -336,14 +337,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4DstUnreachable: received.dstUnreachable.Increment() - pkt.Data().TrimFront(header.ICMPv4MinimumSize) - switch h.Code() { + mtu := h.MTU() + code := h.Code() + pkt.Data().DeleteFront(header.ICMPv4MinimumSize) + switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) case header.ICMPv4PortUnreachable: e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt) case header.ICMPv4FragmentationNeeded: - networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize) if err != nil { networkMTU = 0 } @@ -383,6 +386,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // icmpReason is a marker interface for IPv4 specific ICMP errors. type icmpReason interface { isICMPReason() + // isForwarding indicates whether or not the error arose while attempting to + // forward a packet. isForwarding() bool } @@ -442,6 +447,55 @@ func (r *icmpReasonParamProblem) isForwarding() bool { return r.forwarding } +// icmpReasonNetworkUnreachable is an error in which the network specified in +// the internet destination field of the datagram is unreachable. +type icmpReasonNetworkUnreachable struct{} + +func (*icmpReasonNetworkUnreachable) isICMPReason() {} +func (*icmpReasonNetworkUnreachable) isForwarding() bool { + // If we hit a Net Unreachable error, then we know we are operating as + // a router. As per RFC 792 page 5, Destination Unreachable Message, + // + // If, according to the information in the gateway's routing tables, + // the network specified in the internet destination field of a + // datagram is unreachable, e.g., the distance to the network is + // infinity, the gateway may send a destination unreachable message to + // the internet source host of the datagram. + return true +} + +// icmpReasonFragmentationNeeded is an error where a packet requires +// fragmentation while also having the Don't Fragment flag set, as per RFC 792 +// page 3, Destination Unreachable Message. +type icmpReasonFragmentationNeeded struct{} + +func (*icmpReasonFragmentationNeeded) isICMPReason() {} +func (*icmpReasonFragmentationNeeded) isForwarding() bool { + // If we hit a Don't Fragment error, then we know we are operating as a router. + // As per RFC 792 page 4, Destination Unreachable Message, + // + // Another case is when a datagram must be fragmented to be forwarded by a + // gateway yet the Don't Fragment flag is on. In this case the gateway must + // discard the datagram and may return a destination unreachable message. + return true +} + +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. 
+type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 792 page 5, Destination Unreachable Message, + // + // In addition, in some networks, the gateway may be able to determine + // if the internet destination host is unreachable. Gateways in these + // networks may send destination unreachable messages to the source host + // when the destination host is unreachable. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -498,7 +552,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -610,6 +669,18 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetworkUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4NetUnreachable) + counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4HostUnreachable) + counter = sent.dstUnreachable + case *icmpReasonFragmentationNeeded: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a0bc06465..bb8d53c12 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -62,9 +63,15 @@ const ( fragmentblockSize = 8 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -81,6 +88,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. 
+ forwarding uint32 + mu struct { sync.RWMutex @@ -91,6 +104,16 @@ type endpoint struct { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, return an ICMP error to the original + // packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -150,14 +173,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { e.mu.Lock() defer e.mu.Unlock() + if !e.setForwarding(forwarding) { + return + } + if forwarding { // There does not seem to be an RFC requirement for a node to join the all // routers multicast address but @@ -433,6 +474,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn } if packetMustBeFragmented(pkt, networkMTU) { + h := header.IPv4(pkt.NetworkHeader().View()) + if h.Flags()&header.IPv4FlagDontFragment != 0 && pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/5919): Handle error condition in which DontFragment + // is set but the packet must be fragmented for the non-forwarding case. + return &tcpip.ErrMessageTooLong{} + } sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we @@ -599,22 +646,25 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv4(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() - if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { - // As per RFC 3927 section 7, - // - // A router MUST NOT forward a packet with an IPv4 Link-Local source or - // destination address, irrespective of the router's default route - // configuration or routes obtained from dynamic routing protocols. 
- // - // A router which receives a packet with an IPv4 Link-Local source or - // destination address MUST NOT forward the packet. This prevents - // forwarding of packets back onto the network segment from which they - // originated, or to any other segment. - return nil + // As per RFC 3927 section 7, + // + // A router MUST NOT forward a packet with an IPv4 Link-Local source or + // destination address, irrespective of the router's default route + // configuration or routes obtained from dynamic routing protocols. + // + // A router which receives a packet with an IPv4 Link-Local source or + // destination address MUST NOT forward the packet. This prevents + // forwarding of packets back onto the network segment from which they + // originated, or to any other segment. + if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} } ttl := h.TTL() @@ -624,7 +674,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // If the gateway processing a datagram finds the time to live field // is zero it must discard the datagram. The gateway may also notify // the source host via the time exceeded message. - return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } if opts := h.Options(); len(opts) != 0 { @@ -635,10 +690,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { pointer: optProblem.Pointer, forwarding: true, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - e.stats.ip.MalformedPacketsReceived.Increment() } - return nil // option problems are not reported locally. + return &ip.ErrParameterProblem{} } copied := copy(opts, newOpts) if copied != len(newOpts) { @@ -655,18 +708,44 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } } + stk := e.protocol.stack + // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(ep.nic.ID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + ep.handleValidatedPacket(h, pkt) return nil } - r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. 
+ _ = e.protocol.returnError(&icmpReasonNetworkUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(r.NICID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. @@ -680,10 +759,28 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // spent, the field must be decremented by 1. newHdr.SetTTL(ttl - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + IsForwardedPacket: true, + })); err.(type) { + case nil: + return nil + case *tcpip.ErrMessageTooLong: + // As per RFC 792, page 4, Destination Unreachable: + // + // Another case is when a datagram must be fragmented to be forwarded by a + // gateway yet the Don't Fragment flag is on. In this case the gateway must + // discard the datagram and may return a destination unreachable message. + // + // WriteHeaderIncludedPacket checks for the presence of the Don't Fragment bit + // while sending the packet and returns this error iff fragmentation is + // necessary and the bit is also set. + _ = e.protocol.returnError(&icmpReasonFragmentationNeeded{}, pkt) + return &ip.ErrMessageTooLong{} + default: + return &ip.ErrOther{Err: err} + } } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -764,6 +861,7 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats + stats.ip.ValidPacketsReceived.Increment() srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -794,11 +892,29 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) addressEndpoint.DecRef() pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.ip.InvalidDestinationAddressesReceived.Increment() return } - _ = e.forwardPacket(pkt) + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrParameterProblem: + stats.ip.MalformedPacketsReceived.Increment() + case *ip.ErrMessageTooLong: + stats.ip.Forwarding.PacketTooBig.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + stats.ip.Forwarding.Errors.Increment() return } @@ -828,7 +944,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - 
e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -901,7 +1016,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() } return @@ -955,8 +1069,8 @@ func (e *endpoint) Close() { // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) if err == nil { @@ -967,8 +1081,8 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.mu.addressableEndpointState.RemovePermanentAddress(addr) } @@ -981,8 +1095,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { // AcquireAssignedAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() loopback := e.nic.IsLoopback() return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool { @@ -1067,7 +1181,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1088,12 +1201,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - ids []uint32 hashIV uint32 @@ -1206,35 +1313,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. 
func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 7d413c455..7a8e0aa24 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -36,10 +36,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" - tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -114,65 +114,97 @@ func TestExcludeBroadcast(t *testing.T) { func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 + incomingNICID = 1 + outgoingNICID = 2 randomSequence = 123 randomIdent = 42 randomTimeOffset = 0x10203040 ) - ipv4Addr1 := tcpip.AddressWithPrefix{ + incomingIPv4Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, } - ipv4Addr2 := tcpip.AddressWithPrefix{ + outgoingIPv4Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), PrefixLen: 8, } - remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) - remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2") + remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2") + unreachableIPv4Addr := testutil.MustParse4("12.0.0.2") + multicastIPv4Addr := testutil.MustParse4("225.0.0.0") + linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0") tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code + name string + TTL uint8 + sourceAddr tcpip.Address + destAddr tcpip.Address + expectErrorICMP bool + ipFlags uint8 + mtu uint32 + payloadLength int + options header.IPv4Options + forwardedOptions header.IPv4Options + icmpType header.ICMPv4Type + icmpCode header.ICMPv4Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool + expectPacketForwarded bool + expectedFragmentsForwarded []fragmentInfo }{ { name: "TTL of zero", TTL: 0, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, expectErrorICMP: true, icmpType: header.ICMPv4TimeExceeded, icmpCode: header.ICMPv4TTLExceeded, + mtu: ipv4.MaxTotalSize, }, { - name: "TTL of one", - TTL: 1, - expectErrorICMP: false, + name: "TTL of one", + TTL: 1, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "TTL of two", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Max TTL", + TTL: math.MaxUint8, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, }, { - name: "four EOL options", - TTL: 2, - expectErrorICMP: false, - options: header.IPv4Options{0, 0, 0, 0}, - 
forwardedOptions: header.IPv4Options{0, 0, 0, 0}, + name: "four EOL options", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + mtu: ipv4.MaxTotalSize, + options: header.IPv4Options{0, 0, 0, 0}, + forwardedOptions: header.IPv4Options{0, 0, 0, 0}, }, { - name: "TS type 1 full", - TTL: 2, + name: "TS type 1 full", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 12, 13, 0xF1, 192, 168, 1, 12, @@ -183,8 +215,11 @@ func TestForwarding(t *testing.T) { icmpCode: header.ICMPv4UnusedCode, }, { - name: "TS type 0", - TTL: 2, + name: "TS type 0", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 24, 21, 0x00, 1, 2, 3, 4, @@ -201,10 +236,14 @@ func TestForwarding(t *testing.T) { 13, 14, 15, 16, 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, + expectPacketForwarded: true, }, { - name: "end of options list", - TTL: 2, + name: "end of options list", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: ipv4.MaxTotalSize, options: header.IPv4Options{ 68, 12, 13, 0x11, 192, 168, 1, 12, @@ -220,11 +259,89 @@ func TestForwarding(t *testing.T) { 0, 0, 0, // 7 bytes unknown option removed. 0, 0, 0, 0, }, + expectPacketForwarded: true, + }, + { + name: "Network unreachable", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: unreachableIPv4Addr, + expectErrorICMP: true, + mtu: ipv4.MaxTotalSize, + icmpType: header.ICMPv4DstUnreachable, + icmpCode: header.ICMPv4NetUnreachable, + expectPacketUnrouteableError: true, + }, + { + name: "Multicast destination", + TTL: 2, + destAddr: multicastIPv4Addr, + expectPacketUnrouteableError: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: linkLocalIPv4Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv4Addr, + destAddr: remoteIPv4Addr2, + expectLinkLocalSourceError: true, + }, + { + name: "Fragmentation needed and DF set", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + ipFlags: header.IPv4FlagDontFragment, + // We've picked this MTU because it is: + // + // 1) Greater than the minimum MTU that IPv4 hosts are required to process + // (576 bytes). As per RFC 1812, Section 4.3.2.3: + // + // The ICMP datagram SHOULD contain as much of the original datagram as + // possible without the length of the ICMP datagram exceeding 576 bytes. + // + // Therefore, setting an MTU greater than 576 bytes ensures that we can fit a + // complete ICMP packet on the incoming endpoint (and make assertions about + // it). + // + // 2) Less than `ipv4.MaxTotalSize`, which lets us build an IPv4 packet whose + // size exceeds the MTU. + mtu: 1000, + payloadLength: 1004, + expectErrorICMP: true, + icmpType: header.ICMPv4DstUnreachable, + icmpCode: header.ICMPv4FragmentationNeeded, + }, + { + name: "Fragmentation needed and DF not set", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + mtu: 1000, + payloadLength: 1004, + expectPacketForwarded: true, + // Combined, these fragments have length of 1012 octets, which is equal to + // the length of the payload (1004 octets), plus the length of the ICMP + // header (8 octets). 
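To make the arithmetic behind the expected fragments concrete: with a 1000-byte MTU, a 20-byte IPv4 header (header.IPv4MinimumSize) and an 8-byte ICMPv4 echo header, the 1012 forwarded octets split into a 976-byte first fragment (the largest multiple of 8 that fits) and a 36-byte remainder. A small worked sketch, with the constants spelled out rather than taken from the header package:

```go
package main

import "fmt"

func main() {
	const (
		mtu            = 1000 // test MTU for the outgoing NIC
		ipv4HeaderSize = 20   // header.IPv4MinimumSize
		icmpv4EchoSize = 8    // header.ICMPv4MinimumSize
		payloadLength  = 1004 // test payload
	)
	toForward := icmpv4EchoSize + payloadLength // 1012 octets after the IP header
	// Every fragment but the last must carry a multiple of 8 octets, and each
	// fragment's payload must fit within mtu - ipv4HeaderSize.
	firstPayload := (mtu - ipv4HeaderSize) &^ 7 // 976
	lastPayload := toForward - firstPayload     // 36
	fmt.Println(firstPayload, lastPayload)      // 976 36
}
```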
+ expectedFragmentsForwarded: []fragmentInfo{ + // The first fragment has a length of the greatest multiple of 8 which is + // less than or equal to to `mtu - header.IPv4MinimumSize`. + {offset: 0, payloadSize: uint16(976), more: true}, + // The next fragment holds the rest of the packet. + {offset: uint16(976), payloadSize: 36, more: false}, + }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, @@ -236,46 +353,52 @@ func TestForwarding(t *testing.T) { clock.Advance(time.Millisecond * randomTimeOffset) // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } - ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} - if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) + incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} + if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) } - e2 := channel.New(1, ipv4.MaxTotalSize, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + expectedEmittedPacketCount := 1 + if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount { + expectedEmittedPacketCount = len(test.expectedFragmentsForwarded) } - ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} - if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) + outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr) + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) + } + outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ { - Destination: ipv4Addr1.Subnet(), - NIC: nicID1, + Destination: incomingIPv4Addr.Subnet(), + NIC: incomingNICID, }, { - Destination: ipv4Addr2.Subnet(), - NIC: nicID2, + Destination: outgoingIPv4Addr.Subnet(), + NIC: outgoingNICID, }, }) - if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", header.IPv4ProtocolNumber, err) } ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { t.Fatalf("got ipHeaderLength = %d, want <= %d ", 
ipHeaderLength, header.IPv4MaximumHeaderSize) } - totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpHeaderLength := header.ICMPv4MinimumSize + totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + hdr := buffer.NewPrependable(totalLength) + hdr.Prepend(test.payloadLength) + icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) icmp.SetType(header.ICMPv4Echo) @@ -284,11 +407,12 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(^header.Checksum(icmp, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ - TotalLength: totalLen, + TotalLength: uint16(totalLength), Protocol: uint8(header.ICMPv4ProtocolNumber), TTL: test.TTL, - SrcAddr: remoteIPv4Addr1, - DstAddr: remoteIPv4Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, + Flags: test.ipFlags, }) if len(test.options) != 0 { ip.SetHeaderLength(uint8(ipHeaderLength)) @@ -303,51 +427,122 @@ func TestForwarding(t *testing.T) { requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) - e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + reply, ok := incomingEndpoint.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv4Addr1.Address), - checker.DstAddr(remoteIPv4Addr1), + // We expect the ICMP packet to contain as much of the original packet as + // possible up to a limit of 576 bytes, split between payload, IP header, + // and ICMP header. + expectedICMPPayloadLength := func() int { + maxICMPPacketLength := header.IPv4MinimumProcessableDatagramSize + maxICMPPayloadLength := maxICMPPacketLength - icmpHeaderLength - ipHeaderLength + if len(hdr.View()) > maxICMPPayloadLength { + return maxICMPPayloadLength + } + return len(hdr.View()) + } + + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(incomingIPv4Addr.Address), + checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), checker.ICMPv4( checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } + + if test.expectPacketForwarded { + if len(test.expectedFragmentsForwarded) != 0 { + var fragmentedPackets []*stack.PacketBuffer + for i := 0; i < len(test.expectedFragmentsForwarded); i++ { + reply, ok = outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP Echo fragment through outgoing NIC") + } + fragmentedPackets = append(fragmentedPackets, reply.Pkt) + } + + // The forwarded packet's TTL will have been decremented. + ipHeader := header.IPv4(requestPkt.NetworkHeader().View()) + ipHeader.SetTTL(ipHeader.TTL() - 1) + + // Forwarded packets have available header bytes equalling the sum of the + // maximum IP header size and the maximum size allocated for link layer + // headers. In this case, no size is allocated for link layer headers. 
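The expectedICMPPayloadLength helper above caps how much of the offending datagram is echoed back in the ICMPv4 error, following RFC 1812's 576-byte guidance. A hedged sketch of that rule as the test computes it, assuming a 20-byte IP header and an 8-byte ICMP header:

```go
package main

import "fmt"

// icmpErrorPayloadLen mirrors the truncation rule used by
// expectedICMPPayloadLength, assuming no IPv4 options.
func icmpErrorPayloadLen(offendingDatagramLen int) int {
	const (
		maxICMPPacketLen = 576 // header.IPv4MinimumProcessableDatagramSize
		ipv4HeaderLen    = 20
		icmpv4HeaderLen  = 8
	)
	max := maxICMPPacketLen - icmpv4HeaderLen - ipv4HeaderLen // 548
	if offendingDatagramLen > max {
		return max
	}
	return offendingDatagramLen
}

func main() {
	fmt.Println(icmpErrorPayloadLen(56))   // small packets are echoed in full: 56
	fmt.Println(icmpErrorPayloadLen(1032)) // oversized packets are capped: 548
}
```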
+ expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize + if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { + t.Error(err) + } + } else { + reply, ok = outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP Echo packet through outgoing NIC") + } - if n := e2.Drain(); n != 0 { - t.Fatalf("got e2.Drain() = %d, want = 0", n) + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), + checker.TTL(test.TTL-1), + checker.IPv4Options(test.forwardedOptions), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Payload(nil), + ), + ) } } else { - reply, ok := e2.Read() - if !ok { - t.Fatal("expected ICMP Echo packet through outgoing NIC") + if reply, ok = outgoingEndpoint.Read(); ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) } + } + boolToInt := func(val bool) uint64 { + if val { + return 1 + } + return 0 + } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv4Addr1), - checker.DstAddr(remoteIPv4Addr2), - checker.TTL(test.TTL-1), - checker.IPv4Options(test.forwardedOptions), - checker.ICMPv4( - checker.ICMPv4Checksum(), - checker.ICMPv4Type(header.ICMPv4Echo), - checker.ICMPv4Code(header.ICMPv4UnusedCode), - checker.ICMPv4Payload(nil), - ), - ) + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } - if n := e1.Drain(); n != 0 { - t.Fatalf("got e1.Drain() = %d, want = 0", n) - } + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want { + t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpCode == header.ICMPv4FragmentationNeeded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) } }) } @@ -1116,7 +1311,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1135,7 +1330,7 @@ func TestIPv4Sanity(t *testing.T) { 
checker.ICMPv4( checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1170,13 +1365,25 @@ func TestIPv4Sanity(t *testing.T) { } } -// comparePayloads compared the contents of all the packets against the contents -// of the source packet. -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { +// compareFragments compares the contents of a set of fragmented packets against +// the contents of a source packet. +// +// If withIPHeader is set to true, we will validate the fragmented packets' IP +// headers against the source packet's IP header. If set to false, we validate +// the fragmented packets' IP headers against each other. +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber, withIPHeader bool, expectedAvailableHeaderBytes int) error { // Make a complete array of the sourcePacket packet. - source := header.IPv4(packets[0].NetworkHeader().View()) + var source header.IPv4 vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) - source = append(source, vv.ToView()...) + + // If the packet to be fragmented contains an IPv4 header, use that header for + // validating fragment headers. Else, use the header of the first fragment. + if withIPHeader { + source = header.IPv4(vv.ToView()) + } else { + source = header.IPv4(packets[0].NetworkHeader().View()) + source = append(source, vv.ToView()...) + } // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -1199,12 +1406,12 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB if got := fragmentIPHeader.TransportProtocol(); got != proto { return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) } - if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { - return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) - } if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) } + if got := packet.AvailableHeaderBytes(); got != expectedAvailableHeaderBytes { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, expectedAvailableHeaderBytes) + } if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) } @@ -1220,6 +1427,14 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) + + // If we are validating against the original IP header, we should exclude the + // ID field, which will only be set fo fragmented packets. 
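compareFragments now normalizes fields that legitimately differ between a fragment and the original packet (the Identification field and, consequently, the checksum) before diffing headers. The reduced sketch below shows only the normalize-then-diff idea on fake 8-byte headers, zeroing just the 2-byte ID field; it is an illustration, not the gVisor helper:

```go
package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

// normalize copies a header and zeroes the fields expected to differ, so only
// unexpected differences show up in the diff. Here that is just the IPv4
// Identification field at offsets 4..5.
func normalize(hdr []byte) []byte {
	out := append([]byte(nil), hdr...)
	out[4], out[5] = 0, 0
	return out
}

func main() {
	original := []byte{0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00}
	fragment := []byte{0x45, 0x00, 0x00, 0x54, 0xbe, 0xef, 0x40, 0x00}
	fmt.Println(cmp.Diff(normalize(original), normalize(fragment)) == "") // true
}
```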
+ if withIPHeader { + fragmentIPHeader.SetID(0) + fragmentIPHeader.SetChecksum(0) + fragmentIPHeader.SetChecksum(^fragmentIPHeader.CalculateChecksum()) + } if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } @@ -1327,9 +1542,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -1348,7 +1563,7 @@ func TestFragmentationWritePacket(t *testing.T) { if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { t.Error(err) } }) @@ -1383,7 +1598,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) for _, test := range writePacketsTests { t.Run(test.description, func(t *testing.T) { @@ -1393,13 +1608,13 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -1429,7 +1644,7 @@ func TestFragmentationWritePackets(t *testing.T) { } fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] - if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber, false /* withIPHeader */, extraHeaderReserve); err != nil { t.Error(err) } }) @@ -1507,8 +1722,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, 
header.IPv4ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2082,7 +2297,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv4Type(header.ICMPv4TimeExceeded), checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), checker.ICMPv4Checksum(), - checker.ICMPv4Payload([]byte(firstFragmentSent)), + checker.ICMPv4Payload(firstFragmentSent), ), ) }) @@ -2486,7 +2701,7 @@ func TestReceiveFragments(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, RawFactory: raw.EndpointFactory{}, }) - e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) + e := channel.New(0, 1280, "\xf0\x00") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -2729,7 +2944,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -2855,7 +3070,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ @@ -3025,8 +3240,8 @@ func TestCloseLocking(t *testing.T) { ) var ( - src = tcptestutil.MustParse4("16.0.0.1") - dst = tcptestutil.MustParse4("16.0.0.2") + src = testutil.MustParse4("16.0.0.1") + dst = testutil.MustParse4("16.0.0.2") ) s := stack.New(stack.Options{ @@ -3034,7 +3249,7 @@ func TestCloseLocking(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - // Perform NAT so that the endoint tries to search for a sibling endpoint + // Perform NAT so that the endpoint tries to search for a sibling endpoint // which ends up taking the protocol and endpoint lock (in that order). 
table := stack.Table{ Rules: []stack.Rule{ diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index a637f9d50..d1f9e3cf5 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -19,8 +19,8 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ stack.NetworkInterface = (*testInterface)(nil) diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index db998e83e..f99cbf8f3 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -45,6 +45,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 1319db32b..23fc94303 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe return } + // Keep needed information before trimming header. + p := hdr.TransportProtocol() + dstAddr := hdr.DestinationAddress() + // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data().TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6MinimumSize) if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // because they don't have the transport headers. return } + p = fragHdr.TransportProtocol() // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) } - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize) networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 } + pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: @@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + code := header.ICMPv6(hdr).Code() + pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) + switch code { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: @@ -741,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - stack := e.protocol.stack - - // Is the networking stack operating as a router? 
- if !stack.Forwarding(ProtocolNumber) { - // ... No, silently drop the packet. + if !e.Forwarding() { received.routerOnlyPacketsDroppedByHost.Increment() return } @@ -951,6 +951,19 @@ func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo // icmpReason is a marker interface for IPv6 specific ICMP errors. type icmpReason interface { isICMPReason() + // isForwarding indicates whether or not the error arose while attempting to + // forward a packet. + isForwarding() bool + // respondToMulticast indicates whether this error falls under the exception + // outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are two + // exceptions to this rule: (1) the Packet Too Big Message (Section 3.2) to + // allow Path MTU discovery to work for IPv6 multicast, and (2) the Parameter + // Problem Message, Code 2 (Section 3.4) reporting an unrecognized IPv6 + // option (see Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondsToMulticast() bool } // icmpReasonParameterProblem is an error during processing of extension headers @@ -958,18 +971,6 @@ type icmpReason interface { type icmpReasonParameterProblem struct { code header.ICMPv6Code - // respondToMulticast indicates that we are sending a packet that falls under - // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: - // - // (e.3) A packet destined to an IPv6 multicast address. (There are - // two exceptions to this rule: (1) the Packet Too Big Message - // (Section 3.2) to allow Path MTU discovery to work for IPv6 - // multicast, and (2) the Parameter Problem Message, Code 2 - // (Section 3.4) reporting an unrecognized IPv6 option (see - // Section 4.2 of [IPv6]) that has the Option Type highest- - // order two bits set to 10). - respondToMulticast bool - // pointer is defined in the RFC 4443 setion 3.4 which reads: // // Pointer Identifies the octet offset within the invoking packet @@ -979,9 +980,20 @@ type icmpReasonParameterProblem struct { // packet if the field in error is beyond what can fit // in the maximum size of an ICMPv6 error message. pointer uint32 + + forwarding bool + + respondToMulticast bool } func (*icmpReasonParameterProblem) isICMPReason() {} +func (p *icmpReasonParameterProblem) isForwarding() bool { + return p.forwarding +} + +func (p *icmpReasonParameterProblem) respondsToMulticast() bool { + return p.respondToMulticast +} // icmpReasonPortUnreachable is an error where the transport protocol has no // listener and no alternative means to inform the sender. @@ -989,12 +1001,96 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +func (*icmpReasonPortUnreachable) isForwarding() bool { + return false +} + +func (*icmpReasonPortUnreachable) respondsToMulticast() bool { + return false +} + +// icmpReasonNetUnreachable is an error where no route can be found to the +// network of the final destination. +type icmpReasonNetUnreachable struct{} + +func (*icmpReasonNetUnreachable) isICMPReason() {} + +func (*icmpReasonNetUnreachable) isForwarding() bool { + // If we hit a Network Unreachable error, then we also know we are + // operating as a router. As per RFC 4443 section 3.1: + // + // If the reason for the failure to deliver is lack of a matching + // entry in the forwarding node's routing table, the Code field is + // set to 0 (Network Unreachable). 
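The icmpReason refactor above moves the "did this error arise while forwarding?" knowledge into the interface itself, so returnError can pick the response's source address without per-type assertions. A cut-down sketch of the pattern with illustrative (non-gVisor) names:

```go
package main

import "fmt"

// icmpReason reports whether an error can only arise while forwarding; each
// concrete reason answers for itself instead of returnError special-casing it.
type icmpReason interface {
	isForwarding() bool
}

type reasonPortUnreachable struct{}
type reasonNetUnreachable struct{}

func (reasonPortUnreachable) isForwarding() bool { return false } // local delivery failed
func (reasonNetUnreachable) isForwarding() bool  { return true }  // only a router sees this

// responseSource decides whether the error may be sourced from the offending
// packet's destination address: a forwarding router does not own that address,
// so route selection must pick an owned address instead.
func responseSource(reason icmpReason, origDst string) string {
	if reason.isForwarding() {
		return "" // let FindRoute choose an address we own
	}
	return origDst
}

func main() {
	fmt.Printf("%q\n", responseSource(reasonPortUnreachable{}, "2001:db8::1")) // "2001:db8::1"
	fmt.Printf("%q\n", responseSource(reasonNetUnreachable{}, "2001:db8::1"))  // ""
}
```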
+ return true +} + +func (*icmpReasonNetUnreachable) respondsToMulticast() bool { + return false +} + +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 4443 page 8, Destination Unreachable Message, + // + // If the reason for the failure to deliver cannot be mapped to any of + // other codes, the Code field is set to 3. Example of such cases are + // an inability to resolve the IPv6 destination address into a + // corresponding link address, or a link-specific problem of some sort. + return true +} + +func (*icmpReasonHostUnreachable) respondsToMulticast() bool { + return false +} + +// icmpReasonFragmentationNeeded is an error where a packet is to big to be sent +// out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message. +type icmpReasonPacketTooBig struct{} + +func (*icmpReasonPacketTooBig) isICMPReason() {} + +func (*icmpReasonPacketTooBig) isForwarding() bool { + // If we hit a Packet Too Big error, then we know we are operating as a router. + // As per RFC 4443 section 3.2: + // + // A Packet Too Big MUST be sent by a router in response to a packet that it + // cannot forward because the packet is larger than the MTU of the outgoing + // link. + return true +} + +func (*icmpReasonPacketTooBig) respondsToMulticast() bool { + return true +} + // icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in // transit to its final destination, as per RFC 4443 section 3.3. type icmpReasonHopLimitExceeded struct{} func (*icmpReasonHopLimitExceeded) isICMPReason() {} +func (*icmpReasonHopLimitExceeded) isForwarding() bool { + // If we hit a Hop Limit Exceeded error, then we know we are operating + // as a router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard + // the packet and originate an ICMPv6 Time Exceeded message with Code + // 0 to the source of the packet. This indicates either a routing + // loop or too small an initial Hop Limit value. + return true +} + +func (*icmpReasonHopLimitExceeded) respondsToMulticast() bool { + return false +} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -1002,6 +1098,14 @@ type icmpReasonReassemblyTimeout struct{} func (*icmpReasonReassemblyTimeout) isICMPReason() {} +func (*icmpReasonReassemblyTimeout) isForwarding() bool { + return false +} + +func (*icmpReasonReassemblyTimeout) respondsToMulticast() bool { + return false +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { @@ -1030,25 +1134,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // Section 4.2 of [IPv6]) that has the Option Type highest- // order two bits set to 10). 
// - var allowResponseToMulticast bool - if reason, ok := reason.(*icmpReasonParameterProblem); ok { - allowResponseToMulticast = reason.respondToMulticast - } - + allowResponseToMulticast := reason.respondsToMulticast() isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst) if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any { return nil } - // If we hit a Hop Limit Exceeded error, then we know we are operating as a - // router. As per RFC 4443 section 3.3: - // - // If a router receives a packet with a Hop Limit of zero, or if a - // router decrements a packet's Hop Limit to zero, it MUST discard the - // packet and originate an ICMPv6 Time Exceeded message with Code 0 to - // the source of the packet. This indicates either a routing loop or - // too small an initial Hop Limit value. - // // If we are operating as a router, do not use the packet's destination // address as the response's source address as we should not own the // destination address of a packet we are forwarding. @@ -1058,7 +1149,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // packet as "multicast addresses must not be used as source addresses in IPv6 // packets", as per RFC 4291 section 2.7. localAddr := origIPHdrDst - if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { + if reason.isForwarding() || isOrigDstMulticast { localAddr = "" } // Even if we were able to receive a packet from some remote, we may not have @@ -1072,7 +1163,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -1147,6 +1243,18 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) + counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6AddressUnreachable) + counter = sent.dstUnreachable + case *icmpReasonPacketTooBig: + icmpHdr.SetType(header.ICMPv6PacketTooBig) + icmpHdr.SetCode(header.ICMPv6UnusedCode) + counter = sent.packetTooBig case *icmpReasonHopLimitExceeded: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index e457be3cf..d7b04554e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -21,7 +21,6 @@ import ( "reflect" "strings" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -46,16 +45,12 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 - // Extra time to use when waiting for an async event to occur. 
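For reference, the new returnError cases map the router-generated reasons onto these ICMPv6 types and codes (numeric values per RFC 4443; the string keys below are shorthand for the icmpReason* types, not real identifiers):

```go
package main

import "fmt"

// icmpv6TypeCode summarizes the wire mapping for router-generated errors.
func icmpv6TypeCode(reason string) (typ, code uint8) {
	switch reason {
	case "netUnreachable":
		return 1, 0 // Destination Unreachable, no route to destination
	case "hostUnreachable":
		return 1, 3 // Destination Unreachable, address unreachable
	case "packetTooBig":
		return 2, 0 // Packet Too Big
	case "hopLimitExceeded":
		return 3, 0 // Time Exceeded, hop limit exceeded in transit
	default:
		panic("unknown reason: " + reason)
	}
}

func main() {
	for _, r := range []string{"netUnreachable", "hostUnreachable", "packetTooBig", "hopLimitExceeded"} {
		t, c := icmpv6TypeCode(r)
		fmt.Printf("%s -> type %d, code %d\n", r, t, c)
	}
}
```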
- defaultAsyncPositiveEventTimeout = 30 * time.Second - arbitraryHopLimit = 42 ) var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) - lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -673,8 +668,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(_, _) = %s", err) @@ -1308,7 +1304,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index f7510c243..68f8308f2 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -63,6 +63,11 @@ const ( buckets = 2048 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + // policyTable is the default policy table defined in RFC 6724 section 2.1. // // A more human-readable version: @@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 { var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -187,6 +193,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -270,6 +282,16 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, we should return an ICMP error to the + // original packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -405,27 +427,39 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t } } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. 
+func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.setForwarding(forwarding) { + return + } + allRoutersGroups := [...]tcpip.Address{ header.IPv6AllRoutersInterfaceLocalMulticastAddress, header.IPv6AllRoutersLinkLocalMulticastAddress, header.IPv6AllRoutersSiteLocalMulticastAddress, } - e.mu.Lock() - defer e.mu.Unlock() - if forwarding { - // When transitioning into an IPv6 router, host-only state (NDP discovered - // routers, discovered on-link prefixes, and auto-generated addresses) is - // cleaned up/invalidated and NDP router solicitations are stopped. - e.mu.ndp.stopSolicitingRouters() - e.mu.ndp.cleanupState(true /* hostOnly */) - // As per RFC 4291 section 2.8: // // A router is required to recognize all addresses that a host is @@ -449,28 +483,19 @@ func (e *endpoint) transitionForwarding(forwarding bool) { panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err)) } } - - return - } - - for _, g := range allRoutersGroups { - switch err := e.leaveGroupLocked(g).(type) { - case nil: - case *tcpip.ErrBadLocalAddress: - // The endpoint may have already left the multicast group. - default: - panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } else { + for _, g := range allRoutersGroups { + switch err := e.leaveGroupLocked(g).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } } } - // When transitioning into an IPv6 host, NDP router solicitations are - // started if the endpoint is enabled. - // - // If the endpoint is not currently enabled, routers will be solicited when - // the endpoint becomes enabled (if it is still a host). - if e.Enabled() { - e.mu.ndp.startSolicitingRouters() - } + e.mu.ndp.forwardingChanged(forwarding) } // Enable implements stack.NetworkEndpoint. @@ -552,17 +577,7 @@ func (e *endpoint) Enable() tcpip.Error { e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) } - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyway. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. - if !e.protocol.Forwarding() { - e.mu.ndp.startSolicitingRouters() - } - + e.mu.ndp.startSolicitingRouters() return nil } @@ -613,7 +628,7 @@ func (e *endpoint) disableLocked() { return true }) - e.mu.ndp.cleanupState(false /* hostOnly */) + e.mu.ndp.cleanupState() // The endpoint may have already left the multicast group. 
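Forwarding state now lives on each endpoint as an atomically accessed flag, and the setter reports whether the value actually changed so SetForwarding only joins or leaves the all-routers groups and notifies NDP on real transitions. A self-contained sketch of that flag, mirroring the Swap-based check:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

const (
	forwardingDisabled = 0
	forwardingEnabled  = 1
)

// endpoint is a cut-down stand-in for the IPv6 endpoint: just the atomic flag.
type endpoint struct {
	forwarding uint32 // accessed atomically
}

func (e *endpoint) Forwarding() bool {
	return atomic.LoadUint32(&e.forwarding) == forwardingEnabled
}

// setForwarding returns true only when the stored value actually changed,
// which is what lets SetForwarding skip the multicast-group and NDP work on
// redundant calls.
func (e *endpoint) setForwarding(v bool) bool {
	forwarding := uint32(forwardingDisabled)
	if v {
		forwarding = forwardingEnabled
	}
	return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding
}

func main() {
	var e endpoint
	fmt.Println(e.setForwarding(true), e.Forwarding())  // true true  (transition)
	fmt.Println(e.setForwarding(true), e.Forwarding())  // false true (no change)
	fmt.Println(e.setForwarding(false), e.Forwarding()) // true false (transition)
}
```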
switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) { @@ -786,6 +801,12 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol } if packetMustBeFragmented(pkt, networkMTU) { + if pkt.NetworkPacketInfo.IsForwardedPacket { + // As per RFC 2460, section 4.5: + // Unlike IPv4, fragmentation in IPv6 is performed only by source nodes, + // not by routers along a packet's delivery path. + return &tcpip.ErrMessageTooLong{} + } sent, remain, err := e.handleFragments(r, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we @@ -928,16 +949,19 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv6(pkt.NetworkHeader().View()) dstAddr := h.DestinationAddress() - if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { - // As per RFC 4291 section 2.5.6, - // - // Routers must not forward any packets with Link-Local source or - // destination addresses to other links. - return nil + // As per RFC 4291 section 2.5.6, + // + // Routers must not forward any packets with Link-Local source or + // destination addresses to other links. + if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} } hopLimit := h.HopLimit() @@ -949,21 +973,56 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // packet and originate an ICMPv6 Time Exceeded message with Code 0 to // the source of the packet. This indicates either a routing loop or // too small an initial Hop Limit value. - return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } + stk := e.protocol.stack + // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(ep.nic.ID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + ep.handleValidatedPacket(h, pkt) return nil } - r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + // Check extension headers for any errors requiring action during forwarding. 
+ if err := e.processExtensionHeaders(h, pkt, true /* forwarding */); err != nil { + return &ip.ErrParameterProblem{} + } + + r, err := stk.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning the + // ICMP packet because the original error is more relevant to the caller. + _ = e.protocol.returnError(&icmpReasonNetUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() + inNicName := stk.FindNICNameFromID(e.nic.ID()) + outNicName := stk.FindNICNameFromID(r.NICID()) + if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesForwardDropped.Increment() + return nil + } + // We need to do a deep copy of the IP packet because // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do // not own it. @@ -975,10 +1034,23 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + switch err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + IsForwardedPacket: true, + })); err.(type) { + case nil: + return nil + case *tcpip.ErrMessageTooLong: + // As per RFC 4443, section 3.2: + // A Packet Too Big MUST be sent by a router in response to a packet that + // it cannot forward because the packet is larger than the MTU of the + // outgoing link. 
+ _ = e.protocol.returnError(&icmpReasonPacketTooBig{}, pkt) + return &ip.ErrMessageTooLong{} + default: + return &ip.ErrOther{Err: err} + } } // HandlePacket is called by the link layer when new ipv6 packets arrive for @@ -1059,6 +1131,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.stats.ip + stats.ValidPacketsReceived.Increment() + srcAddr := h.SourceAddress() dstAddr := h.DestinationAddress() @@ -1075,15 +1149,54 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { addressEndpoint.DecRef() } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.InvalidDestinationAddressesReceived.Increment() return } + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + e.stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + e.stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + e.stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + e.stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrParameterProblem: + e.stats.ip.Forwarding.ExtensionHeaderProblem.Increment() + case *ip.ErrMessageTooLong: + e.stats.ip.Forwarding.PacketTooBig.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + e.stats.ip.Forwarding.Errors.Increment() + return + } - _ = e.forwardPacket(pkt) + // iptables filtering. All packets that reach here are intended for + // this machine and need not be forwarded. + inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + // iptables is telling us to drop the packet. + stats.IPTablesInputDropped.Increment() return } + // Any returned error is only useful for terminating execution early, but + // we have nothing left to do, so we can drop it. + _ = e.processExtensionHeaders(h, pkt, false /* forwarding */) +} + +// processExtensionHeaders processes the extension headers in the given packet. +// Returns an error if the processing of a header failed or if the packet should +// be discarded. +func (e *endpoint) processExtensionHeaders(h header.IPv6, pkt *stack.PacketBuffer, forwarding bool) error { + stats := e.stats.ip + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + // Create a VV to parse the packet. We don't plan to modify anything here. // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). @@ -1094,15 +1207,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) vv.AppendViews(pkt.Data().Views()) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) - // iptables filtering. All packets that reach here are intended for - // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { - // iptables is telling us to drop the packet. 
- stats.IPTablesInputDropped.Increment() - return - } - var ( hasFragmentHeader bool routerAlert *header.IPv6RouterAlertOption @@ -1115,22 +1219,41 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) extHdr, done, err := it.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break } + // As per RFC 8200, section 4: + // + // Extension headers (except for the Hop-by-Hop Options header) are + // not processed, inserted, or deleted by any node along a packet's + // delivery path until the packet reaches the node identified in the + // Destination Address field of the IPv6 header. + // + // Furthermore, as per RFC 8200 section 4.1, the Hop By Hop extension + // header is restricted to appear first in the list of extension headers. + // + // Therefore, we can immediately return once we hit any header other + // than the Hop-by-Hop header while forwarding a packet. + if forwarding { + if _, ok := extHdr.(header.IPv6HopByHopOptionsExtHdr); !ok { + return nil + } + } + switch extHdr := extHdr.(type) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { _ = e.protocol.returnError(&icmpReasonParameterProblem{ - code: header.ICMPv6UnknownHeader, - pointer: previousHeaderStart, + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found Hop-by-Hop header = %#v with non-zero previous header offset = %d", extHdr, previousHeaderStart) } optsIt := extHdr.Iter() @@ -1139,7 +1262,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) opt, done, err := optsIt.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break @@ -1154,7 +1277,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // There MUST only be one option of this type, regardless of // value, per Hop-by-Hop header. 
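processExtensionHeaders stops early when forwarding, since RFC 8200 only allows intermediate nodes to examine the Hop-by-Hop header. A toy sketch of that short-circuit, using plain strings in place of the header package's extension-header types:

```go
package main

import "fmt"

// headersToProcess illustrates the forwarding short-circuit: a router only
// looks at the Hop-by-Hop header; everything after it is left untouched for
// the node named in the Destination Address field.
func headersToProcess(extHdrs []string, forwarding bool) []string {
	var processed []string
	for _, h := range extHdrs {
		if forwarding && h != "hop-by-hop" {
			break // RFC 8200 s4: not examined by nodes along the delivery path.
		}
		processed = append(processed, h)
	}
	return processed
}

func main() {
	hdrs := []string{"hop-by-hop", "routing", "fragment"}
	fmt.Println(headersToProcess(hdrs, true))  // [hop-by-hop]
	fmt.Println(headersToProcess(hdrs, false)) // [hop-by-hop routing fragment]
}
```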
stats.MalformedPacketsReceived.Increment() - return + return fmt.Errorf("found multiple Router Alert options (%#v, %#v)", opt, routerAlert) } routerAlert = opt stats.OptionRouterAlertReceived.Increment() @@ -1162,10 +1285,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) switch opt.UnknownAction() { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: - return + return fmt.Errorf("found unknown Hop-by-Hop header option = %#v with discard action", opt) case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: if header.IsV6MulticastAddress(dstAddr) { - return + return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt) } fallthrough case header.IPv6OptionUnknownActionDiscardSendICMP: @@ -1180,10 +1303,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found unknown hop-by-hop header option = %#v with discard action", opt) default: - panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) + panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %#v", opt)) } } } @@ -1205,8 +1329,13 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), + // For the sake of consistency, we're using the value of `forwarding` + // here, even though it should always be false if we've reached this + // point. If `forwarding` is true here, we're executing undefined + // behavior no matter what. + forwarding: forwarding, }, pkt) - return + return fmt.Errorf("found unrecognized routing type with non-zero segments left in header = %#v", extHdr) } case header.IPv6FragmentExtHdr: @@ -1241,7 +1370,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if err != nil { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return err } if done { break @@ -1269,7 +1398,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) default: stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return fmt.Errorf("known extension header = %#v present after fragment header in a non-initial fragment", lastHdr) } } @@ -1278,7 +1407,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // Drop the packet as it's marked as a fragment but has no payload. stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return fmt.Errorf("fragment has no payload") } // As per RFC 2460 Section 4.5: @@ -1296,7 +1425,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, }, pkt) - return + return fmt.Errorf("found fragment length = %d that is not a multiple of 8 octets", fragmentPayloadLen) } // The packet is a fragment, let's try to reassemble it. 
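
The fragment sanity checks enforced above and in the next hunk come straight from RFC 8200 section 4.5: a non-final fragment must carry a payload that is a multiple of 8 octets, and the offset plus payload length may never imply a reassembled payload larger than the 16-bit Payload Length field can express. The short sketch below restates those two rules outside of netstack, assuming only that the fragment offset is expressed in 8-octet units as it is on the wire; the helper name, constant, and error strings are illustrative and not part of the gVisor API.

```go
package main

import "fmt"

// Largest payload the 16-bit IPv6 Payload Length field can describe.
const ipv6MaximumPayloadSize = 65535

// checkFragment applies the two RFC 8200 section 4.5 sanity checks to a
// single fragment, given its offset in 8-octet units, its payload length in
// bytes, and whether more fragments follow.
func checkFragment(offsetUnits uint16, payloadLen int, moreFragments bool) error {
	// Every fragment except the last must carry a payload that is a multiple
	// of 8 octets, because fragment offsets are expressed in 8-octet units.
	if moreFragments && payloadLen%8 != 0 {
		return fmt.Errorf("non-final fragment payload length %d is not a multiple of 8 octets", payloadLen)
	}
	// The reassembled payload must still fit in the Payload Length field.
	if reassembled := int(offsetUnits)*8 + payloadLen; reassembled > ipv6MaximumPayloadSize {
		return fmt.Errorf("reassembled payload length %d exceeds %d", reassembled, ipv6MaximumPayloadSize)
	}
	return nil
}

func main() {
	fmt.Println(checkFragment(0, 1004, true))     // invalid: not a multiple of 8 octets
	fmt.Println(checkFragment(8190, 1280, false)) // invalid: exceeds 65535 after reassembly
	fmt.Println(checkFragment(160, 1280, true))   // valid
}
```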
@@ -1310,14 +1439,15 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // Parameter Problem, Code 0, message should be sent to the source of // the fragment, pointing to the Fragment Offset field of the fragment // packet. - if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { + lengthAfterReassembly := int(start) + fragmentPayloadLen + if lengthAfterReassembly > header.IPv6MaximumPayloadSize { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, }, pkt) - return + return fmt.Errorf("determined that reassembled packet length = %d would exceed allowed length = %d", lengthAfterReassembly, header.IPv6MaximumPayloadSize) } // Note that pkt doesn't have its transport header set after reassembly, @@ -1339,7 +1469,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if err != nil { stats.MalformedPacketsReceived.Increment() stats.MalformedFragmentsReceived.Increment() - return + return err } if ready { @@ -1361,7 +1491,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) opt, done, err := optsIt.Next() if err != nil { stats.MalformedPacketsReceived.Increment() - return + return err } if done { break @@ -1372,10 +1502,10 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) switch opt.UnknownAction() { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: - return + return fmt.Errorf("found unknown destination header option = %#v with discard action", opt) case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: if header.IsV6MulticastAddress(dstAddr) { - return + return fmt.Errorf("found unknown destination header option %#v with discard action", opt) } fallthrough case header.IPv6OptionUnknownActionDiscardSendICMP: @@ -1392,9 +1522,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, }, pkt) - return + return fmt.Errorf("found unknown destination header option %#v with discard action", opt) default: - panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) + panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %#v", opt)) } } @@ -1402,13 +1532,19 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + // Calculate the number of octets parsed from data. We want to remove all + // the data except the unparsed portion located at the end, which its size + // is extHdr.Buf.Size(). + trim := pkt.Data().Size() - extHdr.Buf.Size() + // For unfragmented packets, extHdr still contains the transport header. // Get rid of it. // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. 
- extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) - pkt.Data().Replace(extHdr.Buf) + trim += pkt.TransportHeader().View().Size() + + pkt.Data().DeleteFront(trim) stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -1425,6 +1561,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) + return fmt.Errorf("destination port unreachable") case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -1456,6 +1593,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) code: header.ICMPv6UnknownHeader, pointer: prevHdrIDOffset, }, pkt) + return fmt.Errorf("transport protocol unreachable") default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -1469,6 +1607,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) } } + return nil } // Close cleans up resources associated with the endpoint. @@ -1490,8 +1629,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1532,8 +1671,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { @@ -1610,8 +1749,8 @@ func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { // AcquireAssignedAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { - e.mu.Lock() - defer e.mu.Unlock() + e.mu.RLock() + defer e.mu.RUnlock() return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB) } @@ -1833,7 +1972,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1858,12 +1996,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - fragmentation *fragmentation.Fragmentation } @@ -2038,35 +2170,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. 
-func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 40a793d6b..afc6c3547 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -31,8 +31,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -2603,7 +2604,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -2802,9 +2803,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt.Clone() - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2858,7 +2859,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -2868,14 +2869,14 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := 
iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -2980,8 +2981,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -3003,52 +3004,289 @@ func TestFragmentationErrors(t *testing.T) { func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 + incomingNICID = 1 + outgoingNICID = 2 randomSequence = 123 randomIdent = 42 ) - ipv6Addr1 := tcpip.AddressWithPrefix{ + incomingIPv6Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10::1").To16()), PrefixLen: 64, } - ipv6Addr2 := tcpip.AddressWithPrefix{ + outgoingIPv6Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11::1").To16()), PrefixLen: 64, } + multicastIPv6Addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("ff00::").To16()), + PrefixLen: 64, + } + remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16()) + linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16()) tests := []struct { - name string - TTL uint8 - expectErrorICMP bool + name string + extHdr func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) + TTL uint8 + expectErrorICMP bool + expectPacketForwarded bool + payloadLength int + countUnrouteablePackets uint64 + sourceAddr tcpip.Address + destAddr tcpip.Address + icmpType header.ICMPv6Type + icmpCode header.ICMPv6Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool + expectExtensionHeaderError bool }{ { name: "TTL of zero", TTL: 0, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, }, { name: "TTL of one", TTL: 1, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "TTL of two", + TTL: 2, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "TTL of three", + TTL: 3, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "Network unreachable", + TTL: 2, + expectErrorICMP: true, + sourceAddr: 
remoteIPv6Addr1, + destAddr: unreachableIPv6Addr, + icmpType: header.ICMPv6DstUnreachable, + icmpCode: header.ICMPv6NetworkUnreachable, + expectPacketUnrouteableError: true, + }, + { + name: "Multicast destination", + TTL: 2, + countUnrouteablePackets: 1, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + expectPacketForwarded: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: linkLocalIPv6Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv6Addr, + destAddr: remoteIPv6Addr2, + expectLinkLocalSourceError: true, + }, + { + name: "Hopbyhop with unknown option skippable action", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 62, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6UnknownOption(), checker.IPv6UnknownOption())) + }, + expectPacketForwarded: true, + }, + { + name: "Hopbyhop with unknown option discard action", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard unknown. + 127, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action (unicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action (multicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. 
+ 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectErrorICMP: true, + icmpType: header.ICMPv6ParamProblem, + icmpCode: header.ICMPv6UnknownOption, + expectExtensionHeaderError: true, + }, + { + name: "Hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, }, { - name: "TTL of three", - TTL: 3, - expectErrorICMP: false, + name: "Hopbyhop with router alert option", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 0, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + }, hopByHopExtHdrID, checker.IPv6ExtHdr(checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD))) + }, + expectPacketForwarded: true, + }, + { + name: "Hopbyhop with two router alert options", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + extHdr: func(nextHdr uint8) ([]byte, uint8, checker.NetworkChecker) { + return []byte{ + nextHdr, 1, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + + // Router Alert option. + 5, 2, 0, 0, 0, 0, + }, hopByHopExtHdrID, nil + }, + expectExtensionHeaderError: true, + }, + { + name: "Can't fragment", + TTL: 2, + payloadLength: header.IPv6MinimumMTU + 1, + expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6PacketTooBig, + icmpCode: header.ICMPv6UnusedCode, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Can't fragment multicast", + TTL: 2, + payloadLength: header.IPv6MinimumMTU + 1, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr.Address, + expectErrorICMP: true, + icmpType: header.ICMPv6PacketTooBig, + icmpCode: header.ICMPv6UnusedCode, }, } @@ -3059,41 +3297,60 @@ func TestForwarding(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }) // We expect at most a single packet in response to our ICMP Echo Request. 
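
Each extHdr closure in the table above hand-assembles a Hop-by-Hop options header: the first byte is the Next Header value, the second is the header length in 8-octet units not counting the first 8 octets (so a value of 1 means a 16-byte header), and the remaining bytes are (type, data length, data) option TLVs. Which branch of the unknown-option handling a case exercises is determined entirely by the two high-order bits of the option type, per RFC 8200 section 4.2. The sketch below decodes the option types used in those cases; the helper is illustrative and not part of the header package.

```go
package main

import "fmt"

// unknownOptionAction maps the two high-order bits of a Hop-by-Hop or
// Destination option type to the action a node takes when it does not
// recognize the option (RFC 8200 section 4.2).
func unknownOptionAction(optType byte) string {
	switch optType >> 6 {
	case 0:
		return "skip over the option and keep processing"
	case 1:
		return "discard the packet"
	case 2:
		return "discard and send an ICMP Parameter Problem"
	default: // 3
		return "discard and send an ICMP Parameter Problem unless the destination is multicast"
	}
}

func main() {
	// Option types used by the forwarding test cases above.
	for _, optType := range []byte{62, 63, 127, 191, 255} {
		fmt.Printf("option type %3d (0b%08b): %s\n", optType, optType, unknownOptionAction(optType))
	}
}
```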
- e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } - ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} - if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) + incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} + if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) } - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } - ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} - if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) + outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ { - Destination: ipv6Addr1.Subnet(), - NIC: nicID1, + Destination: incomingIPv6Addr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: outgoingIPv6Addr.Subnet(), + NIC: outgoingNICID, }, { - Destination: ipv6Addr2.Subnet(), - NIC: nicID2, + Destination: multicastIPv6Addr.Subnet(), + NIC: outgoingNICID, }, }) - if err := s.SetForwarding(ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) } - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) - icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + transportProtocol := header.ICMPv6ProtocolNumber + extHdrBytes := []byte{} + extHdrChecker := checker.IPv6ExtHdr() + if test.extHdr != nil { + nextHdrID := hopByHopExtHdrID + extHdrBytes, nextHdrID, extHdrChecker = test.extHdr(uint8(header.ICMPv6ProtocolNumber)) + transportProtocol = tcpip.TransportProtocolNumber(nextHdrID) + } + extHdrLen := len(extHdrBytes) + + ipHeaderLength := header.IPv6MinimumSize + icmpHeaderLength := header.ICMPv6MinimumSize + totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen + hdr := buffer.NewPrependable(totalLength) + hdr.Prepend(test.payloadLength) + icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) + icmp.SetIdent(randomIdent) icmp.SetSequence(randomSequence) icmp.SetType(header.ICMPv6EchoRequest) @@ -3101,52 +3358,72 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmp, - Src: remoteIPv6Addr1, - Dst: remoteIPv6Addr2, + Src: test.sourceAddr, + Dst: test.destAddr, 
})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + copy(hdr.Prepend(extHdrLen), extHdrBytes) + ip := header.IPv6(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - TransportProtocol: header.ICMPv6ProtocolNumber, + PayloadLength: uint16(header.ICMPv6MinimumSize + test.payloadLength), + TransportProtocol: transportProtocol, HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, }) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) - e1.InjectInbound(ProtocolNumber, requestPkt) + incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt) + + reply, ok := incomingEndpoint.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { - t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) + } + + // As per RFC 4443, page 9: + // + // The returned ICMP packet will contain as much of invoking packet + // as possible without the ICMPv6 packet exceeding the minimum IPv6 + // MTU. + expectedICMPPayloadLength := func() int { + maxICMPPayloadLength := header.IPv6MinimumMTU - ipHeaderLength - icmpHeaderLength + if len(hdr.View()) > maxICMPPayloadLength { + return maxICMPPayloadLength + } + return len(hdr.View()) } checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv6Addr1.Address), - checker.DstAddr(remoteIPv6Addr1), + checker.SrcAddr(incomingIPv6Addr.Address), + checker.DstAddr(test.sourceAddr), checker.TTL(DefaultTTL), checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), - checker.ICMPv6Payload([]byte(hdr.View())), + checker.ICMPv6Type(test.icmpType), + checker.ICMPv6Code(test.icmpCode), + checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), ), ) - if n := e2.Drain(); n != 0 { + if n := outgoingEndpoint.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } - } else { - reply, ok := e2.Read() + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } + + reply, ok = outgoingEndpoint.Read() + if test.expectPacketForwarded { if !ok { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv6Addr1), - checker.DstAddr(remoteIPv6Addr2), + checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), + extHdrChecker, checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6EchoRequest), checker.ICMPv6Code(header.ICMPv6UnusedCode), @@ -3154,9 +3431,46 @@ func TestForwarding(t *testing.T) { ), ) - if n := e1.Drain(); n != 0 { + if n := incomingEndpoint.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } + } else if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + } + + boolToInt := func(val bool) uint64 { + if val { + return 1 + } + return 0 + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), 
boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want { + t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value(), boolToInt(test.expectExtensionHeaderError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExtensionHeaderProblem.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.PacketTooBig.Value(), boolToInt(test.icmpType == header.ICMPv6PacketTooBig); got != want { + t.Errorf("got s.Stats().IP.Forwarding.PacketTooBig.Value() = %d, want = %d", got, want) } }) } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index d6e0a81a6..11ff36561 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -48,7 +48,7 @@ const ( // defaultHandleRAs is the default configuration for whether or not to // handle incoming Router Advertisements as a host. - defaultHandleRAs = true + defaultHandleRAs = HandlingRAsEnabledWhenForwardingDisabled // defaultDiscoverDefaultRouters is the default configuration for // whether or not to discover default routers from incoming Router @@ -301,10 +301,60 @@ type NDPDispatcher interface { OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } +var _ fmt.Stringer = HandleRAsConfiguration(0) + +// HandleRAsConfiguration enumerates when RAs may be handled. +type HandleRAsConfiguration int + +const ( + // HandlingRAsDisabled indicates that Router Advertisements will not be + // handled. + HandlingRAsDisabled HandleRAsConfiguration = iota + + // HandlingRAsEnabledWhenForwardingDisabled indicates that router + // advertisements will only be handled when forwarding is disabled. + HandlingRAsEnabledWhenForwardingDisabled + + // HandlingRAsAlwaysEnabled indicates that Router Advertisements will always + // be handled, even when forwarding is enabled. + HandlingRAsAlwaysEnabled +) + +// String implements fmt.Stringer. +func (c HandleRAsConfiguration) String() string { + switch c { + case HandlingRAsDisabled: + return "HandlingRAsDisabled" + case HandlingRAsEnabledWhenForwardingDisabled: + return "HandlingRAsEnabledWhenForwardingDisabled" + case HandlingRAsAlwaysEnabled: + return "HandlingRAsAlwaysEnabled" + default: + return fmt.Sprintf("HandleRAsConfiguration(%d)", c) + } +} + +// enabled returns true iff Router Advertisements may be handled given the +// specified forwarding status. +func (c HandleRAsConfiguration) enabled(forwarding bool) bool { + switch c { + case HandlingRAsDisabled: + return false + case HandlingRAsEnabledWhenForwardingDisabled: + return !forwarding + case HandlingRAsAlwaysEnabled: + return true + default: + panic(fmt.Sprintf("unhandled HandleRAsConfiguration = %d", c)) + } +} + // NDPConfigurations is the NDP configurations for the netstack. 
type NDPConfigurations struct { // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. + // + // Ignored unless configured to handle Router Advertisements. MaxRtrSolicitations uint8 // The amount of time between transmitting Router Solicitation messages. @@ -318,8 +368,9 @@ type NDPConfigurations struct { // Must be greater than or equal to 0s. MaxRtrSolicitationDelay time.Duration - // HandleRAs determines whether or not Router Advertisements are processed. - HandleRAs bool + // HandleRAs is the configuration for when Router Advertisements should be + // handled. + HandleRAs HandleRAsConfiguration // DiscoverDefaultRouters determines whether or not default routers are // discovered from Router Advertisements, as per RFC 4861 section 6. This @@ -654,7 +705,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a protocol-wide configuration, so we check the // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding // packets. - if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { + ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment() return } @@ -1609,44 +1661,16 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs map[t delete(tempAddrs, tempAddr) } -// removeSLAACAddresses removes all SLAAC addresses. -// -// If keepLinkLocal is false, the SLAAC generated link-local address is removed. -// -// The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { - linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - var linkLocalPrefixes int - for prefix, state := range ndp.slaacPrefixes { - // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them if we are cleaning up - // host-only state. - if keepLinkLocal && prefix == linkLocalSubnet { - linkLocalPrefixes++ - continue - } - - ndp.invalidateSLAACPrefix(prefix, state) - } - - if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { - panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) - } -} - // cleanupState cleans up ndp's state. // -// If hostOnly is true, then only host-specific state is cleaned up. -// // This function invalidates all discovered on-link prefixes, discovered // routers, and auto-generated addresses. // -// If hostOnly is true, then the link-local auto-generated address aren't -// invalidated as routers are also expected to generate a link-local address. -// // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { - ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) +func (ndp *ndpState) cleanupState() { + for prefix, state := range ndp.slaacPrefixes { + ndp.invalidateSLAACPrefix(prefix, state) + } for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) @@ -1670,6 +1694,10 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // startSolicitingRouters starts soliciting routers, as per RFC 4861 section // 6.3.7. If routers are already being solicited, this function does nothing. // +// If ndp is not configured to handle Router Advertisements, routers will not +// be solicited as there is no point soliciting routers if we don't handle their +// advertisements. 
+// // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { if ndp.rtrSolicitTimer.timer != nil { @@ -1682,6 +1710,10 @@ func (ndp *ndpState) startSolicitingRouters() { return } + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { + return + } + // Calculate the random delay before sending our first RS, as per RFC // 4861 section 6.3.7. var delay time.Duration @@ -1722,7 +1754,7 @@ func (ndp *ndpState) startSolicitingRouters() { header.NDPSourceLinkLayerAddressOption(linkAddress), } } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + optsSerializer.Length() icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) rs := header.NDPRouterSolicit(icmpData.MessageBody()) @@ -1774,6 +1806,32 @@ func (ndp *ndpState) startSolicitingRouters() { } } +// forwardingChanged handles a change in forwarding configuration. +// +// If transitioning to a host, router solicitation will be started. Otherwise, +// router solicitation will be stopped if NDP is not configured to handle RAs +// as a router. +// +// Precondition: ndp.ep.mu must be locked. +func (ndp *ndpState) forwardingChanged(forwarding bool) { + if forwarding { + if ndp.configs.HandleRAs.enabled(forwarding) { + return + } + + ndp.stopSolicitingRouters() + return + } + + // Solicit routers when transitioning to a host. + // + // If the endpoint is not currently enabled, routers will be solicited when + // the endpoint becomes enabled (if it is still a host). + if ndp.ep.Enabled() { + ndp.startSolicitingRouters() + } +} + // stopSolicitingRouters stops soliciting routers. If routers are not currently // being solicited, this function does nothing. 
// diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 52b9a200c..b300ed894 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -32,58 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - t.Cleanup(ep.Close) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := llladdr.WithPrefix() - if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - addressEP.DecRef() - } - - return s, ep -} - var _ NDPDispatcher = (*testNDPDispatcher)(nil) // testNDPDispatcher is an NDPDispatcher only allows default router discovery. @@ -163,11 +111,6 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } -type linkResolutionResult struct { - linkAddr tcpip.LinkAddress - ok bool -} - // TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -732,15 +675,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { } func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. 
- s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - return s, ep - } + const nicID = 1 handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { var extHdrs header.IPv6ExtHdrSerializer @@ -865,6 +800,11 @@ func TestNDPValidation(t *testing.T) { }, } + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } + for _, typ := range types { for _, isRouter := range []bool{false, true} { name := typ.name @@ -875,13 +815,35 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) if isRouter { - // Enabling forwarding makes the stack act as a router. - s.SetForwarding(ProtocolNumber, true) + if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ProtocolNumber, err) + } } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatal("cannot find network endpoint instance for IPv6") + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}) + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost @@ -906,12 +868,12 @@ func TestNDPValidation(t *testing.T) { // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + t.Errorf("got invalid.Value() = %d, want = 0", got) } - // RouterOnlyPacketsReceivedByHost count should initially be 0. + // Should initially not have dropped any packets. if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + t.Errorf("got routerOnly.Value() = %d, want = 0", got) } if t.Failed() { @@ -931,18 +893,18 @@ func TestNDPValidation(t *testing.T) { want = 1 } if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) + t.Errorf("got invalid.Value() = %d, want = %d", got, want) } want = 0 if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. + // Router only packets are expected to be dropped when operating + // as a host. want = 1 } if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + t.Errorf("got routerOnly.Value() = %d, want = %d", got, want) } - }) } }) diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go index c2758352f..2f18f60e8 100644 --- a/pkg/tcpip/network/ipv6/stats.go +++ b/pkg/tcpip/network/ipv6/stats.go @@ -29,6 +29,10 @@ type Stats struct { // ICMP holds ICMPv6 statistics. ICMP tcpip.ICMPv6Stats + + // UnhandledRouterAdvertisements is the number of Router Advertisements that + // were observed but not handled. + UnhandledRouterAdvertisements *tcpip.StatCounter } // IsNetworkEndpointStats implements stack.NetworkEndpointStats. 
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index a6c877158..b7c2de652 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -18,6 +18,7 @@ import ( "math" "sync/atomic" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sync" ) @@ -213,7 +214,7 @@ type SocketOptions struct { getSendBufferLimits GetSendBufferLimits `state:"manual"` // sendBufferSize determines the send buffer size for this socket. - sendBufferSize int64 + sendBufferSize atomicbitops.AlignedAtomicInt64 // getReceiveBufferLimits provides the handler to get the min, default and // max size for receive buffer. It is initialized at the creation time and @@ -221,7 +222,7 @@ type SocketOptions struct { getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` // receiveBufferSize determines the receive buffer size for this socket. - receiveBufferSize int64 + receiveBufferSize atomicbitops.AlignedAtomicInt64 // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -612,7 +613,7 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { // GetSendBufferSize gets value for SO_SNDBUF option. func (so *SocketOptions) GetSendBufferSize() int64 { - return atomic.LoadInt64(&so.sendBufferSize) + return so.sendBufferSize.Load() } // SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the @@ -621,7 +622,7 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { v := sendBufferSize if !notify { - atomic.StoreInt64(&so.sendBufferSize, v) + so.sendBufferSize.Store(v) return } @@ -647,18 +648,18 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { // Notify endpoint about change in buffer size. newSz := so.handler.OnSetSendBufferSize(v) - atomic.StoreInt64(&so.sendBufferSize, newSz) + so.sendBufferSize.Store(newSz) } // GetReceiveBufferSize gets value for SO_RCVBUF option. func (so *SocketOptions) GetReceiveBufferSize() int64 { - return atomic.LoadInt64(&so.receiveBufferSize) + return so.receiveBufferSize.Load() } // SetReceiveBufferSize sets value for SO_RCVBUF option. func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { if !notify { - atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize) + so.receiveBufferSize.Store(receiveBufferSize) return } @@ -683,8 +684,8 @@ func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bo v = math.MaxInt32 } - oldSz := atomic.LoadInt64(&so.receiveBufferSize) + oldSz := so.receiveBufferSize.Load() // Notify endpoint about change in buffer size. 
newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) - atomic.StoreInt64(&so.receiveBufferSize, newSz) + so.receiveBufferSize.Store(newSz) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 2bd6a67f5..395ff9a07 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", @@ -73,6 +74,8 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/atomicbitops", + "//pkg/buffer", "//pkg/ilist", "//pkg/log", "//pkg/rand", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index e5590ecc0..ce9cebdaa 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -440,33 +440,54 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad // Regardless how the address was obtained, it will be acquired before it is // returned. func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { - a.mu.Lock() - defer a.mu.Unlock() + lookup := func() *addressState { + if addrState, ok := a.mu.endpoints[localAddr]; ok { + if !addrState.IsAssigned(allowTemp) { + return nil + } - if addrState, ok := a.mu.endpoints[localAddr]; ok { - if !addrState.IsAssigned(allowTemp) { - return nil - } + if !addrState.IncRef() { + panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + } - if !addrState.IncRef() { - panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + return addrState } - return addrState - } - - if f != nil { - for _, addrState := range a.mu.endpoints { - if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { - return addrState + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } } } + return nil + } + // Avoid exclusive lock on mu unless we need to add a new address. + a.mu.RLock() + ep := lookup() + a.mu.RUnlock() + + if ep != nil { + return ep } if !allowTemp { return nil } + // Acquire state lock in exclusive mode as we need to add a new temporary + // endpoint. + a.mu.Lock() + defer a.mu.Unlock() + + // Do the lookup again in case another goroutine added the address in the time + // we released and acquired the lock. + ep = lookup() + if ep != nil { + return ep + } + + // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) if err != nil { @@ -475,6 +496,7 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // expect no error. 
panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) } + // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 2d74e0abc..d971db010 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -54,6 +54,11 @@ type fwdTestNetworkEndpoint struct { nic NetworkInterface proto *fwdTestNetworkProtocol dispatcher TransportDispatcher + + mu struct { + sync.RWMutex + forwarding bool + } } func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { @@ -101,7 +106,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: vv.ToView().ToVectorisedView(), }) - // TODO(b/143425874) Decrease the TTL field in forwarded packets. + // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets. _ = r.WriteHeaderIncludedPacket(pkt) } @@ -109,10 +114,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -129,7 +130,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -169,11 +170,6 @@ type fwdTestNetworkProtocol struct { addrResolveDelay time.Duration onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -242,16 +238,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber return fwdTestNetNumber } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v @@ -264,6 +260,8 @@ type fwdTestPacketInfo struct { Pkt *PacketBuffer } +var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil) + type fwdTestLinkEndpoint struct { dispatcher NetworkDispatcher mtu uint32 @@ -306,11 +304,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { return caps | CapabilityResolutionRequired } -// GSOMaxSize returns the maximum GSO packet size. -func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. 
func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { @@ -322,7 +315,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -357,7 +350,7 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } @@ -370,8 +363,10 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f }}, }) - // Enable forwarding. - s.SetForwarding(proto.Number(), true) + protoNum := proto.Number() + if err := s.SetForwardingDefaultAndAllNICs(protoNum, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", protoNum, err) + } // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index e2894c548..3670d5995 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -177,6 +177,7 @@ func DefaultTables() *IPTables { priorities: [NumHooks][]TableID{ Prerouting: {MangleID, NATID}, Input: {NATID, FilterID}, + Forward: {FilterID}, Output: {MangleID, NATID, FilterID}, Postrouting: {MangleID, NATID}, }, diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4631ab93f..93592e7f5 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -280,9 +280,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) case Output: return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) - case Forward, Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. + case Forward: + if !matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert) { + return false + } + + if !matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert) { + return false + } + + return true + case Postrouting: + // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. 
return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b6cf24739..ac2fa777e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -481,13 +481,9 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } e := channelLinkWithHeaderLength{ @@ -499,7 +495,9 @@ func TestDADResolve(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ + Clock: clock, SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, @@ -529,14 +527,10 @@ func TestDADResolve(t *testing.T) { t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - // Make sure the address does not resolve before the resolution time has // passed. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) + const delta = time.Nanosecond + clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Error(err) } @@ -566,13 +560,14 @@ func TestDADResolve(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID) } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { t.Error(err) @@ -1146,57 +1141,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on }) } -// TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. 
- for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) +func TestDynamicConfigurationsDisabled(t *testing.T) { + const ( + nicID = 1 + maxRtrSolicitDelay = time.Second + ) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + prefix := tcpip.AddressWithPrefix{ + Address: testutil.MustParse6("102:304:506:708::"), + PrefixLen: 64, + } - // Rx an RA with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - default: + tests := []struct { + name string + config func(bool) ipv6.NDPConfigurations + ra *stack.PacketBuffer + }{ + { + name: "No Router Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} + }, + ra: raBuf(llAddr2, 1000), + }, + { + name: "No Prefix Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0), + }, + { + name: "No Autogenerate Addresses", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Being configured to discover routers/prefixes or auto-generate + // addresses means RAs must be handled, and router/prefix discovery or + // SLAAC must be enabled. + // + // This tests all possible combinations of the configurations where + // router/prefix discovery or SLAAC are disabled. 
+ for i := 0; i < 7; i++ { + handle := ipv6.HandlingRAsDisabled + if i&1 != 0 { + handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled + } + enable := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + ndpConfigs := test.config(enable) + ndpConfigs.HandleRAs = handle + ndpConfigs.MaxRtrSolicitations = 1 + ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay + ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, + }) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding + ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) + } + stats := ep.Stats() + v6Stats, ok := stats.(*ipv6.Stats) + if !ok { + t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) + } + + // Make sure that when handling RAs are enabled, we solicit routers. + clock.Advance(maxRtrSolicitDelay) + if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { + t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) + } + if handleRAsDisabled { + if p, ok := e.Read(); ok { + t.Errorf("unexpectedly got a packet = %#v", p) + } + } else if p, ok := e.Read(); !ok { + t.Error("expected router solicitation packet") + } else if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } else { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(nil)), + ) + } + + // Make sure we do not discover any routers or prefixes, or perform + // SLAAC on reception of an RA. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) + // Make sure that the unhandled RA stat is only incremented when + // handling RAs is disabled. 
+ if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { + t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e) + default: + } + }) } }) } } +func boolToUint64(v bool) uint64 { + if v { + return 1 + } + return 0 +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // discovered flag set to discovered. func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) } +func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { + tests := [...]struct { + name string + handleRAs ipv6.HandleRAsConfiguration + forwarding bool + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding disabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding enabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f(t, test.handleRAs, test.forwarding) + }) + } +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. 
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { @@ -1207,7 +1343,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1241,103 +1377,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") } - default: - t.Fatal("expected router discovery event") } - } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") } - case <-time.After(timeout): - t.Fatal("timed out waiting for router discovery event") } - } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Rx an RA from another router (lladdr3) with non-zero lifetime. 
- const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Rx an RA from lladdr2 with lesser lifetime. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from another router (lladdr3) with non-zero lifetime. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) + expectRouterEvent(llAddr3, true) - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + // Rx an RA from lladdr2 with lesser lifetime. + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation job to execute. 
The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + }) } // TestRouterDiscoveryMaxRouters tests that only @@ -1351,7 +1493,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1390,57 +1532,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } } -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: testutil.MustParse6("102:304:506:708::"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for prefix on nic with ID 1, and the // discovered flag set to discovered. 
func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { @@ -1459,8 +1550,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1498,87 +1588,93 @@ func TestPrefixDiscovery(t *testing.T) { prefix2, subnet2, _ := prefixSubnetAddr(1, "") prefix3, subnet3, _ := prefixSubnetAddr(2, "") - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") } - default: - t.Fatal("expected prefix discovery event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } - // Receive an RA with prefix2 in a PI. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: + } + + // Wait for prefix2's most recent invalidation job plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for prefix discovery event") } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for prefix discovery event") - } - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) + // Receive RA to invalidate prefix3. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) + }) } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1607,7 +1703,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1692,7 +1788,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: false, DiscoverOnLinkPrefixes: true, }, @@ -1757,53 +1853,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) return containsAddr(list, protocolAddress) } -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. -func TestNoAutoGenAddr(t *testing.T) { - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for addr on nic with ID 1, and the // event type is set to eventType. func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { @@ -1812,7 +1861,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. 
-func TestAutoGenAddr2(t *testing.T) { +func TestAutoGenAddr(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second saved := ipv6.MinPrefixInformationValidLifetimeForUpdate @@ -1824,96 +1873,102 @@ func TestAutoGenAddr2(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } - default: - t.Fatal("expected addr auto gen event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + // Wait for addr of prefix1 to be invalidated. 
+ select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + }) } func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { @@ -2001,7 +2056,7 @@ func TestAutoGenTempAddr(t *testing.T) { RetransmitTimer: test.retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2302,7 +2357,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2389,7 +2444,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2538,7 +2593,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2739,7 +2794,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, @@ -2884,7 +2939,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: ndpDisp, @@ -3351,7 +3406,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ 
-3494,7 +3549,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3561,7 +3616,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3727,7 +3782,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3809,7 +3864,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3973,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { @@ -4000,7 +4055,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Temporary address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -4150,7 +4205,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4278,7 +4333,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4484,7 +4539,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4535,7 +4590,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4629,8 +4684,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } } -// TestCleanupNDPState tests that all discovered routers and prefixes, and -// 
auto-generated addresses are invalidated when a NIC becomes a router. +func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { + const ( + lifetimeSeconds = 999 + nicID = 1 + ) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenLinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) + + e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1) + if err := s.CreateNIC(nicID, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen} + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID) + } + + prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1) + e1.InjectInbound( + header.IPv6ProtocolNumber, + raBufWithPI( + llAddr3, + lifetimeSeconds, + prefix, + true, /* onLink */ + true, /* auto */ + lifetimeSeconds, + lifetimeSeconds, + ), + ) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + } + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID) + } + + // Enabling or disabling forwarding should not invalidate discovered prefixes + // or routers, or auto-generated address. + for _, forwarding := range [...]bool{true, false} { + t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpected router event = %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpected prefix event = %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpected auto-gen addr event = %#v", e) + default: + } + }) + } +} + func TestCleanupNDPState(t *testing.T) { const ( lifetimeSeconds = 5 @@ -4659,18 +4816,6 @@ func TestCleanupNDPState(t *testing.T) { maxAutoGenAddrEvents int skipFinalAddrCheck bool }{ - // A NIC should still keep its auto-generated link-local address when - // becoming a router. 
- { - name: "Enable forwarding", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) - }, - keepAutoGenLinkLocal: true, - maxAutoGenAddrEvents: 4, - }, - // A NIC should cleanup all NDP state when it is disabled. { name: "Disable NIC", @@ -4722,7 +4867,7 @@ func TestCleanupNDPState(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, DiscoverOnLinkPrefixes: true, AutoGenGlobalAddresses: true, @@ -4995,7 +5140,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -5186,96 +5331,127 @@ func TestRouterSolicitation(t *testing.T) { }, } + subTests := []struct { + name string + handleRAs ipv6.HandleRAsConfiguration + afterFirstRS func(*testing.T, *stack.Stack) + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + afterFirstRS: func(*testing.T, *stack.Stack) {}, + }, + + // Enabling forwarding when RAs are always configured to be handled + // should not stop router solicitations. + { + name: "Handle RAs always", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + afterFirstRS: func(t *testing.T, s *stack.Stack) { + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() + + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") + } - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - // Make sure the right remote link address is used. 
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: subTest.handleRAs, + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure each RS is sent at the right time. 
- remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) - remaining-- - } + subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(test.effectiveRtrSolicitInt) - } - } + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) + } else { + waitForPkt(test.effectiveRtrSolicitInt) + } + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) } }) } @@ -5300,11 +5476,17 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, false) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", ipv6.ProtocolNumber, err) + } }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) + + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } }, }, @@ -5373,6 +5555,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, MaxRtrSolicitationDelay: delay, diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 48bb75e2f..90881169d 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. 
testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -1068,7 +1060,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // periodically refreshes the frequently used entry. 
// Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1556,14 +1548,14 @@ func TestNeighborCacheRetryResolution(t *testing.T) { func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() - clock := &tcpip.StdClock{} + clock := tcpip.NewStdClock() linkRes := newTestNeighborResolver(nil, config, clock) linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..463d017fc 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -307,7 +307,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -361,7 +361,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +378,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..c2a291244 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. 
// It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8d615500f..378389db2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. 
+func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,7 +773,7 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } @@ -807,20 +794,20 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. 
return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +839,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable @@ -1000,3 +987,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t return d.CheckDuplicateAddress(addr, h), nil } + +func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return &tcpip.ErrNotSupported{} + } + + forwardingEP.SetForwarding(enable) + return nil +} + +func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return false, &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return false, &tcpip.ErrNotSupported{} + } + + return forwardingEP.Forwarding(), nil +} diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. 
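The new nic_stats.go wires every per-NIC counter to a stack-wide counter so that a single Increment updates both. The sketch below is a minimal stand-in for that idea; it uses hypothetical counter/multiCounter types rather than the real tcpip.StatCounter and tcpip.MultiCounterStat, but it shows the two-sided pattern that makeNICStats sets up via stats.init(&stats.local, &global):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// counter is a stand-in for tcpip.StatCounter.
type counter struct{ v uint64 }

func (c *counter) Increment()    { atomic.AddUint64(&c.v, 1) }
func (c *counter) Value() uint64 { return atomic.LoadUint64(&c.v) }

// multiCounter is a stand-in for tcpip.MultiCounterStat: one Increment is
// reflected in both underlying counters, which is how a per-NIC (local)
// counter and the stack-wide (global) counter stay in sync.
type multiCounter struct{ a, b *counter }

func (m *multiCounter) Init(a, b *counter) { m.a, m.b = a, b }
func (m *multiCounter) Increment()         { m.a.Increment(); m.b.Increment() }

func main() {
	var local, global counter
	var rxPackets multiCounter
	rxPackets.Init(&local, &global)

	rxPackets.Increment()                      // e.g. one packet received on this NIC
	fmt.Println(local.Value(), global.Value()) // 1 1
}
```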
nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..02f905351 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -323,26 +323,21 @@ type Rand interface { type NUDState struct { rng Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex - - config NUDConfigurations - - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration - - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + mu struct { + sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. 
+ reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified @@ -351,7 +346,7 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { s := &NUDState{ rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +354,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +372,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if time.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +403,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. 
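The NUDState change above folds the guarded fields into a single anonymous struct together with the RWMutex, so guarded state can only be reached as s.mu.<field>. A minimal sketch of that locking convention, using a hypothetical nudState type rather than the real one:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// nudState sketches the locking convention adopted above: the RWMutex and the
// fields it guards share one anonymous struct, so every guarded access reads
// as s.mu.<field> and the lock's scope is explicit at the call site.
type nudState struct {
	mu struct {
		sync.RWMutex

		reachableTime time.Duration
		expiration    time.Time
	}
}

func (s *nudState) ReachableTime() time.Duration {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.mu.reachableTime
}

func (s *nudState) setReachableTime(d time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.mu.reachableTime = d
	s.mu.expiration = time.Now().Add(2 * time.Hour)
}

func main() {
	var s nudState
	s.setReachableTime(30 * time.Second)
	fmt.Println(s.ReachableTime()) // 30s
}
```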
- s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = time.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..6ba97d626 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -28,17 +28,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 646979d1e..4ca702121 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -16,9 +16,10 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + tcpipbuffer "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -39,7 +40,11 @@ type PacketBufferOptions struct { // Data is the initial unparsed data for the new packet. If set, it will be // owned by the new packet. - Data buffer.VectorisedView + Data tcpipbuffer.VectorisedView + + // IsForwardedPacket identifies that the PacketBuffer being created is for a + // forwarded packet. + IsForwardedPacket bool } // A PacketBuffer contains all the data of a network packet. @@ -52,6 +57,34 @@ type PacketBufferOptions struct { // empty. Use of PacketBuffer in any other order is unsupported. // // PacketBuffer must be created with NewPacketBuffer. +// +// Internal structure: A PacketBuffer holds a pointer to buffer.Buffer, which +// exposes a logically-contiguous byte storage. The underlying storage structure +// is abstracted out, and should not be a concern here for most of the time. +// +// |- reserved ->| +// |--->| consumed (incoming) +// 0 V V +// +--------+----+----+--------------------+ +// | | | | current data ... | (buf) +// +--------+----+----+--------------------+ +// ^ | +// |<---| pushed (outgoing) +// +// When a PacketBuffer is created, a `reserved` header region can be specified, +// which stack pushes headers in this region for an outgoing packet. There could +// be no such region for an incoming packet, and `reserved` is 0. The value of +// `reserved` never changes in the entire lifetime of the packet. +// +// Outgoing Packet: When a header is pushed, `pushed` gets incremented by the +// pushed length, and the current value is stored for each header. 
PacketBuffer +// substracts this value from `reserved` to compute the starting offset of each +// header in `buf`. +// +// Incoming Packet: When a header is consumed (a.k.a. parsed), the current +// `consumed` value is stored for each header, and it gets incremented by the +// consumed length. PacketBuffer adds this value to `reserved` to compute the +// starting offset of each header in `buf`. type PacketBuffer struct { _ sync.NoCopy @@ -59,28 +92,16 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // data holds the payload of the packet. - // - // For inbound packets, Data is initially the whole packet. Then gets moved to - // headers via PacketHeader.Consume, when the packet is being parsed. - // - // For outbound packets, Data is the innermost layer, defined by the protocol. - // Headers are pushed in front of it via PacketHeader.Push. - // - // The bytes backing Data are immutable, a.k.a. users shouldn't write to its - // backing storage. - data buffer.VectorisedView + // buf is the underlying buffer for the packet. See struct level docs for + // details. + buf *buffer.Buffer + reserved int + pushed int + consumed int // headers stores metadata about each header. headers [numHeaderType]headerInfo - // header is the internal storage for outbound packets. Headers will be pushed - // (prepended) on this storage as the packet is being constructed. - // - // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and - // data are held in the same underlying buffer storage. - header buffer.Prependable - // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() // returns false. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol @@ -127,10 +148,17 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - data: opts.Data, + buf: &buffer.Buffer{}, } if opts.ReserveHeaderBytes != 0 { - pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + pk.buf.AppendOwned(make([]byte, opts.ReserveHeaderBytes)) + pk.reserved = opts.ReserveHeaderBytes + } + for _, v := range opts.Data.Views() { + pk.buf.AppendOwned(v) + } + if opts.IsForwardedPacket { + pk.NetworkPacketInfo.IsForwardedPacket = opts.IsForwardedPacket } return pk } @@ -138,13 +166,13 @@ func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { // ReservedHeaderBytes returns the number of bytes initially reserved for // headers. func (pk *PacketBuffer) ReservedHeaderBytes() int { - return pk.header.UsedLength() + pk.header.AvailableLength() + return pk.reserved } // AvailableHeaderBytes returns the number of bytes currently available for // headers. This is relevant to PacketHeader.Push method only. func (pk *PacketBuffer) AvailableHeaderBytes() int { - return pk.header.AvailableLength() + return pk.reserved - pk.pushed } // LinkHeader returns the handle to link-layer header. @@ -173,24 +201,18 @@ func (pk *PacketBuffer) TransportHeader() PacketHeader { // HeaderSize returns the total size of all headers in bytes. func (pk *PacketBuffer) HeaderSize() int { - // Note for inbound packets (Consume called), headers are not stored in - // pk.header. Thus, calculation of size of each header is needed. - var size int - for i := range pk.headers { - size += len(pk.headers[i].buf) - } - return size + return pk.pushed + pk.consumed } // Size returns the size of packet in bytes. 
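The rewritten PacketBuffer keeps headers and data in one buffer and tracks them purely with the reserved/pushed/consumed offsets described in the struct comment. The following stand-alone sketch (a plain []byte in place of buffer.Buffer, and a hypothetical pkt type) mimics how push and consume derive header offsets:

```go
package main

import "fmt"

// pkt is a hypothetical stand-in for PacketBuffer: one contiguous buffer with
// a reserved header region. Pushed headers grow backwards from the end of the
// reserved region; consumed headers advance forwards into the data.
type pkt struct {
	buf      []byte // whole packet storage; buf[:reserved] is header space
	reserved int    // bytes set aside for outgoing headers
	pushed   int    // header bytes pushed so far (outgoing path)
	consumed int    // header bytes parsed so far (incoming path)
}

// push hands out size bytes immediately in front of previously pushed
// headers, mirroring how PacketBuffer.push derives offsets from `reserved`
// and `pushed`.
func (p *pkt) push(size int) []byte {
	if p.pushed+size > p.reserved {
		panic("push overflows reserved header space")
	}
	p.pushed += size
	off := p.reserved - p.pushed
	return p.buf[off : off+size]
}

// consume hands out the next size bytes of unparsed data as a header,
// mirroring PacketBuffer.consume.
func (p *pkt) consume(size int) ([]byte, bool) {
	off := p.reserved + p.consumed
	if off+size > len(p.buf) {
		return nil, false
	}
	p.consumed += size
	return p.buf[off : off+size], true
}

func main() {
	// 4 bytes reserved for outgoing headers, 6 bytes of initial data.
	p := &pkt{buf: make([]byte, 10), reserved: 4}
	copy(p.buf[4:], "abcdef")

	copy(p.push(2), "TH")   // transport header pushed first...
	copy(p.push(2), "NH")   // ...then the network header lands in front of it
	hdr, ok := p.consume(3) // parse 3 bytes of incoming data as a header

	fmt.Println(string(p.buf), string(hdr), ok) // NHTHabcdef abc true
}
```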
func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.data.Size() + return int(pk.buf.Size()) - pk.headerOffset() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize + return int(pk.buf.Size()) + packetBufferStructSize } // Data returns the handle to data portion of pk. @@ -199,61 +221,65 @@ func (pk *PacketBuffer) Data() PacketData { } // Views returns the underlying storage of the whole packet. -func (pk *PacketBuffer) Views() []buffer.View { - // Optimization for outbound packets that headers are in pk.header. - useHeader := true - for i := range pk.headers { - if !canUseHeader(&pk.headers[i]) { - useHeader = false - break - } - } +func (pk *PacketBuffer) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := pk.headerOffset() + pk.buf.SubApply(offset, int(pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views +} - dataViews := pk.data.Views() - - var vs []buffer.View - if useHeader { - vs = make([]buffer.View, 0, 1+len(dataViews)) - vs = append(vs, pk.header.View()) - } else { - vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) - for i := range pk.headers { - if v := pk.headers[i].buf; len(v) > 0 { - vs = append(vs, v) - } - } - } - return append(vs, dataViews...) +func (pk *PacketBuffer) headerOffset() int { + return pk.reserved - pk.pushed +} + +func (pk *PacketBuffer) headerOffsetOf(typ headerType) int { + return pk.reserved + pk.headers[typ].offset } -func canUseHeader(h *headerInfo) bool { - // h.offset will be negative if the header was pushed in to prependable - // portion, or doesn't matter when it's empty. - return len(h.buf) == 0 || h.offset < 0 +func (pk *PacketBuffer) dataOffset() int { + return pk.reserved + pk.consumed } -func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { +func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] - if h.buf != nil { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + if h.length > 0 { + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) + } + if pk.pushed+size > pk.reserved { + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } - h.buf = buffer.View(pk.header.Prepend(size)) - h.offset = -pk.header.UsedLength() - return h.buf + pk.pushed += size + h.offset = -pk.pushed + h.length = size + return pk.headerView(typ) } -func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { +func (pk *PacketBuffer) consume(typ headerType, size int) (v tcpipbuffer.View, consumed bool) { h := &pk.headers[typ] - if h.buf != nil { + if h.length > 0 { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.data.PullUp(size) + if pk.reserved+pk.consumed+size > int(pk.buf.Size()) { + return nil, false + } + h.offset = pk.consumed + h.length = size + pk.consumed += size + return pk.headerView(typ), true +} + +func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { + h := &pk.headers[typ] + if h.length == 0 { + return nil + } + v, ok := pk.buf.PullUp(pk.headerOffsetOf(typ), h.length) if !ok { - return + panic("PullUp failed") } - pk.data.TrimFront(size) - h.buf = v - return h.buf, true + return v } // Clone makes a shallow copy of pk. 
@@ -263,9 +289,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - data: pk.data.Clone(nil), + buf: pk.buf, + reserved: pk.reserved, + pushed: pk.pushed, + consumed: pk.consumed, headers: pk.headers, - header: pk.header, Hash: pk.Hash, Owner: pk.Owner, GSOOptions: pk.GSOOptions, @@ -299,9 +327,11 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - newPk := NewPacketBuffer(PacketBufferOptions{ - Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), - }) + newPk := &PacketBuffer{ + buf: pk.buf, + // Treat unfilled header portion as reserved. + reserved: pk.AvailableHeaderBytes(), + } // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to // maintain this flag in the packet. Currently conntrack needs this flag to // tell if a noop connection should be inserted at Input hook. Once conntrack @@ -315,15 +345,12 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // headerInfo stores metadata about a header in a packet. type headerInfo struct { - // buf is the memorized slice for both prepended and consumed header. - // When header is prepended, buf serves as memorized value, which is a slice - // of pk.header. When header is consumed, buf is the slice pulled out from - // pk.Data, which is the only place to hold this header. - buf buffer.View - - // offset will be a negative number denoting the offset where this header is - // from the end of pk.header, if it is prepended. Otherwise, zero. + // offset is the offset of the header in pk.buf relative to + // pk.buf[pk.reserved]. See the PacketBuffer struct for details. offset int + + // length is the length of this header. + length int } // PacketHeader is a handle object to a header in the underlying packet. @@ -333,14 +360,14 @@ type PacketHeader struct { } // View returns the underlying storage of h. -func (h PacketHeader) View() buffer.View { - return h.pk.headers[h.typ].buf +func (h PacketHeader) View() tcpipbuffer.View { + return h.pk.headerView(h.typ) } // Push pushes size bytes in the front of its residing packet, and returns the // backing storage. Callers may only call one of Push or Consume once on each // header in the lifetime of the underlying packet. -func (h PacketHeader) Push(size int) buffer.View { +func (h PacketHeader) Push(size int) tcpipbuffer.View { return h.pk.push(h.typ, size) } @@ -349,7 +376,7 @@ func (h PacketHeader) Push(size int) buffer.View { // size, consumed will be false, and the state of h will not be affected. // Callers may only call one of Push or Consume once on each header in the // lifetime of the underlying packet. -func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { +func (h PacketHeader) Consume(size int) (v tcpipbuffer.View, consumed bool) { return h.pk.consume(h.typ, size) } @@ -360,54 +387,84 @@ type PacketData struct { // PullUp returns a contiguous view of size bytes from the beginning of d. // Callers should not write to or keep the view for later use. -func (d PacketData) PullUp(size int) (buffer.View, bool) { - return d.pk.data.PullUp(size) +func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { + return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// TrimFront removes count from the beginning of d. It panics if count > -// d.Size(). 
-func (d PacketData) TrimFront(count int) { - d.pk.data.TrimFront(count) +// DeleteFront removes count from the beginning of d. It panics if count > +// d.Size(). All backing storage references after the front of the d are +// invalidated. +func (d PacketData) DeleteFront(count int) { + if !d.pk.buf.Remove(d.pk.dataOffset(), count) { + panic("count > d.Size()") + } } // CapLength reduces d to at most length bytes. func (d PacketData) CapLength(length int) { - d.pk.data.CapLength(length) + if length < 0 { + panic("length < 0") + } + if currLength := d.Size(); currLength > length { + trim := currLength - length + d.pk.buf.Remove(int(d.pk.buf.Size())-trim, trim) + } } // Views returns the underlying storage of d in a slice of Views. Caller should // not modify the returned slice. -func (d PacketData) Views() []buffer.View { - return d.pk.data.Views() +func (d PacketData) Views() []tcpipbuffer.View { + var views []tcpipbuffer.View + offset := d.pk.dataOffset() + d.pk.buf.SubApply(offset, int(d.pk.buf.Size())-offset, func(v []byte) { + views = append(views, v) + }) + return views } // AppendView appends v into d, taking the ownership of v. -func (d PacketData) AppendView(v buffer.View) { - d.pk.data.AppendView(v) +func (d PacketData) AppendView(v tcpipbuffer.View) { + d.pk.buf.AppendOwned(v) } -// ReadFromData moves at most count bytes from the beginning of srcData to the -// end of d and returns the number of bytes moved. -func (d PacketData) ReadFromData(srcData PacketData, count int) int { - return srcData.pk.data.ReadToVV(&d.pk.data, count) +// MergeFragment appends the data portion of frag to dst. It takes ownership of +// frag and frag should not be used again. +func MergeFragment(dst, frag *PacketBuffer) { + frag.buf.TrimFront(int64(frag.dataOffset())) + dst.buf.Merge(frag.buf) } // ReadFromVV moves at most count bytes from the beginning of srcVV to the end // of d and returns the number of bytes moved. -func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { - return srcVV.ReadToVV(&d.pk.data, count) +func (d PacketData) ReadFromVV(srcVV *tcpipbuffer.VectorisedView, count int) int { + done := 0 + for _, v := range srcVV.Views() { + if len(v) < count { + count -= len(v) + done += len(v) + d.pk.buf.AppendOwned(v) + } else { + v = v[:count] + count -= len(v) + done += len(v) + d.pk.buf.Append(v) + break + } + } + srcVV.TrimFront(done) + return done } // Size returns the number of bytes in the data payload of the packet. func (d PacketData) Size() int { - return d.pk.data.Size() + return int(d.pk.buf.Size()) - d.pk.dataOffset() } // AsRange returns a Range representing the current data payload of the packet. func (d PacketData) AsRange() Range { return Range{ pk: d.pk, - offset: d.pk.HeaderSize(), + offset: d.pk.dataOffset(), length: d.Size(), } } @@ -417,17 +474,12 @@ func (d PacketData) AsRange() Range { // // This method exists for compatibility between PacketBuffer and VectorisedView. // It may be removed later and should be used with care. -func (d PacketData) ExtractVV() buffer.VectorisedView { - return d.pk.data -} - -// Replace replaces the data portion of the packet with vv, taking the ownership -// of vv. -// -// This method exists for compatibility between PacketBuffer and VectorisedView. -// It may be removed later and should be used with care. 
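Renaming TrimFront to DeleteFront also changes the contract: backing storage past the front may be invalidated, which is why callers now copy a PullUp view before trimming (see the stack_test change later in this diff). A short usage sketch, assuming the gvisor.dev/gvisor packages at this revision and only the signatures visible above:

```go
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
		Data: buffer.View([]byte("hello world")).ToVectorisedView(),
	})

	// PullUp returns a view into the packet's backing storage...
	hdr, ok := pkt.Data().PullUp(5)
	if !ok {
		panic("data too short")
	}
	// ...and DeleteFront may invalidate such views, so copy anything that
	// must outlive the trim (the same pattern the stack_test change uses).
	saved := append([]byte(nil), hdr...)
	pkt.Data().DeleteFront(5)

	fmt.Println(string(saved), pkt.Data().Size()) // hello 6
}
```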
-func (d PacketData) Replace(vv buffer.VectorisedView) { - d.pk.data = vv +func (d PacketData) ExtractVV() tcpipbuffer.VectorisedView { + var vv tcpipbuffer.VectorisedView + d.pk.buf.SubApply(d.pk.dataOffset(), d.pk.Size(), func(v []byte) { + vv.AppendView(v) + }) + return vv } // Range represents a contiguous subportion of a PacketBuffer. @@ -471,9 +523,9 @@ func (r Range) Capped(max int) Range { // AsView returns the backing storage of r if possible. It will allocate a new // View if r spans multiple pieces internally. Caller should not write to the // returned View in any way. -func (r Range) AsView() buffer.View { +func (r Range) AsView() tcpipbuffer.View { var allocated bool - var v buffer.View + var v tcpipbuffer.View r.iterate(func(b []byte) { if v == nil { // v has not been assigned, allowing first view to be returned. @@ -494,7 +546,7 @@ func (r Range) AsView() buffer.View { } // ToOwnedView returns a owned copy of data in r. -func (r Range) ToOwnedView() buffer.View { +func (r Range) ToOwnedView() tcpipbuffer.View { if r.length == 0 { return nil } @@ -515,63 +567,7 @@ func (r Range) Checksum() uint16 { // iterate calls fn for each piece in r. fn is always called with a non-empty // slice. func (r Range) iterate(fn func([]byte)) { - w := window{ - offset: r.offset, - length: r.length, - } - // Header portion. - for i := range r.pk.headers { - if b := w.process(r.pk.headers[i].buf); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - // Data portion. - if !w.isDone() { - for _, v := range r.pk.data.Views() { - if b := w.process(v); len(b) > 0 { - fn(b) - } - if w.isDone() { - break - } - } - } -} - -// window represents contiguous region of byte stream. User would call process() -// to input bytes, and obtain a subslice that is inside the window. -type window struct { - offset int - length int -} - -// isDone returns true if the window has passed and further process() calls will -// always return an empty slice. This can be used to end processing early. -func (w *window) isDone() bool { - return w.length == 0 -} - -// process feeds b in and returns a subslice that is inside the window. The -// returned slice will be a subslice of b, and it does not keep b after method -// returns. This method may return an empty slice if nothing in b is inside the -// window. -func (w *window) process(b []byte) (inWindow []byte) { - if w.offset >= len(b) { - w.offset -= len(b) - return nil - } - if w.offset > 0 { - b = b[w.offset:] - w.offset = 0 - } - if w.length < len(b) { - b = b[:w.length] - } - w.length -= len(b) - return b + r.pk.buf.SubApply(r.offset, r.length, fn) } // PayloadSince returns packet payload starting from and including a particular @@ -579,21 +575,14 @@ func (w *window) process(b []byte) (inWindow []byte) { // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. -func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.data.Size() - for _, hinfo := range h.pk.headers[h.typ:] { - size += len(hinfo.buf) +func PayloadSince(h PacketHeader) tcpipbuffer.View { + offset := h.pk.headerOffset() + for i := headerType(0); i < h.typ; i++ { + offset += h.pk.headers[i].length } - - v := make(buffer.View, 0, size) - - for _, hinfo := range h.pk.headers[h.typ:] { - v = append(v, hinfo.buf...) - } - - for _, view := range h.pk.data.Views() { - v = append(v, view...) 
- } - - return v + return Range{ + pk: h.pk, + offset: offset, + length: int(h.pk.buf.Size()) - offset, + }.ToOwnedView() } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 6728370c3..a8da34992 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkData(t, pk, test.data) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), - concatViews(test.link, test.network, test.transport, test.data)) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(test.link, test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(test.transport, test.data)) + // Check the after state. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.link, + network: test.network, + transport: test.transport, + data: test.data, + }) }) } } @@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) { if got, want := pk.Size(), len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - // After state of pk. - var ( - link = test.data[:test.link] - network = test.data[test.link:][:test.network] - transport = test.data[test.link+test.network:][:test.transport] - payload = test.data[allHdrSize:] - ) - checkData(t, pk, payload) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(link, network, transport, payload)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(network, transport, payload)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(transport, payload)) + // Check the after state of pk. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.data[:test.link], + network: test.data[test.link:][:test.network], + transport: test.data[test.link+test.network:][:test.transport], + data: test.data[allHdrSize:], + }) }) } } @@ -252,6 +226,70 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) { }) } +// This is a very obscure use-case seen in the code that verifies packets +// before sending them out. It tries to parse the headers to verify. +// PacketHeader was initially not designed to mix Push() and Consume(), but it +// works and it's been relied upon. Include a test here. 
+func TestPacketHeaderPushConsumeMixed(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := append([]byte(nil), network...) + initData = append(initData, data...) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Consume network header + gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) + if !ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) + } + checkViewEqual(t, "gotNetwork", gotNetwork, network) + + // 2. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + network: network, + data: data, + }) +} + +func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := concatViews(network, data) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + data: initData, + }) + + // 2. Consume network header, with a number of bytes too large. + gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1) + if ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork) + } + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + data: initData, + }) +} + func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { const headerSize = 10 @@ -397,11 +435,11 @@ func TestPacketBufferData(t *testing.T) { } }) - // TrimFront + // DeleteFront for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().TrimFront(n) + pkt.Data().DeleteFront(n) checkData(t, pkt, []byte(tc.data)[n:]) }) @@ -437,23 +475,8 @@ func TestPacketBufferData(t *testing.T) { checkData(t, pkt, []byte(tc.data+s)) }) - // ReadFromData/VV + // ReadFromVV for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { - t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { - s := "TO READ" - otherPkt := NewPacketBuffer(PacketBufferOptions{ - Data: vv(s, s), - }) - s += s - - pkt := tc.makePkt(t) - pkt.Data().ReadFromData(otherPkt.Data(), n) - - if n < len(s) { - s = s[:n] - } - checkData(t, pkt, []byte(tc.data+s)) - }) t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { s := "TO READ" srcVV := vv(s, s) @@ -480,20 +503,41 @@ func TestPacketBufferData(t *testing.T) { t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) } }) - - // Replace - t.Run("Replace", func(t *testing.T) { - s := "REPLACED" - - pkt := tc.makePkt(t) - pkt.Data().Replace(vv(s)) - - checkData(t, pkt, []byte(s)) - }) }) } } +type packetContents struct { + link buffer.View + network buffer.View + transport buffer.View + data buffer.View +} + +func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { + t.Helper() + // Headers. + checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) + checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) + checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) + // Data. 
+ checkData(t, pk, want.data) + // Whole packet. + checkViewEqual(t, prefix+"pk.Views()", + concatViews(pk.Views()...), + concatViews(want.link, want.network, want.transport, want.data)) + // PayloadSince. + checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(want.link, want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(want.transport, want.data)) +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -510,19 +554,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkData(t, pk, data) - checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) - // Check the initial values for each header. - checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) - checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) - checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) - // Check the initial valies for PayloadSince. - checkViewEqual(t, "Initial PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), data) + checkPacketContents(t, "Initial ", pk, packetContents{ + data: data, + }) } func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { @@ -540,7 +574,7 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { t.Helper() if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { - t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + t.Errorf("pkt.Data().Views() = 0x%x, want 0x%x", got, want) } if got := pkt.Data().Size(); got != len(want) { t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 7ad206f6d..85bb87b4b 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -55,6 +55,9 @@ type NetworkPacketInfo struct { // LocalAddressBroadcast is true if the packet's local address is a broadcast // address. LocalAddressBroadcast bool + + // IsForwardedPacket is true if the packet is being forwarded. + IsForwardedPacket bool } // TransportErrorKind enumerates error types that are handled by the transport @@ -655,9 +658,9 @@ type IPNetworkEndpointStats interface { IPStats() *tcpip.IPStats } -// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. -type ForwardingNetworkProtocol interface { - NetworkProtocol +// ForwardingNetworkEndpoint is a network endpoint that may forward packets. +type ForwardingNetworkEndpoint interface { + NetworkEndpoint // Forwarding returns the forwarding configuration. 
Forwarding() bool @@ -756,11 +759,6 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityHardwareGSO - - // CapabilitySoftwareGSO indicates the link endpoint supports of sending - // multiple packets using a single call (LinkEndpoint.WritePackets). - CapabilitySoftwareGSO ) // NetworkLinkEndpoint is a data-link layer that supports sending network @@ -1047,10 +1045,29 @@ type GSO struct { MaxSize uint32 } +// SupportedGSO returns the type of segmentation offloading supported. +type SupportedGSO int + +const ( + // GSONotSupported indicates that segmentation offloading is not supported. + GSONotSupported SupportedGSO = iota + + // HWGSOSupported indicates that segmentation offloading may be performed by + // the hardware. + HWGSOSupported + + // SWGSOSupported indicates that segmentation offloading may be performed in + // software. + SWGSOSupported +) + // GSOEndpoint provides access to GSO properties. type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 + + // SupportedGSO returns the supported segmentation offloading. + SupportedGSO() SupportedGSO } // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 4ecde5995..f17c04277 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool { // HasSoftwareGSOCapability returns true if the route supports software GSO. func (r *Route) HasSoftwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == SWGSOSupported + } + return false } // HasHardwareGSOCapability returns true if the route supports hardware GSO. func (r *Route) HasHardwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == HWGSOSupported + } + return false } // HasSaveRestoreCapability returns true if the route supports save/restore. @@ -440,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool { // If the source NIC and outgoing NIC are different, make sure the stack has // forwarding enabled, or the packet will be handled locally. - if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { + if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { return false } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 843118b13..72760a4a7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -29,6 +29,7 @@ import ( "time" "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -39,13 +40,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. 
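CapabilityHardwareGSO and CapabilitySoftwareGSO are replaced by the SupportedGSO enum, and Route now discovers GSO support by type-asserting the link endpoint to GSOEndpoint. A self-contained sketch of that probe, using local stand-in types rather than the stack package:

```go
package main

import "fmt"

// SupportedGSO mirrors the enum introduced in registration.go.
type SupportedGSO int

const (
	GSONotSupported SupportedGSO = iota
	HWGSOSupported
	SWGSOSupported
)

// GSOEndpoint is the optional capability interface a link endpoint may implement.
type GSOEndpoint interface {
	GSOMaxSize() uint32
	SupportedGSO() SupportedGSO
}

// linkEndpoint is a stand-in for stack.LinkEndpoint in this sketch.
type linkEndpoint interface{ MTU() uint32 }

// hasHWGSO shows the pattern Route now uses: instead of a capability bit,
// probe for the optional interface with a type assertion.
func hasHWGSO(ep linkEndpoint) bool {
	if gso, ok := ep.(GSOEndpoint); ok {
		return gso.SupportedGSO() == HWGSOSupported
	}
	return false
}

// gsoNIC implements both interfaces; plainNIC only the mandatory one.
type gsoNIC struct{}

func (gsoNIC) MTU() uint32                { return 1500 }
func (gsoNIC) GSOMaxSize() uint32         { return 1 << 16 }
func (gsoNIC) SupportedGSO() SupportedGSO { return HWGSOSupported }

type plainNIC struct{}

func (plainNIC) MTU() uint32 { return 1500 }

func main() {
	fmt.Println(hasHWGSO(gsoNIC{}), hasHWGSO(plainNIC{})) // true false
}
```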
- resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -65,10 +59,10 @@ type ResumableEndpoint interface { } // uniqueIDGenerator is a default unique ID generator. -type uniqueIDGenerator uint64 +type uniqueIDGenerator atomicbitops.AlignedAtomicUint64 func (u *uniqueIDGenerator) UniqueID() uint64 { - return atomic.AddUint64((*uint64)(u), 1) + return ((*atomicbitops.AlignedAtomicUint64)(u)).Add(1) } // Stack is a networking stack, with all supported protocols, NICs, and route @@ -94,8 +88,9 @@ type Stack struct { } } - mu sync.RWMutex - nics map[tcpip.NICID]*nic + mu sync.RWMutex + nics map[tcpip.NICID]*nic + defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{} // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -322,7 +317,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {} func New(opts Options) *Stack { clock := opts.Clock if clock == nil { - clock = &tcpip.StdClock{} + clock = tcpip.NewStdClock() } if opts.UniqueID == nil { @@ -347,22 +342,23 @@ func New(opts Options) *Stack { } s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*nic), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - nudConfigs: opts.NUDConfigs, - uniqueIDGenerator: opts.UniqueID, - nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), - secureRNG: opts.SecureRNG, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + nics: make(map[tcpip.NICID]*nic), + defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: opts.IPTables, + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + nudConfigs: opts.NUDConfigs, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + randomGenerator: mathrand.New(randSrc), + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -491,37 +487,61 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables packet forwarding between NICs for the -// passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { - protocol, ok := s.networkProtocols[protocolNum] +// SetNICForwarding enables or disables packet forwarding on the specified NIC +// for the passed protocol. +func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrUnknownProtocol{} + return &tcpip.ErrUnknownNICID{} } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + return nic.setForwarding(protocol, enable) +} + +// NICForwarding returns the forwarding configuration for the specified NIC. 
+func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrNotSupported{} + return false, &tcpip.ErrUnknownNICID{} } - forwardingProtocol.SetForwarding(enable) - return nil + return nic.forwarding(protocol) } -// Forwarding returns true if packet forwarding between NICs is enabled for the -// passed protocol. -func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { - protocol, ok := s.networkProtocols[protocolNum] - if !ok { - return false +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + doneOnce := false + for id, nic := range s.nics { + if err := nic.setForwarding(protocol, enable); err != nil { + // Expect forwarding to be settable on all interfaces if it was set on + // one. + if doneOnce { + panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err)) + } + + return err + } + + doneOnce = true } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) - if !ok { - return false + if enable { + s.defaultForwardingEnabled[protocol] = struct{}{} + } else { + delete(s.defaultForwardingEnabled, protocol) } - return forwardingProtocol.Forwarding() + return nil } // PortRange returns the UDP and TCP inclusive range of ephemeral ports used in @@ -658,6 +678,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) + for proto := range s.defaultForwardingEnabled { + if err := n.setForwarding(proto, true); err != nil { + panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err)) + } + } s.nics[id] = n if !opts.Disabled { return n.enable() @@ -772,7 +797,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -785,6 +810,10 @@ type NICInfo struct { // value sent in haType field of an ARP Request sent by this NIC and the // value expected in the haType field of an ARP response. ARPHardwareType header.ARPHardwareType + + // Forwarding holds the forwarding status for each network endpoint that + // supports forwarding. + Forwarding map[tcpip.NetworkProtocolNumber]bool } // HasNIC returns true if the NICID is defined in the stack. 
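Forwarding becomes a per-NIC, per-protocol setting: the stack delegates to nic.setForwarding, which probes the network endpoint for the new ForwardingNetworkEndpoint interface and distinguishes unknown protocols from protocols that simply cannot forward. A simplified stand-alone sketch of that dispatch, with hypothetical nic/fakeEP types and plain int protocol numbers:

```go
package main

import "fmt"

// forwardingEndpoint mirrors the new stack.ForwardingNetworkEndpoint: only
// network endpoints implementing it can have forwarding toggled per NIC.
type forwardingEndpoint interface {
	Forwarding() bool
	SetForwarding(bool)
}

type fakeEP struct{ forwarding bool }

func (e *fakeEP) Forwarding() bool     { return e.forwarding }
func (e *fakeEP) SetForwarding(v bool) { e.forwarding = v }

// nic holds one network endpoint per protocol number, as stack.nic does.
type nic struct{ endpoints map[int]interface{} }

// setForwarding mimics nic.setForwarding: an unknown protocol and a known
// protocol without forwarding support are reported as distinct failures.
func (n *nic) setForwarding(proto int, enable bool) error {
	ep, ok := n.endpoints[proto]
	if !ok {
		return fmt.Errorf("unknown protocol %d", proto)
	}
	fep, ok := ep.(forwardingEndpoint)
	if !ok {
		return fmt.Errorf("protocol %d does not support forwarding", proto)
	}
	fep.SetForwarding(enable)
	return nil
}

func main() {
	n := &nic{endpoints: map[int]interface{}{
		0x0800: &fakeEP{},  // a forwardable endpoint (e.g. IPv4)
		0x0806: struct{}{}, // an endpoint without forwarding support (e.g. ARP)
	}}
	fmt.Println(n.setForwarding(0x0800, true)) // <nil>
	fmt.Println(n.setForwarding(0x0806, true)) // protocol 2054 does not support forwarding
	fmt.Println(n.setForwarding(0x86dd, true)) // unknown protocol 34525
}
```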
@@ -814,17 +843,33 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { netStats[proto] = netEP.Stats() } - nics[id] = NICInfo{ + info := NICInfo{ Name: nic.name, LinkAddress: nic.LinkEndpoint.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), + Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), } + + for proto := range s.networkProtocols { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + info.Forwarding[proto] = forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } + } + + nics[id] = info } return nics } @@ -1028,6 +1073,20 @@ func (s *Stack) HandleLocal() bool { return s.handleLocal } +func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + return forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + return false + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -1080,7 +1139,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal + onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route @@ -1119,7 +1178,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n // requirement to do this from any RFC but simply a choice made to better // follow a strong host model which the netstack follows at the time of // writing. - if canForward && chosenRoute == (tcpip.Route{}) { + if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) { chosenRoute = route } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8ead3b8df..73e0f0d58 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct { mu struct { sync.RWMutex - enabled bool + enabled bool + forwarding bool } nic stack.NetworkInterface @@ -138,11 +139,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data().TrimFront(fakeNetHeaderLen) + // DeleteFront invalidates slices. Make a copy before trimming. + nb := append([]byte(nil), hdr...) 
+ pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -163,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -194,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -225,11 +224,6 @@ type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int defaultTTL uint8 - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -298,15 +292,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v @@ -465,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -922,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. 
if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -943,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1068,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1120,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1142,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1222,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1250,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1289,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1326,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. 
// testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1576,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1617,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1643,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. @@ -1662,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1711,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. 
rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2051,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2115,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2236,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2250,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } + + nicStats := s.NICInfo()[nicid].Stats - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + // Inbound packet. 
+ rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. + txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2318,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -3020,7 +3055,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -3839,8 +3874,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3877,8 +3910,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := 
stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4218,14 +4249,14 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) } - if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { - t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) } s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { @@ -4273,8 +4304,8 @@ func TestFindRouteWithForwarding(t *testing.T) { // Disabling forwarding when the route is dependent on forwarding being // enabled should make the route invalid. - if err := s.SetForwarding(test.netCfg.proto, false); err != nil { - t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, false): %s", test.netCfg.proto, err) } { err := send(r, data) diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go new file mode 100644 index 000000000..7ce43a68e --- /dev/null +++ b/pkg/tcpip/stdclock.go @@ -0,0 +1,130 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +// stdClock implements Clock with the time package. +// +// +stateify savable +type stdClock struct { + // baseTime holds the time when the clock was constructed. + // + // This value is used to calculate the monotonic time from the time package. + // As per https://golang.org/pkg/time/#hdr-Monotonic_Clocks, + // + // Operating systems provide both a “wall clock,” which is subject to + // changes for clock synchronization, and a “monotonic clock,” which is not. + // The general rule is that the wall clock is for telling time and the + // monotonic clock is for measuring time. Rather than split the API, in this + // package the Time returned by time.Now contains both a wall clock reading + // and a monotonic clock reading; later time-telling operations use the wall + // clock reading, but later time-measuring operations, specifically + // comparisons and subtractions, use the monotonic clock reading. + // + // ... + // + // If Times t and u both contain monotonic clock readings, the operations + // t.After(u), t.Before(u), t.Equal(u), and t.Sub(u) are carried out using + // the monotonic clock readings alone, ignoring the wall clock readings. 
If + // either t or u contains no monotonic clock reading, these operations fall + // back to using the wall clock readings. + // + // Given the above, we can safely conclude that time.Since(baseTime) will + // return monotonically increasing values if we use time.Now() to set baseTime + // at the time of clock construction. + // + // Note that time.Since(t) is shorthand for time.Now().Sub(t), as per + // https://golang.org/pkg/time/#Since. + baseTime time.Time `state:"nosave"` + + // monotonicOffset is the offset applied to the calculated monotonic time. + // + // monotonicOffset is assigned maxMonotonic after restore so that the + // monotonic time will continue from where it "left off" before saving as part + // of S/R. + monotonicOffset int64 `state:"nosave"` + + // monotonicMU protects maxMonotonic. + monotonicMU sync.Mutex `state:"nosave"` + maxMonotonic int64 +} + +// NewStdClock returns an instance of a clock that uses the time package. +func NewStdClock() Clock { + return &stdClock{ + baseTime: time.Now(), + } +} + +var _ Clock = (*stdClock)(nil) + +// NowNanoseconds implements Clock.NowNanoseconds. +func (*stdClock) NowNanoseconds() int64 { + return time.Now().UnixNano() +} + +// NowMonotonic implements Clock.NowMonotonic. +func (s *stdClock) NowMonotonic() int64 { + sinceBase := time.Since(s.baseTime) + if sinceBase < 0 { + panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime)) + } + + monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset + + s.monotonicMU.Lock() + defer s.monotonicMU.Unlock() + + // Monotonic time values must never decrease. + if monotonicValue > s.maxMonotonic { + s.maxMonotonic = monotonicValue + } + + return s.maxMonotonic +} + +// AfterFunc implements Clock.AfterFunc. +func (*stdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/stdclock_state.go b/pkg/tcpip/stdclock_state.go new file mode 100644 index 000000000..795db9181 --- /dev/null +++ b/pkg/tcpip/stdclock_state.go @@ -0,0 +1,26 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "time" + +// afterLoad is invoked by stateify. 
+func (s *stdClock) afterLoad() { + s.baseTime = time.Now() + + s.monotonicMU.Lock() + defer s.monotonicMU.Unlock() + s.monotonicOffset = s.maxMonotonic +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0ba71b62e..34f820053 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -37,9 +37,9 @@ import ( "reflect" "strconv" "strings" - "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) @@ -73,7 +73,7 @@ type Clock interface { // nanoseconds since the Unix epoch. NowNanoseconds() int64 - // NowMonotonic returns a monotonic time value. + // NowMonotonic returns a monotonic time value at nanosecond resolution. NowMonotonic() int64 // AfterFunc waits for the duration to elapse and then calls f in its own @@ -861,6 +861,9 @@ type SettableSocketOption interface { isSettableSocketOption() } +// EndpointState represents the state of an endpoint. +type EndpointState uint8 + // CongestionControlState indicates the current congestion control state for // TCP sender. type CongestionControlState int @@ -897,6 +900,9 @@ type TCPInfoOption struct { // RTO is the retransmission timeout for the endpoint. RTO time.Duration + // State is the current endpoint protocol state. + State EndpointState + // CcState is the congestion control state. CcState CongestionControlState @@ -1107,6 +1113,7 @@ const ( // LingerOption is used by SetSockOpt/GetSockOpt to set/get the // duration for which a socket lingers before returning from Close. // +// +marshal // +stateify savable type LingerOption struct { Enabled bool @@ -1219,7 +1226,7 @@ type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. type StatCounter struct { - count uint64 + count atomicbitops.AlignedAtomicUint64 } // Increment adds one to the counter. @@ -1234,12 +1241,12 @@ func (s *StatCounter) Decrement() { // Value returns the current value of the counter. func (s *StatCounter) Value(name ...string) uint64 { - return atomic.LoadUint64(&s.count) + return s.count.Load() } // IncrementBy increments the counter by v. func (s *StatCounter) IncrementBy(v uint64) { - atomic.AddUint64(&s.count, v) + s.count.Add(v) } func (s *StatCounter) String() string { @@ -1527,6 +1534,46 @@ type IGMPStats struct { // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats) } +// IPForwardingStats collects stats related to IP forwarding (both v4 and v6). +type IPForwardingStats struct { + // LINT.IfChange(IPForwardingStats) + + // Unrouteable is the number of IP packets received which were dropped + // because a route to their destination could not be constructed. + Unrouteable *StatCounter + + // ExhaustedTTL is the number of IP packets received which were dropped + // because their TTL was exhausted. + ExhaustedTTL *StatCounter + + // LinkLocalSource is the number of IP packets which were dropped + // because they contained a link-local source address. + LinkLocalSource *StatCounter + + // LinkLocalDestination is the number of IP packets which were dropped + // because they contained a link-local destination address. + LinkLocalDestination *StatCounter + + // PacketTooBig is the number of IP packets which were dropped because they + // were too big for the outgoing MTU. + PacketTooBig *StatCounter + + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. 
+ HostUnreachable *StatCounter + + // ExtensionHeaderProblem is the number of IP packets which were dropped + // because of a problem encountered when processing an IPv6 extension + // header. + ExtensionHeaderProblem *StatCounter + + // Errors is the number of IP packets received which could not be + // successfully forwarded. + Errors *StatCounter + + // LINT.ThenChange(network/internal/ip/stats.go:multiCounterIPForwardingStats) +} + // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // LINT.IfChange(IPStats) @@ -1534,6 +1581,10 @@ type IPStats struct { // PacketsReceived is the number of IP packets received from the link layer. PacketsReceived *StatCounter + // ValidPacketsReceived is the number of valid IP packets that reached the IP + // layer. + ValidPacketsReceived *StatCounter + // DisabledPacketsReceived is the number of IP packets received from the link // layer when the IP layer is disabled. DisabledPacketsReceived *StatCounter @@ -1573,6 +1624,10 @@ type IPStats struct { // chain. IPTablesInputDropped *StatCounter + // IPTablesForwardDropped is the number of IP packets dropped in the Forward + // chain. + IPTablesForwardDropped *StatCounter + // IPTablesOutputDropped is the number of IP packets dropped in the Output // chain. IPTablesOutputDropped *StatCounter @@ -1595,6 +1650,9 @@ type IPStats struct { // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived *StatCounter + // Forwarding collects stats related to IP forwarding. + Forwarding IPForwardingStats + // LINT.ThenChange(network/internal/ip/stats.go:MultiCounterIPStats) } @@ -1787,37 +1845,104 @@ type UDPStats struct { ChecksumErrors *StatCounter } +// NICNeighborStats holds metrics for the neighbor table. +type NICNeighborStats struct { + // LINT.IfChange(NICNeighborStats) + + // UnreachableEntryLookups counts the number of lookups performed on an + // entry in Unreachable state. + UnreachableEntryLookups *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICNeighborStats) +} + +// NICPacketStats holds basic packet statistics. +type NICPacketStats struct { + // LINT.IfChange(NICPacketStats) + + // Packets is the number of packets counted. + Packets *StatCounter + + // Bytes is the number of bytes counted. + Bytes *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICPacketStats) +} + +// NICStats holds NIC statistics. +type NICStats struct { + // LINT.IfChange(NICStats) + + // UnknownL3ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported network protocol. + UnknownL3ProtocolRcvdPackets *StatCounter + + // UnknownL4ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported transport protocol. + UnknownL4ProtocolRcvdPackets *StatCounter + + // MalformedL4RcvdPackets is the number of packets received by a NIC that + // could not be delivered to a transport endpoint because the L4 header could + // not be parsed. + MalformedL4RcvdPackets *StatCounter + + // Tx contains statistics about transmitted packets. + Tx NICPacketStats + + // Rx contains statistics about received packets. + Rx NICPacketStats + + // DisabledRx contains statistics about received packets on disabled NICs. + DisabledRx NICPacketStats + + // Neighbor contains statistics about neighbor entries. + Neighbor NICNeighborStats + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICStats) +} + +// FillIn returns a copy of s with nil fields initialized to new StatCounters. 
+func (s NICStats) FillIn() NICStats { + InitStatCounters(reflect.ValueOf(&s).Elem()) + return s +} + // Stats holds statistics about the networking stack. -// -// All fields are optional. type Stats struct { - // UnknownProtocolRcvdPackets is the number of packets received by the - // stack that were for an unknown or unsupported protocol. - UnknownProtocolRcvdPackets *StatCounter + // TODO(https://gvisor.dev/issues/5986): Make the DroppedPackets stat less + // ambiguous. - // MalformedRcvdPackets is the number of packets received by the stack - // that were deemed malformed. - MalformedRcvdPackets *StatCounter - - // DroppedPackets is the number of packets dropped due to full queues. + // DroppedPackets is the number of packets dropped at the transport layer. DroppedPackets *StatCounter - // ICMP breaks out ICMP-specific stats (both v4 and v6). + // NICs is an aggregation of every NIC's statistics. These should not be + // incremented using this field, but using the relevant NIC multicounters. + NICs NICStats + + // ICMP is an aggregation of every NetworkEndpoint's ICMP statistics (both v4 + // and v6). These should not be incremented using this field, but using the + // relevant NetworkEndpoint ICMP multicounters. ICMP ICMPStats - // IGMP breaks out IGMP-specific stats. + // IGMP is an aggregation of every NetworkEndpoint's IGMP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint IGMP multicounters. IGMP IGMPStats - // IP breaks out IP-specific stats (both v4 and v6). + // IP is an aggregation of every NetworkEndpoint's IP statistics. These should + // not be incremented using this field, but using the relevant NetworkEndpoint + // IP multicounters. IP IPStats - // ARP breaks out ARP-specific stats. + // ARP is an aggregation of every NetworkEndpoint's ARP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint ARP multicounters. ARP ARPStats - // TCP breaks out TCP-specific stats. + // TCP holds TCP-specific stats. TCP TCPStats - // UDP breaks out UDP-specific stats. + // UDP holds UDP-specific stats. UDP UDPStats } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 269081ff8..c96ae2f02 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "net" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -210,26 +209,6 @@ func TestAddressString(t *testing.T) { } } -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. 
- "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - func TestAddressWithPrefixSubnet(t *testing.T) { tests := []struct { addr Address diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index d4f7bb5ff..8802f36b2 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -31,12 +31,14 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/udp", ], ) @@ -46,17 +48,20 @@ go_test( size = "small", srcs = ["link_resolution_test.go"], deps = [ + "//pkg/context", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index dbd279c94..92fa6257d 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -16,6 +16,7 @@ package forward_test import ( "bytes" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -34,6 +35,39 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +const ttl = 64 + +var ( + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") +) + +func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) +} + +func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) +} + +func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo))) +} + +func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest))) +} + func TestForwarding(t *testing.T) { const listenPort = 8080 @@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) { const ( nicID1 = 1 nicID2 = 2 - ttl = 64 ) var ( ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") - ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") - ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") ) - rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoRequest(e, src, dst, ttl) - } - - rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoRequest(e, src, dst, ttl) - } - - v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - 
checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo))) - } - - v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest))) - } - tests := []struct { name string srcAddr, dstAddr tcpip.Address @@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv4EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) }, }, { @@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv4EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) }, }, @@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv6EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) }, }, { @@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv6EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) }, }, } @@ -475,11 +480,11 @@ func TestMulticastForwarding(t *testing.T) { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) } - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } s.SetRouteTable([]tcpip.Route{ @@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) { }) } } + +func TestPerInterfaceForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + ) + + tests := []struct { + name string + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 unicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 multicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4GlobalMulticastAddr, + rx: rxICMPv4EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + }, + }, + + { + name: "IPv6 unicast", + srcAddr: 
utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 multicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6GlobalMulticastAddr, + rx: rxICMPv6EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + }, + }, + } + + netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + // ARP is not used in this test but it is a network protocol that does + // not support forwarding. We install the protocol to make sure that + // forwarding information for a NIC is only reported for network + // protocols that support forwarding. + arp.NewProtocol, + + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) + } + + for _, add := range [...]struct { + nicID tcpip.NICID + addr tcpip.ProtocolAddress + }{ + { + nicID: nicID1, + addr: utils.RouterNIC1IPv4Addr, + }, + { + nicID: nicID1, + addr: utils.RouterNIC1IPv6Addr, + }, + { + nicID: nicID2, + addr: utils.RouterNIC2IPv4Addr, + }, + { + nicID: nicID2, + addr: utils.RouterNIC2IPv6Addr, + }, + } { + if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) + } + } + + // Only enable forwarding on NIC1 and make sure that only packets arriving + // on NIC1 are forwarded. 
+ for _, netProto := range netProtos { + if err := s.SetNICForwarding(nicID1, netProto, true); err != nil { + t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err) + } + } + + nicsInfo := s.NICInfo() + for _, subTest := range [...]struct { + nicID tcpip.NICID + nicEP *channel.Endpoint + otherNICID tcpip.NICID + otherNICEP *channel.Endpoint + expectForwarding bool + }{ + { + nicID: nicID1, + nicEP: e1, + otherNICID: nicID2, + otherNICEP: e2, + expectForwarding: true, + }, + { + nicID: nicID2, + nicEP: e2, + otherNICID: nicID2, + otherNICEP: e1, + expectForwarding: false, + }, + } { + t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) { + nicInfo, ok := nicsInfo[subTest.nicID] + if !ok { + t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo) + } else { + forwarding := make(map[tcpip.NetworkProtocolNumber]bool) + for _, netProto := range netProtos { + forwarding[netProto] = subTest.expectForwarding + } + + if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" { + t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: subTest.otherNICID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: subTest.otherNICID, + }, + }) + + test.rx(subTest.nicEP, test.srcAddr, test.dstAddr) + if p, ok := subTest.nicEP.Read(); ok { + t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p) + } + if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding { + t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding) + } else if subTest.expectForwarding { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index c61d4e788..07ba2b837 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -19,12 +19,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) @@ -645,3 +647,297 @@ func TestIPTableWritePackets(t *testing.T) { }) } } + +const ttl = 64 + +var ( + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") +) + +func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoReply(e, src, dst, ttl) +} + +func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoReply(e, src, dst, ttl) +} + +func forwardedICMPv4EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) +} + +func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) +} + +func TestForwardingHook(t *testing.T) { + const ( + 
nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + local bool + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 remote", + netProto: ipv4.ProtocolNumber, + local: false, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoReply, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoReplyChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 local", + netProto: ipv4.ProtocolNumber, + local: true, + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr.Address, + rx: rxICMPv4EchoReply, + }, + { + name: "IPv6 remote", + netProto: ipv6.ProtocolNumber, + local: false, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoReply, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoReplyChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 local", + netProto: ipv6.ProtocolNumber, + local: true, + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr.Address, + rx: rxICMPv6EchoReply, + }, + } + + setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[stack.Forward] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. 
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } + } + + boolToInt := func(v bool) uint64 { + if v { + return 1 + } + return 0 + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectForward bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectForward: true, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + expectForward: false, + }, + { + name: "Drop with input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + expectForward: false, + }, + { + name: "Drop with output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + expectForward: false, + }, + { + name: "Drop with input and output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + expectForward: false, + }, + + { + name: "Drop with other input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other input and output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + expectForward: true, + }, + { + name: "Drop with input and other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + expectForward: true, + }, + { + name: "Drop with other input and other output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + expectForward: true, + }, + + { + name: "Drop with inverted input NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + expectForward: true, + }, + { + name: "Drop with inverted output NIC filtering", + setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + expectForward: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, 
utils.Ipv6Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID2, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID2, + }, + }) + + test.rx(e1, test.srcAddr, test.dstAddr) + + expectTransmitPacket := subTest.expectForward && !test.local + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want { + t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want) + } + if got := ip1Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want) + } + if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + p, ok := e2.Read() + if ok != expectTransmitPacket { + t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, expectTransmitPacket) + } + if expectTransmitPacket { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index c657714ba..9f727eb8f 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -17,22 +17,26 @@ package link_resolution_test import ( "bytes" "fmt" + "net" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + 
"gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -395,6 +399,246 @@ func TestTCPLinkResolutionFailure(t *testing.T) { } } +func TestForwardingWithLinkResolutionFailure(t *testing.T) { + const ( + incomingNICID = 1 + outgoingNICID = 2 + ttl = 2 + expectedHostUnreachableErrorCount = 1 + ) + outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != arp.ProtocolNumber { + t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) + } + if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(request.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) + } + } + + ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber) + } + + snmc := header.SolicitedNodeAddr(dst) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), + checker.SrcAddr(src), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(dst), + )) + } + + icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4HostUnreachable), + ), + ) + } + + icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv6.DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6AddressUnreachable), + ), + ) + } + + tests := []struct { + name string + networkProtocolFactory []stack.NetworkProtocolFactory + 
networkProtocolNumber tcpip.NetworkProtocolNumber + sourceAddr tcpip.Address + destAddr tcpip.Address + incomingAddr tcpip.AddressWithPrefix + outgoingAddr tcpip.AddressWithPrefix + transportProtocol func(*stack.Stack) stack.TransportProtocol + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) + icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) + mtu uint32 + }{ + { + name: "IPv4 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + networkProtocolNumber: header.IPv4ProtocolNumber, + sourceAddr: tcptestutil.MustParse4("10.0.0.2"), + destAddr: tcptestutil.MustParse4("11.0.0.2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + }, + transportProtocol: icmp.NewProtocol4, + linkResolutionRequestChecker: arpChecker, + icmpReplyChecker: icmpv4Checker, + rx: rxICMPv4EchoRequest, + mtu: ipv4.MaxTotalSize, + }, + { + name: "IPv6 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + networkProtocolNumber: header.IPv6ProtocolNumber, + sourceAddr: tcptestutil.MustParse6("10::2"), + destAddr: tcptestutil.MustParse6("11::2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + }, + transportProtocol: icmp.NewProtocol6, + linkResolutionRequestChecker: ndpChecker, + icmpReplyChecker: icmpv6Checker, + rx: rxICMPv6EchoRequest, + mtu: header.IPv6MinimumMTU, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + + s := stack.New(stack.Options{ + NetworkProtocols: test.networkProtocolFactory, + TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, + Clock: clock, + }) + + // Set up endpoint through which we will receive packets. + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) + } + incomingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.incomingAddr, + } + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + } + + // Set up endpoint through which we will attempt to forward packets. 
+ outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) + outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) + } + outgoingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.outgoingAddr, + } + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.incomingAddr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: test.outgoingAddr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) + } + + test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) + + var request channel.PacketInfo + var ok bool + nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) + if err != nil { + t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) + } + // Trigger the first packet on the endpoint. + clock.RunImmediatelyScheduledJobs() + + for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { + if request, ok = outgoingEndpoint.Read(); !ok { + t.Fatal("expected ARP packet through outgoing NIC") + } + + test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) + + // Advance the clock the span of one request timeout. + clock.Advance(nudConfigs.RetransmitTimer) + } + + // Next, we make a blocking read to retrieve the error packet. This is + // necessary because outgoing packets are dequeued asynchronously when + // link resolution fails, and this dequeue is what triggers the ICMP + // error. + // + // TODO(gvisor.dev/issue/6012): Replace with asynchronous read after we + // have integrated the stack clock with the dequeuing code. + reply, ok := incomingEndpoint.ReadContext(context.Background()) + if !ok { + t.Fatal("expected ICMP packet through incoming NIC") + } + + test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) + + // Since link resolution failed, we don't expect the packet to be + // forwarded. 
+ forwardedPacket, ok := outgoingEndpoint.Read() + if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) + } + + if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { + t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) + } + }) + } +} + func TestGetLinkAddress(t *testing.T) { const ( host1NICID = 1 diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3df1bbd68..87d36e1dd 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -714,11 +714,11 @@ func TestExternalLoopbackTraffic(t *testing.T) { } if test.forwarding { - if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) } - if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) } } diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 8fd9be32b..2e6ae55ea 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -224,11 +224,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) } - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err) } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err) } if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr); err != nil { @@ -316,13 +316,11 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. }) } -// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on -// the provided endpoint. -func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { +func rxICMPv4Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv4Type) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4Echo) + pkt.SetType(ty) pkt.SetCode(header.ICMPv4UnusedCode) pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, 0)) @@ -341,13 +339,23 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) })) } -// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on // the provided endpoint. 
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4Echo) +} + +// RxICMPv4EchoReply constructs and injects an ICMPv4 echo reply packet on +// the provided endpoint. +func RxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv4Echo(e, src, dst, ttl, header.ICMPv4EchoReply) +} + +func rxICMPv6Echo(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8, ty header.ICMPv6Type) { totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) - pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetType(ty) pkt.SetCode(header.ICMPv6UnusedCode) pkt.SetChecksum(0) pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -368,3 +376,15 @@ func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) Data: hdr.View().ToVectorisedView(), })) } + +// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on +// the provided endpoint. +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoRequest) +} + +// RxICMPv6EchoReply constructs and injects an ICMPv6 echo reply packet on +// the provided endpoint. +func RxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { + rxICMPv6Echo(e, src, dst, ttl, header.ICMPv6EchoReply) +} diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD index 472545a5d..02ee86ff1 100644 --- a/pkg/tcpip/testutil/BUILD +++ b/pkg/tcpip/testutil/BUILD @@ -5,7 +5,10 @@ package(licenses = ["notice"]) go_library( name = "testutil", testonly = True, - srcs = ["testutil.go"], + srcs = [ + "testutil.go", + "testutil_unsafe.go", + ], visibility = ["//visibility:public"], deps = ["//pkg/tcpip"], ) diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go index 1aaed590f..94b580a70 100644 --- a/pkg/tcpip/testutil/testutil.go +++ b/pkg/tcpip/testutil/testutil.go @@ -18,6 +18,8 @@ package testutil import ( "fmt" "net" + "reflect" + "strings" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -41,3 +43,81 @@ func MustParse6(addr string) tcpip.Address { } return tcpip.Address(ip) } + +func checkFieldCounts(ref, multi reflect.Value) error { + refTypeName := ref.Type().Name() + multiTypeName := multi.Type().Name() + refNumField := ref.NumField() + multiNumField := multi.NumField() + + if refNumField != multiNumField { + return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) + } + + return nil +} + +func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { + s, ok := ref.Addr().Interface().(**tcpip.StatCounter) + if !ok { + return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) + } + + // The field names are expected to match (case insensitive). 
+ if !strings.EqualFold(refName, multiName) { + return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) + } + + base := (*s).Value() + m.Increment() + if (*s).Value() != base+1 { + return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) + } + + return nil +} + +// ValidateMultiCounterStats verifies that every counter stored in multi is +// correctly tracking its counterpart in the given counters. +func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { + for _, c := range counters { + if err := checkFieldCounts(c, multi); err != nil { + return err + } + } + + for i := 0; i < multi.NumField(); i++ { + multiName := multi.Type().Field(i).Name + multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) + + if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { + for _, c := range counters { + if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { + return err + } + } + } else { + var countersNextField []reflect.Value + for _, c := range counters { + countersNextField = append(countersNextField, c.Field(i)) + } + if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { + return err + } + } + } + + return nil +} + +// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on +// error. +// +// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. +func MustParseLink(addr string) tcpip.LinkAddress { + parsed, err := tcpip.ParseMACAddress(addr) + if err != nil { + panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) + } + return parsed +} diff --git a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go b/pkg/tcpip/testutil/testutil_unsafe.go index 5ff764800..5ff764800 100644 --- a/pkg/tcpip/network/internal/testutil/testutil_unsafe.go +++ b/pkg/tcpip/testutil/testutil_unsafe.go diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go deleted file mode 100644 index eeea97b12..000000000 --- a/pkg/tcpip/time_unsafe.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.9 -// +build !go1.18 - -// Check go:linkname function signatures when updating Go version. - -package tcpip - -import ( - "time" // Used with go:linkname. - _ "unsafe" // Required for go:linkname. -) - -// StdClock implements Clock with the time package. -// -// +stateify savable -type StdClock struct{} - -var _ Clock = (*StdClock)(nil) - -//go:linkname now time.now -func now() (sec int64, nsec int32, mono int64) - -// NowNanoseconds implements Clock.NowNanoseconds. -func (*StdClock) NowNanoseconds() int64 { - sec, nsec, _ := now() - return sec*1e9 + int64(nsec) -} - -// NowMonotonic implements Clock.NowMonotonic. -func (*StdClock) NowMonotonic() int64 { - _, _, mono := now() - return mono -} - -// AfterFunc implements Clock.AfterFunc. 
-func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { - return &stdTimer{ - t: time.AfterFunc(d, f), - } -} - -type stdTimer struct { - t *time.Timer -} - -var _ Timer = (*stdTimer)(nil) - -// Stop implements Timer.Stop. -func (st *stdTimer) Stop() bool { - return st.t.Stop() -} - -// Reset implements Timer.Reset. -func (st *stdTimer) Reset(d time.Duration) { - st.t.Reset(d) -} - -// NewStdTimer returns a Timer implemented with the time package. -func NewStdTimer(t *time.Timer) Timer { - return &stdTimer{t: t} -} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index a82384c49..4ddb7020d 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -25,11 +25,10 @@ import ( const ( shortDuration = 1 * time.Nanosecond middleDuration = 100 * time.Millisecond - longDuration = 1 * time.Second ) func TestJobReschedule(t *testing.T) { - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var wg sync.WaitGroup var lock sync.Mutex @@ -43,7 +42,7 @@ func TestJobReschedule(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). - job := tcpip.NewJob(&clock, &lock, func() { + job := tcpip.NewJob(clock, &lock, func() { wg.Done() }) job.Schedule(shortDuration) @@ -56,11 +55,11 @@ func TestJobReschedule(t *testing.T) { func TestJobExecution(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) - job := tcpip.NewJob(&clock, &lock, func() { + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) @@ -83,11 +82,11 @@ func TestJobExecution(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(middleDuration) lock.Lock() @@ -114,12 +113,12 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -151,13 +150,13 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -174,12 +173,12 @@ func TestJobImmediatelyCancel(t *testing.T) { func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -206,12 +205,12 @@ func TestJobCancelledRescheduleWithoutLock(t 
*testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. @@ -239,12 +238,12 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) for i := 0; i < 10; i++ { job.Cancel() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 9948f305b..517903ae7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -160,7 +160,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { @@ -349,7 +349,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { +func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { return nil } @@ -390,7 +390,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -606,7 +606,7 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. @@ -747,8 +747,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) - // TODO(b/129292233): Determine if len(h) check is still needed after early - // parsing. + // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed + // after early parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -756,8 +756,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } case header.IPv6ProtocolNumber: h := header.ICMPv6(pkt.TransportHeader().View()) - // TODO(b/129292233): Determine if len(h) check is still needed after early - // parsing. 
+ // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed + // after early parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 496eca581..fa703a0ed 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -159,7 +159,7 @@ func (ep *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -220,19 +220,19 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes *tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { +func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error { return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -318,7 +318,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -339,7 +339,7 @@ func (ep *endpoint) UpdateLastError(err tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -484,7 +484,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { } // SetOwner implements tcpip.Endpoint.SetOwner. -func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*endpoint) SetOwner(tcpip.PacketOwner) {} // SocketOptions implements tcpip.Endpoint.SocketOptions. func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index bcec3d2e7..07a585444 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -183,7 +183,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.mu.Lock() @@ -402,7 +402,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Find a route to the destination. 
- route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) + route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) if err != nil { return err } @@ -428,7 +428,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -439,7 +439,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -513,12 +513,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 48417f192..0f20d3856 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -126,7 +126,15 @@ go_test( go_test( name = "tcp_test", size = "small", - srcs = ["timer_test.go"], + srcs = [ + "segment_test.go", + "timer_test.go", + ], library = ":tcp", - deps = ["//pkg/sleep"], + deps = [ + "//pkg/sleep", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + "@com_github_google_go_cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 524d5cabf..05b41e0f8 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -586,8 +586,14 @@ func (h *handshake) complete() tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + // Check for any ICMP errors notified to us. if n&notifyError != 0 { - return h.ep.lastErrorLocked() + if err := h.ep.lastErrorLocked(); err != nil { + return err + } + // Flag the handshake failure as aborted if the lastError is + // cleared because of a socket layer call. + return &tcpip.ErrConnectionAborted{} } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -1229,7 +1235,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - state := e.state + state := e.EndpointState() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. @@ -1362,8 +1368,24 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Reaching this point means that we successfully completed the 3-way - // handshake with our peer. - // + // handshake with our peer. The current endpoint state could be any state + // post ESTABLISHED, including CLOSED or ERROR if the endpoint processes a + // RST from the peer via the dispatcher fast path, before the loop is + // started. 
+ if s := e.EndpointState(); !s.connected() { + switch s { + case StateClose, StateError: + // If the endpoint is in CLOSED/ERROR state, sender state has to be + // initialized if the endpoint was previously established. + if e.snd != nil { + break + } + fallthrough + default: + panic("endpoint was not established, current state " + s.String()) + } + } + // Completing the 3-way handshake is an indication that the route is valid // and the remote is reachable as the only way we can complete a handshake // is if our SYN reached the remote and their ACK reached us. diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 512053a04..0ca986512 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -177,7 +177,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans s := newIncomingSegment(id, pkt) if !s.parse(pkt.RXTransportChecksumValidated) { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() @@ -185,7 +184,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans } if !s.csumValid { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.ChecksumErrors.Increment() ep.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index d6d68f128..5342aacfd 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -37,8 +38,8 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { // Start connection attempt, it must fail. err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -49,8 +50,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -156,8 +157,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -391,7 +392,7 @@ func testV4Accept(t *testing.T, c *context.Context) { defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. 
select { case <-ch: @@ -420,7 +421,7 @@ func testV4Accept(t *testing.T, c *context.Context) { r.Reset(data) nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) + tcp = header.IPv4(b).Payload() if string(tcp.Payload()) != data { t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) } @@ -525,7 +526,7 @@ func TestV6AcceptOnV6(t *testing.T) { defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress _, _, err := c.EP.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -611,7 +612,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.ReadableEvents) defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3a7b2d166..fb7670adb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -38,19 +38,15 @@ import ( ) // EndpointState represents the state of a TCP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - // Endpoint states internal to netstack. These map to the TCP state CLOSED. - StateInitial EndpointState = iota - StateBound - StateConnecting // Connect() called, but the initial SYN hasn't been sent. - StateError - - // TCP protocol states. + _ EndpointState = iota + // TCP protocol states in sync with the definitions in + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/tcp_states.h#L13 StateEstablished StateSynSent StateSynRecv @@ -62,6 +58,12 @@ const ( StateLastAck StateListen StateClosing + + // Endpoint states internal to netstack. + StateInitial + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError ) const ( @@ -97,6 +99,16 @@ func (s EndpointState) connecting() bool { } } +// internal returns true when the state is netstack internal. +func (s EndpointState) internal() bool { + switch s { + case StateInitial, StateBound, StateConnecting, StateError: + return true + default: + return false + } +} + // handshake returns true when s is one of the states representing an endpoint // in the middle of a TCP handshake. func (s EndpointState) handshake() bool { @@ -422,12 +434,12 @@ type endpoint struct { // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState `state:".(EndpointState)"` + state uint32 `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the // endpoint state at restore time as the socket is moved to it's correct // state. - origEndpointState EndpointState `state:"nosave"` + origEndpointState uint32 `state:"nosave"` isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` @@ -747,7 +759,7 @@ func (e *endpoint) ResumeWork() { // // Precondition: e.mu must be held to call this method. 
func (e *endpoint) setEndpointState(state EndpointState) { - oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + oldstate := EndpointState(atomic.LoadUint32(&e.state)) switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() @@ -764,12 +776,12 @@ func (e *endpoint) setEndpointState(state EndpointState) { e.stack.Stats().TCP.CurrentEstablished.Decrement() } } - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // setRecentTimestamp sets the recentTS field to the provided value. @@ -806,11 +818,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue }, sndQueueInfo: sndQueueInfo{ TCPSndBufState: stack.TCPSndBufState{ - SndMTU: int(math.MaxInt32), + SndMTU: math.MaxInt32, }, }, waiterQueue: waiterQueue, - state: StateInitial, + state: uint32(StateInitial), keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -1280,6 +1292,12 @@ func (e *endpoint) LastError() tcpip.Error { return e.lastErrorLocked() } +// LastErrorLocked reads and clears lastError with e.mu held. +// Only to be used in tests. +func (e *endpoint) LastErrorLocked() tcpip.Error { + return e.lastErrorLocked() +} + // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. func (e *endpoint) UpdateLastError(err tcpip.Error) { e.LockUser() @@ -1595,7 +1613,7 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) { // // For large receive buffers, the threshold is aMSS - once reader reads more // than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of -// receive buffer size. This is chosen arbitrairly. +// receive buffer size. This is chosen arbitrarily. // crossed will be true if the window size crossed the ACK threshold. // above will be true if the new window is >= ACK threshold and false // otherwise. @@ -1950,6 +1968,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { info := tcpip.TCPInfoOption{} e.LockUser() + if state := e.EndpointState(); state.internal() { + info.State = tcpip.EndpointState(StateClose) + } else { + info.State = tcpip.EndpointState(state) + } snd := e.snd if snd != nil { // We do not calculate RTT before sending the data packets. If @@ -2725,7 +2748,7 @@ func (e *endpoint) updateSndBufferUsage(v int) { // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1 e.sndQueueInfo.sndQueueMu.Unlock() if notify { diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 6e9777fe4..a56d34dc5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -154,7 +154,7 @@ func (e *endpoint) afterLoad() { e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. 
- e.state = StateInitial + e.state = uint32(StateInitial) // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) @@ -167,7 +167,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() - epState := e.origEndpointState + epState := EndpointState(e.origEndpointState) switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption @@ -281,11 +281,11 @@ func (e *endpoint) Resume(s *stack.Stack) { }() case epState == StateClose: e.isPortReserved = false - e.state = StateClose + e.state = uint32(StateClose) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) case epState == StateError: - e.state = StateError + e.state = uint32(StateError) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index a3d1aa1a3..d43e21426 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -401,7 +401,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er case *tcpip.TCPTimeWaitReuseOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) + *v = p.timeWaitReuse p.mu.RUnlock() return nil diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 9e332dcf7..813a6dffd 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -279,7 +279,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // been observed RACK uses reo_wnd of zero during loss recovery, in order to // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. -func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { +func (rc *rackControl) updateRACKReorderWindow() { dsackSeen := rc.DSACKSeen snd := rc.snd diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index ee2c08cd6..4fd8a0624 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -137,9 +137,9 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + if r.RcvNxt.Add(curWnd).LessThan(r.RcvNxt.Add(newWnd)) && toGrow { // If the new window moves the right edge, then update RcvAcc. - r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) + r.RcvAcc = r.RcvNxt.Add(newWnd) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -148,6 +148,18 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { } newWnd = curWnd } + + // Apply silly-window avoidance when recovering from zero-window situation. + // Keep advertising zero receive window up until the new window reaches a + // threshold. + if r.rcvWnd == 0 && newWnd != 0 { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + if crossed, above := r.ep.windowCrossedACKThresholdLocked(int(newWnd), int(r.ep.ops.GetReceiveBufferSize())); !crossed && !above { + newWnd = 0 + } + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() + } + // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. 
r.rcvWnd = newWnd diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index c28641be3..61754de29 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -140,6 +140,15 @@ func (s *segment) clone() *segment { return t } +// merge merges data in oth and clears oth. +func (s *segment) merge(oth *segment) { + s.data.Append(oth.data) + s.dataMemSize = s.data.Size() + + oth.data = buffer.VectorisedView{} + oth.dataMemSize = oth.data.Size() +} + // flagIsSet checks if at least one flag in flags is set in s.flags. func (s *segment) flagIsSet(flags header.TCPFlags) bool { return s.flags&flags != 0 @@ -234,7 +243,7 @@ func (s *segment) parse(skipChecksumValidation bool) bool { return false } - s.options = []byte(s.hdr[header.TCPMinimumSize:]) + s.options = s.hdr[header.TCPMinimumSize:] s.parsedOptions = header.ParseTCPOptions(s.options) if skipChecksumValidation { s.csumValid = true @@ -253,5 +262,5 @@ func (s *segment) parse(skipChecksumValidation bool) bool { // sackBlock returns a header.SACKBlock that represents this segment. func (s *segment) sackBlock() header.SACKBlock { - return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} + return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())} } diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go new file mode 100644 index 000000000..486016fc0 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -0,0 +1,67 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tcp + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type segmentSizeWants struct { + DataSize int + SegMemSize int +} + +func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeWants) { + t.Helper() + got := segmentSizeWants{ + DataSize: seg.data.Size(), + SegMemSize: seg.segMemSize(), + } + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("%s differs (-want +got):\n%s", name, diff) + } +} + +func TestSegmentMerge(t *testing.T) { + id := stack.TransportEndpointID{} + seg1 := newOutgoingSegment(id, buffer.NewView(10)) + defer seg1.decRef() + seg2 := newOutgoingSegment(id, buffer.NewView(20)) + defer seg2.decRef() + + checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ + DataSize: 10, + SegMemSize: SegSize + 10, + }) + checkSegmentSize(t, "seg2", seg2, segmentSizeWants{ + DataSize: 20, + SegMemSize: SegSize + 20, + }) + + seg1.merge(seg2) + + checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ + DataSize: 30, + SegMemSize: SegSize + 30, + }) + checkSegmentSize(t, "seg2", seg2, segmentSizeWants{ + DataSize: 0, + SegMemSize: SegSize, + }) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 2b32cb7b2..6b7293755 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -616,7 +616,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // 'S2' that meets the following 3 criteria for determinig // loss, the sequence range of one segment of up to SMSS // octects starting with S2 MUST be returned. - if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) { + if !s.ep.scoreboard.IsSACKED(header.SACKBlock{Start: segSeq, End: segSeq.Add(1)}) { // NextSeg(): // // (1.a) S2 is greater than HighRxt @@ -716,15 +716,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // triggering bugs in poorly written DNS // implementations. var nextTooBig bool - for seg.Next() != nil && seg.Next().data.Size() != 0 { - if seg.data.Size()+seg.Next().data.Size() > available { + for nSeg := seg.Next(); nSeg != nil && nSeg.data.Size() != 0; nSeg = seg.Next() { + if seg.data.Size()+nSeg.data.Size() > available { nextTooBig = true break } - seg.data.Append(seg.Next().data) - - // Consume the segment that we just merged in. - s.writeList.Remove(seg.Next()) + seg.merge(nSeg) + s.writeList.Remove(nSeg) + nSeg.decRef() } if !nextTooBig && seg.data.Size() < available { // Segment is not full. 
@@ -1025,7 +1024,7 @@ func (s *sender) SetPipe() { if segEnd.LessThan(endSeq) { endSeq = segEnd } - sb := header.SACKBlock{startSeq, endSeq} + sb := header.SACKBlock{Start: startSeq, End: endSeq} // SetPipe(): // // After initializing pipe to zero, the following steps are @@ -1456,7 +1455,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: // * Step 4: Update RACK reordering window - s.rc.updateRACKReorderWindow(rcvdSeg) + s.rc.updateRACKReorderWindow() // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index c58361bc1..29af87b7d 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -38,7 +38,7 @@ const ( func setStackRACKPermitted(t *testing.T, c *context.Context) { t.Helper() - opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection) + opt := tcpip.TCPRACKLossDetection if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 3750b0691..e7ede7662 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for receive to be notified. select { case <-notifyRead: @@ -130,8 +130,8 @@ func TestGiveUpConnect(t *testing.T) { { err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -145,8 +145,8 @@ func TestGiveUpConnect(t *testing.T) { // and stats updates. { err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAborted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{}) + if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -159,6 +159,76 @@ func TestGiveUpConnect(t *testing.T) { } } +// Test for ICMP error handling without completing handshake. +func TestConnectICMPError(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventHUp) + defer wq.EventUnregister(&waitEntry) + + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) 
mismatch (-want +got):\n%s", d) + } + } + + syn := c.GetPacket() + checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) + + wep := ep.(interface { + StopWork() + ResumeWork() + LastErrorLocked() tcpip.Error + }) + + // Stop the protocol loop, ensure that the ICMP error is processed and + // the last ICMP error is read before the loop is resumed. This sanity + // tests the handshake completion logic on ICMP errors. + wep.StopWork() + + c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU) + + for { + if err := wep.LastErrorLocked(); err != nil { + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) + } + break + } + time.Sleep(time.Millisecond) + } + + wep.ResumeWork() + + <-notifyCh + + // The stack would have unregistered the endpoint because of the ICMP error. + // Expect a RST for any subsequent packets sent to the endpoint. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, + AckNum: c.IRS + 1, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + func TestConnectIncrementActiveConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -202,8 +272,8 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -393,7 +463,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -936,8 +1006,8 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} { err := c.EP.Connect(connectAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Connect(%+v): %s", connectAddr, err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) } } @@ -1543,8 +1613,8 @@ func TestConnectBindToDevice(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -1604,8 +1674,8 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + t.Fatalf("Connect(...) 
mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -2473,7 +2543,7 @@ func TestScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -2545,7 +2615,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3077,8 +3147,8 @@ func TestSetTTL(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -3137,7 +3207,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3191,7 +3261,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3266,8 +3336,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -3385,8 +3455,8 @@ loop: case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) } break loop case <-time.After(1 * time.Second): @@ -3436,8 +3506,8 @@ func TestSendOnResetConnection(t *testing.T) { var r bytes.Reader r.Reset(make([]byte, 10)) _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Write(...) 
mismatch (-want +got):\n%s", d) } } @@ -4189,7 +4259,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(iss), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -4390,8 +4460,8 @@ func TestReadAfterClosedState(t *testing.T) { var buf bytes.Buffer { _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) - if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) + if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { + t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) } } } @@ -4435,8 +4505,8 @@ func TestReusePort(t *testing.T) { } { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } c.EP.Close() @@ -4724,8 +4794,8 @@ func TestSelfConnect(t *testing.T) { { err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -5265,7 +5335,7 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next-1)), + checker.TCPSeqNum(next-1), checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), @@ -5290,12 +5360,7 @@ func TestKeepalive(t *testing.T) { }) checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), + checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), ) if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { @@ -5428,7 +5493,7 @@ func TestListenBacklogFull(t *testing.T) { } lastPortOffset := uint16(0) - for ; int(lastPortOffset) < listenBacklog+1; lastPortOffset++ { + for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } @@ -5437,7 +5502,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + uint16(lastPortOffset), + SrcPort: context.TestPort + lastPortOffset, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(context.TestInitialSequenceNumber), @@ -5452,7 +5517,7 @@ func TestListenBacklogFull(t *testing.T) { for i := 0; i < listenBacklog; i++ { _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5469,7 +5534,7 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that there are no more connections that can be accepted. 
_, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5481,7 +5546,7 @@ func TestListenBacklogFull(t *testing.T) { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5794,7 +5859,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Try to accept the connections in the backlog. newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5814,7 +5879,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { r.Reset(data) newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) + tcp = header.IPv4(pkt).Payload() if string(tcp.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } @@ -5865,7 +5930,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { defer c.WQ.EventUnregister(&we) _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5881,7 +5946,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -6020,7 +6085,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { t.Fatalf("Accept failed: %s", err) } - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Try to accept the connections in the backlog. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.ReadableEvents) @@ -6048,7 +6113,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { } pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + tcpHdr = header.IPv4(pkt).Payload() if string(tcpHdr.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) } @@ -6088,7 +6153,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { // Verify that there is only one acceptable connection at this point. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6158,7 +6223,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6210,7 +6275,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { defer wq.EventUnregister(&we) aep, _, err := ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. 
select { case <-ch: @@ -6228,8 +6293,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } { err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) + if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { + t.Errorf("Connect(...) mismatch (-want +got):\n%s", d) } } // Listening endpoint remains in listen state. @@ -6305,7 +6370,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Allocate a large enough payload for the test. payloadSize := receiveBufferSize * 2 - b := make([]byte, int(payloadSize)) + b := make([]byte, payloadSize) worker := (c.EP).(interface { StopWork() @@ -6349,7 +6414,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // window increases to the full available buffer size. for { _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } } @@ -6359,7 +6424,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // ack, 1 for the non-zero window p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(uint32(wantAckNum)), + checker.TCPAckNum(wantAckNum), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { @@ -6414,14 +6479,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := uint32(rawEP.TSVal) + tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) + return uint16(rcvWnd >> c.WindowScale) } // Allocate a large array to send to the endpoint. b := make([]byte, receiveBufferSize*48) @@ -6480,7 +6545,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalCopied := 0 for { res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } totalCopied += res.Count @@ -6549,19 +6614,16 @@ func TestDelayEnabled(t *testing.T) { defer c.Cleanup() checkDelayOption(t, c, false, false) // Delay is disabled by default. 
- for _, v := range []struct { - delayEnabled tcpip.TCPDelayEnabled - wantDelayOption bool - }{ - {delayEnabled: false, wantDelayOption: false}, - {delayEnabled: true, wantDelayOption: true}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + for _, delayEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + opt := tcpip.TCPDelayEnabled(delayEnabled) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) + } + checkDelayOption(t, c, opt, delayEnabled) + }) } } @@ -6672,7 +6734,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6791,7 +6853,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6898,7 +6960,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6972,7 +7034,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Receive the SYN-ACK reply. b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) + tcpHdr = header.IPv4(b).Payload() c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders = &context.Headers{ @@ -6988,7 +7050,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Try to accept the connection. c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7062,7 +7124,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7212,7 +7274,7 @@ func TestTCPCloseWithData(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. 
select { case <-ch: @@ -7397,7 +7459,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), + checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), @@ -7475,7 +7537,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: iss, - AckNum: seqnum.Value(c.IRS + 1), + AckNum: c.IRS + 1, RcvWnd: 30000, }) @@ -7643,8 +7705,8 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) } // Send data. This should result in an acceptable endpoint. @@ -7702,8 +7764,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) } // Sleep for a little of the tcpDeferAccept timeout. diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 16f8c5212..53efecc5a 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -1214,9 +1214,9 @@ func (c *Context) SACKEnabled() bool { // SetGSOEnabled enables or disables generic segmentation offload. func (c *Context) SetGSOEnabled(enable bool) { if enable { - c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO + c.linkEP.SupportedGSOKind = stack.HWGSOSupported } else { - c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO + c.linkEP.SupportedGSOKind = stack.GSONotSupported } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f7dd50d35..623e069a6 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -40,13 +40,14 @@ type udpPacket struct { } // EndpointState represents the state of a UDP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - StateInitial EndpointState = iota + _ EndpointState = iota + StateInitial StateBound StateConnected StateClosed @@ -98,7 +99,7 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState + state uint32 route *stack.Route `state:"manual"` dstPort uint16 ttl uint8 @@ -176,7 +177,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. 
multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), - state: StateInitial, + state: uint32(StateInitial), uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) @@ -204,12 +205,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState() returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -290,7 +291,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -801,8 +802,8 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ - e.multicastNICID, - e.multicastAddr, + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, } e.mu.Unlock() @@ -1301,12 +1302,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.LocalAddress, - Port: header.UDP(hdr).DestinationPort(), + Port: hdr.DestinationPort(), }, data: pkt.Data().ExtractVV(), } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 705ad1f64..7c357cb09 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = StateConnected + ep.state = uint32(StateConnected) ep.rcvMu.Lock() ep.rcvReady = true diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index dc2e3f493..2e283e52b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1462,7 +1462,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { name: "IPv4 unicast", proto: header.IPv4ProtocolNumber, flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, }, { name: "IPv4 multicast", @@ -1474,7 +1474,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, }, { name: "IPv4 broadcast", @@ -1486,13 +1486,13 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. 
We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, }, { name: "IPv6 unicast", proto: header.IPv6ProtocolNumber, flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, }, { name: "IPv6 multicast", @@ -1504,7 +1504,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, }, } @@ -2115,8 +2115,8 @@ func TestShortHeader(t *testing.T) { Data: buf.ToVectorisedView(), })) - if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) } } @@ -2124,25 +2124,27 @@ func TestShortHeader(t *testing.T) { // global and endpoint stats are incremented. func TestBadChecksumErrors(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV6} { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } + c.createEndpoint(flow.sockProto()) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) + payload := newPayload() + c.injectPacket(flow, payload, true /* badChecksum */) - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } + }) } } diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go index 4855a52fc..12fe98b16 100644 --- a/pkg/test/dockerutil/profile.go +++ b/pkg/test/dockerutil/profile.go @@ -82,10 +82,15 @@ func (p *profile) createProcess(c *Container) error { } // The root directory of this container's runtime. 
- root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) + rootDir := fmt.Sprintf("/var/run/docker/runtime-%s/moby", c.runtime) + if _, err := os.Stat(rootDir); os.IsNotExist(err) { + // In docker v20+, due to https://github.com/moby/moby/issues/42345 the + // rootDir seems to always be the following. + rootDir = "/var/run/docker/runtime-runc/moby" + } - // Format is `runsc --root=rootdir debug --profile-*=file --duration=24h containerID`. - args := []string{root, "debug"} + // Format is `runsc --root=rootDir debug --profile-*=file --duration=24h containerID`. + args := []string{fmt.Sprintf("--root=%s", rootDir), "debug"} for _, profileArg := range p.Types { outputPath := filepath.Join(p.BasePath, fmt.Sprintf("%s.pprof", profileArg)) args = append(args, fmt.Sprintf("--profile-%s=%s", profileArg, outputPath)) diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 05b721b28..9b270cbf2 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -402,7 +402,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { ctx := k.SupervisorContext() mntr := newContainerMounter(&cm.l.root, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled) if kernel.VFS2Enabled { - ctx, err = mntr.configureRestore(ctx, cm.l.root.conf) + ctx, err = mntr.configureRestore(ctx) if err != nil { return fmt.Errorf("configuring filesystem restore: %v", err) } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 3c0cef6db..bf4a41f77 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -232,7 +232,7 @@ func parseMountOption(opt string, allowedKeys ...string) (bool, error) { // mountDevice returns a device string based on the fs type and target // of the mount. -func mountDevice(m specs.Mount) string { +func mountDevice(m *specs.Mount) string { if m.Type == bind { // Make a device string that includes the target, which is consistent across // S/R and uniquely identifies the connection. @@ -256,6 +256,8 @@ func mountFlags(opts []string) fs.MountSourceFlags { mf.NoAtime = true case "noexec": mf.NoExec = true + case "bind", "rbind": + // These are the same as a mount with type="bind". default: log.Warningf("ignoring unknown mount option %q", o) } @@ -486,9 +488,9 @@ func (m *mountHint) isSupported() bool { // For now enforce that all options are the same. Once bind mount is properly // supported, then we should ensure the master is less restrictive than the // container, e.g. master can be 'rw' while container mounts as 'ro'. -func (m *mountHint) checkCompatible(mount specs.Mount) error { +func (m *mountHint) checkCompatible(mount *specs.Mount) error { // Remove options that don't affect to mount's behavior. 
- masterOpts := filterUnsupportedOptions(m.mount) + masterOpts := filterUnsupportedOptions(&m.mount) replicaOpts := filterUnsupportedOptions(mount) if len(masterOpts) != len(replicaOpts) { @@ -512,7 +514,7 @@ func (m *mountHint) fileAccessType() config.FileAccessType { return config.FileAccessShared } -func filterUnsupportedOptions(mount specs.Mount) []string { +func filterUnsupportedOptions(mount *specs.Mount) []string { rv := make([]string, 0, len(mount.Options)) for _, o := range mount.Options { if isSupportedMountFlag(mount.Type, o) { @@ -576,7 +578,7 @@ func newPodMountHints(spec *specs.Spec) (*podMountHints, error) { return &podMountHints{mounts: mnts}, nil } -func (p *podMountHints) findMount(mount specs.Mount) *mountHint { +func (p *podMountHints) findMount(mount *specs.Mount) *mountHint { for _, m := range p.mounts { if m.mount.Source == mount.Source { return m @@ -679,7 +681,8 @@ func (c *containerMounter) mountSubmounts(ctx context.Context, conf *config.Conf root := mns.Root() defer root.DecRef(ctx) - for _, m := range c.mounts { + for i := range c.mounts { + m := &c.mounts[i] log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options) if hint := c.hints.findMount(m); hint != nil && hint.isSupported() { if err := c.mountSharedSubmount(ctx, mns, root, m, hint); err != nil { @@ -714,7 +717,7 @@ func (c *containerMounter) checkDispenser() error { func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *config.Config, hint *mountHint) (*fs.Inode, error) { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. - fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, hint.mount) + fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, &hint.mount) if err != nil { return nil, err } @@ -734,7 +737,7 @@ func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *config.C mf.ReadOnly = true } - inode, err := filesystem.Mount(ctx, mountDevice(hint.mount), mf, strings.Join(opts, ","), nil) + inode, err := filesystem.Mount(ctx, mountDevice(&hint.mount), mf, strings.Join(opts, ","), nil) if err != nil { return nil, fmt.Errorf("creating mount %q: %v", hint.name, err) } @@ -796,13 +799,14 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *config.Con // getMountNameAndOptions retrieves the fsName, opts, and useOverlay values // used for mounts. -func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.Mount) (string, []string, bool, error) { +func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m *specs.Mount) (string, []string, bool, error) { + specutils.MaybeConvertToBindMount(m) + var ( fsName string opts []string useOverlay bool ) - switch m.Type { case devpts.Name, devtmpfs.Name, procvfs2.Name, sysvfs2.Name: fsName = m.Type @@ -836,7 +840,7 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M return fsName, opts, useOverlay, nil } -func (c *containerMounter) getMountAccessType(conf *config.Config, mount specs.Mount) config.FileAccessType { +func (c *containerMounter) getMountAccessType(conf *config.Config, mount *specs.Mount) config.FileAccessType { if hint := c.hints.findMount(mount); hint != nil { return hint.fileAccessType() } @@ -847,7 +851,7 @@ func (c *containerMounter) getMountAccessType(conf *config.Config, mount specs.M // be readonly, a lower ramfs overlay is added to create the mount point dir. 
// Another overlay is added with tmpfs on top if Config.Overlay is true. // 'm.Destination' must be an absolute path with '..' and symlinks resolved. -func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Config, mns *fs.MountNamespace, root *fs.Dirent, m specs.Mount) error { +func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Config, mns *fs.MountNamespace, root *fs.Dirent, m *specs.Mount) error { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m) @@ -921,7 +925,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Confi // mountSharedSubmount binds mount to a previously mounted volume that is shared // among containers in the same pod. -func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.MountNamespace, root *fs.Dirent, mount specs.Mount, source *mountHint) error { +func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.MountNamespace, root *fs.Dirent, mount *specs.Mount, source *mountHint) error { if err := source.checkCompatible(mount); err != nil { return err } @@ -946,7 +950,7 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun // addRestoreMount adds a mount to the MountSources map used for restoring a // checkpointed container. -func (c *containerMounter) addRestoreMount(conf *config.Config, renv *fs.RestoreEnvironment, m specs.Mount) error { +func (c *containerMounter) addRestoreMount(conf *config.Config, renv *fs.RestoreEnvironment, m *specs.Mount) error { fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m) if err != nil { return err @@ -994,7 +998,8 @@ func (c *containerMounter) createRestoreEnvironment(conf *config.Config) (*fs.Re // Add submounts. var tmpMounted bool - for _, m := range c.mounts { + for i := range c.mounts { + m := &c.mounts[i] if err := c.addRestoreMount(conf, renv, m); err != nil { return nil, err } @@ -1009,7 +1014,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *config.Config) (*fs.Re Type: tmpfsvfs2.Name, Destination: "/tmp", } - if err := c.addRestoreMount(conf, renv, tmpMount); err != nil { + if err := c.addRestoreMount(conf, renv, &tmpMount); err != nil { return nil, err } } @@ -1068,7 +1073,7 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *config.Config, mn // another user. This is normally done for /tmp. Options: []string{"mode=01777"}, } - return c.mountSubmount(ctx, conf, mns, root, tmpMount) + return c.mountSubmount(ctx, conf, mns, root, &tmpMount) default: return err diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go index b4f12d034..09ffda628 100644 --- a/runsc/boot/fs_test.go +++ b/runsc/boot/fs_test.go @@ -244,7 +244,7 @@ func TestGetMountAccessType(t *testing.T) { } mounter := containerMounter{hints: podHints} conf := &config.Config{FileAccessMounts: config.FileAccessShared} - if got := mounter.getMountAccessType(conf, specs.Mount{Source: source}); got != tst.want { + if got := mounter.getMountAccessType(conf, &specs.Mount{Source: source}); got != tst.want { t.Errorf("getMountAccessType(), want: %v, got: %v", tst.want, got) } }) diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 25f06165f..10f2d3d35 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -230,6 +230,33 @@ func New(args Args) (*Loader, error) { vfs2.Override() } + // Make host FDs stable between invocations. 
Host FDs must map to the exact + // same number when the sandbox is restored. Otherwise the wrong FD will be + // used. + info := containerInfo{} + newfd := startingStdioFD + + for _, stdioFD := range args.StdioFDs { + // Check that newfd is unused to avoid clobbering over it. + if _, err := unix.FcntlInt(uintptr(newfd), unix.F_GETFD, 0); !errors.Is(err, unix.EBADF) { + if err != nil { + return nil, fmt.Errorf("error checking for FD (%d) conflict: %w", newfd, err) + } + return nil, fmt.Errorf("unable to remap stdios, FD %d is already in use", newfd) + } + + err := unix.Dup3(stdioFD, newfd, unix.O_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("dup3 of stdios failed: %w", err) + } + info.stdioFDs = append(info.stdioFDs, fd.New(newfd)) + _ = unix.Close(stdioFD) + newfd++ + } + for _, goferFD := range args.GoferFDs { + info.goferFDs = append(info.goferFDs, fd.New(goferFD)) + } + // Create kernel and platform. p, err := createPlatform(args.Conf, args.Device) if err != nil { @@ -349,6 +376,7 @@ func New(args Args) (*Loader, error) { if err != nil { return nil, fmt.Errorf("creating init process for root container: %v", err) } + info.procArgs = procArgs if err := initCompatLogs(args.UserLogFD); err != nil { return nil, fmt.Errorf("initializing compat logs: %v", err) @@ -359,6 +387,9 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("creating pod mount hints: %v", err) } + info.conf = args.Conf + info.spec = args.Spec + if kernel.VFS2Enabled { // Set up host mount that will be used for imported fds. hostFilesystem, err := hostvfs2.NewFilesystem(k.VFS()) @@ -373,37 +404,6 @@ func New(args Args) (*Loader, error) { k.SetHostMount(hostMount) } - info := containerInfo{ - conf: args.Conf, - spec: args.Spec, - procArgs: procArgs, - } - - // Make host FDs stable between invocations. Host FDs must map to the exact - // same number when the sandbox is restored. Otherwise the wrong FD will be - // used. - newfd := startingStdioFD - for _, stdioFD := range args.StdioFDs { - // Check that newfd is unused to avoid clobbering over it. - if _, err := unix.FcntlInt(uintptr(newfd), unix.F_GETFD, 0); !errors.Is(err, unix.EBADF) { - if err != nil { - return nil, fmt.Errorf("error checking for FD (%d) conflict: %w", newfd, err) - } - return nil, fmt.Errorf("unable to remap stdios, FD %d is already in use", newfd) - } - - err := unix.Dup3(stdioFD, newfd, unix.O_CLOEXEC) - if err != nil { - return nil, fmt.Errorf("dup3 of stdios failed: %w", err) - } - info.stdioFDs = append(info.stdioFDs, fd.New(newfd)) - _ = unix.Close(stdioFD) - newfd++ - } - for _, goferFD := range args.GoferFDs { - info.goferFDs = append(info.goferFDs, fd.New(goferFD)) - } - eid := execID{cid: args.ID} l := &Loader{ k: k, diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 7d8fd0483..7be5176b0 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -46,6 +46,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/runsc/config" + "gvisor.dev/gvisor/runsc/specutils" ) func registerFilesystems(k *kernel.Kernel) error { @@ -362,33 +363,33 @@ func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *config. 
for i := range mounts { submount := &mounts[i] - log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options) + log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.mount.Source, submount.mount.Destination, submount.mount.Type, submount.mount.Options) var ( mnt *vfs.Mount err error ) - if hint := c.hints.findMount(submount.Mount); hint != nil && hint.isSupported() { - mnt, err = c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.Mount, hint) + if hint := c.hints.findMount(submount.mount); hint != nil && hint.isSupported() { + mnt, err = c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.mount, hint) if err != nil { - return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.Destination, err) + return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.mount.Destination, err) } } else { mnt, err = c.mountSubmountVFS2(ctx, conf, mns, creds, submount) if err != nil { - return fmt.Errorf("mount submount %q: %w", submount.Destination, err) + return fmt.Errorf("mount submount %q: %w", submount.mount.Destination, err) } } if mnt != nil && mnt.ReadOnly() { // Switch to ReadWrite while we setup submounts. if err := c.k.VFS().SetMountReadOnly(mnt, false); err != nil { - return fmt.Errorf("failed to set mount at %q readwrite: %w", submount.Destination, err) + return fmt.Errorf("failed to set mount at %q readwrite: %w", submount.mount.Destination, err) } // Restore back to ReadOnly at the end. defer func() { if err := c.k.VFS().SetMountReadOnly(mnt, true); err != nil { - panic(fmt.Sprintf("failed to restore mount at %q back to readonly: %v", submount.Destination, err)) + panic(fmt.Sprintf("failed to restore mount at %q back to readonly: %v", submount.mount.Destination, err)) } }() } @@ -401,8 +402,8 @@ func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *config. } type mountAndFD struct { - specs.Mount - fd int + mount *specs.Mount + fd int } func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { @@ -410,15 +411,18 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { // undocumented assumption that FDs are dispensed in the order in which // they are required by mounts. var mounts []mountAndFD - for _, m := range c.mounts { - fd := -1 + for i := range c.mounts { + m := &c.mounts[i] + specutils.MaybeConvertToBindMount(m) + // Only bind mounts use host FDs; see // containerMounter.getMountNameAndOptionsVFS2. + fd := -1 if m.Type == bind { fd = c.fds.remove() } mounts = append(mounts, mountAndFD{ - Mount: m, + mount: m, fd: fd, }) } @@ -428,7 +432,7 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { // Sort the mounts so that we don't place children before parents. 
sort.Slice(mounts, func(i, j int) bool { - return len(mounts[i].Destination) < len(mounts[j].Destination) + return len(mounts[i].mount.Destination) < len(mounts[j].mount.Destination) }) return mounts, nil @@ -444,16 +448,16 @@ func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.C return nil, nil } - if err := c.makeMountPoint(ctx, creds, mns, submount.Destination); err != nil { - return nil, fmt.Errorf("creating mount point %q: %w", submount.Destination, err) + if err := c.makeMountPoint(ctx, creds, mns, submount.mount.Destination); err != nil { + return nil, fmt.Errorf("creating mount point %q: %w", submount.mount.Destination, err) } if useOverlay { - log.Infof("Adding overlay on top of mount %q", submount.Destination) + log.Infof("Adding overlay on top of mount %q", submount.mount.Destination) var cleanup func() opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) if err != nil { - return nil, fmt.Errorf("mounting volume with overlay at %q: %w", submount.Destination, err) + return nil, fmt.Errorf("mounting volume with overlay at %q: %w", submount.mount.Destination, err) } defer cleanup() fsName = overlay.Name @@ -465,32 +469,34 @@ func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.C target := &vfs.PathOperation{ Root: root, Start: root, - Path: fspath.Parse(submount.Destination), + Path: fspath.Parse(submount.mount.Destination), } mnt, err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts) if err != nil { - return nil, fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts) + return nil, fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.mount.Destination, submount.mount.Type, err, opts) } - log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts.GetFilesystemOptions.Data) + log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.mount.Source, submount.mount.Destination, submount.mount.Type, opts.GetFilesystemOptions.Data) return mnt, nil } // getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values // used for mounts. func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mountAndFD) (string, *vfs.MountOptions, bool, error) { - fsName := m.Type + fsName := m.mount.Type useOverlay := false - var data []string - var iopts interface{} + var ( + data []string + internalData interface{} + ) - verityData, verityOpts, verityRequested, remainingMOpts, err := parseVerityMountOptions(m.Options) + verityData, verityOpts, verityRequested, remainingMOpts, err := parseVerityMountOptions(m.mount.Options) if err != nil { return "", nil, false, err } - m.Options = remainingMOpts + m.mount.Options = remainingMOpts // Find filesystem name and FS specific data field. - switch m.Type { + switch m.mount.Type { case devpts.Name, devtmpfs.Name, proc.Name, sys.Name: // Nothing to do. @@ -499,7 +505,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo case tmpfs.Name: var err error - data, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...) + data, err = parseAndFilterOptions(m.mount.Options, tmpfsAllowedData...) if err != nil { return "", nil, false, err } @@ -511,35 +517,35 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo // but unlikely to be correct in this context. 
return "", nil, false, fmt.Errorf("9P mount requires a connection FD") } - data = p9MountData(m.fd, c.getMountAccessType(conf, m.Mount), true /* vfs2 */) - iopts = gofer.InternalFilesystemOptions{ - UniqueID: m.Destination, + data = p9MountData(m.fd, c.getMountAccessType(conf, m.mount), true /* vfs2 */) + internalData = gofer.InternalFilesystemOptions{ + UniqueID: m.mount.Destination, } // If configured, add overlay to all writable mounts. - useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly + useOverlay = conf.Overlay && !mountFlags(m.mount.Options).ReadOnly case cgroupfs.Name: var err error - data, err = parseAndFilterOptions(m.Options, cgroupfs.SupportedMountOptions...) + data, err = parseAndFilterOptions(m.mount.Options, cgroupfs.SupportedMountOptions...) if err != nil { return "", nil, false, err } default: - log.Warningf("ignoring unknown filesystem type %q", m.Type) + log.Warningf("ignoring unknown filesystem type %q", m.mount.Type) return "", nil, false, nil } opts := &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ Data: strings.Join(data, ","), - InternalData: iopts, + InternalData: internalData, }, InternalMount: true, } - for _, o := range m.Options { + for _, o := range m.mount.Options { switch o { case "rw": opts.ReadOnly = false @@ -549,13 +555,15 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo opts.Flags.NoATime = true case "noexec": opts.Flags.NoExec = true + case "bind", "rbind": + // These are the same as a mount with type="bind". default: log.Warningf("ignoring unknown mount option %q", o) } } if verityRequested { - verityData = verityData + "root_name=" + path.Base(m.Mount.Destination) + verityData = verityData + "root_name=" + path.Base(m.mount.Destination) verityOpts.LowerName = fsName verityOpts.LowerGetFSOptions = opts.GetFilesystemOptions fsName = verity.Name @@ -649,7 +657,6 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config Start: root, Path: fspath.Parse("/tmp"), } - // TODO(gvisor.dev/issue/2782): Use O_PATH when available. fd, err := c.k.VFS().OpenAt(ctx, creds, &pop, &vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_DIRECTORY}) switch err { case nil: @@ -684,7 +691,7 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config // another user. This is normally done for /tmp. Options: []string{"mode=01777"}, } - _, err := c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{Mount: tmpMount}) + _, err := c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{mount: &tmpMount}) return err case syserror.ENOTDIR: @@ -723,7 +730,7 @@ func (c *containerMounter) processHintsVFS2(conf *config.Config, creds *auth.Cre func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *config.Config, hint *mountHint, creds *auth.Credentials) (*vfs.Mount, error) { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. 
- mntFD := &mountAndFD{Mount: hint.mount} + mntFD := &mountAndFD{mount: &hint.mount} fsName, opts, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, mntFD) if err != nil { return nil, err @@ -733,11 +740,11 @@ func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *conf } if useOverlay { - log.Infof("Adding overlay on top of shared mount %q", mntFD.Destination) + log.Infof("Adding overlay on top of shared mount %q", mntFD.mount.Destination) var cleanup func() opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) if err != nil { - return nil, fmt.Errorf("mounting shared volume with overlay at %q: %w", mntFD.Destination, err) + return nil, fmt.Errorf("mounting shared volume with overlay at %q: %w", mntFD.mount.Destination, err) } defer cleanup() fsName = overlay.Name @@ -748,14 +755,14 @@ func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *conf // mountSharedSubmount binds mount to a previously mounted volume that is shared // among containers in the same pod. -func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount specs.Mount, source *mountHint) (*vfs.Mount, error) { +func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount *specs.Mount, source *mountHint) (*vfs.Mount, error) { if err := source.checkCompatible(mount); err != nil { return nil, err } // Ignore data and useOverlay because these were already applied to // the master mount. - _, opts, _, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount}) + _, opts, _, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{mount: mount}) if err != nil { return nil, err } @@ -808,7 +815,7 @@ func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Crede // configureRestore returns an updated context.Context including filesystem // state used by restore defined by conf. -func (c *containerMounter) configureRestore(ctx context.Context, conf *config.Config) (context.Context, error) { +func (c *containerMounter) configureRestore(ctx context.Context) (context.Context, error) { fdmap := make(map[string]int) fdmap["/"] = c.fds.remove() mounts, err := c.prepareMountsVFS2() @@ -818,7 +825,7 @@ func (c *containerMounter) configureRestore(ctx context.Context, conf *config.Co for i := range c.mounts { submount := &mounts[i] if submount.fd >= 0 { - fdmap[submount.Destination] = submount.fd + fdmap[submount.mount.Destination] = submount.fd } } return context.WithValue(ctx, gofer.CtxRestoreServerFDMap, fdmap), nil diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index 438b7ef3e..66a6a0f68 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -40,23 +40,24 @@ const ( cgroupRoot = "/sys/fs/cgroup" ) -var controllers = map[string]config{ - "blkio": {ctrlr: &blockIO{}}, - "cpu": {ctrlr: &cpu{}}, - "cpuset": {ctrlr: &cpuSet{}}, - "hugetlb": {ctrlr: &hugeTLB{}, optional: true}, - "memory": {ctrlr: &memory{}}, - "net_cls": {ctrlr: &networkClass{}}, - "net_prio": {ctrlr: &networkPrio{}}, - "pids": {ctrlr: &pids{}}, +var controllers = map[string]controller{ + "blkio": &blockIO{}, + "cpu": &cpu{}, + "cpuset": &cpuSet{}, + "hugetlb": &hugeTLB{}, + "memory": &memory{}, + "net_cls": &networkClass{}, + "net_prio": &networkPrio{}, + "pids": &pids{}, // These controllers either don't have anything in the OCI spec or is // irrelevant for a sandbox. 
- "devices": {ctrlr: &noop{}}, - "freezer": {ctrlr: &noop{}}, - "perf_event": {ctrlr: &noop{}}, - "rdma": {ctrlr: &noop{}, optional: true}, - "systemd": {ctrlr: &noop{}}, + "cpuacct": &noop{}, + "devices": &noop{}, + "freezer": &noop{}, + "perf_event": &noop{}, + "rdma": &noop{isOptional: true}, + "systemd": &noop{}, } // IsOnlyV2 checks whether cgroups V2 is enabled and V1 is not. @@ -201,31 +202,26 @@ func countCpuset(cpuset string) (int, error) { return count, nil } -// LoadPaths loads cgroup paths for given 'pid', may be set to 'self'. -func LoadPaths(pid string) (map[string]string, error) { - f, err := os.Open(filepath.Join("/proc", pid, "cgroup")) +// loadPaths loads cgroup paths for given 'pid', may be set to 'self'. +func loadPaths(pid string) (map[string]string, error) { + procCgroup, err := os.Open(filepath.Join("/proc", pid, "cgroup")) if err != nil { return nil, err } - defer f.Close() + defer procCgroup.Close() - return loadPathsHelper(f) -} - -func loadPathsHelper(cgroup io.Reader) (map[string]string, error) { - // For nested containers, in /proc/self/cgroup we see paths from host, - // which don't exist in container, so recover the container paths here by - // double-checking with /proc/pid/mountinfo - mountinfo, err := os.Open("/proc/self/mountinfo") + // Load mountinfo for the current process, because it's where cgroups is + // being accessed from. + mountinfo, err := os.Open(filepath.Join("/proc/self/mountinfo")) if err != nil { return nil, err } defer mountinfo.Close() - return loadPathsHelperWithMountinfo(cgroup, mountinfo) + return loadPathsHelper(procCgroup, mountinfo) } -func loadPathsHelperWithMountinfo(cgroup, mountinfo io.Reader) (map[string]string, error) { +func loadPathsHelper(cgroup, mountinfo io.Reader) (map[string]string, error) { paths := make(map[string]string) scanner := bufio.NewScanner(cgroup) @@ -242,34 +238,51 @@ func loadPathsHelperWithMountinfo(cgroup, mountinfo io.Reader) (map[string]strin for _, ctrlr := range strings.Split(tokens[1], ",") { // Remove prefix for cgroups with no controller, eg. systemd. ctrlr = strings.TrimPrefix(ctrlr, "name=") - paths[ctrlr] = tokens[2] + // Discard unknown controllers. + if _, ok := controllers[ctrlr]; ok { + paths[ctrlr] = tokens[2] + } } } if err := scanner.Err(); err != nil { return nil, err } - mfScanner := bufio.NewScanner(mountinfo) - for mfScanner.Scan() { - txt := mfScanner.Text() - fields := strings.Fields(txt) + // For nested containers, in /proc/[pid]/cgroup we see paths from host, + // which don't exist in container, so recover the container paths here by + // double-checking with /proc/[pid]/mountinfo + mountScanner := bufio.NewScanner(mountinfo) + for mountScanner.Scan() { + // Format: ID parent major:minor root mount-point options opt-fields - fs-type source super-options + // Example: 39 32 0:34 / /sys/fs/cgroup/devices rw,noexec shared:18 - cgroup cgroup rw,devices + fields := strings.Fields(mountScanner.Text()) if len(fields) < 9 || fields[len(fields)-3] != "cgroup" { + // Skip mounts that are not cgroup mounts. continue } - for _, opt := range strings.Split(fields[len(fields)-1], ",") { + // Cgroup controller type is in the super-options field. + superOptions := strings.Split(fields[len(fields)-1], ",") + for _, opt := range superOptions { // Remove prefix for cgroups with no controller, eg. systemd. opt = strings.TrimPrefix(opt, "name=") + + // Only considers cgroup controllers that are registered, and skip other + // irrelevant options, e.g. rw. 
if cgroupPath, ok := paths[opt]; ok { - root := fields[3] - relCgroupPath, err := filepath.Rel(root, cgroupPath) - if err != nil { - return nil, err + rootDir := fields[3] + if rootDir != "/" { + // When cgroup is in submount, remove repeated path components from + // cgroup path to avoid duplicating them. + relCgroupPath, err := filepath.Rel(rootDir, cgroupPath) + if err != nil { + return nil, err + } + paths[opt] = relCgroupPath } - paths[opt] = relCgroupPath } } } - if err := mfScanner.Err(); err != nil { + if err := mountScanner.Err(); err != nil { return nil, err } @@ -279,70 +292,95 @@ func loadPathsHelperWithMountinfo(cgroup, mountinfo io.Reader) (map[string]strin // Cgroup represents a group inside all controllers. For example: // Name='/foo/bar' maps to /sys/fs/cgroup/<controller>/foo/bar on // all controllers. +// +// If Name is relative, it uses the parent cgroup path to determine the +// location. For example: +// Name='foo/bar' and Parent[ctrl]="/user.slice", then it will map to +// /sys/fs/cgroup/<ctrl>/user.slice/foo/bar type Cgroup struct { Name string `json:"name"` Parents map[string]string `json:"parents"` Own map[string]bool `json:"own"` } -// New creates a new Cgroup instance if the spec includes a cgroup path. -// Returns nil otherwise. -func New(spec *specs.Spec) (*Cgroup, error) { +// NewFromSpec creates a new Cgroup instance if the spec includes a cgroup path. +// Returns nil otherwise. Cgroup paths are loaded based on the current process. +func NewFromSpec(spec *specs.Spec) (*Cgroup, error) { if spec.Linux == nil || spec.Linux.CgroupsPath == "" { return nil, nil } - return NewFromPath(spec.Linux.CgroupsPath) + return new("self", spec.Linux.CgroupsPath) } -// NewFromPath creates a new Cgroup instance. -func NewFromPath(cgroupsPath string) (*Cgroup, error) { +// NewFromPid loads cgroup for the given process. +func NewFromPid(pid int) (*Cgroup, error) { + return new(strconv.Itoa(pid), "") +} + +func new(pid, cgroupsPath string) (*Cgroup, error) { var parents map[string]string + + // If path is relative, load cgroup paths for the process to build the + // relative paths. if !filepath.IsAbs(cgroupsPath) { var err error - parents, err = LoadPaths("self") + parents, err = loadPaths(pid) if err != nil { return nil, fmt.Errorf("finding current cgroups: %w", err) } } - own := make(map[string]bool) - return &Cgroup{ + cg := &Cgroup{ Name: cgroupsPath, Parents: parents, - Own: own, - }, nil + Own: make(map[string]bool), + } + log.Debugf("New cgroup for pid: %s, %+v", pid, cg) + return cg, nil } // Install creates and configures cgroups according to 'res'. If cgroup path // already exists, it means that the caller has already provided a // pre-configured cgroups, and 'res' is ignored. func (c *Cgroup) Install(res *specs.LinuxResources) error { - log.Debugf("Creating cgroup %q", c.Name) + log.Debugf("Installing cgroup path %q", c.Name) - // The Cleanup object cleans up partially created cgroups when an error occurs. - // Errors occuring during cleanup itself are ignored. + // Clean up partially created cgroups on error. Errors during cleanup itself + // are ignored. clean := cleanup.Make(func() { _ = c.Uninstall() }) defer clean.Clean() - for key, cfg := range controllers { - path := c.makePath(key) - if _, err := os.Stat(path); err == nil { - // If cgroup has already been created; it has been setup by caller. Don't - // make any changes to configuration, just join when sandbox/gofer starts. 
- log.Debugf("Using pre-created cgroup %q", path) - continue + // Controllers can be symlinks to a group of controllers (e.g. cpu,cpuacct). + // So first check what directories need to be created. Otherwise, when + // the directory for one of the controllers in a group is created, it will + // make it seem like the directory already existed and it's not owned by the + // other controllers in the group. + var missing []string + for key := range controllers { + path := c.MakePath(key) + if _, err := os.Stat(path); err != nil { + missing = append(missing, key) + } else { + log.Debugf("Using pre-created cgroup %q: %q", key, path) } - - // Mark that cgroup resources are owned by me. - c.Own[key] = true - + } + for _, key := range missing { + ctrlr := controllers[key] + path := c.MakePath(key) + log.Debugf("Creating cgroup %q: %q", key, path) if err := os.MkdirAll(path, 0755); err != nil { - if cfg.optional && errors.Is(err, unix.EROFS) { + if ctrlr.optional() && errors.Is(err, unix.EROFS) { + if err := ctrlr.skip(res); err != nil { + return err + } log.Infof("Skipping cgroup %q", key) continue } return err } - if err := cfg.ctrlr.set(res, path); err != nil { + + // Only set controllers that were created by me. + c.Own[key] = true + if err := ctrlr.set(res, path); err != nil { return err } } @@ -359,7 +397,7 @@ func (c *Cgroup) Uninstall() error { // cgroup is managed by caller, don't touch it. continue } - path := c.makePath(key) + path := c.MakePath(key) log.Debugf("Removing cgroup controller for key=%q path=%q", key, path) // If we try to remove the cgroup too soon after killing the @@ -387,7 +425,7 @@ func (c *Cgroup) Uninstall() error { func (c *Cgroup) Join() (func(), error) { // First save the current state so it can be restored. undo := func() {} - paths, err := LoadPaths("self") + paths, err := loadPaths("self") if err != nil { return undo, err } @@ -414,14 +452,13 @@ func (c *Cgroup) Join() (func(), error) { } // Now join the cgroups. - for key, cfg := range controllers { - path := c.makePath(key) + for key, ctrlr := range controllers { + path := c.MakePath(key) log.Debugf("Joining cgroup %q", path) - // Writing the value 0 to a cgroup.procs file causes the - // writing process to be moved to the corresponding cgroup. - // - cgroups(7). + // Writing the value 0 to a cgroup.procs file causes the writing process to + // be moved to the corresponding cgroup - cgroups(7). if err := setValue(path, "cgroup.procs", "0"); err != nil { - if cfg.optional && os.IsNotExist(err) { + if ctrlr.optional() && os.IsNotExist(err) { continue } return undo, err @@ -432,7 +469,7 @@ func (c *Cgroup) Join() (func(), error) { // CPUQuota returns the CFS CPU quota. func (c *Cgroup) CPUQuota() (float64, error) { - path := c.makePath("cpu") + path := c.MakePath("cpu") quota, err := getInt(path, "cpu.cfs_quota_us") if err != nil { return -1, err @@ -449,7 +486,7 @@ func (c *Cgroup) CPUQuota() (float64, error) { // CPUUsage returns the total CPU usage of the cgroup. func (c *Cgroup) CPUUsage() (uint64, error) { - path := c.makePath("cpuacct") + path := c.MakePath("cpuacct") usage, err := getValue(path, "cpuacct.usage") if err != nil { return 0, err @@ -459,7 +496,7 @@ func (c *Cgroup) CPUUsage() (uint64, error) { // NumCPU returns the number of CPUs configured in 'cpuset/cpuset.cpus'. 
func (c *Cgroup) NumCPU() (int, error) { - path := c.makePath("cpuset") + path := c.MakePath("cpuset") cpuset, err := getValue(path, "cpuset.cpus") if err != nil { return 0, err @@ -469,7 +506,7 @@ func (c *Cgroup) NumCPU() (int, error) { // MemoryLimit returns the memory limit. func (c *Cgroup) MemoryLimit() (uint64, error) { - path := c.makePath("memory") + path := c.MakePath("memory") limStr, err := getValue(path, "memory.limit_in_bytes") if err != nil { return 0, err @@ -477,7 +514,8 @@ func (c *Cgroup) MemoryLimit() (uint64, error) { return strconv.ParseUint(strings.TrimSpace(limStr), 10, 64) } -func (c *Cgroup) makePath(controllerName string) string { +// MakePath builds a path to the given controller. +func (c *Cgroup) MakePath(controllerName string) string { path := c.Name if parent, ok := c.Parents[controllerName]; ok { path = filepath.Join(parent, c.Name) @@ -485,22 +523,48 @@ func (c *Cgroup) makePath(controllerName string) string { return filepath.Join(cgroupRoot, controllerName, path) } -type config struct { - ctrlr controller - optional bool -} - type controller interface { + // optional controllers don't fail if not found. + optional() bool + // set applies resource limits to controller. set(*specs.LinuxResources, string) error + // skip is called when controller is not found to check if it can be safely + // skipped or not based on the spec. + skip(*specs.LinuxResources) error } -type noop struct{} +type noop struct { + isOptional bool +} + +func (n *noop) optional() bool { + return n.isOptional +} func (*noop) set(*specs.LinuxResources, string) error { return nil } -type memory struct{} +func (n *noop) skip(*specs.LinuxResources) error { + if !n.isOptional { + panic("cgroup controller is not optional") + } + return nil +} + +type mandatory struct{} + +func (*mandatory) optional() bool { + return false +} + +func (*mandatory) skip(*specs.LinuxResources) error { + panic("cgroup controller is not optional") +} + +type memory struct { + mandatory +} func (*memory) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.Memory == nil { @@ -533,7 +597,9 @@ func (*memory) set(spec *specs.LinuxResources, path string) error { return nil } -type cpu struct{} +type cpu struct { + mandatory +} func (*cpu) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.CPU == nil { @@ -554,7 +620,9 @@ func (*cpu) set(spec *specs.LinuxResources, path string) error { return setOptionalValueInt(path, "cpu.rt_runtime_us", spec.CPU.RealtimeRuntime) } -type cpuSet struct{} +type cpuSet struct { + mandatory +} func (*cpuSet) set(spec *specs.LinuxResources, path string) error { // cpuset.cpus and mems are required fields, but are not set on a new cgroup. 
@@ -576,7 +644,9 @@ func (*cpuSet) set(spec *specs.LinuxResources, path string) error { return setValue(path, "cpuset.mems", spec.CPU.Mems) } -type blockIO struct{} +type blockIO struct { + mandatory +} func (*blockIO) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.BlockIO == nil { @@ -628,6 +698,10 @@ func setThrottle(path, name string, devs []specs.LinuxThrottleDevice) error { type networkClass struct{} +func (*networkClass) optional() bool { + return true +} + func (*networkClass) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.Network == nil { return nil @@ -635,8 +709,19 @@ func (*networkClass) set(spec *specs.LinuxResources, path string) error { return setOptionalValueUint32(path, "net_cls.classid", spec.Network.ClassID) } +func (*networkClass) skip(spec *specs.LinuxResources) error { + if spec != nil && spec.Network != nil && spec.Network.ClassID != nil { + return fmt.Errorf("Network.ClassID set but net_cls cgroup controller not found") + } + return nil +} + type networkPrio struct{} +func (*networkPrio) optional() bool { + return true +} + func (*networkPrio) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.Network == nil { return nil @@ -650,7 +735,16 @@ func (*networkPrio) set(spec *specs.LinuxResources, path string) error { return nil } -type pids struct{} +func (*networkPrio) skip(spec *specs.LinuxResources) error { + if spec != nil && spec.Network != nil && len(spec.Network.Priorities) > 0 { + return fmt.Errorf("Network.Priorities set but net_prio cgroup controller not found") + } + return nil +} + +type pids struct { + mandatory +} func (*pids) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.Pids == nil || spec.Pids.Limit <= 0 { @@ -662,6 +756,17 @@ func (*pids) set(spec *specs.LinuxResources, path string) error { type hugeTLB struct{} +func (*hugeTLB) optional() bool { + return true +} + +func (*hugeTLB) skip(spec *specs.LinuxResources) error { + if spec != nil && len(spec.HugepageLimits) > 0 { + return fmt.Errorf("HugepageLimits set but hugetlb cgroup controller not found") + } + return nil +} + func (*hugeTLB) set(spec *specs.LinuxResources, path string) error { if spec == nil { return nil diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go index 48d71cfa6..eba40621e 100644 --- a/runsc/cgroup/cgroup_test.go +++ b/runsc/cgroup/cgroup_test.go @@ -43,27 +43,27 @@ var debianMountinfo = ` ` var dindMountinfo = ` -1305 1304 0:64 / /sys/fs/cgroup rw - tmpfs tmpfs rw,mode=755 -1306 1305 0:32 /docker/136 /sys/fs/cgroup/systemd ro master:11 - cgroup cgroup rw,xattr,name=systemd -1307 1305 0:36 /docker/136 /sys/fs/cgroup/cpu,cpuacct ro master:16 - cgroup cgroup rw,cpu,cpuacct -1308 1305 0:37 /docker/136 /sys/fs/cgroup/freezer ro master:17 - cgroup cgroup rw,freezer -1309 1305 0:38 /docker/136 /sys/fs/cgroup/hugetlb ro master:18 - cgroup cgroup rw,hugetlb -1310 1305 0:39 /docker/136 /sys/fs/cgroup/cpuset ro master:19 - cgroup cgroup rw,cpuset -1311 1305 0:40 /docker/136 /sys/fs/cgroup/net_cls,net_prio ro master:20 - cgroup cgroup rw,net_cls,net_prio -1312 1305 0:41 /docker/136 /sys/fs/cgroup/pids ro master:21 - cgroup cgroup rw,pids -1313 1305 0:42 /docker/136 /sys/fs/cgroup/perf_event ro master:22 - cgroup cgroup rw,perf_event -1314 1305 0:43 /docker/136 /sys/fs/cgroup/memory ro master:23 - cgroup cgroup rw,memory -1316 1305 0:44 /docker/136 /sys/fs/cgroup/blkio ro master:24 - cgroup cgroup rw,blkio -1317 1305 0:45 /docker/136 /sys/fs/cgroup/devices ro 
master:25 - cgroup cgroup rw,devices -1318 1305 0:46 / /sys/fs/cgroup/rdma ro master:26 - cgroup cgroup rw,rdma +05 04 0:64 / /sys/fs/cgroup rw - tmpfs tmpfs rw,mode=755 +06 05 0:32 /docker/136 /sys/fs/cgroup/systemd ro master:11 - cgroup cgroup rw,xattr,name=systemd +07 05 0:36 /docker/136 /sys/fs/cgroup/cpu,cpuacct ro master:16 - cgroup cgroup rw,cpu,cpuacct +08 05 0:37 /docker/136 /sys/fs/cgroup/freezer ro master:17 - cgroup cgroup rw,freezer +09 05 0:38 /docker/136 /sys/fs/cgroup/hugetlb ro master:18 - cgroup cgroup rw,hugetlb +10 05 0:39 /docker/136 /sys/fs/cgroup/cpuset ro master:19 - cgroup cgroup rw,cpuset +11 05 0:40 /docker/136 /sys/fs/cgroup/net_cls,net_prio ro master:20 - cgroup cgroup rw,net_cls,net_prio +12 05 0:41 /docker/136 /sys/fs/cgroup/pids ro master:21 - cgroup cgroup rw,pids +13 05 0:42 /docker/136 /sys/fs/cgroup/perf_event ro master:22 - cgroup cgroup rw,perf_event +14 05 0:43 /docker/136 /sys/fs/cgroup/memory ro master:23 - cgroup cgroup rw,memory +16 05 0:44 /docker/136 /sys/fs/cgroup/blkio ro master:24 - cgroup cgroup rw,blkio +17 05 0:45 /docker/136 /sys/fs/cgroup/devices ro master:25 - cgroup cgroup rw,devices +18 05 0:46 / /sys/fs/cgroup/rdma ro master:26 - cgroup cgroup rw,rdma ` func TestUninstallEnoent(t *testing.T) { c := Cgroup{ - // set a non-existent name + // Use a non-existent name. Name: "runsc-test-uninstall-656e6f656e740a", + Own: make(map[string]bool), } - c.Own = make(map[string]bool) for key := range controllers { c.Own[key] = true } @@ -693,36 +693,42 @@ func TestLoadPaths(t *testing.T) { err string }{ { - name: "abs-path-unknown-controller", - cgroups: "0:ctr:/path", + name: "empty", mountinfo: debianMountinfo, - want: map[string]string{"ctr": "/path"}, + }, + { + name: "abs-path", + cgroups: "0:cpu:/path", + mountinfo: debianMountinfo, + want: map[string]string{"cpu": "/path"}, }, { name: "rel-path", - cgroups: "0:ctr:rel-path", + cgroups: "0:cpu:rel-path", mountinfo: debianMountinfo, - want: map[string]string{"ctr": "rel-path"}, + want: map[string]string{"cpu": "rel-path"}, }, { name: "non-controller", cgroups: "0:name=systemd:/path", mountinfo: debianMountinfo, - want: map[string]string{"systemd": "path"}, + want: map[string]string{"systemd": "/path"}, }, { - name: "empty", + name: "unknown-controller", + cgroups: "0:ctr:/path", mountinfo: debianMountinfo, + want: map[string]string{}, }, { name: "multiple", - cgroups: "0:ctr0:/path0\n" + - "1:ctr1:/path1\n" + + cgroups: "0:cpu:/path0\n" + + "1:memory:/path1\n" + "2::/empty\n", mountinfo: debianMountinfo, want: map[string]string{ - "ctr0": "/path0", - "ctr1": "/path1", + "cpu": "/path0", + "memory": "/path1", }, }, { @@ -747,10 +753,10 @@ func TestLoadPaths(t *testing.T) { }, { name: "nested-cgroup", - cgroups: `9:memory:/docker/136 -2:cpu,cpuacct:/docker/136 -1:name=systemd:/docker/136 -0::/system.slice/containerd.service`, + cgroups: "9:memory:/docker/136\n" + + "2:cpu,cpuacct:/docker/136\n" + + "1:name=systemd:/docker/136\n" + + "0::/system.slice/containerd.service\n", mountinfo: dindMountinfo, // we want relative path to /sys/fs/cgroup inside the nested container. 
// Subgroup inside the container will be created at /sys/fs/cgroup/cpu @@ -781,15 +787,15 @@ }, { name: "invalid-rel-path-in-proc-cgroup", - cgroups: "9:memory:./invalid", + cgroups: "9:memory:invalid", mountinfo: dindMountinfo, - err: "can't make ./invalid relative to /docker/136", + err: "can't make invalid relative to /docker/136", }, } { t.Run(tc.name, func(t *testing.T) { r := strings.NewReader(tc.cgroups) mountinfo := strings.NewReader(tc.mountinfo) - got, err := loadPathsHelperWithMountinfo(r, mountinfo) + got, err := loadPathsHelper(r, mountinfo) if len(tc.err) == 0 { if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -813,3 +819,47 @@ func TestLoadPaths(t *testing.T) { }) } } + +func TestOptional(t *testing.T) { + for _, tc := range []struct { + name string + ctrlr controller + spec *specs.LinuxResources + err string + }{ + { + name: "net-cls", + ctrlr: &networkClass{}, + spec: &specs.LinuxResources{Network: &specs.LinuxNetwork{ClassID: uint32Ptr(1)}}, + err: "Network.ClassID set but net_cls cgroup controller not found", + }, + { + name: "net-prio", + ctrlr: &networkPrio{}, + spec: &specs.LinuxResources{Network: &specs.LinuxNetwork{ + Priorities: []specs.LinuxInterfacePriority{ + {Name: "foo", Priority: 1}, + }, + }}, + err: "Network.Priorities set but net_prio cgroup controller not found", + }, + { + name: "hugetlb", + ctrlr: &hugeTLB{}, + spec: &specs.LinuxResources{HugepageLimits: []specs.LinuxHugepageLimit{ + {Pagesize: "1", Limit: 2}, + }}, + err: "HugepageLimits set but hugetlb cgroup controller not found", + }, + } { + t.Run(tc.name, func(t *testing.T) { + err := tc.ctrlr.skip(tc.spec) + if err == nil { + t.Fatalf("ctrlr.skip() didn't fail") + } + if !strings.Contains(err.Error(), tc.err) { + t.Errorf("ctrlr.skip() want: *%s*, got: %q", tc.err, err) + } + }) + } +} diff --git a/runsc/cmd/mitigate.go b/runsc/cmd/mitigate.go index d37ab80ba..f4e65adb8 100644 --- a/runsc/cmd/mitigate.go +++ b/runsc/cmd/mitigate.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "io/ioutil" + "runtime" "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" @@ -72,6 +73,11 @@ func (m *Mitigate) SetFlags(f *flag.FlagSet) { // Execute implements subcommands.Command.Execute. func (m *Mitigate) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if runtime.GOARCH == "arm64" || runtime.GOARCH == "arm" { + log.Warningf("As ARM is not affected by MDS, mitigate is not supported") + return subcommands.ExitFailure + } + if f.NArg() != 0 { f.Usage() return subcommands.ExitUsageError diff --git a/runsc/cmd/mitigate_test.go b/runsc/cmd/mitigate_test.go index 5a76667e3..2d3fef7c1 100644 --- a/runsc/cmd/mitigate_test.go +++ b/runsc/cmd/mitigate_test.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build amd64 + package cmd import ( diff --git a/runsc/container/container.go b/runsc/container/container.go index e72ada311..0820edaec 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -233,7 +233,7 @@ func New(conf *config.Config, args Args) (*Container, error) { } // Create and join cgroup before processes are created to ensure they are // part of the cgroup from the start (and all their children processes).
- cg, err := cgroup.New(args.Spec) + cg, err := cgroup.NewFromSpec(args.Spec) if err != nil { return nil, err } @@ -1132,7 +1132,7 @@ func (c *Container) populateStats(event *boot.EventOut) { // account for the full cgroup CPU usage. We split cgroup usage // proportionally according to the sentry-internal usage measurements, // only counting Running containers. - log.Warningf("event.ContainerUsage: %v", event.ContainerUsage) + log.Debugf("event.ContainerUsage: %v", event.ContainerUsage) var containerUsage uint64 var allContainersUsage uint64 for ID, usage := range event.ContainerUsage { @@ -1142,7 +1142,7 @@ func (c *Container) populateStats(event *boot.EventOut) { } } - cgroup, err := c.Sandbox.FindCgroup() + cgroup, err := c.Sandbox.NewCGroup() if err != nil { // No cgroup, so rely purely on the sentry's accounting. log.Warningf("events: no cgroups") @@ -1159,17 +1159,18 @@ func (c *Container) populateStats(event *boot.EventOut) { return } - // If the sentry reports no memory usage, fall back on cgroups and - // split usage equally across containers. + // If the sentry reports no CPU usage, fall back on cgroups and split usage + // equally across containers. if allContainersUsage == 0 { log.Warningf("events: no sentry CPU usage reported") allContainersUsage = cgroupsUsage containerUsage = cgroupsUsage / uint64(len(event.ContainerUsage)) } - log.Warningf("%f, %f, %f", containerUsage, cgroupsUsage, allContainersUsage) // Scaling can easily overflow a uint64 (e.g. a containerUsage and // cgroupsUsage of 16 seconds each will overflow), so use floats. - event.Event.Data.CPU.Usage.Total = uint64(float64(containerUsage) * (float64(cgroupsUsage) / float64(allContainersUsage))) + total := float64(containerUsage) * (float64(cgroupsUsage) / float64(allContainersUsage)) + log.Debugf("Usage, container: %d, cgroups: %d, all: %d, total: %.0f", containerUsage, cgroupsUsage, allContainersUsage, total) + event.Event.Data.CPU.Usage.Total = uint64(total) return } diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 5a0c468a4..0e79877b7 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -2449,6 +2449,27 @@ func TestCreateWithCorruptedStateFile(t *testing.T) { } } +func TestBindMountByOption(t *testing.T) { + for name, conf := range configs(t, all...) 
{ + t.Run(name, func(t *testing.T) { + dir, err := ioutil.TempDir(testutil.TmpDir(), "bind-mount") + spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file")) + if err != nil { + t.Fatalf("ioutil.TempDir(): %v", err) + } + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: dir, + Source: dir, + Type: "none", + Options: []string{"rw", "bind"}, + }) + if err := run(spec, conf); err != nil { + t.Fatalf("error running sandbox: %v", err) + } + }) + } +} + func execute(cont *Container, name string, arg ...string) (unix.WaitStatus, error) { args := &control.ExecArgs{ Filename: name, diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 0f0a223ce..0dbe1e323 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cleanup" @@ -1510,7 +1511,7 @@ func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { Destination: "/mydir/test", Source: "/some/dir", Type: "tmpfs", - Options: []string{"rw", "rbind", "relatime"}, + Options: []string{"rw", "relatime"}, } podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) @@ -1917,9 +1918,9 @@ func TestMultiContainerEvent(t *testing.T) { } defer cleanup() - for _, cont := range containers { - t.Logf("Running containerd %s", cont.ID) - } + t.Logf("Running container sleep %s", containers[0].ID) + t.Logf("Running container busy %s", containers[1].ID) + t.Logf("Running container quick %s", containers[2].ID) // Wait for last container to stabilize the process count that is // checked further below. @@ -1940,50 +1941,61 @@ func TestMultiContainerEvent(t *testing.T) { } // Check events for running containers. - var prevUsage uint64 for _, cont := range containers[:2] { ret, err := cont.Event() if err != nil { - t.Errorf("Container.Events(): %v", err) + t.Errorf("Container.Event(%q): %v", cont.ID, err) } evt := ret.Event if want := "stats"; evt.Type != want { - t.Errorf("Wrong event type, want: %s, got: %s", want, evt.Type) + t.Errorf("Wrong event type, cid: %q, want: %s, got: %s", cont.ID, want, evt.Type) } if cont.ID != evt.ID { t.Errorf("Wrong container ID, want: %s, got: %s", cont.ID, evt.ID) } // One process per remaining container. if got, want := evt.Data.Pids.Current, uint64(2); got != want { - t.Errorf("Wrong number of PIDs, want: %d, got: %d", want, got) + t.Errorf("Wrong number of PIDs, cid: %q, want: %d, got: %d", cont.ID, want, got) } - // Both remaining containers should have nonzero usage, and - // 'busy' should have higher usage than 'sleep'. - usage := evt.Data.CPU.Usage.Total - if usage == 0 { - t.Errorf("Running container should report nonzero CPU usage, but got %d", usage) + // The exited container should always have a usage of zero. + if exited := ret.ContainerUsage[containers[2].ID]; exited != 0 { + t.Errorf("Exited container should report 0 CPU usage, got: %d", exited) + } + } + + // Check that CPU reported by busy container is higher than sleep. 
+ cb := func() error { + sleepEvt, err := containers[0].Event() + if err != nil { + return &backoff.PermanentError{Err: err} } - if usage <= prevUsage { - t.Errorf("Expected container %s to use more than %d ns of CPU, but used %d", cont.ID, prevUsage, usage) + sleepUsage := sleepEvt.Event.Data.CPU.Usage.Total + + busyEvt, err := containers[1].Event() + if err != nil { + return &backoff.PermanentError{Err: err} } - t.Logf("Container %s usage: %d", cont.ID, usage) - prevUsage = usage + busyUsage := busyEvt.Event.Data.CPU.Usage.Total - // The exited container should have a usage of zero. - if exited := ret.ContainerUsage[containers[2].ID]; exited != 0 { - t.Errorf("Exited container should report 0 CPU usage, but got %d", exited) + if busyUsage <= sleepUsage { + t.Logf("Busy container usage lower than sleep (busy: %d, sleep: %d), retrying...", busyUsage, sleepUsage) + return fmt.Errorf("Busy container should have higher usage than sleep, busy: %d, sleep: %d", busyUsage, sleepUsage) } + return nil + } + // Give time for busy container to run and use more CPU than sleep. + if err := testutil.Poll(cb, 10*time.Second); err != nil { + t.Fatal(err) } - // Check that stop and destroyed containers return error. + // Check that stopped and destroyed containers return error. if err := containers[1].Destroy(); err != nil { t.Fatalf("container.Destroy: %v", err) } for _, cont := range containers[1:] { - _, err := cont.Event() - if err == nil { - t.Errorf("Container.Events() should have failed, cid:%s, state: %v", cont.ID, cont.Status) + if _, err := cont.Event(); err == nil { + t.Errorf("Container.Event() should have failed, cid: %q, state: %v", cont.ID, cont.Status) } } } diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index e04ddda47..3f362b25e 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -21,6 +21,7 @@ package fsgofer import ( + "errors" "fmt" "io" "math" @@ -58,9 +59,6 @@ var verityXattrs = map[string]struct{}{ // join is equivalent to path.Join() but skips path.Clean() which is expensive. func join(parent, child string) string { - if child == "." || child == ".." { - panic(fmt.Sprintf("invalid child path %q", child)) - } return parent + "/" + child } @@ -1226,3 +1224,60 @@ func (l *localFile) checkROMount() error { } return nil } + +func (l *localFile) MultiGetAttr(names []string) ([]p9.FullStat, error) { + stats := make([]p9.FullStat, 0, len(names)) + + if len(names) > 0 && names[0] == "" { + qid, valid, attr, err := l.GetAttr(p9.AttrMask{}) + if err != nil { + return nil, err + } + stats = append(stats, p9.FullStat{ + QID: qid, + Valid: valid, + Attr: attr, + }) + names = names[1:] + } + + parent := l.file.FD() + for _, name := range names { + child, err := unix.Openat(parent, name, openFlags|unix.O_PATH, 0) + if parent != l.file.FD() { + // Parent is no longer needed. + _ = unix.Close(parent) + parent = -1 + } + if err != nil { + if errors.Is(err, unix.ENOENT) { + // No point in continuing any further. + break + } + return nil, err + } + + var stat unix.Stat_t + if err := unix.Fstat(child, &stat); err != nil { + _ = unix.Close(child) + return nil, err + } + valid, attr := l.fillAttr(&stat) + stats = append(stats, p9.FullStat{ + QID: l.attachPoint.makeQID(&stat), + Valid: valid, + Attr: attr, + }) + if (stat.Mode & unix.S_IFMT) != unix.S_IFDIR { + // Doesn't need to continue if entry is not a dir. Including symlinks + // that cannot be followed.
+ _ = unix.Close(child) + break + } + parent = child + } + if parent != -1 && parent != l.file.FD() { + _ = unix.Close(parent) + } + return stats, nil +} diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index d7e141476..77723827a 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -703,16 +703,6 @@ func TestWalkNotFound(t *testing.T) { }) } -func TestWalkPanic(t *testing.T) { - runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) { - for _, name := range []string{".", ".."} { - assertPanic(t, func() { - s.file.Walk([]string{name}) - }) - } - }) -} - func TestWalkDup(t *testing.T) { runAll(t, func(t *testing.T, s state) { _, dup, err := s.file.Walk([]string{}) diff --git a/runsc/mitigate/mitigate_test.go b/runsc/mitigate/mitigate_test.go index 3bf9ef547..890c65f05 100644 --- a/runsc/mitigate/mitigate_test.go +++ b/runsc/mitigate/mitigate_test.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build amd64 + package mitigate import ( diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index f3f60f116..29e202b7d 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -131,8 +131,9 @@ func New(conf *config.Config, args *Args) (*Sandbox, error) { // The Cleanup object cleans up partially created sandboxes when an error // occurs. Any errors occurring during cleanup itself are ignored. c := cleanup.Make(func() { - err := s.destroy() - log.Warningf("error destroying sandbox: %v", err) + if err := s.destroy(); err != nil { + log.Warningf("error destroying sandbox: %v", err) + } }) defer c.Clean() @@ -310,20 +311,9 @@ func (s *Sandbox) Processes(cid string) ([]*control.Process, error) { return pl, nil } -// FindCgroup returns the sandbox's Cgroup, or an error if it does not have one. -func (s *Sandbox) FindCgroup() (*cgroup.Cgroup, error) { - paths, err := cgroup.LoadPaths(strconv.Itoa(s.Pid)) - if err != nil { - return nil, err - } - // runsc places sandboxes in the same cgroup for each controller, so we - // pick an arbitrary controller here to get the cgroup path. - const controller = "cpuacct" - controllerPath, ok := paths[controller] - if !ok { - return nil, fmt.Errorf("no %q controller found", controller) - } - return cgroup.NewFromPath(controllerPath) +// NewCGroup returns the sandbox's Cgroup, or an error if it does not have one. +func (s *Sandbox) NewCGroup() (*cgroup.Cgroup, error) { + return cgroup.NewFromPid(s.Pid) } // Execute runs the specified command in the container. 
It returns the PID of diff --git a/runsc/specutils/seccomp/BUILD b/runsc/specutils/seccomp/BUILD index e9e647d82..c5f5b863e 100644 --- a/runsc/specutils/seccomp/BUILD +++ b/runsc/specutils/seccomp/BUILD @@ -28,8 +28,10 @@ go_test( srcs = ["seccomp_test.go"], library = ":seccomp", deps = [ - "//pkg/binary", + "//pkg/abi/linux", "//pkg/bpf", + "//pkg/hostarch", + "//pkg/marshal", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/runsc/specutils/seccomp/seccomp_test.go b/runsc/specutils/seccomp/seccomp_test.go index 11a6c8daa..20796bf14 100644 --- a/runsc/specutils/seccomp/seccomp_test.go +++ b/runsc/specutils/seccomp/seccomp_test.go @@ -20,20 +20,15 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" ) -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - -// asInput converts a seccompData to a bpf.Input. -func asInput(d seccompData) bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +// asInput converts a linux.SeccompData to a bpf.Input. +func asInput(d *linux.SeccompData) bpf.Input { + return bpf.InputBytes{marshal.Marshal(d), hostarch.ByteOrder} } // testInput creates an Input struct with given seccomp input values. @@ -49,13 +44,13 @@ func testInput(arch uint32, syscallName string, args *[6]uint64) bpf.Input { args = &argArray } - data := seccompData{ - nr: syscallNo, - arch: arch, - args: *args, + data := linux.SeccompData{ + Nr: int32(syscallNo), + Arch: arch, + Args: *args, } - return asInput(data) + return asInput(&data) } // testCase holds a seccomp test case. @@ -100,7 +95,7 @@ var ( }, // Syscall matches but the arch is AUDIT_ARCH_X86 so the return // value is the bad arch action. - input: asInput(seccompData{nr: 183, arch: 0x40000003}), // + input: asInput(&linux.SeccompData{Nr: 183, Arch: 0x40000003}), // expected: uint32(killThreadAction), }, { diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index e5e66546c..11b476690 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -335,9 +335,27 @@ func capsFromNames(names []string, skipSet map[linux.Capability]struct{}) (auth. // Is9PMount returns true if the given mount can be mounted as an external // gofer. func Is9PMount(m specs.Mount, vfs2Enabled bool) bool { + MaybeConvertToBindMount(&m) return m.Type == "bind" && m.Source != "" && IsSupportedDevMount(m, vfs2Enabled) } +// MaybeConvertToBindMount converts mount type to "bind" in case any of the +// mount options are either "bind" or "rbind" as required by the OCI spec. +// +// "For bind mounts (when options include either bind or rbind), the type is a +// dummy, often "none" (not listed in /proc/filesystems)." +func MaybeConvertToBindMount(m *specs.Mount) { + if m.Type == "bind" { + return + } + for _, opt := range m.Options { + if opt == "bind" || opt == "rbind" { + m.Type = "bind" + return + } + } +} + // IsSupportedDevMount returns true if m.Destination does not specify a // path that is hardcoded by VFS1's implementation of /dev. 
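A small usage sketch (not from the patch) of the new MaybeConvertToBindMount helper, assuming the gvisor.dev/gvisor/runsc/specutils import path:

```go
// Sketch only, not part of the patch.
package main

import (
	"fmt"

	specs "github.com/opencontainers/runtime-spec/specs-go"
	"gvisor.dev/gvisor/runsc/specutils"
)

func main() {
	// Per the OCI spec, bind mounts may carry a dummy type such as "none";
	// the "bind"/"rbind" option is what marks them as bind mounts.
	m := specs.Mount{
		Destination: "/mydir",
		Source:      "/some/dir",
		Type:        "none",
		Options:     []string{"rw", "bind"},
	}
	specutils.MaybeConvertToBindMount(&m)
	fmt.Println(m.Type) // "bind", so Is9PMount now treats it as a gofer-backed mount.
}
```

This is what lets Is9PMount accept OCI mounts that declare Type "none" but carry the bind or rbind option, as exercised by the TestBindMountByOption and multi-container test changes earlier in this diff.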
func IsSupportedDevMount(m specs.Mount, vfs2Enabled bool) bool { diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 634c15727..afe73a69a 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -252,6 +252,9 @@ ALL_TESTS = [ name = "tcp_syncookie", ), PacketimpactTestInfo( + name = "tcp_connect_icmp_error", + ), + PacketimpactTestInfo( name = "icmpv6_param_problem", ), PacketimpactTestInfo( diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 616215dc3..d8059ab98 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -16,6 +16,8 @@ go_library( ], visibility = ["//test/packetimpact:__subpackages__"], deps = [ + "//pkg/abi/linux", + "//pkg/binary", "//pkg/hostarch", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index eabdc8cb3..269e163bb 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -22,11 +22,13 @@ import ( "testing" "time" - pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" - "golang.org/x/sys/unix" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "gvisor.dev/gvisor/pkg/abi/linux" + bin "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" + pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" ) // DUT communicates with the DUT to force it to make POSIX calls. @@ -428,6 +430,33 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, so return ret, timeval, errno } +// GetSockOptTCPInfo retrieves TCPInfo for the given socket descriptor. +func (dut *DUT) GetSockOptTCPInfo(t *testing.T, sockfd int32) linux.TCPInfo { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + ret, info, err := dut.GetSockOptTCPInfoWithErrno(ctx, t, sockfd) + if ret != 0 || err != unix.Errno(0) { + t.Fatalf("failed to GetSockOptTCPInfo: %s", err) + } + return info +} + +// GetSockOptTCPInfoWithErrno retrieves TCPInfo with any errno. +func (dut *DUT) GetSockOptTCPInfoWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, linux.TCPInfo, error) { + t.Helper() + + info := linux.TCPInfo{} + ret, infoBytes, errno := dut.GetSockOptWithErrno(ctx, t, sockfd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } + bin.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + + return ret, info, errno +} + // Listen calls listen on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // ListenWithErrno. diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index a73c07e64..37d02365a 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -57,6 +57,11 @@ type DUTUname struct { OperatingSystem string } +// IsLinux returns true if we are running natively on Linux. +func (n *DUTUname) IsLinux() bool { + return Native && n.OperatingSystem == "GNU/Linux" +} + // DUTTestNet describes the test network setup on dut and how the testbench // should connect with an existing DUT.
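A brief sketch (not part of the patch) of how a packetimpact test can use the new GetSockOptTCPInfo helper in place of the GetSockOpt(TCP_INFO) plus binary.Unmarshal sequence it replaces; it mirrors the test updates further below and assumes the standard testbench and abi/linux import paths.

```go
// Sketch only, not part of the patch.
package example_test

import (
	"testing"

	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/test/packetimpact/testbench"
)

// checkEstablished asserts that the DUT socket has reached ESTABLISHED.
// One helper call replaces the old GetSockOpt + manual unmarshal pattern.
func checkEstablished(t *testing.T, dut *testbench.DUT, acceptFD int32) {
	t.Helper()
	info := dut.GetSockOptTCPInfo(t, acceptFD)
	if got, want := uint32(info.State), uint32(linux.TCP_ESTABLISHED); got != want {
		t.Fatalf("got TCP state %d, want %d", got, want)
	}
}
```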
type DUTTestNet struct { diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index e015c1f0e..b1d280f98 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -104,8 +104,6 @@ packetimpact_testbench( srcs = ["tcp_retransmits_test.go"], deps = [ "//pkg/abi/linux", - "//pkg/binary", - "//pkg/hostarch", "//pkg/tcpip/header", "//test/packetimpact/testbench", "@org_golang_x_sys//unix:go_default_library", @@ -189,6 +187,7 @@ packetimpact_testbench( name = "tcp_synsent_reset", srcs = ["tcp_synsent_reset_test.go"], deps = [ + "//pkg/abi/linux", "//pkg/tcpip/header", "//test/packetimpact/testbench", "@org_golang_x_sys//unix:go_default_library", @@ -353,8 +352,6 @@ packetimpact_testbench( srcs = ["tcp_rack_test.go"], deps = [ "//pkg/abi/linux", - "//pkg/binary", - "//pkg/hostarch", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", "//test/packetimpact/testbench", @@ -367,8 +364,6 @@ packetimpact_testbench( srcs = ["tcp_info_test.go"], deps = [ "//pkg/abi/linux", - "//pkg/binary", - "//pkg/hostarch", "//pkg/tcpip/header", "//test/packetimpact/testbench", "@org_golang_x_sys//unix:go_default_library", @@ -405,6 +400,16 @@ packetimpact_testbench( ], ) +packetimpact_testbench( + name = "tcp_connect_icmp_error", + srcs = ["tcp_connect_icmp_error_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + validate_all_tests() [packetimpact_go_test( diff --git a/test/packetimpact/tests/tcp_connect_icmp_error_test.go b/test/packetimpact/tests/tcp_connect_icmp_error_test.go new file mode 100644 index 000000000..79bfe9eb7 --- /dev/null +++ b/test/packetimpact/tests/tcp_connect_icmp_error_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_connect_icmp_error_test + +import ( + "context" + "flag" + "sync" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +func sendICMPError(t *testing.T, conn *testbench.TCPIPv4, tcp *testbench.TCP) { + t.Helper() + + layers := conn.CreateFrame(t, nil) + layers = layers[:len(layers)-1] + ip, ok := tcp.Prev().(*testbench.IPv4) + if !ok { + t.Fatalf("expected %s to be IPv4", tcp.Prev()) + } + icmpErr := &testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), + Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable)} + + layers = append(layers, icmpErr, ip, tcp) + conn.SendFrameStateless(t, layers) +} + +// TestTCPConnectICMPError tests for the handshake to fail and the socket state +// cleaned up on receiving an ICMP error. 
+func TestTCPConnectICMPError(t *testing.T) { + dut := testbench.NewDUT(t) + + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, dut.Net.RemoteIPv4) + port := uint16(9001) + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port}) + defer conn.Close(t) + sa := unix.SockaddrInet4{Port: int(port)} + copy(sa.Addr[:], dut.Net.LocalIPv4) + // Bring the dut to SYN-SENT state with a non-blocking connect. + dut.Connect(t, clientFD, &sa) + tcp, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second) + if err != nil { + t.Fatalf("expected SYN, %s", err) + } + + done := make(chan bool) + defer close(done) + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + var block sync.WaitGroup + block.Add(1) + go func() { + defer wg.Done() + _, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + block.Done() + for { + select { + case <-done: + return + default: + if errno := dut.GetSockOptInt(t, clientFD, unix.SOL_SOCKET, unix.SO_ERROR); errno != 0 { + return + } + } + } + }() + block.Wait() + + sendICMPError(t, &conn, tcp) + + dut.PollOne(t, clientFD, unix.POLLHUP, time.Second) + + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + // The DUT should reply with RST to our ACK as the state should have + // transitioned to CLOSED because of handshake error. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected RST, %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go index 93f58ec49..b7514e846 100644 --- a/test/packetimpact/tests/tcp_info_test.go +++ b/test/packetimpact/tests/tcp_info_test.go @@ -21,8 +21,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -53,13 +51,10 @@ func TestTCPInfo(t *testing.T) { } conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) - info := linux.TCPInfo{} - infoBytes := dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + info := dut.GetSockOptTCPInfo(t, acceptFD) + if got, want := uint32(info.State), linux.TCP_ESTABLISHED; got != want { + t.Fatalf("got %d want %d", got, want) } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) - rtt := time.Duration(info.RTT) * time.Microsecond rttvar := time.Duration(info.RTTVar) * time.Microsecond rto := time.Duration(info.RTO) * time.Microsecond @@ -94,12 +89,7 @@ func TestTCPInfo(t *testing.T) { t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) } - info = linux.TCPInfo{} - infoBytes = dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) - } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + info = dut.GetSockOptTCPInfo(t, acceptFD) if info.CaState != linux.TCP_CA_Loss { t.Errorf("expected the connection to be in loss recovery, got: %v want: %v", info.CaState, linux.TCP_CA_Loss) } diff --git 
a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go index ff1431bbf..5a60bf712 100644 --- a/test/packetimpact/tests/tcp_rack_test.go +++ b/test/packetimpact/tests/tcp_rack_test.go @@ -21,8 +21,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/test/packetimpact/testbench" @@ -69,12 +67,7 @@ func closeSACKConnection(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4 } func getRTTAndRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rtt, rto time.Duration) { - info := linux.TCPInfo{} - infoBytes := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) - } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + info := dut.GetSockOptTCPInfo(t, acceptFd) return time.Duration(info.RTT) * time.Microsecond, time.Duration(info.RTO) * time.Microsecond } @@ -402,12 +395,7 @@ func TestRACKWithLostRetransmission(t *testing.T) { } // Check the congestion control state. - info := linux.TCPInfo{} - infoBytes := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) - } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + info := dut.GetSockOptTCPInfo(t, acceptFd) if info.CaState != linux.TCP_CA_Recovery { t.Fatalf("expected connection to be in fast recovery, want: %v got: %v", linux.TCP_CA_Recovery, info.CaState) } diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go index 1eafe20c3..d3fb789f4 100644 --- a/test/packetimpact/tests/tcp_retransmits_test.go +++ b/test/packetimpact/tests/tcp_retransmits_test.go @@ -21,9 +21,6 @@ import ( "time" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -33,12 +30,7 @@ func init() { } func getRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rto time.Duration) { - info := linux.TCPInfo{} - infoBytes := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { - t.Fatalf("unexpected size for TCP_INFO, got %d bytes want %d bytes", got, want) - } - binary.Unmarshal(infoBytes, hostarch.ByteOrder, &info) + info := dut.GetSockOptTCPInfo(t, acceptFd) return time.Duration(info.RTO) * time.Microsecond } diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go index cccb0abc6..fe53e7061 100644 --- a/test/packetimpact/tests/tcp_synsent_reset_test.go +++ b/test/packetimpact/tests/tcp_synsent_reset_test.go @@ -20,6 +20,7 @@ import ( "time" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -29,7 +30,7 @@ func init() { } // dutSynSentState sets up the dut connection in SYN-SENT state. 
-func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16, uint16) { +func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, int32, uint16, uint16) { t.Helper() dut := testbench.NewDUT(t) @@ -46,26 +47,29 @@ func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16, t.Fatalf("expected SYN\n") } - return &dut, &conn, port, clientPort + return &dut, &conn, clientFD, port, clientPort } // TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition. func TestTCPSynSentReset(t *testing.T) { - _, conn, _, _ := dutSynSentState(t) + dut, conn, fd, _, _ := dutSynSentState(t) defer conn.Close(t) conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}) // Expect the connection to have closed. - // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } + info := dut.GetSockOptTCPInfo(t, fd) + if got, want := uint32(info.State), linux.TCP_CLOSE; got != want { + t.Fatalf("got %d want %d", got, want) + } } // TestTCPSynSentRcvdReset tests RFC793, p70, SYN-SENT to SYN-RCVD to CLOSED // transitions. func TestTCPSynSentRcvdReset(t *testing.T) { - dut, c, remotePort, clientPort := dutSynSentState(t) + dut, c, fd, remotePort, clientPort := dutSynSentState(t) defer c.Close(t) conn := dut.Net.NewTCPIPv4(t, testbench.TCP{SrcPort: &remotePort, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &remotePort}) @@ -79,9 +83,12 @@ func TestTCPSynSentRcvdReset(t *testing.T) { } conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. - // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } + info := dut.GetSockOptTCPInfo(t, fd) + if got, want := uint32(info.State), linux.TCP_CLOSE; got != want { + t.Fatalf("got %d want %d", got, want) + } } diff --git a/test/packetimpact/tests/tcp_zero_receive_window_test.go b/test/packetimpact/tests/tcp_zero_receive_window_test.go index d73495454..bd33a2a03 100644 --- a/test/packetimpact/tests/tcp_zero_receive_window_test.go +++ b/test/packetimpact/tests/tcp_zero_receive_window_test.go @@ -45,37 +45,114 @@ func TestZeroReceiveWindow(t *testing.T) { dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) - samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)} - // Expect the DUT to eventually advertise zero receive window. - // The test would timeout otherwise. - for readOnce := false; ; { - conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) - if err != nil { - t.Fatalf("expected packet was not received: %s", err) - } - // Read once to trigger the subsequent window update from the - // DUT to grow the right edge of the receive window from what - // was advertised in the SYN-ACK. 
This ensures that we test - // for the full default buffer size (1MB on gVisor at the time - // of writing this comment), thus testing for cases when the - // scaled receive window size ends up > 65535 (0xffff). - if !readOnce { - if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen { - t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) - } - readOnce = true - } - windowSize := *gotTCP.WindowSize - t.Logf("got window size = %d", windowSize) - if windowSize == 0 { - break - } - } + fillRecvBuffer(t, &conn, &dut, acceptFd, payloadLen) }) } } +func fillRecvBuffer(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, acceptFd int32, payloadLen int) { + // Expect the DUT to eventually advertise zero receive window. + // The test would timeout otherwise. + for readOnce := false; ; { + samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)} + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + // Read once to trigger the subsequent window update from the + // DUT to grow the right edge of the receive window from what + // was advertised in the SYN-ACK. This ensures that we test + // for the full default buffer size (1MB on gVisor at the time + // of writing this comment), thus testing for cases when the + // scaled receive window size ends up > 65535 (0xffff). + if !readOnce { + if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) + } + readOnce = true + } + windowSize := *gotTCP.WindowSize + t.Logf("got window size = %d", windowSize) + if windowSize == 0 { + break + } + if payloadLen > int(windowSize) { + payloadLen = int(windowSize) + } + } +} + +func TestZeroToNonZeroWindowUpdate(t *testing.T) { + dut := testbench.NewDUT(t) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("didn't get synack during handshake: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + + mss := header.ParseSynOptions(synAck.Options, true).MSS + fillRecvBuffer(t, &conn, &dut, acceptFd, int(mss)) + + // Read < mss worth of data from the receive buffer and expect the DUT to + // not send a non-zero window update. + payloadLen := mss - 1 + if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != int(payloadLen) { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) + } + // Send a zero-window-probe to force an ACK from the receiver with any + // window updates. 
+ conn.Send(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1)), Flags: testbench.TCPFlags(header.TCPFlagAck)}) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + if windowSize := *gotTCP.WindowSize; windowSize != 0 { + t.Fatalf("got non zero window = %d", windowSize) + } + + // Now, ensure that the DUT eventually sends non-zero window update. + seqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1)) + ackNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + recvCheckWindowUpdate := func(readLen int) uint16 { + if got := dut.Recv(t, acceptFd, int32(readLen), 0); len(got) != readLen { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, readLen, len(got), readLen) + } + conn.Send(t, testbench.TCP{SeqNum: seqNum, Flags: testbench.TCPFlags(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: make([]byte, 1)}) + gotTCP, err := conn.Expect(t, testbench.TCP{AckNum: ackNum, Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + return *gotTCP.WindowSize + } + + if !dut.Uname.IsLinux() { + if win := recvCheckWindowUpdate(1); win == 0 { + t.Fatal("expected non-zero window update") + } + } else { + // Linux stack takes additional socket reads to send out window update, + // it's a function of sysctl_tcp_rmem among other things. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_input.c#L687 + for { + if win := recvCheckWindowUpdate(int(payloadLen)); win != 0 { + break + } + } + } +} + // TestNonZeroReceiveWindow tests for the DUT to never send a zero receive // window when the data is being read from the socket buffer. func TestNonZeroReceiveWindow(t *testing.T) { diff --git a/test/perf/BUILD b/test/perf/BUILD index 71982fc4d..75b5003e2 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -1,3 +1,4 @@ +load("//tools:defs.bzl", "more_shards") load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) @@ -37,6 +38,7 @@ syscall_test( syscall_test( size = "large", debug = False, + shard_count = more_shards, tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", ) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index a74d6b1c1..39e838582 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -308,8 +308,8 @@ func TestCgroup(t *testing.T) { } } -// TestCgroupParent sets the "CgroupParent" option and checks that the child and parent's -// cgroups are created correctly relative to each other. +// TestCgroupParent sets the "CgroupParent" option and checks that the child and +// parent's cgroups are created correctly relative to each other. func TestCgroupParent(t *testing.T) { ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) @@ -343,15 +343,19 @@ func TestCgroupParent(t *testing.T) { // Finds cgroup for the sandbox's parent process to check that cgroup is // created in the right location relative to the parent.
cmd := fmt.Sprintf("grep PPid: /proc/%d/status | sed 's/PPid:\\s//'", pid) - ppid, err := exec.Command("bash", "-c", cmd).CombinedOutput() + ppidStr, err := exec.Command("bash", "-c", cmd).CombinedOutput() if err != nil { t.Fatalf("Executing %q: %v", cmd, err) } - cgroups, err := cgroup.LoadPaths(strings.TrimSpace(string(ppid))) + ppid, err := strconv.Atoi(strings.TrimSpace(string(ppidStr))) if err != nil { - t.Fatalf("cgroup.LoadPath(%s): %v", ppid, err) + t.Fatalf("invalid PID (%s): %v", ppidStr, err) } - path := filepath.Join("/sys/fs/cgroup/memory", cgroups["memory"], parent, gid, "cgroup.procs") + cgroups, err := cgroup.NewFromPid(ppid) + if err != nil { + t.Fatalf("cgroup.NewFromPid(%d): %v", ppid, err) + } + path := filepath.Join(cgroups.MakePath("cpuacct"), parent, gid, "cgroup.procs") if err := verifyPid(pid, path); err != nil { t.Errorf("cgroup control %q processes: %v", "memory", err) } diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 2a0ef2cec..416f51935 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -88,6 +88,12 @@ def _syscall_test( tags = list(tags) tags += [full_platform, "file_" + file_access] + # Hash this target into one of 15 buckets. This can be used to + # randomly split targets between different workflows. + hash15 = hash(native.package_name() + name) % 15 + tags.append("hash15:" + str(hash15)) + tags.append("hash15") + # Disable off-host networking. tags.append("requires-net:loopback") tags.append("requires-net:ipv4") diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go index 2ad5f58ef..38e57d62f 100644 --- a/test/runner/gtest/gtest.go +++ b/test/runner/gtest/gtest.go @@ -35,39 +35,6 @@ var ( filterBenchmarkFlag = "--benchmark_filter" ) -// BuildTestArgs builds arguments to be passed to the test binary to execute -// only the test cases in `indices`. -func BuildTestArgs(indices []int, testCases []TestCase) []string { - var testFilter, benchFilter string - for _, tci := range indices { - tc := testCases[tci] - if tc.all { - // No argument will make all tests run. - return nil - } - if tc.benchmark { - if len(benchFilter) > 0 { - benchFilter += "|" - } - benchFilter += "^" + tc.Name + "$" - } else { - if len(testFilter) > 0 { - testFilter += ":" - } - testFilter += tc.FullName() - } - } - - var args []string - if len(testFilter) > 0 { - args = append(args, fmt.Sprintf("%s=%s", filterTestFlag, testFilter)) - } - if len(benchFilter) > 0 { - args = append(args, fmt.Sprintf("%s=%s", filterBenchmarkFlag, benchFilter)) - } - return args -} - // TestCase is a single gtest test case. type TestCase struct { // Suite is the suite for this test. @@ -92,6 +59,22 @@ func (tc TestCase) FullName() string { return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) } +// Args returns arguments to be passed when invoking the test. +func (tc TestCase) Args() []string { + if tc.all { + return []string{} // No arguments. + } + if tc.benchmark { + return []string{ + fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), + fmt.Sprintf("%s=", filterTestFlag), + } + } + return []string{ + fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()), + } +} + // ParseTestCases calls a gtest test binary to list its test and returns a // slice with the name and suite of each test. // @@ -107,7 +90,6 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes // We failed to list tests with the given flags. Just // return something that will run the binary with no // flags, which should execute all tests. 
- fmt.Printf("failed to get test list: %v\n", err) return []TestCase{ { Suite: "Default", diff --git a/test/runner/runner.go b/test/runner/runner.go index d314a5036..7e8e88ba2 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -26,6 +26,7 @@ import ( "path/filepath" "strings" "syscall" + "testing" "time" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -56,82 +57,13 @@ var ( leakCheck = flag.Bool("leak-check", false, "check for reference leaks") ) -func main() { - flag.Parse() - if flag.NArg() != 1 { - fatalf("test must be provided") - } - - log.SetLevel(log.Info) - if *debug { - log.SetLevel(log.Debug) - } - - if *platform != "native" && *runscPath == "" { - if err := testutil.ConfigureExePath(); err != nil { - panic(err.Error()) - } - *runscPath = specutils.ExePath - } - - // Make sure stdout and stderr are opened with O_APPEND, otherwise logs - // from outside the sandbox can (and will) stomp on logs from inside - // the sandbox. - for _, f := range []*os.File{os.Stdout, os.Stderr} { - flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) - if err != nil { - fatalf("error getting file flags for %v: %v", f, err) - } - if flags&unix.O_APPEND == 0 { - flags |= unix.O_APPEND - if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { - fatalf("error setting file flags for %v: %v", f, err) - } - } - } - - // Resolve the absolute path for the binary. - testBin, err := filepath.Abs(flag.Args()[0]) - if err != nil { - fatalf("Abs(%q) failed: %v", flag.Args()[0], err) - } - - // Get all test cases in each binary. - testCases, err := gtest.ParseTestCases(testBin, true) - if err != nil { - fatalf("ParseTestCases(%q) failed: %v", testBin, err) - } - - // Get subset of tests corresponding to shard. - indices, err := testutil.TestIndicesForShard(len(testCases)) - if err != nil { - fatalf("TestsForShard() failed: %v", err) - } - if len(indices) == 0 { - log.Warningf("No tests to run in this shard") - return - } - args := gtest.BuildTestArgs(indices, testCases) - - switch *platform { - case "native": - if err := runTestCaseNative(testBin, args); err != nil { - fatalf(err.Error()) - } - default: - if err := runTestCaseRunsc(testBin, args); err != nil { - fatalf(err.Error()) - } - } -} - // runTestCaseNative runs the test case directly on the host machine. -func runTestCaseNative(testBin string, args []string) error { +func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { // These tests might be running in parallel, so make sure they have a // unique test temp dir. tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") if err != nil { - return fmt.Errorf("could not create temp dir: %v", err) + t.Fatalf("could not create temp dir: %v", err) } defer os.RemoveAll(tmpDir) @@ -152,12 +84,12 @@ func runTestCaseNative(testBin string, args []string) error { } // Remove shard env variables so that the gunit binary does not try to // interpret them. - env = filterEnv(env, "TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS") + env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) if *addUDSTree { socketDir, cleanup, err := uds.CreateSocketTree("/tmp") if err != nil { - return fmt.Errorf("failed to create socket tree: %v", err) + t.Fatalf("failed to create socket tree: %v", err) } defer cleanup() @@ -167,7 +99,7 @@ func runTestCaseNative(testBin string, args []string) error { env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) } - cmd := exec.Command(testBin, args...) 
+ cmd := exec.Command(testBin, tc.Args()...) cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -183,9 +115,8 @@ func runTestCaseNative(testBin string, args []string) error { if err := cmd.Run(); err != nil { ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) - return fmt.Errorf("test exited with status %d, want 0", ws.ExitStatus()) + t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) } - return nil } // runRunsc runs spec in runsc in a standard test configuration. @@ -193,7 +124,7 @@ func runTestCaseNative(testBin string, args []string) error { // runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR. // // Returns an error if the sandboxed application exits non-zero. -func runRunsc(spec *specs.Spec) error { +func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { bundleDir, cleanup, err := testutil.SetupBundleDir(spec) if err != nil { return fmt.Errorf("SetupBundleDir failed: %v", err) @@ -206,8 +137,9 @@ func runRunsc(spec *specs.Spec) error { } defer cleanup() + name := tc.FullName() id := testutil.RandomContainerID() - log.Infof("Running test in container %q", id) + log.Infof("Running test %q in container %q", name, id) specutils.LogSpec(spec) args := []string{ @@ -243,8 +175,13 @@ func runRunsc(spec *specs.Spec) error { args = append(args, "-ref-leak-mode=log-names") } - testLogDir := os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR") - if len(testLogDir) > 0 { + testLogDir := "" + if undeclaredOutputsDir, ok := unix.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { + // Create log directory dedicated for this test. + testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) + if err := os.MkdirAll(testLogDir, 0755); err != nil { + return fmt.Errorf("could not create test dir: %v", err) + } debugLogDir, err := ioutil.TempDir(testLogDir, "runsc") if err != nil { return fmt.Errorf("could not create temp dir: %v", err) @@ -290,7 +227,7 @@ func runRunsc(spec *specs.Spec) error { if !ok { return } - log.Warningf("Got signal: %v", s) + log.Warningf("%s: Got signal: %v", name, s) done := make(chan bool, 1) dArgs := append([]string{}, args...) dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) @@ -323,7 +260,7 @@ func runRunsc(spec *specs.Spec) error { if err == nil && len(testLogDir) > 0 { // If the test passed, then we erase the log directory. This speeds up // uploading logs in continuous integration & saves on disk space. - _ = os.RemoveAll(testLogDir) + os.RemoveAll(testLogDir) } return err @@ -378,10 +315,10 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { } // runsTestCaseRunsc runs the test case in runsc. -func runTestCaseRunsc(testBin string, args []string) error { +func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Run a new container with the test executable and filter for the // given test suite and name. - spec := testutil.NewSpecWithArgs(append([]string{testBin}, args...)...) + spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...) // Mark the root as writeable, as some tests attempt to // write to the rootfs, and expect EACCES, not EROFS. @@ -407,12 +344,12 @@ func runTestCaseRunsc(testBin string, args []string) error { // users, so make sure it is world-accessible. 
tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") if err != nil { - return fmt.Errorf("could not create temp dir: %v", err) + t.Fatalf("could not create temp dir: %v", err) } defer os.RemoveAll(tmpDir) if err := os.Chmod(tmpDir, 0777); err != nil { - return fmt.Errorf("could not chmod temp dir: %v", err) + t.Fatalf("could not chmod temp dir: %v", err) } // "/tmp" is not replaced with a tmpfs mount inside the sandbox @@ -432,12 +369,13 @@ func runTestCaseRunsc(testBin string, args []string) error { // Set environment variables that indicate we are running in gVisor with // the given platform, network, and filesystem stack. - env := []string{"TEST_ON_GVISOR=" + *platform, "GVISOR_NETWORK=" + *network} - env = append(env, os.Environ()...) - const vfsVar = "GVISOR_VFS" + platformVar := "TEST_ON_GVISOR" + networkVar := "GVISOR_NETWORK" + env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network) + vfsVar := "GVISOR_VFS" if *vfs2 { env = append(env, vfsVar+"=VFS2") - const fuseVar = "FUSE_ENABLED" + fuseVar := "FUSE_ENABLED" if *fuse { env = append(env, fuseVar+"=TRUE") } else { @@ -449,11 +387,11 @@ func runTestCaseRunsc(testBin string, args []string) error { // Remove shard env variables so that the gunit binary does not try to // interpret them. - env = filterEnv(env, "TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS") + env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to // be backed by tmpfs. - env = filterEnv(env, "TEST_TMPDIR") + env = filterEnv(env, []string{"TEST_TMPDIR"}) env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir)) spec.Process.Env = env @@ -461,19 +399,18 @@ func runTestCaseRunsc(testBin string, args []string) error { if *addUDSTree { cleanup, err := setupUDSTree(spec) if err != nil { - return fmt.Errorf("error creating UDS tree: %v", err) + t.Fatalf("error creating UDS tree: %v", err) } defer cleanup() } - if err := runRunsc(spec); err != nil { - return fmt.Errorf("test failed with error %v, want nil", err) + if err := runRunsc(tc, spec); err != nil { + t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err) } - return nil } // filterEnv returns an environment with the excluded variables removed. -func filterEnv(env []string, exclude ...string) []string { +func filterEnv(env, exclude []string) []string { var out []string for _, kv := range env { ok := true @@ -494,3 +431,82 @@ func fatalf(s string, args ...interface{}) { fmt.Fprintf(os.Stderr, s+"\n", args...) os.Exit(1) } + +func matchString(a, b string) (bool, error) { + return a == b, nil +} + +func main() { + flag.Parse() + if flag.NArg() != 1 { + fatalf("test must be provided") + } + testBin := flag.Args()[0] // Only argument. + + log.SetLevel(log.Info) + if *debug { + log.SetLevel(log.Debug) + } + + if *platform != "native" && *runscPath == "" { + if err := testutil.ConfigureExePath(); err != nil { + panic(err.Error()) + } + *runscPath = specutils.ExePath + } + + // Make sure stdout and stderr are opened with O_APPEND, otherwise logs + // from outside the sandbox can (and will) stomp on logs from inside + // the sandbox. 
+ for _, f := range []*os.File{os.Stdout, os.Stderr} { + flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) + if err != nil { + fatalf("error getting file flags for %v: %v", f, err) + } + if flags&unix.O_APPEND == 0 { + flags |= unix.O_APPEND + if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { + fatalf("error setting file flags for %v: %v", f, err) + } + } + } + + // Get all test cases in each binary. + testCases, err := gtest.ParseTestCases(testBin, true) + if err != nil { + fatalf("ParseTestCases(%q) failed: %v", testBin, err) + } + + // Get subset of tests corresponding to shard. + indices, err := testutil.TestIndicesForShard(len(testCases)) + if err != nil { + fatalf("TestsForShard() failed: %v", err) + } + + // Resolve the absolute path for the binary. + testBin, err = filepath.Abs(testBin) + if err != nil { + fatalf("Abs() failed: %v", err) + } + + // Run the tests. + var tests []testing.InternalTest + for _, tci := range indices { + // Capture tc. + tc := testCases[tci] + tests = append(tests, testing.InternalTest{ + Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name), + F: func(t *testing.T) { + if *platform == "native" { + // Run the test case on host. + runTestCaseNative(testBin, tc, t) + } else { + // Run the test case in runsc. + runTestCaseRunsc(testBin, tc, t) + } + }, + }) + } + + testing.Main(matchString, tests, nil, nil) +} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 0435f61a2..85412f54b 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -313,6 +313,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:verity_mmap_test", +) + +syscall_test( add_overlay = True, test = "//test/syscalls/linux:mount_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 94a582256..0582e16ce 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gtest", "select_arch", "select_system") +load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gbenchmark", "gtest", "select_arch", "select_system") package( default_visibility = ["//:sandbox"], @@ -520,13 +520,14 @@ cc_binary( srcs = ["concurrency.cc"], linkstatic = 1, deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", + gbenchmark, gtest, "//test/util:platform_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", ], ) @@ -1024,6 +1025,7 @@ cc_binary( "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "//test/util:verity_util", ], ) @@ -1294,6 +1296,23 @@ cc_binary( ) cc_binary( + name = "verity_mmap_test", + testonly = 1, + srcs = ["verity_mmap.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + gtest, + "//test/util:fs_util", + "//test/util:memory_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:verity_util", + ], +) + +cc_binary( name = "mount_test", testonly = 1, srcs = ["mount.cc"], @@ -1471,6 +1490,7 @@ cc_binary( "//test/util:cleanup", "//test/util:posix_error", "//test/util:pty_util", + "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", @@ -1497,7 +1517,8 @@ cc_binary( cc_binary( name = "partial_bad_buffer_test", testonly = 1, - srcs = ["partial_bad_buffer.cc"], + # Android does not support preadv or pwritev in r22. 
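The rewritten main above wraps each discovered gtest case as a testing.InternalTest and hands the slice to testing.Main, so every case reports its own pass/fail through the standard Go test harness (and, in runsc mode, gets its own log directory). A minimal, self-contained sketch of that pattern, with hypothetical case names standing in for the real gtest listing:

package main

import (
	"flag"
	"testing"
)

// matchString accepts every test; -test.run style filtering is not needed
// because the case list has already been sharded upstream.
func matchString(pat, str string) (bool, error) {
	return true, nil
}

func main() {
	flag.Parse()
	cases := []string{"MkdirTest.HonorsUmask", "OpenTest.OTrunc"} // hypothetical names
	var tests []testing.InternalTest
	for _, name := range cases {
		name := name // capture loop variable for the closure
		tests = append(tests, testing.InternalTest{
			Name: name,
			F: func(t *testing.T) {
				// A real runner would exec the test binary here and call
				// t.Errorf/t.Fatalf on failure.
				t.Logf("would execute %q", name)
			},
		})
	}
	// No benchmarks or examples are driven, hence the nil arguments.
	testing.Main(matchString, tests, nil, nil)
}

Driving cases through testing.Main (rather than returning a single aggregate error) is what lets the runner attribute failures, logs, and signals to individual gtest cases.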
+ srcs = select_system(linux = ["partial_bad_buffer.cc"]), linkstatic = 1, deps = [ ":socket_test_util", @@ -1556,6 +1577,7 @@ cc_binary( "@com_google_absl//absl/time", gtest, "//test/util:posix_error", + "//test/util:signal_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -3653,7 +3675,8 @@ cc_binary( cc_binary( name = "sync_test", testonly = 1, - srcs = ["sync.cc"], + # Android does not support syncfs in r22. + srcs = select_system(linux = ["sync.cc"]), linkstatic = 1, deps = [ gtest, @@ -3766,10 +3789,9 @@ cc_binary( srcs = ["timers.cc"], linkstatic = 1, deps = [ - "//test/util:cleanup", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/time", + gbenchmark, gtest, + "//test/util:cleanup", "//test/util:logging", "//test/util:multiprocess_util", "//test/util:posix_error", @@ -3777,6 +3799,8 @@ cc_binary( "//test/util:test_util", "//test/util:thread_util", "//test/util:timer_util", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/time", ], ) @@ -3948,7 +3972,8 @@ cc_binary( cc_binary( name = "utimes_test", testonly = 1, - srcs = ["utimes.cc"], + # Android does not support futimesat in r22. + srcs = select_system(linux = ["utimes.cc"]), linkstatic = 1, deps = [ "//test/util:file_descriptor", @@ -4064,7 +4089,8 @@ cc_binary( cc_binary( name = "semaphore_test", testonly = 1, - srcs = ["semaphore.cc"], + # Android does not support XSI semaphores in r22. + srcs = select_system(linux = ["semaphore.cc"]), linkstatic = 1, deps = [ "//test/util:capability_util", @@ -4247,10 +4273,12 @@ cc_binary( "//test/util:mount_util", "@com_google_absl//absl/strings", gtest, + "//test/util:cleanup", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "//test/util:thread_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], diff --git a/test/syscalls/linux/cgroup.cc b/test/syscalls/linux/cgroup.cc index 70ad5868f..f29891571 100644 --- a/test/syscalls/linux/cgroup.cc +++ b/test/syscalls/linux/cgroup.cc @@ -25,9 +25,11 @@ #include "absl/strings/str_split.h" #include "test/util/capability_util.h" #include "test/util/cgroup_util.h" +#include "test/util/cleanup.h" #include "test/util/mount_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -192,6 +194,91 @@ TEST(Cgroup, MoptAllMustBeExclusive) { SyscallFailsWithErrno(EINVAL)); } +TEST(Cgroup, MountRace) { + SKIP_IF(!CgroupsAvailable()); + + TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + const DisableSave ds; // Too many syscalls. + + auto mount_thread = [&mountpoint]() { + for (int i = 0; i < 100; ++i) { + mount("none", mountpoint.path().c_str(), "cgroup", 0, 0); + } + }; + std::list<ScopedThread> threads; + for (int i = 0; i < 10; ++i) { + threads.emplace_back(mount_thread); + } + for (auto& t : threads) { + t.Join(); + } + + auto cleanup = Cleanup([&mountpoint] { + // We need 1 umount call per successful mount. If some of the mount calls + // were unsuccessful, their corresponding umount will silently fail. + for (int i = 0; i < (10 * 100) + 1; ++i) { + umount(mountpoint.path().c_str()); + } + }); + + Cgroup c = Cgroup(mountpoint.path()); + // c should be a valid cgroup. 
+ EXPECT_NO_ERRNO(c.ContainsCallingProcess()); +} + +TEST(Cgroup, MountUnmountRace) { + SKIP_IF(!CgroupsAvailable()); + + TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + const DisableSave ds; // Too many syscalls. + + auto mount_thread = [&mountpoint]() { + for (int i = 0; i < 100; ++i) { + mount("none", mountpoint.path().c_str(), "cgroup", 0, 0); + } + }; + auto unmount_thread = [&mountpoint]() { + for (int i = 0; i < 100; ++i) { + umount(mountpoint.path().c_str()); + } + }; + std::list<ScopedThread> threads; + for (int i = 0; i < 10; ++i) { + threads.emplace_back(mount_thread); + } + for (int i = 0; i < 10; ++i) { + threads.emplace_back(unmount_thread); + } + for (auto& t : threads) { + t.Join(); + } + + // We don't know how many mount refs are remaining, since the count depends on + // the ordering of mount and umount calls. Keep calling unmount until it + // returns an error. + while (umount(mountpoint.path().c_str()) == 0) { + } +} + +TEST(Cgroup, UnmountRepeated) { + SKIP_IF(!CgroupsAvailable()); + + const DisableSave ds; // Too many syscalls. + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + + // First unmount should succeed. + EXPECT_THAT(umount(c.Path().c_str()), SyscallSucceeds()); + + // We just manually unmounted, so release managed resources. + m.release(c); + + EXPECT_THAT(umount(c.Path().c_str()), SyscallFailsWithErrno(EINVAL)); +} + TEST(MemoryCgroup, MemoryUsageInBytes) { SKIP_IF(!CgroupsAvailable()); diff --git a/test/syscalls/linux/chdir.cc b/test/syscalls/linux/chdir.cc index 3182c228b..3c64b9eab 100644 --- a/test/syscalls/linux/chdir.cc +++ b/test/syscalls/linux/chdir.cc @@ -41,8 +41,8 @@ TEST(ChdirTest, Success) { TEST(ChdirTest, PermissionDenied) { // Drop capabilities that allow us to override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc index 4a5ea84d4..dd82c5fb1 100644 --- a/test/syscalls/linux/chmod.cc +++ b/test/syscalls/linux/chmod.cc @@ -33,7 +33,7 @@ namespace { TEST(ChmodTest, ChmodFileSucceeds) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -43,8 +43,8 @@ TEST(ChmodTest, ChmodFileSucceeds) { TEST(ChmodTest, ChmodDirSucceeds) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const std::string fileInDir = NewTempAbsPathInDir(dir.path()); @@ -55,7 +55,7 @@ TEST(ChmodTest, ChmodDirSucceeds) { TEST(ChmodTest, FchmodFileSucceeds) { // Drop capabilities that allow us to file directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); int fd; @@ -72,8 +72,8 @@ TEST(ChmodTest, FchmodFileSucceeds) { TEST(ChmodTest, FchmodDirSucceeds) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); int fd; @@ -118,7 +118,7 @@ TEST(ChmodTest, FchmodDirWithOpath) { TEST(ChmodTest, FchmodatWithOpath) { SKIP_IF(IsRunningWithVFS1()); // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -140,7 +140,7 @@ TEST(ChmodTest, FchmodatNotDir) { TEST(ChmodTest, FchmodatFileAbsolutePath) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -150,8 +150,8 @@ TEST(ChmodTest, FchmodatFileAbsolutePath) { TEST(ChmodTest, FchmodatDirAbsolutePath) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -167,7 +167,7 @@ TEST(ChmodTest, FchmodatDirAbsolutePath) { TEST(ChmodTest, FchmodatFile) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -188,8 +188,8 @@ TEST(ChmodTest, FchmodatFile) { TEST(ChmodTest, FchmodatDir) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -227,8 +227,8 @@ TEST(ChmodTest, ChmodDowngradeWritability) { TEST(ChmodTest, ChmodFileToNoPermissionsSucceeds) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); @@ -254,8 +254,8 @@ TEST(ChmodTest, FchmodDowngradeWritability) { TEST(ChmodTest, FchmodFileToNoPermissionsSucceeds) { // Drop capabilities that allow us to override file permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc index ff0d39343..b0c1b6f4a 100644 --- a/test/syscalls/linux/chown.cc +++ b/test/syscalls/linux/chown.cc @@ -91,9 +91,7 @@ using Chown = class ChownParamTest : public ::testing::TestWithParam<Chown> {}; TEST_P(ChownParamTest, ChownFileSucceeds) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_CHOWN))) { - ASSERT_NO_ERRNO(SetCapability(CAP_CHOWN, false)); - } + AutoCapability cap(CAP_CHOWN, false); const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -135,9 +133,7 @@ TEST_P(ChownParamTest, ChownFilePermissionDenied) { // thread won't be able to open some log files after the test ends. ScopedThread([&] { // Drop privileges. - if (HaveCapability(CAP_CHOWN).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_CHOWN, false)); - } + AutoCapability cap(CAP_CHOWN, false); // Change EUID and EGID. // diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc index 7cd6a75bd..f2daf49ee 100644 --- a/test/syscalls/linux/concurrency.cc +++ b/test/syscalls/linux/concurrency.cc @@ -20,6 +20,7 @@ #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "benchmark/benchmark.h" #include "test/util/platform_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -106,6 +107,8 @@ TEST(ConcurrencyTest, MultiProcessConcurrency) { pid_t child_pid = fork(); if (child_pid == 0) { while (true) { + int x = 0; + benchmark::DoNotOptimize(x); // Don't optimize this loop away. } } ASSERT_THAT(child_pid, SyscallSucceeds()); diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index b180f633c..3ef8b0327 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -39,6 +39,15 @@ namespace { constexpr int kFDsPerEpoll = 3; constexpr uint64_t kMagicConstant = 0x0102030405060708; +#ifndef SYS_epoll_pwait2 +#define SYS_epoll_pwait2 441 +#endif + +int epoll_pwait2(int fd, struct epoll_event* events, int maxevents, + const struct timespec* timeout, const sigset_t* sigset) { + return syscall(SYS_epoll_pwait2, fd, events, maxevents, timeout, sigset); +} + TEST(EpollTest, AllWritable) { auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); std::vector<FileDescriptor> eventfds; @@ -144,6 +153,50 @@ TEST(EpollTest, Timeout) { EXPECT_GT(ms_elapsed(begin, end), kTimeoutMs - 1); } +TEST(EpollTest, EpollPwait2Timeout) { + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + // 200 milliseconds. + constexpr int kTimeoutNs = 200000000; + struct timespec timeout; + timeout.tv_sec = 0; + timeout.tv_nsec = 0; + struct timespec begin; + struct timespec end; + struct epoll_event result[kFDsPerEpoll]; + + std::vector<FileDescriptor> eventfds; + for (int i = 0; i < kFDsPerEpoll; i++) { + eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); + ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, + kMagicConstant + i)); + } + + // Pass valid arguments so that the syscall won't be blocked indefinitely + // nor return errno EINVAL. + // + // The syscall returns immediately when timeout is zero, + // even if no events are available. 
+ SKIP_IF(!IsRunningOnGvisor() && + epoll_pwait2(epollfd.get(), result, kFDsPerEpoll, &timeout, nullptr) < + 0 && + errno == ENOSYS); + + { + const DisableSave ds; // Timing-related. + EXPECT_THAT(clock_gettime(CLOCK_MONOTONIC, &begin), SyscallSucceeds()); + + timeout.tv_nsec = kTimeoutNs; + ASSERT_THAT(RetryEINTR(epoll_pwait2)(epollfd.get(), result, kFDsPerEpoll, + &timeout, nullptr), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(clock_gettime(CLOCK_MONOTONIC, &end), SyscallSucceeds()); + } + + // Check the lower bound on the timeout. Checking for an upper bound is + // fragile because Linux can overrun the timeout due to scheduling delays. + EXPECT_GT(ns_elapsed(begin, end), kTimeoutNs - 1); +} + void* writer(void* arg) { int fd = *reinterpret_cast<int*>(arg); uint64_t tmp = 1; @@ -177,6 +230,8 @@ TEST(EpollTest, WaitThenUnblock) { EXPECT_THAT(pthread_detach(thread), SyscallSucceeds()); } +#ifndef ANDROID // Android does not support pthread_cancel + void sighandler(int s) {} void* signaler(void* arg) { @@ -219,6 +274,8 @@ TEST(EpollTest, UnblockWithSignal) { EXPECT_THAT(pthread_detach(thread), SyscallSucceeds()); } +#endif // ANDROID + TEST(EpollTest, TimeoutNoFds) { auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); struct epoll_event result[kFDsPerEpoll]; diff --git a/test/syscalls/linux/exec_state_workload.cc b/test/syscalls/linux/exec_state_workload.cc index 028902b14..eafdc2bfa 100644 --- a/test/syscalls/linux/exec_state_workload.cc +++ b/test/syscalls/linux/exec_state_workload.cc @@ -26,6 +26,8 @@ #include "absl/strings/numbers.h" +#ifndef ANDROID // Conflicts with existing operator<< on Android. + // Pretty-print a sigset_t. std::ostream& operator<<(std::ostream& out, const sigset_t& s) { out << "{ "; @@ -40,6 +42,8 @@ std::ostream& operator<<(std::ostream& out, const sigset_t& s) { return out; } +#endif + // Verify that the signo handler is handler. int CheckSigHandler(uint32_t signo, uintptr_t handler) { struct sigaction sa; diff --git a/test/syscalls/linux/fchdir.cc b/test/syscalls/linux/fchdir.cc index c6675802d..0383f3f85 100644 --- a/test/syscalls/linux/fchdir.cc +++ b/test/syscalls/linux/fchdir.cc @@ -46,8 +46,8 @@ TEST(FchdirTest, InvalidFD) { TEST(FchdirTest, PermissionDenied) { // Drop capabilities that allow us to override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 4fa6751ff..91526572b 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -390,9 +390,7 @@ TEST_F(FcntlLockTest, SetLockDir) { } TEST_F(FcntlLockTest, SetLockSymlink) { - // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH - // is supported. 
- SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(IsRunningWithVFS1()); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); auto symlink = ASSERT_NO_ERRNO_AND_VALUE( diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index fd387aa45..10dad042f 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -662,9 +662,7 @@ TEST(FlockTestNoFixture, FlockDir) { } TEST(FlockTestNoFixture, FlockSymlink) { - // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH - // is supported. - SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(IsRunningWithVFS1()); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); auto symlink = ASSERT_NO_ERRNO_AND_VALUE( diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index 98d07ae85..95082a0f2 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -174,13 +174,21 @@ SocketKind IPv6TCPUnboundSocket(int type) { PosixError IfAddrHelper::Load() { Release(); +#ifndef ANDROID RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_)); +#else + // Android does not support getifaddrs in r22. + return PosixError(ENOSYS, "getifaddrs"); +#endif return NoError(); } void IfAddrHelper::Release() { if (ifaddr_) { +#ifndef ANDROID + // Android does not support freeifaddrs in r22. freeifaddrs(ifaddr_); +#endif ifaddr_ = nullptr; } } diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc index 6ce1e6cc3..d4f89527c 100644 --- a/test/syscalls/linux/lseek.cc +++ b/test/syscalls/linux/lseek.cc @@ -150,7 +150,7 @@ TEST(LseekTest, SeekCurrentDir) { // From include/linux/fs.h. constexpr loff_t MAX_LFS_FILESIZE = 0x7fffffffffffffff; - char* dir = get_current_dir_name(); + char* dir = getcwd(NULL, 0); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir, O_RDONLY)); ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceeds()); diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc index 94aea4077..867a4513b 100644 --- a/test/syscalls/linux/memory_accounting.cc +++ b/test/syscalls/linux/memory_accounting.cc @@ -83,7 +83,7 @@ TEST(MemoryAccounting, AnonAccountingPreservedOnSaveRestore) { uint64_t anon_after_alloc = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo()); EXPECT_THAT(anon_after_alloc, - EquivalentWithin(anon_initial + map_bytes, 0.03)); + EquivalentWithin(anon_initial + map_bytes, 0.04)); // We have many implicit S/R cycles from scraping /proc/meminfo throughout the // test, but throw an explicit S/R in here as well. @@ -91,7 +91,7 @@ TEST(MemoryAccounting, AnonAccountingPreservedOnSaveRestore) { // Usage should remain the same across S/R. uint64_t anon_after_sr = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo()); - EXPECT_THAT(anon_after_sr, EquivalentWithin(anon_after_alloc, 0.03)); + EXPECT_THAT(anon_after_sr, EquivalentWithin(anon_after_alloc, 0.04)); } } // namespace diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index 11fbfa5c5..36504fe6d 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -72,8 +72,8 @@ TEST_F(MkdirTest, HonorsUmask2) { TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { // Drop capabilities that allow us to override file and directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds()); auto dir = JoinPath(dirname_.c_str(), "foo"); @@ -84,8 +84,8 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { TEST_F(MkdirTest, DirAlreadyExists) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); auto dir = JoinPath(dirname_.c_str(), "foo"); diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc index 78ac96bed..dfa5b7133 100644 --- a/test/syscalls/linux/mlock.cc +++ b/test/syscalls/linux/mlock.cc @@ -114,9 +114,7 @@ TEST(MlockTest, Fork) { } TEST(MlockTest, RlimitMemlockZero) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( @@ -127,9 +125,7 @@ TEST(MlockTest, RlimitMemlockZero) { } TEST(MlockTest, RlimitMemlockInsufficient) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, kPageSize)); auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( @@ -255,9 +251,7 @@ TEST(MapLockedTest, Basic) { } TEST(MapLockedTest, RlimitMemlockZero) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); EXPECT_THAT( @@ -266,9 +260,7 @@ TEST(MapLockedTest, RlimitMemlockZero) { } TEST(MapLockedTest, RlimitMemlockInsufficient) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, kPageSize)); EXPECT_THAT( @@ -298,9 +290,7 @@ TEST(MremapLockedTest, RlimitMemlockZero) { MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); EXPECT_TRUE(IsPageMlocked(mapping.addr())); - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); void* addr = mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(), @@ -315,9 +305,7 @@ TEST(MremapLockedTest, RlimitMemlockInsufficient) { MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); EXPECT_TRUE(IsPageMlocked(mapping.addr())); - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE( ScopedSetSoftRlimit(RLIMIT_MEMLOCK, mapping.len())); void* addr = 
mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(), diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 4697c404c..ab9d19fef 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -433,7 +433,7 @@ TEST_F(OpenTest, CanTruncateReadOnly) { // O_TRUNC should fail. TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); const DisableSave ds; // Permissions are dropped. ASSERT_THAT(chmod(test_file_name_.c_str(), S_IRUSR | S_IRGRP), @@ -473,8 +473,8 @@ TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission) { } TEST_F(OpenTest, CanTruncateWithStrangePermissions) { - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const DisableSave ds; // Permissions are dropped. std::string path = NewTempAbsPath(); // Create a file without user permissions. @@ -510,8 +510,8 @@ TEST_F(OpenTest, OpenWithStrangeFlags) { TEST_F(OpenTest, OpenWithOpath) { SKIP_IF(IsRunningWithVFS1()); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const DisableSave ds; // Permissions are dropped. std::string path = NewTempAbsPath(); diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 43d446926..177bda54d 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -93,7 +93,8 @@ TEST(CreateTest, CreatFileWithOTruncAndReadOnly) { TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // always override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); + auto parent = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555)); auto file = JoinPath(parent.path(), "foo"); @@ -123,8 +124,8 @@ TEST(CreateTest, ChmodReadToWriteBetweenOpens) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be // cleared for the same reason. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400)); @@ -152,7 +153,7 @@ TEST(CreateTest, ChmodReadToWriteBetweenOpens) { TEST(CreateTest, ChmodWriteToReadBetweenOpens) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // override file read/write permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0200)); @@ -186,8 +187,8 @@ TEST(CreateTest, CreateWithReadFlagNotAllowedByMode) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be // cleared for the same reason. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); // Create and open a file with read flag but without read permissions. const std::string path = NewTempAbsPath(); @@ -212,7 +213,7 @@ TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // override file read/write permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); // Create and open a file with write flag but without write permissions. const std::string path = NewTempAbsPath(); diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 96c454485..294a72468 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -14,6 +14,7 @@ #include <fcntl.h> /* Obtain O_* constant definitions */ #include <linux/magic.h> +#include <signal.h> #include <sys/ioctl.h> #include <sys/statfs.h> #include <sys/uio.h> @@ -29,6 +30,7 @@ #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/posix_error.h" +#include "test/util/signal_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -44,6 +46,28 @@ constexpr int kTestValue = 0x12345678; // Used for synchronization in race tests. const absl::Duration syncDelay = absl::Seconds(2); +std::atomic<int> global_num_signals_received = 0; +void SigRecordingHandler(int signum, siginfo_t* siginfo, + void* unused_ucontext) { + global_num_signals_received++; +} + +PosixErrorOr<Cleanup> RegisterSignalHandler(int signum) { + struct sigaction handler; + handler.sa_sigaction = SigRecordingHandler; + sigemptyset(&handler.sa_mask); + handler.sa_flags = SA_SIGINFO; + return ScopedSigaction(signum, handler); +} + +void WaitForSignalDelivery(absl::Duration timeout, int max_expected) { + absl::Time wait_start = absl::Now(); + while (global_num_signals_received < max_expected && + absl::Now() - wait_start < timeout) { + absl::SleepFor(absl::Milliseconds(10)); + } +} + struct PipeCreator { std::string name_; @@ -267,6 +291,9 @@ TEST_P(PipeTest, Seek) { } } +#ifndef ANDROID +// Android does not support preadv or pwritev in r22. 
+ TEST_P(PipeTest, OffsetCalls) { SKIP_IF(!CreateBlocking()); @@ -283,6 +310,8 @@ TEST_P(PipeTest, OffsetCalls) { EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE)); } +#endif // ANDROID + TEST_P(PipeTest, WriterSideCloses) { SKIP_IF(!CreateBlocking()); @@ -333,10 +362,16 @@ TEST_P(PipeTest, WriterSideClosesReadDataFirst) { TEST_P(PipeTest, ReaderSideCloses) { SKIP_IF(!CreateBlocking()); + const auto signal_cleanup = + ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGPIPE)); + ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); int buf = kTestValue; EXPECT_THAT(write(wfd_.get(), &buf, sizeof(buf)), SyscallFailsWithErrno(EPIPE)); + + WaitForSignalDelivery(absl::Seconds(1), 1); + ASSERT_EQ(global_num_signals_received, 1); } TEST_P(PipeTest, CloseTwice) { @@ -355,6 +390,9 @@ TEST_P(PipeTest, CloseTwice) { TEST_P(PipeTest, BlockWriteClosed) { SKIP_IF(!CreateBlocking()); + const auto signal_cleanup = + ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGPIPE)); + absl::Notification notify; ScopedThread t([this, ¬ify]() { std::vector<char> buf(Size()); @@ -371,6 +409,10 @@ TEST_P(PipeTest, BlockWriteClosed) { notify.WaitForNotification(); ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); + + WaitForSignalDelivery(absl::Seconds(1), 1); + ASSERT_EQ(global_num_signals_received, 1); + t.Join(); } @@ -379,6 +421,9 @@ TEST_P(PipeTest, BlockWriteClosed) { TEST_P(PipeTest, BlockPartialWriteClosed) { SKIP_IF(!CreateBlocking()); + const auto signal_cleanup = + ASSERT_NO_ERRNO_AND_VALUE(RegisterSignalHandler(SIGPIPE)); + ScopedThread t([this]() { const int pipe_size = Size(); std::vector<char> buf(2 * pipe_size); @@ -396,6 +441,10 @@ TEST_P(PipeTest, BlockPartialWriteClosed) { // Unblock the above. ASSERT_THAT(close(rfd_.release()), SyscallSucceeds()); + + WaitForSignalDelivery(absl::Seconds(1), 2); + ASSERT_EQ(global_num_signals_received, 2); + t.Join(); } diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index f675dc430..19a57d353 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -184,10 +184,8 @@ TEST(PrctlTest, PDeathSig) { // This test is to validate that calling prctl with PR_SET_MM without the // CAP_SYS_RESOURCE returns EPERM. TEST(PrctlTest, InvalidPrSetMM) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))) { - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_RESOURCE, - false)); // Drop capability to test below. - } + // Drop capability to test below. + AutoCapability cap(CAP_SYS_RESOURCE, false); ASSERT_THAT(prctl(PR_SET_MM, 0, 0, 0, 0), SyscallFailsWithErrno(EPERM)); } diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 9e48fbca5..24928d876 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1201,6 +1201,15 @@ TEST(ProcSelfCwd, Absolute) { EXPECT_EQ(exe[0], '/'); } +// Sanity check that /proc/cmdline is present. +TEST(ProcCmdline, IsPresent) { + SKIP_IF(IsRunningWithVFS1()); + + std::string proc_cmdline = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/cmdline")); + ASSERT_FALSE(proc_cmdline.empty()); +} + // Sanity check for /proc/cpuinfo fields that must be present. 
TEST(ProcCpuinfo, RequiredFieldsArePresent) { std::string proc_cpuinfo = @@ -1849,8 +1858,8 @@ TEST(ProcPidSymlink, SubprocessRunning) { } TEST(ProcPidSymlink, SubprocessZombied) { - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); char buf[1]; @@ -2252,7 +2261,7 @@ TEST(ProcTask, VerifyTaskDir) { TEST(ProcTask, TaskDirCannotBeDeleted) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); EXPECT_THAT(rmdir("/proc/self/task"), SyscallFails()); EXPECT_THAT(rmdir(absl::StrCat("/proc/self/task/", getpid()).c_str()), @@ -2698,6 +2707,14 @@ TEST(Proc, Statfs) { EXPECT_EQ(st.f_namelen, NAME_MAX); } +// Tests that /proc/[pid]/fd/[num] can resolve to a path inside /proc. +TEST(Proc, ResolveSymlinkToProc) { + const auto proc = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/cmdline", 0)); + const auto path = JoinPath("/proc/self/fd/", absl::StrCat(proc.get())); + const auto target = ASSERT_NO_ERRNO_AND_VALUE(ReadLink(path)); + EXPECT_EQ(target, JoinPath("/proc/", absl::StrCat(getpid()), "/cmdline")); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index 2d9fec371..d519b65e6 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -175,7 +175,7 @@ TEST(PtraceTest, AttachSameThreadGroup) { TEST(PtraceTest, TraceParentNotAllowed) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) < 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); pid_t const child_pid = fork(); if (child_pid == 0) { @@ -193,7 +193,7 @@ TEST(PtraceTest, TraceParentNotAllowed) { TEST(PtraceTest, TraceNonDescendantNotAllowed) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) < 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); pid_t const tracee_pid = fork(); if (tracee_pid == 0) { @@ -259,7 +259,7 @@ TEST(PtraceTest, TraceNonDescendantWithCapabilityAllowed) { TEST(PtraceTest, TraceDescendantsAllowed) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) > 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use socket pair to communicate tids to this process from its grandchild. int sockets[2]; @@ -346,7 +346,7 @@ TEST(PtraceTest, PrctlSetPtracerInvalidPID) { TEST(PtraceTest, PrctlSetPtracerPID) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -410,7 +410,7 @@ TEST(PtraceTest, PrctlSetPtracerPID) { TEST(PtraceTest, PrctlSetPtracerAny) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. 
int sockets[2]; @@ -475,7 +475,7 @@ TEST(PtraceTest, PrctlSetPtracerAny) { TEST(PtraceTest, PrctlClearPtracer) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -543,7 +543,7 @@ TEST(PtraceTest, PrctlClearPtracer) { TEST(PtraceTest, PrctlReplacePtracer) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); pid_t const unused_pid = fork(); if (unused_pid == 0) { @@ -633,7 +633,7 @@ TEST(PtraceTest, PrctlReplacePtracer) { // thread group leader is still around. TEST(PtraceTest, PrctlSetPtracerPersistsPastTraceeThreadExit) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -703,7 +703,7 @@ TEST(PtraceTest, PrctlSetPtracerPersistsPastTraceeThreadExit) { // even if the tracee thread is terminated. TEST(PtraceTest, PrctlSetPtracerPersistsPastLeaderExec) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -770,7 +770,7 @@ TEST(PtraceTest, PrctlSetPtracerPersistsPastLeaderExec) { // exec. TEST(PtraceTest, PrctlSetPtracerDoesNotPersistPastNonLeaderExec) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -904,7 +904,7 @@ TEST(PtraceTest, PrctlSetPtracerDoesNotPersistPastTracerThreadExit) { [[noreturn]] void RunPrctlSetPtracerDoesNotPersistPastTracerThreadExit( int tracee_tid, int fd) { - TEST_PCHECK(SetCapability(CAP_SYS_PTRACE, false).ok()); + AutoCapability cap(CAP_SYS_PTRACE, false); ScopedThread t([fd] { pid_t const tracer_tid = gettid(); @@ -1033,7 +1033,7 @@ TEST(PtraceTest, PrctlSetPtracerRespectsTracerThreadID) { // attached. TEST(PtraceTest, PrctlClearPtracerDoesNotAffectCurrentTracer) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -1118,7 +1118,7 @@ TEST(PtraceTest, PrctlClearPtracerDoesNotAffectCurrentTracer) { TEST(PtraceTest, PrctlNotInherited) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Allow any ptracer. This should not affect the child processes. ASSERT_THAT(prctl(PR_SET_PTRACER, PR_SET_PTRACER_ANY), SyscallSucceeds()); @@ -2302,7 +2302,7 @@ TEST(PtraceTest, SetYAMAPtraceScope) { EXPECT_STREQ(buf.data(), "0\n"); // Test that a child can attach to its parent when ptrace_scope is 0. 
- ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); pid_t const child_pid = fork(); if (child_pid == 0) { TEST_PCHECK(CheckPtraceAttach(getppid()) == 0); diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index 8d15c491e..5ff1f12a0 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -40,6 +40,7 @@ #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" #include "test/util/pty_util.h" +#include "test/util/signal_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -387,6 +388,22 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, } TEST(PtyTrunc, Truncate) { + SKIP_IF(IsRunningWithVFS1()); + + // setsid either puts us in a new session or fails because we're already the + // session leader. Either way, this ensures we're the session leader and have + // no controlling terminal. + ASSERT_THAT(setsid(), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EPERM))); + + // Make sure we're ignoring SIGHUP, which will be sent to this process once we + // disconnect the TTY. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + sa.sa_flags = 0; + sigemptyset(&sa.sa_mask); + const Cleanup cleanup = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGHUP, sa)); + // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to // (f)truncate should. FileDescriptor master = @@ -395,6 +412,7 @@ TEST(PtyTrunc, Truncate) { std::string spath = absl::StrCat("/dev/pts/", n); FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); + ASSERT_THAT(ioctl(replica.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); @@ -464,10 +482,10 @@ TEST(BasicPtyTest, OpenSetsControllingTTY) { SKIP_IF(IsRunningWithVFS1()); // setsid either puts us in a new session or fails because we're already the // session leader. Either way, this ensures we're the session leader. - setsid(); + ASSERT_THAT(setsid(), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EPERM))); // Make sure we're ignoring SIGHUP, which will be sent to this process once we - // disconnect they TTY. + // disconnect the TTY. struct sigaction sa = {}; sa.sa_handler = SIG_IGN; sa.sa_flags = 0; @@ -491,7 +509,7 @@ TEST(BasicPtyTest, OpenMasterDoesNotSetsControllingTTY) { SKIP_IF(IsRunningWithVFS1()); // setsid either puts us in a new session or fails because we're already the // session leader. Either way, this ensures we're the session leader. - setsid(); + ASSERT_THAT(setsid(), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EPERM))); FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); // Opening master does not set the controlling TTY, and therefore we are @@ -503,7 +521,7 @@ TEST(BasicPtyTest, OpenNOCTTY) { SKIP_IF(IsRunningWithVFS1()); // setsid either puts us in a new session or fails because we're already the // session leader. Either way, this ensures we're the session leader. 
- setsid(); + ASSERT_THAT(setsid(), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EPERM))); FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE( OpenReplica(master, O_NOCTTY | O_NONBLOCK | O_RDWR)); @@ -1405,7 +1423,7 @@ TEST_F(JobControlTest, ReleaseTTY) { ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Make sure we're ignoring SIGHUP, which will be sent to this process once we - // disconnect they TTY. + // disconnect the TTY. struct sigaction sa = {}; sa.sa_handler = SIG_IGN; sa.sa_flags = 0; @@ -1526,7 +1544,7 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { EXPECT_THAT(setpgid(diff_pgrp_child, diff_pgrp_child), SyscallSucceeds()); // Make sure we're ignoring SIGHUP, which will be sent to this process once we - // disconnect they TTY. + // disconnect the TTY. struct sigaction sighup_sa = {}; sighup_sa.sa_handler = SIG_IGN; sighup_sa.sa_flags = 0; diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc index 2f25aceb2..8b3d02d97 100644 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -177,10 +177,8 @@ TEST_F(RawHDRINCL, ConnectToLoopback) { SyscallSucceeds()); } -TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { - // FIXME(gvisor.dev/issue/3159): Test currently flaky. - SKIP_IF(true); - +// FIXME(gvisor.dev/issue/3159): Test currently flaky. +TEST_F(RawHDRINCL, DISABLED_SendWithoutConnectSucceeds) { struct iphdr hdr = LoopbackHeader(); ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), SyscallSucceedsWithValue(sizeof(hdr))); diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc index 7056342d7..7756af24d 100644 --- a/test/syscalls/linux/read.cc +++ b/test/syscalls/linux/read.cc @@ -157,7 +157,8 @@ TEST_F(ReadTest, PartialReadSIGSEGV) { .iov_len = size, }, }; - EXPECT_THAT(preadv(fd.get(), iov, ABSL_ARRAYSIZE(iov), 0), + EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT(readv(fd.get(), iov, ABSL_ARRAYSIZE(iov)), SyscallSucceedsWithValue(size)); } diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index b1a813de0..76a8da65f 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -259,8 +259,8 @@ TEST(RenameTest, DirectoryDoesNotOverwriteNonemptyDirectory) { TEST(RenameTest, FailsWhenOldParentNotWritable) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); @@ -275,8 +275,8 @@ TEST(RenameTest, FailsWhenOldParentNotWritable) { TEST(RenameTest, FailsWhenNewParentNotWritable) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); @@ -293,8 +293,8 @@ TEST(RenameTest, FailsWhenNewParentNotWritable) { // to overwrite. 
TEST(RenameTest, OverwriteFailsWhenNewParentNotWritable) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); @@ -312,8 +312,8 @@ TEST(RenameTest, OverwriteFailsWhenNewParentNotWritable) { // because the user cannot determine if source exists. TEST(RenameTest, FileDoesNotExistWhenNewParentNotExecutable) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); // No execute permission. auto dir = ASSERT_NO_ERRNO_AND_VALUE( diff --git a/test/syscalls/linux/rlimits.cc b/test/syscalls/linux/rlimits.cc index 860f0f688..d31a2a880 100644 --- a/test/syscalls/linux/rlimits.cc +++ b/test/syscalls/linux/rlimits.cc @@ -41,9 +41,7 @@ TEST(RlimitTest, SetRlimitHigher) { TEST(RlimitTest, UnprivilegedSetRlimit) { // Drop privileges if necessary. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_RESOURCE, false)); - } + AutoCapability cap(CAP_SYS_RESOURCE, false); struct rlimit rl = {}; rl.rlim_cur = 1000; diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index 207377efb..2ce8f836c 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -535,7 +535,7 @@ TEST(SemaphoreTest, SemCtlGetPidFork) { TEST(SemaphoreTest, SemIpcSet) { // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -560,7 +560,7 @@ TEST(SemaphoreTest, SemIpcSet) { TEST(SemaphoreTest, SemCtlIpcStat) { // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); const uid_t kUid = getuid(); const gid_t kGid = getgid(); time_t start_time = time(nullptr); @@ -635,7 +635,7 @@ PosixErrorOr<int> WaitSemctl(int semid, int target, int cmd) { TEST(SemaphoreTest, SemopGetzcnt) { // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); // Create a write only semaphore set. AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -743,7 +743,7 @@ TEST(SemaphoreTest, SemopGetzcntOnSignal) { TEST(SemaphoreTest, SemopGetncnt) { // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); // Create a write only semaphore set. AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -853,7 +853,7 @@ TEST(SemaphoreTest, IpcInfo) { std::set<int> sem_ids; struct seminfo info; // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); for (int i = 0; i < kLoops; i++) { AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -923,7 +923,7 @@ TEST(SemaphoreTest, SemInfo) { std::set<int> sem_ids; struct seminfo info; // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); for (int i = 0; i < kLoops; i++) { AutoSem sem(semget(IPC_PRIVATE, kSemSetSize, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc index f4ee775bd..ce5f63938 100644 --- a/test/syscalls/linux/socket_bind_to_device_util.cc +++ b/test/syscalls/linux/socket_bind_to_device_util.cc @@ -58,8 +58,10 @@ PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) { } std::unordered_set<string> GetInterfaceNames() { - struct if_nameindex* interfaces = if_nameindex(); std::unordered_set<string> names; +#ifndef ANDROID + // Android does not support if_nameindex in r22. + struct if_nameindex* interfaces = if_nameindex(); if (interfaces == nullptr) { return names; } @@ -68,6 +70,7 @@ std::unordered_set<string> GetInterfaceNames() { names.insert(interface->if_name); } if_freenameindex(interfaces); +#endif return names; } diff --git a/test/syscalls/linux/socket_capability.cc b/test/syscalls/linux/socket_capability.cc index 84b5b2b21..f75482aba 100644 --- a/test/syscalls/linux/socket_capability.cc +++ b/test/syscalls/linux/socket_capability.cc @@ -40,7 +40,7 @@ TEST(SocketTest, UnixConnectNeedsWritePerm) { // Drop capabilites that allow us to override permision checks. Otherwise if // the test is run as root, the connect below will bypass permission checks // and succeed unexpectedly. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); // Connect should fail without write perms. ASSERT_THAT(chmod(addr.sun_path, 0500), SyscallSucceeds()); diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 9a6b089f6..f99d6f1c7 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -472,6 +472,77 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } } +// Test the protocol state information returned by TCPINFO. +TEST_P(SocketInetLoopbackTest, TCPInfoState) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + FileDescriptor const listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + + auto state = [](int fd) -> int { + struct tcp_info opt = {}; + socklen_t optLen = sizeof(opt); + EXPECT_THAT(getsockopt(fd, SOL_TCP, TCP_INFO, &opt, &optLen), + SyscallSucceeds()); + return opt.tcpi_state; + }; + ASSERT_EQ(state(listen_fd.get()), TCP_CLOSE); + + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_EQ(state(listen_fd.get()), TCP_CLOSE); + + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + ASSERT_EQ(state(listen_fd.get()), TCP_LISTEN); + + // Get the port bound by the listening socket. 
+ socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_EQ(state(conn_fd.get()), TCP_CLOSE); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + ASSERT_EQ(state(conn_fd.get()), TCP_ESTABLISHED); + + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + ASSERT_EQ(state(accepted.get()), TCP_ESTABLISHED); + + ASSERT_THAT(close(accepted.release()), SyscallSucceeds()); + + struct pollfd pfd = { + .fd = conn_fd.get(), + .events = POLLIN | POLLRDHUP, + }; + constexpr int kTimeout = 10000; + int n = poll(&pfd, 1, kTimeout); + ASSERT_GE(n, 0) << strerror(errno); + ASSERT_EQ(n, 1); + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/6015): Notify POLLRDHUP on incoming FIN. + ASSERT_EQ(pfd.revents, POLLIN); + } else { + ASSERT_EQ(pfd.revents, POLLIN | POLLRDHUP); + } + + ASSERT_THAT(state(conn_fd.get()), TCP_CLOSE_WAIT); + ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds()); +} + void TestHangupDuringConnect(const TestParam& param, void (*hangup)(FileDescriptor&)) { TestAddress const& listener = param.listener; diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 59b56dc1a..2f5743cda 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -1155,7 +1155,7 @@ TEST_P(TCPSocketPairTest, IpMulticastLoopDefault) { TEST_P(TCPSocketPairTest, TCPResetDuringClose) { DisableSave ds; // Too many syscalls. - constexpr int kThreadCount = 1000; + constexpr int kThreadCount = 100; std::unique_ptr<ScopedThread> instances[kThreadCount]; for (int i = 0; i < kThreadCount; i++) { instances[i] = absl::make_unique<ScopedThread>([&]() { diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc index 8390f7c3b..09f070797 100644 --- a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc @@ -38,8 +38,8 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) { ipv6_mreq group_req = { .ipv6mr_multiaddr = reinterpret_cast<sockaddr_in6*>(&multicast_addr.addr)->sin6_addr, - .ipv6mr_interface = - (unsigned int)ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")), + .ipv6mr_interface = static_cast<decltype(ipv6_mreq::ipv6mr_interface)>( + ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"))), }; ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, &group_req, sizeof(group_req)), diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc index 4afed6d08..5a2841899 100644 --- a/test/syscalls/linux/sticky.cc +++ b/test/syscalls/linux/sticky.cc @@ -56,9 +56,7 @@ TEST(StickyTest, StickyBitPermDenied) { // thread won't be able to open some log files after the test ends. ScopedThread([&] { // Drop privileges. 
- if (HaveCapability(CAP_FOWNER).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, false)); - } + AutoCapability cap(CAP_FOWNER, false); // Change EUID and EGID. EXPECT_THAT( @@ -98,9 +96,7 @@ TEST(StickyTest, StickyBitSameUID) { // thread won't be able to open some log files after the test ends. ScopedThread([&] { // Drop privileges. - if (HaveCapability(CAP_FOWNER).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, false)); - } + AutoCapability cap(CAP_FOWNER, false); // Change EGID. EXPECT_THAT( diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index 9f6c59446..fa6849f11 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -100,8 +100,8 @@ TEST(SymlinkTest, CanCreateSymlinkDir) { TEST(SymlinkTest, CannotCreateSymlinkInReadOnlyDir) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const std::string olddir = NewTempAbsPath(); ASSERT_THAT(mkdir(olddir.c_str(), 0444), SyscallSucceeds()); @@ -250,8 +250,8 @@ TEST(SymlinkTest, PwriteToSymlink) { TEST(SymlinkTest, SymlinkAtDegradedPermissions) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); @@ -301,8 +301,8 @@ TEST(SymlinkTest, ReadlinkAtDirWithOpath) { TEST(SymlinkTest, ReadlinkAtDegradedPermissions) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const std::string oldpath = NewTempAbsPathInDir(dir.path()); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 011b60f0e..5bfdecc79 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1164,7 +1164,13 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSend) { ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); - std::vector<char> writebuf(512 << 10); // 512 KiB. + // Ensure the write buffer is large enough not to block on a single write. + size_t write_size = 512 << 10; // 512 KiB. + EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_SNDBUF, &write_size, + sizeof(write_size)), + SyscallSucceedsWithValue(0)); + + std::vector<char> writebuf(write_size); // Try to send the whole thing. 
int n; diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index 93a98adb1..bc12dd4af 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -26,6 +26,7 @@ #include "absl/flags/flag.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "benchmark/benchmark.h" #include "test/util/cleanup.h" #include "test/util/logging.h" #include "test/util/multiprocess_util.h" @@ -92,6 +93,8 @@ TEST(TimerTest, ProcessKilledOnCPUSoftLimit) { TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); MaybeSave(); for (;;) { + int x = 0; + benchmark::DoNotOptimize(x); // Don't optimize this loop away. } } ASSERT_THAT(pid, SyscallSucceeds()); @@ -151,6 +154,8 @@ TEST(TimerTest, ProcessPingedRepeatedlyAfterCPUSoftLimit) { TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); MaybeSave(); for (;;) { + int x = 0; + benchmark::DoNotOptimize(x); // Don't optimize this loop away. } } ASSERT_THAT(pid, SyscallSucceeds()); @@ -197,6 +202,8 @@ TEST(TimerTest, ProcessKilledOnCPUHardLimit) { TEST_PCHECK(setrlimit(RLIMIT_CPU, &cpu_limits) == 0); MaybeSave(); for (;;) { + int x = 0; + benchmark::DoNotOptimize(x); // Don't optimize this loop away. } } ASSERT_THAT(pid, SyscallSucceeds()); diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc index 5db0b8276..0f08d9996 100644 --- a/test/syscalls/linux/truncate.cc +++ b/test/syscalls/linux/truncate.cc @@ -181,7 +181,7 @@ TEST(TruncateTest, FtruncateDir) { TEST(TruncateTest, TruncateNonWriteable) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // always override write permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */)); EXPECT_THAT(truncate(temp_file.path().c_str(), 0), @@ -210,7 +210,7 @@ TEST(TruncateTest, FtruncateWithOpath) { // regardless of whether the file permissions allow writing. TEST(TruncateTest, FtruncateWithoutWritePermission) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); // The only time we can open a file with flags forbidden by its permissions // is when we are creating the file. We cannot re-open with the same flags, diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc index 6e3a00d2c..279fe342c 100644 --- a/test/syscalls/linux/tuntap.cc +++ b/test/syscalls/linux/tuntap.cc @@ -170,10 +170,10 @@ TEST(TuntapStaticTest, NetTunExists) { class TuntapTest : public ::testing::Test { protected: void SetUp() override { - have_net_admin_cap_ = + const bool have_net_admin_cap = ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)); - if (have_net_admin_cap_ && !IsRunningOnGvisor()) { + if (have_net_admin_cap && !IsRunningOnGvisor()) { // gVisor always creates enabled/up'd interfaces, while Linux does not (as // observed in b/110961832). Some of the tests require the Linux stack to // notify the socket of any link-address-resolution failures. Those @@ -183,21 +183,12 @@ class TuntapTest : public ::testing::Test { ASSERT_NO_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP)); } } - - void TearDown() override { - if (have_net_admin_cap_) { - // Bring back capability if we had dropped it in test case. 
- ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); - } - } - - bool have_net_admin_cap_; }; TEST_F(TuntapTest, CreateInterfaceNoCap) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); + AutoCapability cap(CAP_NET_ADMIN, false); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc index d8824b171..759ea4f53 100644 --- a/test/syscalls/linux/uname.cc +++ b/test/syscalls/linux/uname.cc @@ -76,9 +76,7 @@ TEST(UnameTest, SetNames) { } TEST(UnameTest, UnprivilegedSetNames) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); - } + AutoCapability cap(CAP_SYS_ADMIN, false); EXPECT_THAT(sethostname("", 0), SyscallFailsWithErrno(EPERM)); EXPECT_THAT(setdomainname("", 0), SyscallFailsWithErrno(EPERM)); diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc index 7c301c305..75dcf4465 100644 --- a/test/syscalls/linux/unlink.cc +++ b/test/syscalls/linux/unlink.cc @@ -66,8 +66,8 @@ TEST(UnlinkTest, AtDir) { TEST(UnlinkTest, AtDirDegradedPermissions) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -86,8 +86,8 @@ TEST(UnlinkTest, AtDirDegradedPermissions) { // Files cannot be unlinked if the parent is not writable and executable. TEST(UnlinkTest, ParentDegradedPermissions) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc index e647d2896..e711d6657 100644 --- a/test/syscalls/linux/utimes.cc +++ b/test/syscalls/linux/utimes.cc @@ -225,7 +225,8 @@ void TestUtimensat(int dirFd, std::string const& path) { EXPECT_GE(mtime3, before); EXPECT_LE(mtime3, after); - EXPECT_EQ(atime3, mtime3); + // TODO(b/187074006): atime/mtime may differ with local_gofer_uncached. 
+ // EXPECT_EQ(atime3, mtime3); } TEST(UtimensatTest, OnAbsPath) { diff --git a/test/syscalls/linux/verity_ioctl.cc b/test/syscalls/linux/verity_ioctl.cc index 822e16f3c..be91b23d0 100644 --- a/test/syscalls/linux/verity_ioctl.cc +++ b/test/syscalls/linux/verity_ioctl.cc @@ -28,40 +28,13 @@ #include "test/util/mount_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" +#include "test/util/verity_util.h" namespace gvisor { namespace testing { namespace { -#ifndef FS_IOC_ENABLE_VERITY -#define FS_IOC_ENABLE_VERITY 1082156677 -#endif - -#ifndef FS_IOC_MEASURE_VERITY -#define FS_IOC_MEASURE_VERITY 3221513862 -#endif - -#ifndef FS_VERITY_FL -#define FS_VERITY_FL 1048576 -#endif - -#ifndef FS_IOC_GETFLAGS -#define FS_IOC_GETFLAGS 2148034049 -#endif - -struct fsverity_digest { - __u16 digest_algorithm; - __u16 digest_size; /* input/output */ - __u8 digest[]; -}; - -constexpr int kMaxDigestSize = 64; -constexpr int kDefaultDigestSize = 32; -constexpr char kContents[] = "foobarbaz"; -constexpr char kMerklePrefix[] = ".merkle.verity."; -constexpr char kMerkleRootPrefix[] = ".merkleroot.verity."; - class IoctlTest : public ::testing::Test { protected: void SetUp() override { @@ -85,80 +58,6 @@ class IoctlTest : public ::testing::Test { std::string filename_; }; -// Provide a function to convert bytes to hex string, since -// absl::BytesToHexString does not seem to be compatible with golang -// hex.DecodeString used in verity due to zero-padding. -std::string BytesToHexString(uint8_t bytes[], int size) { - std::stringstream ss; - ss << std::hex; - for (int i = 0; i < size; ++i) { - ss << std::setw(2) << std::setfill('0') << static_cast<int>(bytes[i]); - } - return ss.str(); -} - -std::string MerklePath(absl::string_view path) { - return JoinPath(Dirname(path), - std::string(kMerklePrefix) + std::string(Basename(path))); -} - -std::string MerkleRootPath(absl::string_view path) { - return JoinPath(Dirname(path), - std::string(kMerkleRootPrefix) + std::string(Basename(path))); -} - -// Flip a random bit in the file represented by fd. -PosixError FlipRandomBit(int fd, int size) { - // Generate a random offset in the file. - srand(time(nullptr)); - unsigned int seed = 0; - int random_offset = rand_r(&seed) % size; - - // Read a random byte and flip a bit in it. - char buf[1]; - RETURN_ERROR_IF_SYSCALL_FAIL(PreadFd(fd, buf, 1, random_offset)); - buf[0] ^= 1; - RETURN_ERROR_IF_SYSCALL_FAIL(PwriteFd(fd, buf, 1, random_offset)); - return NoError(); -} - -// Mount a verity on the tmpfs and enable both the file and the direcotry. Then -// mount a new verity with measured root hash. -PosixErrorOr<std::string> MountVerity(std::string tmpfs_dir, - std::string filename) { - // Mount a verity fs on the existing tmpfs mount. - std::string mount_opts = "lower_path=" + tmpfs_dir; - ASSIGN_OR_RETURN_ERRNO(TempPath verity_dir, TempPath::CreateDir()); - RETURN_ERROR_IF_SYSCALL_FAIL( - mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str())); - - // Enable both the file and the directory. - ASSIGN_OR_RETURN_ERRNO( - auto fd, Open(JoinPath(verity_dir.path(), filename), O_RDONLY, 0777)); - RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd.get(), FS_IOC_ENABLE_VERITY)); - ASSIGN_OR_RETURN_ERRNO(auto dir_fd, Open(verity_dir.path(), O_RDONLY, 0777)); - RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(dir_fd.get(), FS_IOC_ENABLE_VERITY)); - - // Measure the root hash. 
- uint8_t digest_array[sizeof(struct fsverity_digest) + kMaxDigestSize] = {0}; - struct fsverity_digest* digest = - reinterpret_cast<struct fsverity_digest*>(digest_array); - digest->digest_size = kMaxDigestSize; - RETURN_ERROR_IF_SYSCALL_FAIL( - ioctl(dir_fd.get(), FS_IOC_MEASURE_VERITY, digest)); - - // Mount a verity fs with specified root hash. - mount_opts += - ",root_hash=" + BytesToHexString(digest->digest, digest->digest_size); - ASSIGN_OR_RETURN_ERRNO(TempPath verity_with_hash_dir, TempPath::CreateDir()); - RETURN_ERROR_IF_SYSCALL_FAIL(mount("", verity_with_hash_dir.path().c_str(), - "verity", 0, mount_opts.c_str())); - // Verity directories should not be deleted. Release the TempPath objects to - // prevent those directories from being deleted by the destructor. - verity_dir.release(); - return verity_with_hash_dir.release(); -} - TEST_F(IoctlTest, Enable) { // Mount a verity fs on the existing tmpfs mount. std::string mount_opts = "lower_path=" + tmpfs_dir_.path(); diff --git a/test/syscalls/linux/verity_mmap.cc b/test/syscalls/linux/verity_mmap.cc new file mode 100644 index 000000000..dde74cc91 --- /dev/null +++ b/test/syscalls/linux/verity_mmap.cc @@ -0,0 +1,158 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdint.h> +#include <stdlib.h> +#include <sys/mman.h> +#include <sys/mount.h> + +#include <string> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/fs_util.h" +#include "test/util/memory_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" +#include "test/util/verity_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class MmapTest : public ::testing::Test { + protected: + void SetUp() override { + // Verity is implemented in VFS2. + SKIP_IF(IsRunningWithVFS1()); + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + // Mount a tmpfs file system, to be wrapped by a verity fs. + tmpfs_dir_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(mount("", tmpfs_dir_.path().c_str(), "tmpfs", 0, ""), + SyscallSucceeds()); + + // Create a new file in the tmpfs mount. + file_ = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(tmpfs_dir_.path(), kContents, 0777)); + filename_ = Basename(file_.path()); + } + + TempPath tmpfs_dir_; + TempPath file_; + std::string filename_; +}; + +TEST_F(MmapTest, MmapRead) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + // Make sure the file can be open and mmapped in the mounted verity fs. 
+ auto const verity_fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(verity_dir, filename_), O_RDONLY, 0777)); + + Mapping const m = + ASSERT_NO_ERRNO_AND_VALUE(Mmap(nullptr, sizeof(kContents) - 1, PROT_READ, + MAP_SHARED, verity_fd.get(), 0)); + EXPECT_THAT(std::string(m.view()), ::testing::StrEq(kContents)); +} + +TEST_F(MmapTest, ModifiedBeforeMmap) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + // Modify the file and check verification failure upon mmapping. + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(tmpfs_dir_.path(), filename_), O_RDWR, 0777)); + ASSERT_NO_ERRNO(FlipRandomBit(fd.get(), sizeof(kContents) - 1)); + + auto const verity_fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(verity_dir, filename_), O_RDONLY, 0777)); + Mapping const m = + ASSERT_NO_ERRNO_AND_VALUE(Mmap(nullptr, sizeof(kContents) - 1, PROT_READ, + MAP_SHARED, verity_fd.get(), 0)); + + // Memory fault is expected when Translate fails. + EXPECT_EXIT(std::string(m.view()), ::testing::KilledBySignal(SIGSEGV), ""); +} + +TEST_F(MmapTest, ModifiedAfterMmap) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + auto const verity_fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(verity_dir, filename_), O_RDONLY, 0777)); + Mapping const m = + ASSERT_NO_ERRNO_AND_VALUE(Mmap(nullptr, sizeof(kContents) - 1, PROT_READ, + MAP_SHARED, verity_fd.get(), 0)); + + // Modify the file after mapping and check verification failure upon mmapping. + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(tmpfs_dir_.path(), filename_), O_RDWR, 0777)); + ASSERT_NO_ERRNO(FlipRandomBit(fd.get(), sizeof(kContents) - 1)); + + // Memory fault is expected when Translate fails. + EXPECT_EXIT(std::string(m.view()), ::testing::KilledBySignal(SIGSEGV), ""); +} + +class MmapParamTest + : public MmapTest, + public ::testing::WithParamInterface<std::tuple<int, int>> { + protected: + int prot() const { return std::get<0>(GetParam()); } + int flags() const { return std::get<1>(GetParam()); } +}; + +INSTANTIATE_TEST_SUITE_P( + WriteExecNoneSharedPrivate, MmapParamTest, + ::testing::Combine(::testing::ValuesIn({ + PROT_WRITE, + PROT_EXEC, + PROT_NONE, + }), + ::testing::ValuesIn({MAP_SHARED, MAP_PRIVATE}))); + +TEST_P(MmapParamTest, Mmap) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + // Make sure the file can be open and mmapped in the mounted verity fs. + auto const verity_fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(verity_dir, filename_), O_RDONLY, 0777)); + + if (prot() == PROT_WRITE && flags() == MAP_SHARED) { + // Verity file system is read-only. + EXPECT_THAT( + reinterpret_cast<intptr_t>(mmap(nullptr, sizeof(kContents) - 1, prot(), + flags(), verity_fd.get(), 0)), + SyscallFailsWithErrno(EACCES)); + } else { + Mapping const m = ASSERT_NO_ERRNO_AND_VALUE(Mmap( + nullptr, sizeof(kContents) - 1, prot(), flags(), verity_fd.get(), 0)); + if (prot() == PROT_NONE) { + // Memory mapped with PROT_NONE cannot be accessed.
+ EXPECT_EXIT(std::string(m.view()), ::testing::KilledBySignal(SIGSEGV), + ""); + } else { + EXPECT_THAT(std::string(m.view()), ::testing::StrEq(kContents)); + } + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index dd8067807..c8a97df6b 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -109,8 +109,8 @@ TEST_F(XattrTest, XattrInvalidPrefix) { // the restore will fail to open it with r/w permissions. TEST_F(XattrTest, XattrReadOnly) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const char* path = test_file_name_.c_str(); const char name[] = "user.test"; @@ -140,8 +140,8 @@ TEST_F(XattrTest, XattrReadOnly) { // the restore will fail to open it with r/w permissions. TEST_F(XattrTest, XattrWriteOnly) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); DisableSave ds; ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR)); @@ -632,7 +632,7 @@ TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) { // Trusted namespace not supported in VFS1. SKIP_IF(IsRunningWithVFS1()); - // TODO(b/66162845): Only gVisor tmpfs currently supports trusted namespace. + // TODO(b/166162845): Only gVisor tmpfs currently supports trusted namespace. SKIP_IF(IsRunningOnGvisor() && !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); @@ -680,9 +680,7 @@ TEST_F(XattrTest, TrustedNamespaceWithoutCapSysAdmin) { !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); // Drop CAP_SYS_ADMIN if we have it. 
- if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); - } + AutoCapability cap(CAP_SYS_ADMIN, false); const char* path = test_file_name_.c_str(); const char name[] = "trusted.test"; diff --git a/test/util/BUILD b/test/util/BUILD index 8985b54af..cc83221ea 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -401,3 +401,16 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "verity_util", + testonly = 1, + srcs = ["verity_util.cc"], + hdrs = ["verity_util.h"], + deps = [ + ":fs_util", + ":mount_util", + ":posix_error", + ":temp_path", + ], +) diff --git a/test/util/cgroup_util.cc b/test/util/cgroup_util.cc index 04d4f8de0..977993f41 100644 --- a/test/util/cgroup_util.cc +++ b/test/util/cgroup_util.cc @@ -142,6 +142,20 @@ PosixError Mounter::Unmount(const Cgroup& c) { return NoError(); } +void Mounter::release(const Cgroup& c) { + auto mp = mountpoints_.find(c.id()); + if (mp != mountpoints_.end()) { + mp->second.release(); + mountpoints_.erase(mp); + } + + auto m = mounts_.find(c.id()); + if (m != mounts_.end()) { + m->second.Release(); + mounts_.erase(m); + } +} + constexpr char kProcCgroupsHeader[] = "#subsys_name\thierarchy\tnum_cgroups\tenabled"; diff --git a/test/util/cgroup_util.h b/test/util/cgroup_util.h index b797a8b24..e3f696a89 100644 --- a/test/util/cgroup_util.h +++ b/test/util/cgroup_util.h @@ -83,6 +83,8 @@ class Mounter { PosixError Unmount(const Cgroup& c); + void release(const Cgroup& c); + private: // The destruction order of these members avoids errors during cleanup. We // first unmount (by executing the mounts_ cleanups), then delete the diff --git a/test/util/test_util.h b/test/util/test_util.h index 876ff58db..bcbb388ed 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -272,10 +272,15 @@ PosixErrorOr<std::vector<OpenFd>> GetOpenFDs(); // Returns the number of hard links to a path. PosixErrorOr<uint64_t> Links(const std::string& path); +inline uint64_t ns_elapsed(const struct timespec& begin, + const struct timespec& end) { + return (end.tv_sec - begin.tv_sec) * 1000000000 + + (end.tv_nsec - begin.tv_nsec); +} + inline uint64_t ms_elapsed(const struct timespec& begin, const struct timespec& end) { - return (end.tv_sec - begin.tv_sec) * 1000 + - (end.tv_nsec - begin.tv_nsec) / 1000000; + return ns_elapsed(begin, end) / 1000000; } namespace internal { diff --git a/test/util/verity_util.cc b/test/util/verity_util.cc new file mode 100644 index 000000000..f1b4c251b --- /dev/null +++ b/test/util/verity_util.cc @@ -0,0 +1,93 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "test/util/verity_util.h" + +#include "test/util/fs_util.h" +#include "test/util/mount_util.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +std::string BytesToHexString(uint8_t bytes[], int size) { + std::stringstream ss; + ss << std::hex; + for (int i = 0; i < size; ++i) { + ss << std::setw(2) << std::setfill('0') << static_cast<int>(bytes[i]); + } + return ss.str(); +} + +std::string MerklePath(absl::string_view path) { + return JoinPath(Dirname(path), + std::string(kMerklePrefix) + std::string(Basename(path))); +} + +std::string MerkleRootPath(absl::string_view path) { + return JoinPath(Dirname(path), + std::string(kMerkleRootPrefix) + std::string(Basename(path))); +} + +PosixError FlipRandomBit(int fd, int size) { + // Generate a random offset in the file. + srand(time(nullptr)); + unsigned int seed = 0; + int random_offset = rand_r(&seed) % size; + + // Read a random byte and flip a bit in it. + char buf[1]; + RETURN_ERROR_IF_SYSCALL_FAIL(PreadFd(fd, buf, 1, random_offset)); + buf[0] ^= 1; + RETURN_ERROR_IF_SYSCALL_FAIL(PwriteFd(fd, buf, 1, random_offset)); + return NoError(); +} + +PosixErrorOr<std::string> MountVerity(std::string tmpfs_dir, + std::string filename) { + // Mount a verity fs on the existing tmpfs mount. + std::string mount_opts = "lower_path=" + tmpfs_dir; + ASSIGN_OR_RETURN_ERRNO(TempPath verity_dir, TempPath::CreateDir()); + RETURN_ERROR_IF_SYSCALL_FAIL( + mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str())); + + // Enable both the file and the directory. + ASSIGN_OR_RETURN_ERRNO( + auto fd, Open(JoinPath(verity_dir.path(), filename), O_RDONLY, 0777)); + RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd.get(), FS_IOC_ENABLE_VERITY)); + ASSIGN_OR_RETURN_ERRNO(auto dir_fd, Open(verity_dir.path(), O_RDONLY, 0777)); + RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(dir_fd.get(), FS_IOC_ENABLE_VERITY)); + + // Measure the root hash. + uint8_t digest_array[sizeof(struct fsverity_digest) + kMaxDigestSize] = {0}; + struct fsverity_digest* digest = + reinterpret_cast<struct fsverity_digest*>(digest_array); + digest->digest_size = kMaxDigestSize; + RETURN_ERROR_IF_SYSCALL_FAIL( + ioctl(dir_fd.get(), FS_IOC_MEASURE_VERITY, digest)); + + // Mount a verity fs with specified root hash. + mount_opts += + ",root_hash=" + BytesToHexString(digest->digest, digest->digest_size); + ASSIGN_OR_RETURN_ERRNO(TempPath verity_with_hash_dir, TempPath::CreateDir()); + RETURN_ERROR_IF_SYSCALL_FAIL(mount("", verity_with_hash_dir.path().c_str(), + "verity", 0, mount_opts.c_str())); + // Verity directories should not be deleted. Release the TempPath objects to + // prevent those directories from being deleted by the destructor. + verity_dir.release(); + return verity_with_hash_dir.release(); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/verity_util.h b/test/util/verity_util.h new file mode 100644 index 000000000..18743ecd6 --- /dev/null +++ b/test/util/verity_util.h @@ -0,0 +1,75 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_UTIL_VERITY_UTIL_H_ +#define GVISOR_TEST_UTIL_VERITY_UTIL_H_ + +#include <stdint.h> + +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +#ifndef FS_IOC_ENABLE_VERITY +#define FS_IOC_ENABLE_VERITY 1082156677 +#endif + +#ifndef FS_IOC_MEASURE_VERITY +#define FS_IOC_MEASURE_VERITY 3221513862 +#endif + +#ifndef FS_VERITY_FL +#define FS_VERITY_FL 1048576 +#endif + +#ifndef FS_IOC_GETFLAGS +#define FS_IOC_GETFLAGS 2148034049 +#endif + +struct fsverity_digest { + unsigned short digest_algorithm; + unsigned short digest_size; /* input/output */ + unsigned char digest[]; +}; + +constexpr int kMaxDigestSize = 64; +constexpr int kDefaultDigestSize = 32; +constexpr char kContents[] = "foobarbaz"; +constexpr char kMerklePrefix[] = ".merkle.verity."; +constexpr char kMerkleRootPrefix[] = ".merkleroot.verity."; + +// Get the Merkle tree file path for |path|. +std::string MerklePath(absl::string_view path); + +// Get the root Merkle tree file path for |path|. +std::string MerkleRootPath(absl::string_view path); + +// Provide a function to convert bytes to hex string, since +// absl::BytesToHexString does not seem to be compatible with golang +// hex.DecodeString used in verity due to zero-padding. +std::string BytesToHexString(uint8_t bytes[], int size); + +// Flip a random bit in the file represented by fd. +PosixError FlipRandomBit(int fd, int size); + +// Mount a verity on the tmpfs and enable both the file and the directory. Then +// mount a new verity with measured root hash. +PosixErrorOr<std::string> MountVerity(std::string tmpfs_dir, + std::string filename); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_UTIL_VERITY_UTIL_H_ diff --git a/tools/bazeldefs/tags.bzl b/tools/bazeldefs/tags.bzl index f5d7a7b21..6564c3b25 100644 --- a/tools/bazeldefs/tags.bzl +++ b/tools/bazeldefs/tags.bzl @@ -33,6 +33,10 @@ archs = [ "_s390x", "_sparc64", "_x86", + + # Pseudo-architectures to group by word size. + "_32bit", + "_64bit", ] oses = [ diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD index 81994f954..2b116fe0d 100644 --- a/tools/bigquery/BUILD +++ b/tools/bigquery/BUILD @@ -6,6 +6,7 @@ go_library( name = "bigquery", testonly = 1, srcs = ["bigquery.go"], + nogo = False, # FIXME(b/184974218): Analysis failing for cloud libraries. visibility = [ "//:sandbox", ], diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index 8eeabbc3d..c788654a8 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -709,8 +709,13 @@ func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) { return } - // Recursively collect information from - // the other analyzers. + // If this package is the atomic package, the implementation + // may be replaced by intrinsics that don't have analysis. + if x.Pkg.Pkg.Path() == "sync/atomic" { + return + } + + // Recursively collect information. var imp packageEscapeFacts if !pass.ImportPackageFact(x.Pkg.Pkg, &imp) { // Unable to import the dependency; we must diff --git a/tools/deps.bzl b/tools/deps.bzl index ed1135a9e..91442617c 100644 --- a/tools/deps.bzl +++ b/tools/deps.bzl @@ -94,8 +94,13 @@ def _deps_test_impl(ctx): ) return [] -# Checks that library and its deps only depends on gVisor and an allowlist of -# other dependencies. +# Checks that targets only depend on an allowlist of other targets.
Targets can +# be specified directly, or prefixes can be used to allow entire packages or +# directory trees. +# +# This recursively checks the "deps" attribute of each target, dependencies +# expressed other ways are not checked. For example, protobuf targets pull in +# protobuf code, but aren't analyzed by deps_test. deps_test = rule( implementation = _deps_test_impl, attrs = { diff --git a/tools/github/reviver/github.go b/tools/github/reviver/github.go index c4b624f2a..b360f0544 100644 --- a/tools/github/reviver/github.go +++ b/tools/github/reviver/github.go @@ -92,7 +92,7 @@ func (b *GitHubBugger) Activate(todo *Todo) (bool, error) { fmt.Fprintln(&comment, "There are TODOs still referencing this issue:") for _, l := range todo.Locations { fmt.Fprintf(&comment, - "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#%d): %s\n", + "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#L%d): %s\n", l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment) } fmt.Fprintf(&comment, @@ -110,6 +110,11 @@ func (b *GitHubBugger) Activate(todo *Todo) (bool, error) { return true, fmt.Errorf("failed to reactivate issue %d: %v", id, err) } + _, _, err = b.client.Issues.AddLabelsToIssue(ctx, b.owner, b.repo, id, []string{"revived"}) + if err != nil { + return true, fmt.Errorf("failed to set label on issue %d: %v", id, err) + } + cmt := &github.IssueComment{ Body: github.String(comment.String()), Reactions: &github.Reactions{Confused: github.Int(1)}, diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index f79defea7..85a1adf66 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -5,9 +5,7 @@ licenses(["notice"]) go_binary( name = "go_marshal", srcs = ["main.go"], - visibility = [ - "//:sandbox", - ], + visibility = ["//:sandbox"], deps = [ "//tools/go_marshal/gomarshal", ], @@ -16,6 +14,7 @@ go_binary( config_setting( name = "marshal_config_verbose", values = {"define": "gomarshal=verbose"}, + visibility = ["//:sandbox"], ) bzl_library( diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index eddba0c21..bbd4c9f48 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -140,3 +140,6 @@ options, depending on how go-marshal is being invoked: - Set `debug = True` on the `go_marshal` BUILD rule. - Pass `-debug` to the go-marshal tool invocation. + +If bazel complains about stdout output being too large, set a larger value +through `--experimental_ui_max_stdouterr_bytes`, or `-1` for unlimited output. diff --git a/website/BUILD b/website/BUILD index 6f52e9208..1a38967e5 100644 --- a/website/BUILD +++ b/website/BUILD @@ -165,6 +165,7 @@ docs( "//g3doc/user_guide/tutorials:cni", "//g3doc/user_guide/tutorials:docker", "//g3doc/user_guide/tutorials:docker_compose", + "//g3doc/user_guide/tutorials:knative", "//g3doc/user_guide/tutorials:kubernetes", ], ) |
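The recurring change across the test files in this patch is the replacement of manual HaveCapability()/SetCapability() bookkeeping, including the hand-written restore code in TearDown(), with a scoped AutoCapability guard. The following is a minimal usage sketch, not the authoritative implementation: it assumes the RAII semantics implied by the replacements above (the constructor applies the requested capability change and the destructor restores the previous state), and it assumes the guard is declared alongside the other capability helpers, e.g. in test/util/capability_util.h.

```cpp
// Hypothetical sketch of the AutoCapability pattern adopted in this patch.
#include <unistd.h>

#include "gtest/gtest.h"
#include "test/util/capability_util.h"  // Assumed location of AutoCapability.
#include "test/util/test_util.h"

namespace gvisor {
namespace testing {
namespace {

TEST(AutoCapabilityExample, UnprivilegedSethostnameFails) {
  // Drop CAP_SYS_ADMIN for this scope only. When `cap` is destroyed at the
  // end of the test body, the previous capability state is restored, so no
  // manual TearDown()/SetCapability(..., true) pair is needed.
  AutoCapability cap(CAP_SYS_ADMIN, false);

  // Without CAP_SYS_ADMIN, sethostname(2) is expected to fail with EPERM
  // (mirrors UnameTest.UnprivilegedSetNames above).
  EXPECT_THAT(sethostname("", 0), SyscallFailsWithErrno(EPERM));
}

}  // namespace
}  // namespace testing
}  // namespace gvisor
```

Because the guard restores state in its destructor, the capability is re-acquired even when an assertion fails partway through the test, which is what the TearDown() removed from tuntap.cc previously did by hand.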