diff options
499 files changed, 23935 insertions, 4296 deletions
@@ -21,6 +21,9 @@ go_path( mode = "link", deps = [ "//runsc", + + # Packages that are not dependencies of //runsc. + "//pkg/tcpip/link/channel", ], ) @@ -133,11 +133,9 @@ The [gvisor-users mailing list][gvisor-users-list] and [gvisor-dev mailing list][gvisor-dev-list] are good starting points for questions and discussion. -## Security +## Security Policy -Sensitive security-related questions, comments and disclosures can be sent to -the [gvisor-security mailing list][gvisor-security-list]. The full security -disclosure policy is defined in the [community][community] repository. +See [SECURITY.md](SECURITY.md). ## Contributing @@ -147,7 +145,6 @@ See [Contributing.md](CONTRIBUTING.md). [community]: https://gvisor.googlesource.com/community [docker]: https://www.docker.com [git]: https://git-scm.com -[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security [gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users [gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev [oci]: https://www.opencontainers.org diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..154d68cb3 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,11 @@ +# Security and Vulnerability Reporting + +Sensitive security-related questions, comments, and reports should be sent to +the [gvisor-security mailing list][gvisor-security-list]. You should receive a +prompt response, typically within 48 hours. + +Policies for security list access, vulnerability embargo, and vulnerability +disclosure are outlined in the [community][community] repository. + +[community]: https://gvisor.googlesource.com/community +[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security @@ -3,19 +3,19 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( name = "io_bazel_rules_go", - sha256 = "ae8c36ff6e565f674c7a3692d6a9ea1096e4c1ade497272c2108a810fb39acd2", + sha256 = "b9aa86ec08a292b97ec4591cf578e020b35f98e12173bbd4a921f84f583aebd9", urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz", + "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.20.2/rules_go-v0.20.2.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.20.2/rules_go-v0.20.2.tar.gz", ], ) http_archive( name = "bazel_gazelle", - sha256 = "7fc87f4170011201b1690326e8c16c5d802836e3a0d617d8f75c3af2b23180c4", + sha256 = "86c6d481b3f7aedc1d60c1c211c6f76da282ae197c3b3160f54bd3a8f847896f", urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/0.18.2/bazel-gazelle-0.18.2.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/0.18.2/bazel-gazelle-0.18.2.tar.gz", + "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.1/bazel-gazelle-v0.19.1.tar.gz", + "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.1/bazel-gazelle-v0.19.1.tar.gz", ], ) @@ -24,7 +24,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to go_rules_dependencies() go_register_toolchains( - go_version = "1.13", + go_version = "1.13.4", nogo = "@//:nogo", ) @@ -75,6 +75,16 @@ load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig") rbe_autoconfig(name = "rbe_default") +http_archive( + name = "rules_pkg", + sha256 = "5bdc04987af79bd27bc5b00fe30f59a858f77ffa0bd2d8143d5b31ad8b1bd71c", + url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.0/rules_pkg-0.2.0.tar.gz", +) + +load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") + +rules_pkg_dependencies() + # External repositories, in sorted order. go_repository( name = "com_github_cenkalti_backoff", diff --git a/kokoro/build.cfg b/kokoro/build.cfg index 084347dde..c9ceda947 100644 --- a/kokoro/build.cfg +++ b/kokoro/build.cfg @@ -16,7 +16,9 @@ env_vars { action { define_artifacts { + regex: "**/runsc" regex: "**/runsc.*" regex: "**/dists/**" + regex: "**/pool/**" } } diff --git a/kokoro/build_tests.cfg b/kokoro/build_tests.cfg new file mode 100644 index 000000000..c64b7e679 --- /dev/null +++ b/kokoro/build_tests.cfg @@ -0,0 +1 @@ +build_file: "repo/scripts/build.sh" diff --git a/kokoro/docker_tests.cfg b/kokoro/docker_tests.cfg index 717d71dd3..0a0ef87ed 100644 --- a/kokoro/docker_tests.cfg +++ b/kokoro/docker_tests.cfg @@ -5,5 +5,6 @@ action { regex: "**/sponge_log.xml" regex: "**/sponge_log.log" regex: "**/outputs.zip" + regex: "**/runsc_logs_*.tar.gz" } } diff --git a/kokoro/hostnet_tests.cfg b/kokoro/hostnet_tests.cfg index 532755f4a..520dc55a3 100644 --- a/kokoro/hostnet_tests.cfg +++ b/kokoro/hostnet_tests.cfg @@ -5,5 +5,6 @@ action { regex: "**/sponge_log.xml" regex: "**/sponge_log.log" regex: "**/outputs.zip" + regex: "**/runsc_logs_*.tar.gz" } } diff --git a/kokoro/kvm_tests.cfg b/kokoro/kvm_tests.cfg index 54365c2b2..1feb60c8a 100644 --- a/kokoro/kvm_tests.cfg +++ b/kokoro/kvm_tests.cfg @@ -5,5 +5,6 @@ action { regex: "**/sponge_log.xml" regex: "**/sponge_log.log" regex: "**/outputs.zip" + regex: "**/runsc_logs_*.tar.gz" } } diff --git a/kokoro/overlay_tests.cfg b/kokoro/overlay_tests.cfg index abd96f60c..6a2ddbd03 100644 --- a/kokoro/overlay_tests.cfg +++ b/kokoro/overlay_tests.cfg @@ -5,5 +5,6 @@ action { regex: "**/sponge_log.xml" regex: "**/sponge_log.log" regex: "**/outputs.zip" + regex: "**/runsc_logs_*.tar.gz" } } diff --git a/kokoro/root_tests.cfg b/kokoro/root_tests.cfg index 20b97766a..28351695c 100644 --- a/kokoro/root_tests.cfg +++ b/kokoro/root_tests.cfg @@ -5,5 +5,6 @@ action { regex: "**/sponge_log.xml" regex: "**/sponge_log.log" regex: "**/outputs.zip" + regex: "**/runsc_logs_*.tar.gz" } } diff --git a/kokoro/runtime_tests.cfg b/kokoro/runtime_tests.cfg new file mode 100644 index 000000000..7d56d5aca --- /dev/null +++ b/kokoro/runtime_tests.cfg @@ -0,0 +1 @@ +build_file: "repo/scripts/runtime_tests.sh" diff --git a/kokoro/swgso_tests.cfg b/kokoro/swgso_tests.cfg new file mode 100644 index 000000000..101a9c607 --- /dev/null +++ b/kokoro/swgso_tests.cfg @@ -0,0 +1,9 @@ +build_file: "repo/scripts/swgso_tests.sh" + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + } +} diff --git a/kokoro/syscall_kvm_tests.cfg b/kokoro/syscall_kvm_tests.cfg new file mode 100644 index 000000000..3b99e9c13 --- /dev/null +++ b/kokoro/syscall_kvm_tests.cfg @@ -0,0 +1,9 @@ +build_file: "repo/scripts/syscall_kvm_tests.sh" + +action { + define_artifacts { + regex: "**/sponge_log.xml" + regex: "**/sponge_log.log" + regex: "**/outputs.zip" + } +} diff --git a/kokoro/ubuntu1604/40_kokoro.sh b/kokoro/ubuntu1604/40_kokoro.sh index 64772d74d..3f50929d5 100755 --- a/kokoro/ubuntu1604/40_kokoro.sh +++ b/kokoro/ubuntu1604/40_kokoro.sh @@ -23,7 +23,10 @@ declare -r ssh_public_keys=( ) # Install dependencies. -apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm +apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm python-pip zip + +# junitparser is used to merge junit xml files. +pip install junitparser # We need a kbuilder user. if useradd -c "kbuilder user" -m -s /bin/bash kbuilder; then diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD index 32c601a03..f5c08ea06 100644 --- a/pkg/abi/BUILD +++ b/pkg/abi/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "abi", srcs = [ diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index f45934466..51774c6b6 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -1,13 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + # Package linux contains the constants and types needed to interface with a # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix # when the host OS may not be Linux. -load("@io_bazel_rules_go//go:def.bzl", "go_test") - package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "linux", srcs = [ @@ -24,6 +23,8 @@ go_library( "exec.go", "fcntl.go", "file.go", + "file_amd64.go", + "file_arm64.go", "fs.go", "futex.go", "inotify.go", diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 7d742871a..c9ee098f4 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -186,25 +186,6 @@ const ( RWF_VALID = RWF_HIPRI | RWF_DSYNC | RWF_SYNC ) -// Stat represents struct stat. -type Stat struct { - Dev uint64 - Ino uint64 - Nlink uint64 - Mode uint32 - UID uint32 - GID uint32 - _ int32 - Rdev uint64 - Size int64 - Blksize int64 - Blocks int64 - ATime Timespec - MTime Timespec - CTime Timespec - _ [3]int64 -} - // SizeOfStat is the size of a Stat struct. var SizeOfStat = binary.Size(Stat{}) @@ -271,7 +252,7 @@ type Statx struct { } // FileMode represents a mode_t. -type FileMode uint +type FileMode uint16 // Permissions returns just the permission bits. func (m FileMode) Permissions() FileMode { diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go new file mode 100644 index 000000000..74c554be6 --- /dev/null +++ b/pkg/abi/linux/file_amd64.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Nlink uint64 + Mode uint32 + UID uint32 + GID uint32 + _ int32 + Rdev uint64 + Size int64 + Blksize int64 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [3]int64 +} diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go new file mode 100644 index 000000000..f16c07589 --- /dev/null +++ b/pkg/abi/linux/file_arm64.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// Stat represents struct stat. +type Stat struct { + Dev uint64 + Ino uint64 + Mode uint32 + Nlink uint32 + UID uint32 + GID uint32 + Rdev uint64 + _ uint64 + Size int64 + Blksize int32 + _ int32 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [2]int32 +} diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 152f6b144..3898d2314 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -325,3 +325,9 @@ const ( RTA_SPORT = 28 RTA_DPORT = 29 ) + +// Route flags, from include/uapi/linux/route.h. +const ( + RTF_GATEWAY = 0x2 + RTF_UP = 0x1 +) diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index d5b731390..2e2cc6be7 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -256,6 +256,17 @@ type SockAddrInet6 struct { Scope_id uint32 } +// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h. +type SockAddrLink struct { + Family uint16 + Protocol uint16 + InterfaceIndex int32 + ARPHardwareType uint16 + PacketType byte + HardwareAddrLen byte + HardwareAddr [8]byte +} + // UnixPathMax is the maximum length of the path in an AF_UNIX socket. // // From uapi/linux/un.h. @@ -278,6 +289,7 @@ type SockAddr interface { func (s *SockAddrInet) implementsSockAddr() {} func (s *SockAddrInet6) implementsSockAddr() {} +func (s *SockAddrLink) implementsSockAddr() {} func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 5f59866fa..36beaade9 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -8,6 +8,7 @@ go_library( srcs = [ "atomic_bitops.go", "atomic_bitops_amd64.s", + "atomic_bitops_arm64.s", "atomic_bitops_common.go", ], importpath = "gvisor.dev/gvisor/pkg/atomicbitops", diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomic_bitops.go index 63aa2b7f1..fcc41a9ea 100644 --- a/pkg/atomicbitops/atomic_bitops.go +++ b/pkg/atomicbitops/atomic_bitops.go @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 // Package atomicbitops provides basic bitwise operations in an atomic way. // The implementation on amd64 leverages the LOCK prefix directly instead of -// relying on the generic cas primitives. +// relying on the generic cas primitives, and the arm64 leverages the LDAXR +// and STLXR pair primitives. // // WARNING: the bitwise ops provided in this package doesn't imply any memory // ordering. Using them to construct locks must employ proper memory barriers. diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomic_bitops_arm64.s new file mode 100644 index 000000000..97f8808c1 --- /dev/null +++ b/pkg/atomicbitops/atomic_bitops_arm64.s @@ -0,0 +1,139 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +#include "textflag.h" + +TEXT ·AndUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + ANDW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·OrUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + ORRW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·XorUint32(SB),$0-12 + MOVD ptr+0(FP), R0 + MOVW val+8(FP), R1 +again: + LDAXRW (R0), R2 + EORW R1, R2 + STLXRW R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·CompareAndSwapUint32(SB),$0-20 + MOVD addr+0(FP), R0 + MOVW old+8(FP), R1 + MOVW new+12(FP), R2 + +again: + LDAXRW (R0), R3 + CMPW R1, R3 + BNE done + STLXRW R2, (R0), R4 + CBNZ R4, again +done: + MOVW R3, prev+16(FP) + RET + +TEXT ·AndUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + AND R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·OrUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + ORR R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·XorUint64(SB),$0-16 + MOVD ptr+0(FP), R0 + MOVD val+8(FP), R1 +again: + LDAXR (R0), R2 + EOR R1, R2 + STLXR R2, (R0), R3 + CBNZ R3, again + RET + +TEXT ·CompareAndSwapUint64(SB),$0-32 + MOVD addr+0(FP), R0 + MOVD old+8(FP), R1 + MOVD new+16(FP), R2 + +again: + LDAXR (R0), R3 + CMP R1, R3 + BNE done + STLXR R2, (R0), R4 + CBNZ R4, again +done: + MOVD R3, prev+24(FP) + RET + +TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9 + MOVD addr+0(FP), R0 + +again: + LDAXRW (R0), R1 + CBZ R1, fail + ADDW $1, R1 + STLXRW R1, (R0), R2 + CBNZ R2, again + MOVW $1, R2 + MOVB R2, ret+8(FP) + RET +fail: + MOVB ZR, ret+8(FP) + RET + +TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9 + MOVD addr+0(FP), R0 + +again: + LDAXRW (R0), R1 + SUBSW $1, R1, R1 + BEQ fail + STLXRW R1, (R0), R2 + CBNZ R2, again + MOVW $1, R2 + MOVB R2, ret+8(FP) + RET +fail: + MOVB ZR, ret+8(FP) + RET diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomic_bitops_common.go index b2a943dcb..85163ad62 100644 --- a/pkg/atomicbitops/atomic_bitops_common.go +++ b/pkg/atomicbitops/atomic_bitops_common.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !amd64 +// +build !amd64,!arm64 package atomicbitops diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 51967b811..93b88a29a 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,18 +1,18 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - go_library( name = "bits", srcs = [ "bits.go", "bits32.go", "bits64.go", - "uint64_arch_amd64.go", + "uint64_arch.go", "uint64_arch_amd64_asm.s", + "uint64_arch_arm64_asm.s", "uint64_arch_generic.go", ], importpath = "gvisor.dev/gvisor/pkg/bits", diff --git a/pkg/bits/uint64_arch_amd64.go b/pkg/bits/uint64_arch.go index faccaa61a..9f23eff77 100644 --- a/pkg/bits/uint64_arch_amd64.go +++ b/pkg/bits/uint64_arch.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 package bits diff --git a/pkg/bits/uint64_arch_arm64_asm.s b/pkg/bits/uint64_arch_arm64_asm.s new file mode 100644 index 000000000..814ba562d --- /dev/null +++ b/pkg/bits/uint64_arch_arm64_asm.s @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +TEXT ·TrailingZeros64(SB),$0-16 + MOVD x+0(FP), R0 + RBIT R0, R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD R0, ret+8(FP) + RET + +TEXT ·MostSignificantOne64(SB),$0-16 + MOVD x+0(FP), R0 + CLZ R0, R0 // return 64 if x == 0 + MOVD $63, R1 + SUBS R0, R1, R0 // ret = 63 - CLZ + BPL end + MOVD $64, R0 // x == 0 +end: + MOVD R0, ret+8(FP) + RET diff --git a/pkg/bits/uint64_arch_generic.go b/pkg/bits/uint64_arch_generic.go index 7dd2d1480..9dd2098d1 100644 --- a/pkg/bits/uint64_arch_generic.go +++ b/pkg/bits/uint64_arch_generic.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !amd64 +// +build !amd64,!arm64 package bits diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index 8d31e068c..fba5643e8 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "bpf", srcs = [ diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index 32422f9e2..ed111fd2a 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "cpuid", srcs = [ diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index 5d61dc2ff..d37047368 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -183,6 +183,33 @@ const ( X86FeatureAVX512VBMI X86FeatureUMIP X86FeaturePKU + X86FeatureOSPKE + X86FeatureWAITPKG + X86FeatureAVX512_VBMI2 + _ // ecx bit 7 is reserved + X86FeatureGFNI + X86FeatureVAES + X86FeatureVPCLMULQDQ + X86FeatureAVX512_VNNI + X86FeatureAVX512_BITALG + X86FeatureTME + X86FeatureAVX512_VPOPCNTDQ + _ // ecx bit 15 is reserved + X86FeatureLA57 + // ecx bits 17-21 are reserved + _ + _ + _ + _ + _ + X86FeatureRDPID + // ecx bits 23-24 are reserved + _ + _ + X86FeatureCLDEMOTE + _ // ecx bit 26 is reserved + X86FeatureMOVDIRI + X86FeatureMOVDIR64B ) // Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX. @@ -353,9 +380,24 @@ var x86FeatureStrings = map[Feature]string{ X86FeatureAVX512VL: "avx512vl", // Block 3. - X86FeatureAVX512VBMI: "avx512vbmi", - X86FeatureUMIP: "umip", - X86FeaturePKU: "pku", + X86FeatureAVX512VBMI: "avx512vbmi", + X86FeatureUMIP: "umip", + X86FeaturePKU: "pku", + X86FeatureOSPKE: "ospke", + X86FeatureWAITPKG: "waitpkg", + X86FeatureAVX512_VBMI2: "avx512_vbmi2", + X86FeatureGFNI: "gfni", + X86FeatureVAES: "vaes", + X86FeatureVPCLMULQDQ: "vpclmulqdq", + X86FeatureAVX512_VNNI: "avx512_vnni", + X86FeatureAVX512_BITALG: "avx512_bitalg", + X86FeatureTME: "tme", + X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq", + X86FeatureLA57: "la57", + X86FeatureRDPID: "rdpid", + X86FeatureCLDEMOTE: "cldemote", + X86FeatureMOVDIRI: "movdiri", + X86FeatureMOVDIR64B: "movdir64b", // Block 4. X86FeatureXSAVEOPT: "xsaveopt", diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index c7f549428..afa8f7659 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -8,9 +8,6 @@ go_library( srcs = ["fd.go"], importpath = "gvisor.dev/gvisor/pkg/fd", visibility = ["//visibility:public"], - deps = [ - "//pkg/unet", - ], ) go_test( diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go index 7691b477b..83bcfe220 100644 --- a/pkg/fd/fd.go +++ b/pkg/fd/fd.go @@ -22,8 +22,6 @@ import ( "runtime" "sync/atomic" "syscall" - - "gvisor.dev/gvisor/pkg/unet" ) // ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It @@ -187,12 +185,6 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) { return New(f), nil } -// DialUnix connects to a Unix Domain Socket and return the file descriptor. -func DialUnix(path string) (*FD, error) { - socket, err := unet.Connect(path, false) - return New(socket.FD()), err -} - // Close closes the file descriptor contained in the FD. // // Close is safe to call multiple times, but will return an error after the diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index 8390915a2..e7c3a3a0b 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -113,7 +113,7 @@ func (ep *Endpoint) enterFutexWait() error { return nil case epsBlocked | epsShutdown: atomic.AddInt32(&ep.ctrl.state, -epsBlocked) - return shutdownError{} + return ShutdownError{} default: // Most likely due to ep.enterFutexWait() being called concurrently // from multiple goroutines. diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 386cee42c..3cdb576e1 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -136,8 +136,8 @@ func (ep *Endpoint) unmapPacket() { // Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), // ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer -// Endpoint, to unblock and return errors. It does not wait for concurrent -// calls to return. Successive calls to Shutdown have no effect. +// Endpoint, to unblock and return ShutdownErrors. It does not wait for +// concurrent calls to return. Successive calls to Shutdown have no effect. // // Shutdown is the only Endpoint method that may be called concurrently with // other methods on the same Endpoint. @@ -154,10 +154,12 @@ func (ep *Endpoint) isShutdownLocally() bool { return atomic.LoadUint32(&ep.shutdown) != 0 } -type shutdownError struct{} +// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown() +// has been called. +type ShutdownError struct{} // Error implements error.Error. -func (shutdownError) Error() string { +func (ShutdownError) Error() string { return "flipcall connection shutdown" } diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index b127a2bbb..168c1ccff 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -61,7 +61,7 @@ func (ep *Endpoint) futexSwitchToPeer() error { if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { switch cs := atomic.LoadUint32(ep.connState()); cs { case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) } @@ -81,14 +81,14 @@ func (ep *Endpoint) futexSwitchFromPeer() error { return nil case ep.inactiveState: if ep.isShutdownLocally() { - return shutdownError{} + return ShutdownError{} } if err := ep.futexWaitConnState(ep.inactiveState); err != nil { return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) } continue case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 2412aa5e1..221516c6c 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -505,12 +505,27 @@ func (c *Client) sendRecvChannel(t message, r message) error { ch.active = false c.channelsMu.Unlock() c.channelsWg.Done() - return err + // Map all transport errors to EIO, but ensure that the real error + // is logged. + log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err) + return syscall.EIO } } - // Send the message. - err := ch.sendRecv(c, t, r) + // Send the request and receive the server's response. + rsz, err := ch.send(t) + if err != nil { + // See above. + c.channelsMu.Lock() + ch.active = false + c.channelsMu.Unlock() + c.channelsWg.Done() + log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err) + return syscall.EIO + } + + // Parse the server's response. + _, retErr := ch.recv(r, rsz) // Release the channel. c.channelsMu.Lock() @@ -523,7 +538,7 @@ func (c *Client) sendRecvChannel(t message, r message) error { c.channelsMu.Unlock() c.channelsWg.Done() - return err + return retErr } // Version returns the negotiated 9P2000.L.Google version number. diff --git a/pkg/p9/file.go b/pkg/p9/file.go index 907445e15..96d1f2a8e 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -116,7 +116,7 @@ type File interface { // N.B. The server must resolve any lazy paths when open is called. // After this point, read and write may be called on files with no // deletion check, so resolving in the data path is not viable. - Open(mode OpenFlags) (*fd.FD, QID, uint32, error) + Open(flags OpenFlags) (*fd.FD, QID, uint32, error) // Read reads from this file. Open must be called first. // diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 25530adca..415200d60 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -32,18 +32,22 @@ import ( type OpenFlags uint32 const ( - // ReadOnly is a Topen and Tcreate flag indicating read-only mode. + // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode. ReadOnly OpenFlags = 0 - // WriteOnly is a Topen and Tcreate flag indicating write-only mode. + // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode. WriteOnly OpenFlags = 1 - // ReadWrite is a Topen flag indicates read-write mode. + // ReadWrite is a Tlopen flag indicates read-write mode. ReadWrite OpenFlags = 2 // OpenFlagsModeMask is a mask of valid OpenFlags mode bits. OpenFlagsModeMask OpenFlags = 3 + // OpenTruncate is a Tlopen flag indicating that the opened file should be + // truncated. + OpenTruncate OpenFlags = 01000 + // OpenFlagsIgnoreMask is a list of OpenFlags mode bits that are ignored for Tlopen. // Note that syscall.O_LARGEFILE is set to zero, use value from Linux fcntl.h. OpenFlagsIgnoreMask OpenFlags = syscall.O_DIRECTORY | syscall.O_NOATIME | 0100000 @@ -71,25 +75,32 @@ const ( // OSFlags converts a p9.OpenFlags to an int compatible with open(2). func (o OpenFlags) OSFlags() int { - return int(o & OpenFlagsModeMask) + // "flags contains Linux open(2) flags bits" - 9P2000.L + return int(o) } // String implements fmt.Stringer. func (o OpenFlags) String() string { - switch o { + var buf strings.Builder + switch mode := o & OpenFlagsModeMask; mode { case ReadOnly: - return "ReadOnly" + buf.WriteString("ReadOnly") case WriteOnly: - return "WriteOnly" + buf.WriteString("WriteOnly") case ReadWrite: - return "ReadWrite" - case OpenFlagsModeMask: - return "OpenFlagsModeMask" - case OpenFlagsIgnoreMask: - return "OpenFlagsIgnoreMask" + buf.WriteString("ReadWrite") default: - return "UNDEFINED" + fmt.Fprintf(&buf, "%#o", mode) + } + otherFlags := o &^ OpenFlagsModeMask + if otherFlags&OpenTruncate != 0 { + buf.WriteString("|OpenTruncate") + otherFlags &^= OpenTruncate + } + if otherFlags != 0 { + fmt.Fprintf(&buf, "|%#o", otherFlags) } + return buf.String() } // Tag is a message tag. @@ -814,7 +825,7 @@ func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) { attr.Mode = FileMode(s.Mode) } if req.NLink { - attr.NLink = s.Nlink + attr.NLink = uint64(s.Nlink) } if req.UID { attr.UID = UID(s.Uid) diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 69c886a5d..40b8fa023 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -452,7 +452,13 @@ func (cs *connState) initializeChannels() (err error) { cs.channelWg.Add(1) go func() { // S/R-SAFE: Server side. defer cs.channelWg.Done() - res.service(cs) + if err := res.service(cs); err != nil { + // Don't log flipcall.ShutdownErrors, which we expect to be + // returned during server shutdown. + if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("p9.channel.service: %v", err) + } + } }() } diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go index 7cdf4ecc3..233f825e3 100644 --- a/pkg/p9/transport_flipcall.go +++ b/pkg/p9/transport_flipcall.go @@ -132,7 +132,7 @@ func (ch *channel) send(m message) (uint32, error) { if filer, ok := m.(filer); ok { if f := filer.FilePayload(); f != nil { if err := ch.fds.SendFD(f.FD()); err != nil { - return 0, syscall.EIO // Map everything to EIO. + return 0, err } f.Close() // Per sendRecvLegacy. sentFD = true // To mark below. @@ -162,15 +162,7 @@ func (ch *channel) send(m message) (uint32, error) { } // Perform the one-shot communication. - n, err := ch.data.SendRecv(ssz) - if err != nil { - if n > 0 { - return n, nil - } - return 0, syscall.EIO // See above. - } - - return n, nil + return ch.data.SendRecv(ssz) } // recv decodes a message that exists on the channel. @@ -249,15 +241,3 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) { return r, nil } - -// sendRecv sends the given message over the channel. -// -// This is used by the client. -func (ch *channel) sendRecv(c *Client, m, r message) error { - rsz, err := ch.send(m) - if err != nil { - return err - } - _, err = ch.recv(r, rsz) - return err -} diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 30ec8e6e2..38cea9be3 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index e340d9f98..4f4b70fef 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.14 +// +build !go1.15 #include "textflag.h" diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 827385139..7ad59dfd7 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "weak_ref_list", out = "weak_ref_list.go", diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD index 700385907..1b487b887 100644 --- a/pkg/segment/BUILD +++ b/pkg/segment/BUILD @@ -1,10 +1,10 @@ +load("//tools/go_generics:defs.bzl", "go_template") + package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) -load("//tools/go_generics:defs.bzl", "go_template") - go_template( name = "generic_range", srcs = ["range.go"], diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index 12d7c77d2..a27c35e21 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,13 +1,12 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package( default_visibility = ["//visibility:private"], licenses = ["notice"], ) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "int_range", out = "int_range.go", diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index c71cff9f3..18c73cc24 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "arch", srcs = [ diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go index dfd62cbdb..23e009ef3 100644 --- a/pkg/sentry/context/context.go +++ b/pkg/sentry/context/context.go @@ -12,10 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package context defines the sentry's Context type. +// Package context defines an internal context type. +// +// The given Context conforms to the standard Go context, but mandates +// additional methods that are specific to the kernel internals. Note however, +// that the Context described by this package carries additional constraints +// regarding concurrent access and retaining beyond the scope of a call. +// +// See the Context type for complete details. package context import ( + "context" + "time" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/log" ) @@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { type Context interface { log.Logger amutex.Sleeper + context.Context // UninterruptibleSleepStart indicates the beginning of an uninterruptible // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate @@ -72,19 +83,36 @@ type Context interface { // AddressSpace is activated. Normally activate is the same value as the // deactivate parameter passed to UninterruptibleSleepStart. UninterruptibleSleepFinish(activate bool) +} + +// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep +// methods for anonymous embedding in other types that do not implement sleeps. +type NoopSleeper struct { + amutex.NoopSleeper +} + +// UninterruptibleSleepStart does nothing. +func (NoopSleeper) UninterruptibleSleepStart(bool) {} + +// UninterruptibleSleepFinish does nothing. +func (NoopSleeper) UninterruptibleSleepFinish(bool) {} + +// Deadline returns zero values, meaning no deadline. +func (NoopSleeper) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done returns nil. +func (NoopSleeper) Done() <-chan struct{} { + return nil +} - // Value returns the value associated with this Context for key, or nil if - // no value is associated with key. Successive calls to Value with the same - // key returns the same result. - // - // A key identifies a specific value in a Context. Functions that wish to - // retrieve values from Context typically allocate a key in a global - // variable then use that key as the argument to Context.Value. A key can - // be any type that supports equality; packages should define keys as an - // unexported type to avoid collisions. - Value(key interface{}) interface{} +// Err returns nil. +func (NoopSleeper) Err() error { + return nil } +// logContext implements basic logging. type logContext struct { log.Logger NoopSleeper @@ -95,19 +123,6 @@ func (logContext) Value(key interface{}) interface{} { return nil } -// NoopSleeper is a noop implementation of amutex.Sleeper and -// Context.UninterruptibleSleep* methods for anonymous embedding in other types -// that do not want to notify kernel.Task about sleeps. -type NoopSleeper struct { - amutex.NoopSleeper -} - -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} - -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} - // bgContext is the context returned by context.Background. var bgContext = &logContext{Logger: log.Log()} diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD index 3b6841b7e..581e7aa96 100644 --- a/pkg/sentry/context/contexttest/BUILD +++ b/pkg/sentry/context/contexttest/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "contexttest", testonly = 1, diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 0c86197f7..1098ed777 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "device", srcs = ["device.go"], diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 3119a61b6..378602cc9 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "fs", srcs = [ diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 80e106e6f..a0d9e8496 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "dev", srcs = [ diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index b9bd9ed17..277ee4c31 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "fdpipe", srcs = [ diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD index a9d6d9301..358dc2be3 100644 --- a/pkg/sentry/fs/filetest/BUILD +++ b/pkg/sentry/fs/filetest/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "filetest", testonly = 1, diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index b4ac83dc4..b2e8d9c77 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "dirty_set_impl", out = "dirty_set_impl.go", diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index e239f12a5..b06a71cc2 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -209,3 +209,29 @@ func (f *HostFileMapper) unmapAndRemoveLocked(chunkStart uint64, m mapping) { } delete(f.mappings, chunkStart) } + +// RegenerateMappings must be called when the file description mapped by f +// changes, to replace existing mappings of the previous file description. +func (f *HostFileMapper) RegenerateMappings(fd int) error { + f.mapsMu.Lock() + defer f.mapsMu.Unlock() + + for chunkStart, m := range f.mappings { + prot := syscall.PROT_READ + if m.writable { + prot |= syscall.PROT_WRITE + } + _, _, errno := syscall.Syscall6( + syscall.SYS_MMAP, + m.addr, + chunkSize, + uintptr(prot), + syscall.MAP_SHARED|syscall.MAP_FIXED, + uintptr(fd), + uintptr(chunkStart)) + if errno != 0 { + return errno + } + } + return nil +} diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index d2495cb83..30475f340 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -100,13 +100,30 @@ func (h *HostMappable) Translate(ctx context.Context, required, optional memmap. } // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. -func (h *HostMappable) InvalidateUnsavable(ctx context.Context) error { +func (h *HostMappable) InvalidateUnsavable(_ context.Context) error { h.mu.Lock() h.mappings.InvalidateAll(memmap.InvalidateOpts{}) h.mu.Unlock() return nil } +// NotifyChangeFD must be called after the file description represented by +// CachedFileObject.FD() changes. +func (h *HostMappable) NotifyChangeFD() error { + // Update existing sentry mappings to refer to the new file description. + if err := h.hostFileMapper.RegenerateMappings(h.backingFile.FD()); err != nil { + return err + } + + // Shoot down existing application mappings of the old file description; + // they will be remapped with the new file description on demand. + h.mu.Lock() + defer h.mu.Unlock() + + h.mappings.InvalidateAll(memmap.InvalidateOpts{}) + return nil +} + // MapInternal implements platform.File.MapInternal. func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write) @@ -144,7 +161,7 @@ func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error { mask := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: newSize} - if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil { + if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil { return err } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index d404a79d4..798920d18 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -140,12 +140,16 @@ type CachedFileObject interface { // WriteFromBlocksAt may return a partial write without an error. WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) - // SetMaskedAttributes sets the attributes in attr that are true in mask - // on the backing file. + // SetMaskedAttributes sets the attributes in attr that are true in + // mask on the backing file. If the mask contains only ATime or MTime + // and the CachedFileObject has an FD to the file, then this operation + // is a noop unless forceSetTimestamps is true. This avoids an extra + // RPC to the gofer in the open-read/write-close case, when the + // timestamps on the file will be updated by the host kernel for us. // // SetMaskedAttributes may be called at any point, regardless of whether // the file was opened. - SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error + SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error // Allocate allows the caller to reserve disk space for the inode. // It's equivalent to fallocate(2) with 'mode=0'. @@ -224,7 +228,7 @@ func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.I now := ktime.NowFromContext(ctx) masked := fs.AttrMask{Perms: true} - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil { return false } c.attr.Perms = perms @@ -246,7 +250,7 @@ func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode, UID: owner.UID.Ok(), GID: owner.GID.Ok(), } - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil { return err } if owner.UID.Ok() { @@ -282,7 +286,9 @@ func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.In AccessTime: !ts.ATimeOmit, ModificationTime: !ts.MTimeOmit, } - if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil { + // Call SetMaskedAttributes with forceSetTimestamps = true to make sure + // the timestamp is updated. + if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil { return err } if !ts.ATimeOmit { @@ -305,7 +311,7 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode, now := ktime.NowFromContext(ctx) masked := fs.AttrMask{Size: true} attr := fs.UnstableAttr{Size: size} - if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil { c.dataMu.Unlock() return err } @@ -394,7 +400,7 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) c.dirtyAttr.Size = false // Write out cached attributes. - if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil { + if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil { c.attrMu.Unlock() return err } @@ -953,6 +959,23 @@ func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error return nil } +// NotifyChangeFD must be called after the file description represented by +// CachedFileObject.FD() changes. +func (c *CachingInodeOperations) NotifyChangeFD() error { + // Update existing sentry mappings to refer to the new file description. + if err := c.hostFileMapper.RegenerateMappings(c.backingFile.FD()); err != nil { + return err + } + + // Shoot down existing application mappings of the old file description; + // they will be remapped with the new file description on demand. + c.mapsMu.Lock() + defer c.mapsMu.Unlock() + + c.mappings.InvalidateAll(memmap.InvalidateOpts{}) + return nil +} + // Evict implements pgalloc.EvictableMemoryUser.Evict. func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.EvictableRange) { c.mapsMu.Lock() @@ -1021,7 +1044,6 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { } c.refs.MergeAdjacent(fr) c.dataMu.Unlock() - } // MapInternal implements platform.File.MapInternal. This is used when we diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go index eb5730c35..129f314c8 100644 --- a/pkg/sentry/fs/fsutil/inode_cached_test.go +++ b/pkg/sentry/fs/fsutil/inode_cached_test.go @@ -39,7 +39,7 @@ func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.Block return srcs.NumBytes(), nil } -func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error { +func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { return nil } @@ -230,7 +230,7 @@ func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.B return w.WriteFromBlocks(srcs) } -func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error { +func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error { return nil } diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 2b71ca0e1..4a005c605 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "gofer", srcs = [ diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index 9e2e412cd..7960b9c7b 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -214,28 +214,64 @@ func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr return entries, nil } +// maybeSync will call FSync on the file if either the cache policy or file +// flags require it. +func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error { + if n == 0 { + // Nothing to sync. + return nil + } + + if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) { + // Call WriteOut directly, as some "writethrough" filesystems + // do not support sync. + return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode) + } + + flags := file.Flags() + var syncType fs.SyncType + switch { + case flags.Direct || flags.Sync: + syncType = fs.SyncAll + case flags.DSync: + syncType = fs.SyncData + default: + // No need to sync. + return nil + } + + return f.Fsync(ctx, file, offset, offset+n, syncType) +} + // Write implements fs.FileOperations.Write. func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { if fs.IsDir(file.Dirent.Inode.StableAttr) { // Not all remote file systems enforce this so this client does. return 0, syserror.EISDIR } - cp := f.inodeOperations.session().cachePolicy - if cp.useCachingInodeOps(file.Dirent.Inode) { - n, err := f.inodeOperations.cachingInodeOps.Write(ctx, src, offset) - if err != nil { - return n, err - } - if cp.writeThrough(file.Dirent.Inode) { - // Write out the file. - err = f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode) - } - return n, err + + var ( + n int64 + err error + ) + // The write is handled in different ways depending on the cache policy + // and availability of a host-mappable FD. + if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) { + n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset) + } else if f.inodeOperations.fileState.hostMappable != nil { + n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset) + } else { + n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) } - if f.inodeOperations.fileState.hostMappable != nil { - return f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset) + + // We may need to sync the written bytes. + if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil { + // Sync failed. Report 0 bytes written, since none of them are + // guaranteed to have been synced. + return 0, syncErr } - return src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset)) + + return n, err } // incrementReadCounters increments the read counters for the read starting at the given time. We @@ -273,7 +309,7 @@ func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO } // Fsync implements fs.FileOperations.Fsync. -func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start int64, end int64, syncType fs.SyncType) error { +func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start, end int64, syncType fs.SyncType) error { switch syncType { case fs.SyncAll, fs.SyncData: if err := file.Dirent.Inode.WriteOut(ctx); err != nil { diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go index 9aa68a70e..c2fbb4be9 100644 --- a/pkg/sentry/fs/gofer/file_state.go +++ b/pkg/sentry/fs/gofer/file_state.go @@ -29,7 +29,7 @@ func (f *fileOperations) afterLoad() { // Manually load the open handles. var err error // TODO(b/38173783): Context is not plumbed to save/restore. - f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags) + f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags, f.inodeOperations.cachingInodeOps) if err != nil { return fmt.Errorf("failed to re-open handle: %v", err) } diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go index 8f8ab5d29..cf96dd9fa 100644 --- a/pkg/sentry/fs/gofer/fs.go +++ b/pkg/sentry/fs/gofer/fs.go @@ -58,6 +58,11 @@ const ( // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to // true. limitHostFDTranslationKey = "limit_host_fd_translation" + + // overlayfsStaleRead if present closes cached readonly file after the first + // write. This is done to workaround a limitation of overlayfs in kernels + // before 4.19 where open FDs are not updated after the file is copied up. + overlayfsStaleRead = "overlayfs_stale_read" ) // defaultAname is the default attach name. @@ -145,6 +150,7 @@ type opts struct { version string privateunixsocket bool limitHostFDTranslation bool + overlayfsStaleRead bool } // options parses mount(2) data into structured options. @@ -247,6 +253,11 @@ func options(data string) (opts, error) { delete(options, limitHostFDTranslationKey) } + if _, ok := options[overlayfsStaleRead]; ok { + o.overlayfsStaleRead = true + delete(options, overlayfsStaleRead) + } + // Fail to attach if the caller wanted us to do something that we // don't support. if len(options) > 0 { diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go index 27eeae3d9..39c8ec33d 100644 --- a/pkg/sentry/fs/gofer/handles.go +++ b/pkg/sentry/fs/gofer/handles.go @@ -39,14 +39,22 @@ type handles struct { // Host is an *fd.FD handle. May be nil. Host *fd.FD + + // isHostBorrowed tells whether 'Host' is owned or borrowed. If owned, it's + // closed on destruction, otherwise it's released. + isHostBorrowed bool } // DecRef drops a reference on handles. func (h *handles) DecRef() { h.DecRefWithDestructor(func() { if h.Host != nil { - if err := h.Host.Close(); err != nil { - log.Warningf("error closing host file: %v", err) + if h.isHostBorrowed { + h.Host.Release() + } else { + if err := h.Host.Close(); err != nil { + log.Warningf("error closing host file: %v", err) + } } } // FIXME(b/38173783): Context is not plumbed here. diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 95b064aea..99910388f 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -100,7 +100,7 @@ type inodeFileState struct { // true. // // Once readHandles becomes non-nil, it can't be changed until - // inodeFileState.Release(), because of a defect in the + // inodeFileState.Release()*, because of a defect in the // fsutil.CachedFileObject interface: there's no way for the caller of // fsutil.CachedFileObject.FD() to keep the returned FD open, so if we // racily replace readHandles after inodeFileState.FD() has returned @@ -108,6 +108,9 @@ type inodeFileState struct { // FD. writeHandles can be changed if writeHandlesRW is false, since // inodeFileState.FD() can't return a write-only FD, but can't be changed // if writeHandlesRW is true for the same reason. + // + // * There is one notable exception in recreateReadHandles(), where it dup's + // the FD and invalidates the page cache. readHandles *handles `state:"nosave"` writeHandles *handles `state:"nosave"` writeHandlesRW bool `state:"nosave"` @@ -175,48 +178,129 @@ func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles) // getHandles returns a set of handles for a new file using i opened with the // given flags. -func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags) (*handles, error) { +func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) { if !i.canShareHandles() { return newHandles(ctx, i.file, flags) } + i.handlesMu.Lock() - defer i.handlesMu.Unlock() + h, invalidate, err := i.getHandlesLocked(ctx, flags) + i.handlesMu.Unlock() + + if invalidate { + cache.NotifyChangeFD() + if i.hostMappable != nil { + i.hostMappable.NotifyChangeFD() + } + } + + return h, err +} + +// getHandlesLocked returns a pointer to cached handles and a boolean indicating +// whether previously open read handle was recreated. Host mappings must be +// invalidated if so. +func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) { // Do we already have usable shared handles? if flags.Write { if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) { i.writeHandles.IncRef() - return i.writeHandles, nil + return i.writeHandles, false, nil } } else if i.readHandles != nil { i.readHandles.IncRef() - return i.readHandles, nil + return i.readHandles, false, nil } + // No; get new handles and cache them for future sharing. h, err := newHandles(ctx, i.file, flags) if err != nil { - return nil, err + return nil, false, err + } + + // Read handles invalidation is needed if: + // - Mount option 'overlayfs_stale_read' is set + // - Read handle is open: nothing to invalidate otherwise + // - Write handle is not open: file was not open for write and is being open + // for write now (will trigger copy up in overlayfs). + invalidate := false + if i.s.overlayfsStaleRead && i.readHandles != nil && i.writeHandles == nil && flags.Write { + if err := i.recreateReadHandles(ctx, h, flags); err != nil { + return nil, false, err + } + invalidate = true } i.setSharedHandlesLocked(flags, h) - return h, nil + return h, invalidate, nil +} + +func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handles, flags fs.FileFlags) error { + h := writer + if !flags.Read { + // Writer can't be used for read, must create a new handle. + var err error + h, err = newHandles(ctx, i.file, fs.FileFlags{Read: true}) + if err != nil { + return err + } + defer h.DecRef() + } + + if i.readHandles.Host == nil { + // If current readHandles doesn't have a host FD, it can simply be replaced. + i.readHandles.DecRef() + + h.IncRef() + i.readHandles = h + return nil + } + + if h.Host == nil { + // Current read handle has a host FD and can't be replaced with one that + // doesn't, because it breaks fsutil.CachedFileObject.FD() contract. + log.Warningf("Read handle can't be invalidated, reads may return stale data") + return nil + } + + // Due to a defect in the fsutil.CachedFileObject interface, + // readHandles.Host.FD() may be used outside locks, making it impossible to + // reliably close it. To workaround it, we dup the new FD into the old one, so + // operations on the old will see the new data. Then, make the new handle take + // ownereship of the old FD and mark the old readHandle to not close the FD + // when done. + if err := syscall.Dup2(h.Host.FD(), i.readHandles.Host.FD()); err != nil { + return err + } + + h.Host.Close() + h.Host = fd.New(i.readHandles.Host.FD()) + i.readHandles.isHostBorrowed = true + i.readHandles.DecRef() + + h.IncRef() + i.readHandles = h + return nil } // ReadToBlocksAt implements fsutil.CachedFileObject.ReadToBlocksAt. func (i *inodeFileState) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) { i.handlesMu.RLock() - defer i.handlesMu.RUnlock() - return i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts) + n, err := i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts) + i.handlesMu.RUnlock() + return n, err } // WriteFromBlocksAt implements fsutil.CachedFileObject.WriteFromBlocksAt. func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) { i.handlesMu.RLock() - defer i.handlesMu.RUnlock() - return i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs) + n, err := i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs) + i.handlesMu.RUnlock() + return n, err } // SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. -func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error { - if i.skipSetAttr(mask) { +func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error { + if i.skipSetAttr(mask, forceSetTimestamps) { return nil } as, ans := attr.AccessTime.Unix() @@ -251,13 +335,14 @@ func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMa // when: // - Mask is empty // - Mask contains only attributes that cannot be set in the gofer -// - Mask contains only atime and/or mtime, and host FD exists +// - forceSetTimestamps is false and mask contains only atime and/or mtime +// and host FD exists // // Updates to atime and mtime can be skipped because cached value will be // "close enough" to host value, given that operation went directly to host FD. // Skipping atime updates is particularly important to reduce the number of // operations sent to the Gofer for readonly files. -func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool { +func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool { // First remove attributes that cannot be updated. cpy := mask cpy.Type = false @@ -277,6 +362,12 @@ func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool { return false } + // If forceSetTimestamps was passed, then we cannot skip. + if forceSetTimestamps { + return false + } + + // Skip if we have a host FD. i.handlesMu.RLock() defer i.handlesMu.RUnlock() return (i.readHandles != nil && i.readHandles.Host != nil) || @@ -442,7 +533,7 @@ func (i *inodeOperations) NonBlockingOpen(ctx context.Context, p fs.PermMask) (* } func (i *inodeOperations) getFileDefault(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - h, err := i.fileState.getHandles(ctx, flags) + h, err := i.fileState.getHandles(ctx, flags, i.cachingInodeOps) if err != nil { return nil, err } diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 50da865c1..0da608548 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -122,6 +122,10 @@ type session struct { // CachingInodeOperations created by the session. limitHostFDTranslation bool + // overlayfsStaleRead when set causes the readonly handle to be invalidated + // after file is open for write. + overlayfsStaleRead bool + // connID is a unique identifier for the session connection. connID string `state:"wait"` @@ -257,6 +261,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF aname: o.aname, superBlockFlags: superBlockFlags, limitHostFDTranslation: o.limitHostFDTranslation, + overlayfsStaleRead: o.overlayfsStaleRead, mounter: mounter, } s.EnableLeakCheck("gofer.session") diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 3e532332e..1cbed07ae 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "host", srcs = [ diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index 894ab01f0..a6e4a09e3 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -114,7 +114,7 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo } // SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. -func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error { +func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error { if mask.Empty() { return nil } @@ -163,7 +163,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err return unstableAttr(i.mops, &s), nil } -// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes. +// Allocate implements fsutil.CachedFileObject.Allocate. func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error { return syscall.Fallocate(i.FD(), 0, offset, length) } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 2392787cb..107336a3e 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -385,3 +385,6 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 { func (c *ConnectedEndpoint) Release() { c.ref.DecRefWithDestructor(c.close) } + +// CloseUnread implements transport.ConnectedEndpoint.CloseUnread. +func (c *ConnectedEndpoint) CloseUnread() {} diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 246b97161..5a388dad1 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -15,6 +15,7 @@ package fs import ( + "fmt" "strings" "gvisor.dev/gvisor/pkg/abi/linux" @@ -207,6 +208,11 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name } func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name string, flags FileFlags, perm FilePermissions) (*File, error) { + // Sanity check. + if parent.Inode.overlay == nil { + panic(fmt.Sprintf("overlayCreate called with non-overlay parent inode (parent InodeOperations type is %T)", parent.Inode.InodeOperations)) + } + // Dirent.Create takes renameMu if the Inode is an overlay Inode. if err := copyUpLockedForRename(ctx, parent); err != nil { return nil, err diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 5a7a5b8cd..8d62642e7 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "lock_range", out = "lock_range.go", diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 1c93e8886..75cbb0622 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "proc", srcs = [ @@ -53,6 +52,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/syserror", + "//pkg/tcpip/header", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 5e28982c5..402919924 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" "io" + "reflect" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -33,6 +34,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // newNet creates a new proc net entry. @@ -40,7 +43,8 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo var contents map[string]*fs.Inode if s := p.k.NetworkStack(); s != nil { contents = map[string]*fs.Inode{ - "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), // The following files are simple stubs until they are // implemented in netstack, if the file contains a @@ -57,7 +61,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo // (ClockGetres returns 1ns resolution). "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")), - "route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")), + "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc), "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), @@ -66,7 +70,7 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo if s.SupportsIPv6() { contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc) contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte("")) - contents["tcp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) + contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc) contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode")) } } @@ -195,6 +199,193 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netSnmp implements seqfile.SeqSource for /proc/net/snmp. +// +// +stateify savable +type netSnmp struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netSnmp) NeedsUpdate(generation int64) bool { + return true +} + +type snmpLine struct { + prefix string + header string +} + +var snmp = []snmpLine{ + { + prefix: "Ip", + header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates", + }, + { + prefix: "Icmp", + header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps", + }, + { + prefix: "IcmpMsg", + }, + { + prefix: "Tcp", + header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors", + }, + { + prefix: "Udp", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, + { + prefix: "UdpLite", + header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti", + }, +} + +func toSlice(a interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(a)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + +func sprintSlice(s []uint64) string { + if len(s) == 0 { + return "" + } + r := fmt.Sprint(s) + return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. See Linux's +// net/core/net-procfs.c:dev_seq_show. +func (n *netSnmp) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + contents := make([]string, 0, len(snmp)*2) + types := []interface{}{ + &inet.StatSNMPIP{}, + &inet.StatSNMPICMP{}, + nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats. + &inet.StatSNMPTCP{}, + &inet.StatSNMPUDP{}, + &inet.StatSNMPUDPLite{}, + } + for i, stat := range types { + line := snmp[i] + if stat == nil { + contents = append( + contents, + fmt.Sprintf("%s:\n", line.prefix), + fmt.Sprintf("%s:\n", line.prefix), + ) + continue + } + if err := n.s.Statistics(stat, line.prefix); err != nil { + if err == syserror.EOPNOTSUPP { + log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } else { + log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err) + } + } + var values string + if line.prefix == "Tcp" { + tcp := stat.(*inet.StatSNMPTCP) + // "Tcp" needs special processing because MaxConn is signed. RFC 2012. + values = fmt.Sprintf("%s %d %s", sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:])) + } else { + values = sprintSlice(toSlice(stat)) + } + contents = append( + contents, + fmt.Sprintf("%s: %s\n", line.prefix, line.header), + fmt.Sprintf("%s: %s\n", line.prefix, values), + ) + } + + data := make([]seqfile.SeqData, 0, len(snmp)*2) + for _, l := range contents { + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netSnmp)(nil)}) + } + + return data, 0 +} + +// netRoute implements seqfile.SeqSource for /proc/net/route. +// +// +stateify savable +type netRoute struct { + s inet.Stack +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (n *netRoute) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. +func (n *netRoute) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + if h != nil { + return nil, 0 + } + + interfaces := n.s.Interfaces() + contents := []string{"Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT"} + for _, rt := range n.s.RouteTable() { + // /proc/net/route only includes ipv4 routes. + if rt.Family != linux.AF_INET { + continue + } + + // /proc/net/route does not include broadcast or multicast routes. + if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST { + continue + } + + iface, ok := interfaces[rt.OutputInterface] + if !ok || iface.Name == "lo" { + continue + } + + var ( + gw uint32 + prefix uint32 + flags = linux.RTF_UP + ) + if len(rt.GatewayAddr) == header.IPv4AddressSize { + flags |= linux.RTF_GATEWAY + gw = usermem.ByteOrder.Uint32(rt.GatewayAddr) + } + if len(rt.DstAddr) == header.IPv4AddressSize { + prefix = usermem.ByteOrder.Uint32(rt.DstAddr) + } + l := fmt.Sprintf( + "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d", + iface.Name, + prefix, + gw, + flags, + 0, // RefCnt. + 0, // Use. + 0, // Metric. + (uint32(1)<<rt.DstLen)-1, + 0, // MTU. + 0, // Window. + 0, // RTT. + ) + contents = append(contents, l) + } + + var data []seqfile.SeqData + for _, l := range contents { + l = fmt.Sprintf("%-127s\n", l) + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netRoute)(nil)}) + } + + return data, 0 +} + // netUnix implements seqfile.SeqSource for /proc/net/unix. // // +stateify savable @@ -310,44 +501,51 @@ func networkToHost16(n uint16) uint16 { return usermem.ByteOrder.Uint16(buf[:]) } -func writeInetAddr(w io.Writer, a linux.SockAddrInet) { - // linux.SockAddrInet.Port is stored in the network byte order and is - // printed like a number in host byte order. Note that all numbers in host - // byte order are printed with the most-significant byte first when - // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. - port := networkToHost16(a.Port) - - // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order - // (i.e. most-significant byte in index 0). Linux represents this as a - // __be32 which is a typedef for an unsigned int, and is printed with - // %X. This means that for a little-endian machine, Linux prints the - // least-significant byte of the address first. To emulate this, we first - // invert the byte order for the address using usermem.ByteOrder.Uint32, - // which makes it have the equivalent encoding to a __be32 on a little - // endian machine. Note that this operation is a no-op on a big endian - // machine. Then similar to Linux, we format it with %X, which will print - // the most-significant byte of the __be32 address first, which is now - // actually the least-significant byte of the original address in - // linux.SockAddrInet.Addr on little endian machines, due to the conversion. - addr := usermem.ByteOrder.Uint32(a.Addr[:]) - - fmt.Fprintf(w, "%08X:%04X ", addr, port) -} +func writeInetAddr(w io.Writer, family int, i linux.SockAddr) { + switch family { + case linux.AF_INET: + var a linux.SockAddrInet + if i != nil { + a = *i.(*linux.SockAddrInet) + } -// netTCP implements seqfile.SeqSource for /proc/net/tcp. -// -// +stateify savable -type netTCP struct { - k *kernel.Kernel -} + // linux.SockAddrInet.Port is stored in the network byte order and is + // printed like a number in host byte order. Note that all numbers in host + // byte order are printed with the most-significant byte first when + // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux. + port := networkToHost16(a.Port) + + // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order + // (i.e. most-significant byte in index 0). Linux represents this as a + // __be32 which is a typedef for an unsigned int, and is printed with + // %X. This means that for a little-endian machine, Linux prints the + // least-significant byte of the address first. To emulate this, we first + // invert the byte order for the address using usermem.ByteOrder.Uint32, + // which makes it have the equivalent encoding to a __be32 on a little + // endian machine. Note that this operation is a no-op on a big endian + // machine. Then similar to Linux, we format it with %X, which will print + // the most-significant byte of the __be32 address first, which is now + // actually the least-significant byte of the original address in + // linux.SockAddrInet.Addr on little endian machines, due to the conversion. + addr := usermem.ByteOrder.Uint32(a.Addr[:]) + + fmt.Fprintf(w, "%08X:%04X ", addr, port) + case linux.AF_INET6: + var a linux.SockAddrInet6 + if i != nil { + a = *i.(*linux.SockAddrInet6) + } -// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. -func (*netTCP) NeedsUpdate(generation int64) bool { - return true + port := networkToHost16(a.Port) + addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4]) + addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8]) + addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12]) + addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16]) + fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port) + } } -// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. -func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { +func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kernel.Kernel, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) { // t may be nil here if our caller is not part of a task goroutine. This can // happen for example if we're here for "sentryctl cat". When t is nil, // degrade gracefully and retrieve what we can. @@ -358,7 +556,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se } var buf bytes.Buffer - for _, se := range n.k.ListSockets() { + for _, se := range k.ListSockets() { s := se.Sock.Get() if s == nil { log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID) @@ -369,7 +567,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se if !ok { panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile)) } - if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) { + if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) { s.DecRef() // Not tcp4 sockets. continue @@ -384,22 +582,22 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se fmt.Fprintf(&buf, "%4d: ", se.ID) // Field: local_adddress. - var localAddr linux.SockAddrInet + var localAddr linux.SockAddr if t != nil { if local, _, err := sops.GetSockName(t); err == nil { - localAddr = *local.(*linux.SockAddrInet) + localAddr = local } } - writeInetAddr(&buf, localAddr) + writeInetAddr(&buf, fa, localAddr) // Field: rem_address. - var remoteAddr linux.SockAddrInet + var remoteAddr linux.SockAddr if t != nil { if remote, _, err := sops.GetPeerName(t); err == nil { - remoteAddr = *remote.(*linux.SockAddrInet) + remoteAddr = remote } } - writeInetAddr(&buf, remoteAddr) + writeInetAddr(&buf, fa, remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) @@ -465,7 +663,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se data := []seqfile.SeqData{ { - Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n"), + Buf: header, Handle: n, }, { @@ -476,6 +674,42 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se return data, 0 } +// netTCP implements seqfile.SeqSource for /proc/net/tcp. +// +// +stateify savable +type netTCP struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET, header) +} + +// netTCP6 implements seqfile.SeqSource for /proc/net/tcp6. +// +// +stateify savable +type netTCP6 struct { + k *kernel.Kernel +} + +// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. +func (*netTCP6) NeedsUpdate(generation int64) bool { + return true +} + +// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData. +func (n *netTCP6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) { + header := []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n") + return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET6, header) +} + // netUDP implements seqfile.SeqSource for /proc/net/udp. // // +stateify savable @@ -529,7 +763,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se localAddr = *local.(*linux.SockAddrInet) } } - writeInetAddr(&buf, localAddr) + writeInetAddr(&buf, linux.AF_INET, &localAddr) // Field: rem_address. var remoteAddr linux.SockAddrInet @@ -538,7 +772,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se remoteAddr = *remote.(*linux.SockAddrInet) } } - writeInetAddr(&buf, remoteAddr) + writeInetAddr(&buf, linux.AF_INET, &remoteAddr) // Field: state; socket state. fmt.Fprintf(&buf, "%02X ", sops.State()) diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 0ef13f2f5..56e92721e 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -230,7 +230,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent // But for whatever crazy reason, you can still walk to the given node. for _, tg := range rpf.iops.pidns.ThreadGroups() { if leader := tg.Leader(); leader != nil { - name := strconv.FormatUint(uint64(tg.ID()), 10) + name := strconv.FormatUint(uint64(rpf.iops.pidns.IDOfThreadGroup(tg)), 10) m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice) names = append(names, name) } diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 76433c7d0..fe7067be1 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "seqfile", srcs = ["seqfile.go"], diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index d0f351e5a..012cb3e44 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "ramfs", srcs = [ diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go index b03b7f836..311798811 100644 --- a/pkg/sentry/fs/splice.go +++ b/pkg/sentry/fs/splice.go @@ -139,7 +139,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, // Attempt to do a WriteTo; this is likely the most efficient. n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup) - if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup { + if n == 0 && err == syserror.ENOSYS && !opts.Dup { // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be // more efficient than a copy if buffers are cached or readily // available. (It's unlikely that they can actually be donated). @@ -151,7 +151,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64, // if we block at some point, we could lose data. If the source is // not a pipe then reading is not destructive; if the destination // is a regular file, then it is guaranteed not to block writing. - if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup && (!dstPipe || !srcPipe) { + if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) { // Fallback to an in-kernel copy. n, err = io.Copy(w, &io.LimitedReader{ R: r, diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD index 70fa3af89..25f0f124e 100644 --- a/pkg/sentry/fs/sys/BUILD +++ b/pkg/sentry/fs/sys/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "sys", srcs = [ diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index 1d80daeaf..a215c1b95 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "timerfd", srcs = ["timerfd.go"], diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go index 59403d9db..f8bf663bb 100644 --- a/pkg/sentry/fs/timerfd/timerfd.go +++ b/pkg/sentry/fs/timerfd/timerfd.go @@ -141,9 +141,10 @@ func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, i } // Notify implements ktime.TimerListener.Notify. -func (t *TimerOperations) Notify(exp uint64) { +func (t *TimerOperations) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { atomic.AddUint64(&t.val, exp) t.events.Notify(waiter.EventIn) + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 11b680929..59ce400c2 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "tmpfs", srcs = [ diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 159fb7c08..69089c8a8 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -324,7 +324,7 @@ type Fifo struct { // NewFifo creates a new named pipe. func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { // First create a pipe. - p := pipe.NewPipe(ctx, true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) // Build pipe InodeOperations. iops := pipe.NewInodeOperations(ctx, perms, p) diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 25811f668..95ad98cb0 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "tty", srcs = [ diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index b0c286b7a..7ccff8b0d 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") +package(licenses = ["notice"]) + go_template_instance( name = "dirent_list", out = "dirent_list.go", diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 0b471d121..91802dc1e 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -301,8 +301,8 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in return offset, nil } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { // mmap(2) specifies that EACCESS should be returned for non-regular file fds. return syserror.EACCES } diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 2d50e30aa..fcfaf5c3e 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "disklayout", srcs = [ diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go index a0065343b..4d18b28cb 100644 --- a/pkg/sentry/fsimpl/ext/file_description.go +++ b/pkg/sentry/fsimpl/ext/file_description.go @@ -43,9 +43,6 @@ func (fd *fileDescription) inode() *inode { return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode } -// OnClose implements vfs.FileDescriptionImpl.OnClose. -func (fd *fileDescription) OnClose() error { return nil } - // StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) { return fd.flags, nil diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index ffc76ba5b..aec33e00a 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -152,8 +152,8 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( return offset, nil } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { // TODO(b/134676337): Implement mmap(2). return syserror.ENODEV } diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index e06548a98..bdf8705c1 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -105,7 +105,7 @@ func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int6 return 0, syserror.EBADF } -// IterDirents implements vfs.FileDescriptionImpl.IterDirents. -func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return syserror.EBADF } diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD index 7e364c5fd..04d667273 100644 --- a/pkg/sentry/fsimpl/memfs/BUILD +++ b/pkg/sentry/fsimpl/memfs/BUILD @@ -24,14 +24,18 @@ go_library( "directory.go", "filesystem.go", "memfs.go", + "named_pipe.go", "regular_file.go", "symlink.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs", deps = [ "//pkg/abi/linux", + "//pkg/amutex", + "//pkg/sentry/arch", "//pkg/sentry/context", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/pipe", "//pkg/sentry/usermem", "//pkg/sentry/vfs", "//pkg/syserror", @@ -54,3 +58,19 @@ go_test( "//pkg/syserror", ], ) + +go_test( + name = "memfs_test", + size = "small", + srcs = ["pipe_test.go"], + embed = [":memfs"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/context/contexttest", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/syserror", + ], +) diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go index c620227c9..0bd82e480 100644 --- a/pkg/sentry/fsimpl/memfs/directory.go +++ b/pkg/sentry/fsimpl/memfs/directory.go @@ -32,7 +32,7 @@ type directory struct { childList dentryList } -func (fs *filesystem) newDirectory(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode { dir := &directory{} dir.inode.init(dir, fs, creds, mode) dir.inode.nlink = 2 // from "." and parent directory or ".." for root diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go index f79e2d9c8..f006c40cd 100644 --- a/pkg/sentry/fsimpl/memfs/filesystem.go +++ b/pkg/sentry/fsimpl/memfs/filesystem.go @@ -233,7 +233,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if err != nil { return err } - _, err = checkCreateLocked(rp, parentVFSD, parentInode) + pc, err := checkCreateLocked(rp, parentVFSD, parentInode) if err != nil { return err } @@ -241,8 +241,40 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v return err } defer rp.Mount().EndWrite() - // TODO: actually implement mknod - return syserror.EPERM + + switch opts.Mode.FileType() { + case 0: + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + fallthrough + case linux.ModeRegular: + // TODO(b/138862511): Implement. + return syserror.EINVAL + + case linux.ModeNamedPipe: + child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode)) + parentVFSD.InsertChild(&child.vfsd, pc) + parentInode.impl.(*directory).childList.PushBack(child) + return nil + + case linux.ModeSocket: + // TODO(b/138862511): Implement. + return syserror.EINVAL + + case linux.ModeCharacterDevice: + fallthrough + case linux.ModeBlockDevice: + // TODO(b/72101894): We don't support creating block or character + // devices at the moment. + // + // When we start supporting block and character devices, we'll + // need to check for CAP_MKNOD here. + return syserror.EPERM + + default: + // "EINVAL - mode requested creation of something other than a + // regular file, device special file, FIFO or socket." - mknod(2) + return syserror.EINVAL + } } // OpenAt implements vfs.FilesystemImpl.OpenAt. @@ -250,8 +282,9 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // Filter out flags that are not supported by memfs. O_DIRECTORY and // O_NOFOLLOW have no effect here (they're handled by VFS by setting // appropriate bits in rp), but are returned by - // FileDescriptionImpl.StatusFlags(). - opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW + // FileDescriptionImpl.StatusFlags(). O_NONBLOCK is supported only by + // pipes. + opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() @@ -260,7 +293,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return nil, err } - return inode.open(rp, vfsd, opts.Flags, false) + return inode.open(ctx, rp, vfsd, opts.Flags, false) } mustCreate := opts.Flags&linux.O_EXCL != 0 @@ -275,7 +308,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - return inode.open(rp, vfsd, opts.Flags, false) + return inode.open(ctx, rp, vfsd, opts.Flags, false) } afterTrailingSymlink: // Walk to the parent directory of the last path component. @@ -320,7 +353,7 @@ afterTrailingSymlink: child := fs.newDentry(childInode) vfsd.InsertChild(&child.vfsd, pc) inode.impl.(*directory).childList.PushBack(child) - return childInode.open(rp, &child.vfsd, opts.Flags, true) + return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true) } // Open existing file or follow symlink. if mustCreate { @@ -336,10 +369,10 @@ afterTrailingSymlink: // symlink target. goto afterTrailingSymlink } - return childInode.open(rp, childVFSD, opts.Flags, false) + return childInode.open(ctx, rp, childVFSD, opts.Flags, false) } -func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { +func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(flags) if !afterCreate { if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil { @@ -378,6 +411,8 @@ func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte case *symlink: // Can't open symlinks without O_PATH (which is unimplemented). return nil, syserror.ELOOP + case *namedPipe: + return newNamedPipeFD(ctx, impl, rp, vfsd, flags) default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) } diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go index 45cd42b3e..64c851c1a 100644 --- a/pkg/sentry/fsimpl/memfs/memfs.go +++ b/pkg/sentry/fsimpl/memfs/memfs.go @@ -137,7 +137,7 @@ type inode struct { impl interface{} // immutable } -func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode uint16) { +func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) { i.refs = 1 i.mode = uint32(mode) i.uid = uint32(creds.EffectiveKUID) @@ -227,6 +227,8 @@ func (i *inode) statTo(stat *linux.Statx) { stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS stat.Size = uint64(len(impl.target)) stat.Blocks = allocatedBlocksForSize(stat.Size) + case *namedPipe: + stat.Mode |= linux.S_IFIFO default: panic(fmt.Sprintf("unknown inode type: %T", i.impl)) } diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/memfs/named_pipe.go new file mode 100644 index 000000000..732ed7c58 --- /dev/null +++ b/pkg/sentry/fsimpl/memfs/named_pipe.go @@ -0,0 +1,59 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +type namedPipe struct { + inode inode + + pipe *pipe.VFSPipe +} + +// Preconditions: +// * fs.mu must be locked. +// * rp.Mount().CheckBeginWrite() has been called successfully. +func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode { + file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)} + file.inode.init(file, fs, creds, mode) + file.inode.nlink = 1 // Only the parent has a link. + return &file.inode +} + +// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented +// entirely via struct embedding. +type namedPipeFD struct { + fileDescription + + *pipe.VFSPipeFD +} + +func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) { + var err error + var fd namedPipeFD + fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, rp, vfsd, &fd.vfsfd, flags) + if err != nil { + return nil, err + } + fd.vfsfd.Init(&fd, rp.Mount(), vfsd) + return &fd.vfsfd, nil +} diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go new file mode 100644 index 000000000..0674b81a3 --- /dev/null +++ b/pkg/sentry/fsimpl/memfs/pipe_test.go @@ -0,0 +1,233 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memfs + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/context/contexttest" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +const fileName = "mypipe" + +func TestSeparateFDs(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the read side. This is done in a concurrently because opening + // One end the pipe blocks until the other end is opened. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + rfdchan := make(chan *vfs.FileDescription) + go func() { + openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY} + rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + rfdchan <- rfd + }() + + // Open the write side. + openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY} + wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer wfd.DecRef() + + rfd, ok := <-rfdchan + if !ok { + t.Fatalf("failed to open pipe for reading %q", fileName) + } + defer rfd.DecRef() + + const msg = "vamos azul" + checkEmpty(ctx, t, rfd) + checkWrite(ctx, t, wfd, msg) + checkRead(ctx, t, rfd, msg) +} + +func TestNonblockingRead(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the read side as nonblocking. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK} + rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for reading %q: %v", fileName, err) + } + defer rfd.DecRef() + + // Open the write side. + openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY} + wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer wfd.DecRef() + + const msg = "geh blau" + checkEmpty(ctx, t, rfd) + checkWrite(ctx, t, wfd, msg) + checkRead(ctx, t, rfd, msg) +} + +func TestNonblockingWriteError(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the write side as nonblocking, which should return ENXIO. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK} + _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != syserror.ENXIO { + t.Fatalf("expected ENXIO, but got error: %v", err) + } +} + +func TestSingleFD(t *testing.T) { + ctx, creds, vfsObj, root := setup(t) + defer root.DecRef() + + // Open the pipe as readable and writable. + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + openOpts := vfs.OpenOptions{Flags: linux.O_RDWR} + fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts) + if err != nil { + t.Fatalf("failed to open pipe for writing %q: %v", fileName, err) + } + defer fd.DecRef() + + const msg = "forza blu" + checkEmpty(ctx, t, fd) + checkWrite(ctx, t, fd, msg) + checkRead(ctx, t, fd, msg) +} + +// setup creates a VFS with a pipe in the root directory at path fileName. The +// returned VirtualDentry must be DecRef()'d be the caller. It calls t.Fatal +// upon failure. +func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) { + ctx := contexttest.Context(t) + creds := auth.CredentialsFromContext(ctx) + + // Create VFS. + vfsObj := vfs.New() + vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{}) + if err != nil { + t.Fatalf("failed to create tmpfs root mount: %v", err) + } + + // Create the pipe. + root := mntns.Root() + pop := vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + } + mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644} + if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil { + t.Fatalf("failed to create file %q: %v", fileName, err) + } + + // Sanity check: the file pipe exists and has the correct mode. + stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Pathname: fileName, + FollowFinalSymlink: true, + }, &vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat(%q) failed: %v", fileName, err) + } + if stat.Mode&^linux.S_IFMT != 0644 { + t.Errorf("got wrong permissions (%0o)", stat.Mode) + } + if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe { + t.Errorf("got wrong file type (%0o)", stat.Mode) + } + + return ctx, creds, vfsObj, root +} + +// checkEmpty calls t.Fatal if the pipe in fd is not empty. +func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { + readData := make([]byte, 1) + dst := usermem.BytesIOSequence(readData) + bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{}) + if err != syserror.ErrWouldBlock { + t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err) + } + if bytesRead != 0 { + t.Fatalf("expected to read 0 bytes, but got %d", bytesRead) + } +} + +// checkWrite calls t.Fatal if it fails to write all of msg to fd. +func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { + writeData := []byte(msg) + src := usermem.BytesIOSequence(writeData) + bytesWritten, err := fd.Impl().Write(ctx, src, vfs.WriteOptions{}) + if err != nil { + t.Fatalf("error writing to pipe %q: %v", fileName, err) + } + if bytesWritten != int64(len(writeData)) { + t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten) + } +} + +// checkRead calls t.Fatal if it fails to read msg from fd. +func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) { + readData := make([]byte, len(msg)) + dst := usermem.BytesIOSequence(readData) + bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{}) + if err != nil { + t.Fatalf("error reading from pipe %q: %v", fileName, err) + } + if bytesRead != int64(len(msg)) { + t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead) + } + if !bytes.Equal(readData, []byte(msg)) { + t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData)) + } +} diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go index 55f869798..b7f4853b3 100644 --- a/pkg/sentry/fsimpl/memfs/regular_file.go +++ b/pkg/sentry/fsimpl/memfs/regular_file.go @@ -37,7 +37,7 @@ type regularFile struct { dataLen int64 } -func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *inode { +func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode { file := ®ularFile{} file.inode.init(file, fs, creds, mode) file.inode.nlink = 1 // from parent directory diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 184b566d9..8d60ad4ad 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -1,10 +1,10 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + package( default_visibility = ["//:sandbox"], licenses = ["notice"], ) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "inet", srcs = [ @@ -13,5 +13,8 @@ go_library( "test_stack.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/inet", - deps = ["//pkg/sentry/context"], + deps = [ + "//pkg/sentry/context", + "//pkg/tcpip/stack", + ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 80f227dbe..a7dfb78a7 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,6 +15,8 @@ // Package inet defines semantics for IP stacks. package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // Stack represents a TCP/IP stack. type Stack interface { // Interfaces returns all network interfaces as a mapping from interface @@ -58,6 +60,16 @@ type Stack interface { // Resume restarts the network stack after restore. Resume() + + // RegisteredEndpoints returns all endpoints which are currently registered. + RegisteredEndpoints() []stack.TransportEndpoint + + // CleanupEndpoints returns endpoints currently in the cleanup state. + CleanupEndpoints() []stack.TransportEndpoint + + // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful + // for restoring a stack after a save. + RestoreCleanupEndpoints([]stack.TransportEndpoint) } // Interface contains information about a network interface. @@ -153,3 +165,23 @@ type Route struct { // GatewayAddr is the route gateway address (RTA_GATEWAY). GatewayAddr []byte } + +// Below SNMP metrics are from Linux/usr/include/linux/snmp.h. + +// StatSNMPIP describes Ip line of /proc/net/snmp. +type StatSNMPIP [19]uint64 + +// StatSNMPICMP describes Icmp line of /proc/net/snmp. +type StatSNMPICMP [27]uint64 + +// StatSNMPICMPMSG describes IcmpMsg line of /proc/net/snmp. +type StatSNMPICMPMSG [512]uint64 + +// StatSNMPTCP describes Tcp line of /proc/net/snmp. +type StatSNMPTCP [15]uint64 + +// StatSNMPUDP describes Udp line of /proc/net/snmp. +type StatSNMPUDP [8]uint64 + +// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp. +type StatSNMPUDPLite [8]uint64 diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index b9eed7c3a..dcfcbd97e 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,6 +14,8 @@ package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // TestStack is a dummy implementation of Stack for tests. type TestStack struct { InterfacesMap map[int32]Interface @@ -94,5 +96,17 @@ func (s *TestStack) RouteTable() []Route { } // Resume implements Stack.Resume. -func (s *TestStack) Resume() { +func (s *TestStack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint { + return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { + return nil +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index aba2414d4..e041c51b3 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,12 +1,11 @@ load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") load("@rules_cc//cc:defs.bzl", "cc_proto_library") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "pending_signals_list", out = "pending_signals_list.go", diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 1d00a6310..51de4568a 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "atomicptr_credentials", out = "atomicptr_credentials_unsafe.go", diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index e3f5b0d83..3c9dceaba 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -15,6 +15,8 @@ package kernel import ( + "time" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" ) @@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task { return nil } +// Deadline implements context.Context.Deadline. +func (*Task) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done implements context.Context.Done. +func (*Task) Done() <-chan struct{} { + return nil +} + +// Err implements context.Context.Err. +func (*Task) Err() error { + return nil +} + // AsyncContext returns a context.Context that may be used by goroutines that // do work on behalf of t and therefore share its contextual values, but are // not t's task goroutine (e.g. asynchronous I/O). @@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool { return ctx.t.IsLogging(level) } +// Deadline implements context.Context.Deadline. +func (ctx taskAsyncContext) Deadline() (time.Time, bool) { + return ctx.t.Deadline() +} + +// Done implements context.Context.Done. +func (ctx taskAsyncContext) Done() <-chan struct{} { + return ctx.t.Done() +} + +// Err implements context.Context.Err. +func (ctx taskAsyncContext) Err() error { + return ctx.t.Err() +} + // Value implements context.Context.Value. func (ctx taskAsyncContext) Value(key interface{}) interface{} { return ctx.t.Value(key) diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD index bec13a3d9..3a88a585c 100644 --- a/pkg/sentry/kernel/contexttest/BUILD +++ b/pkg/sentry/kernel/contexttest/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "contexttest", testonly = 1, diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index 65427b112..3361e8b7d 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "epoll_list", out = "epoll_list.go", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index 983ca67ed..e65b961e8 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "eventfd", srcs = ["eventfd.go"], diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 5eddca115..49d81b712 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "fasync", srcs = ["fasync.go"], diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index cc3f43a45..11f613a11 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -81,6 +81,9 @@ type FDTable struct { // mu protects below. mu sync.Mutex `state:"nosave"` + // next is start position to find fd. + next int32 + // used contains the number of non-nil entries. It must be accessed // atomically. It may be read atomically without holding mu (but not // written). @@ -226,6 +229,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags f.mu.Lock() defer f.mu.Unlock() + // From f.next to find available fd. + if fd < f.next { + fd = f.next + } + // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.get(i); d == nil { @@ -242,6 +250,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return nil, syscall.EMFILE } + if fd == f.next { + // Update next search start position. + f.next = fds[len(fds)-1] + 1 + } + return fds, nil } @@ -361,6 +374,11 @@ func (f *FDTable) Remove(fd int32) *fs.File { f.mu.Lock() defer f.mu.Unlock() + // Update current available position. + if fd < f.next { + f.next = fd + } + orig, _, _ := f.get(fd) if orig != nil { orig.IncRef() // Reference for caller. @@ -377,6 +395,10 @@ func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) { f.forEach(func(fd int32, file *fs.File, flags FDFlags) { if cond(file, flags) { f.set(fd, nil, FDFlags{}) // Clear from table. + // Update current available position. + if fd < f.next { + f.next = fd + } } }) } diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 2413788e7..2bcb6216a 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -70,6 +70,42 @@ func TestFDTableMany(t *testing.T) { if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil { t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) } + + i := int32(2) + fdTable.Remove(i) + if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { + t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) + } + }) +} + +func TestFDTableOverLimit(t *testing.T) { + runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) { + if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil { + t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error") + } + + if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil { + t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error") + } + + if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil { + t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) + } else { + for _, fd := range fds { + fdTable.Remove(fd) + } + } + + if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 { + t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err) + } + + if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil { + t.Fatalf("Adding an FD to a resized map: got %v, want nil", err) + } else if len(fds) != 1 || fds[0] != 0 { + t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds) + } }) } diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 41f44999c..34286c7a8 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "atomicptr_bucket", out = "atomicptr_bucket_unsafe.go", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8c1f79ab5..28ba950bd 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -24,6 +24,7 @@ // TaskSet.mu // SignalHandlers.mu // Task.mu +// runningTasksMu // // Locking SignalHandlers.mu in multiple SignalHandlers requires locking // TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same @@ -135,6 +136,22 @@ type Kernel struct { // syslog is the kernel log. syslog syslog + // runningTasksMu synchronizes disable/enable of cpuClockTicker when + // the kernel is idle (runningTasks == 0). + // + // runningTasksMu is used to exclude critical sections when the timer + // disables itself and when the first active task enables the timer, + // ensuring that tasks always see a valid cpuClock value. + runningTasksMu sync.Mutex `state:"nosave"` + + // runningTasks is the total count of tasks currently in + // TaskGoroutineRunningSys or TaskGoroutineRunningApp. i.e., they are + // not blocked or stopped. + // + // runningTasks must be accessed atomically. Increments from 0 to 1 are + // further protected by runningTasksMu (see incRunningTasks). + runningTasks int64 + // cpuClock is incremented every linux.ClockTick. cpuClock is used to // measure task CPU usage, since sampling monotonicClock twice on every // syscall turns out to be unreasonably expensive. This is similar to how @@ -150,6 +167,22 @@ type Kernel struct { // cpuClockTicker increments cpuClock. cpuClockTicker *ktime.Timer `state:"nosave"` + // cpuClockTickerDisabled indicates that cpuClockTicker has been + // disabled because no tasks are running. + // + // cpuClockTickerDisabled is protected by runningTasksMu. + cpuClockTickerDisabled bool + + // cpuClockTickerSetting is the ktime.Setting of cpuClockTicker at the + // point it was disabled. It is cached here to avoid a lock ordering + // violation with cpuClockTicker.mu when runningTaskMu is held. + // + // cpuClockTickerSetting is only valid when cpuClockTickerDisabled is + // true. + // + // cpuClockTickerSetting is protected by runningTasksMu. + cpuClockTickerSetting ktime.Setting + // fdMapUids is an ever-increasing counter for generating FDTable uids. // // fdMapUids is mutable, and is accessed using atomic memory operations. @@ -358,7 +391,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if err := state.Save(w, k.FeatureSet(), nil); err != nil { + if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) @@ -366,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Save(w, k, &stats); err != nil { + if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil { return err } log.Infof("Kernel save stats: %s", &stats) @@ -374,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(w); err != nil { + if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -509,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if err := state.Load(r, &features, nil); err != nil { + if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -525,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Load(r, k, &stats); err != nil { + if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil { return err } log.Infof("Kernel load stats: %s", &stats) @@ -533,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(r); err != nil { + if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -771,8 +804,21 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, // Create a fresh task context. remainingTraversals = uint(args.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: mounts, + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: true, + Filename: args.Filename, + File: args.File, + CloseOnExec: false, + Argv: args.Argv, + Envv: args.Envv, + Features: k.featureSet, + } - tc, se := k.LoadTaskImage(ctx, mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet) + tc, se := k.LoadTaskImage(ctx, loadArgs) if se != nil { return nil, 0, errors.New(se.String()) } @@ -912,6 +958,102 @@ func (k *Kernel) resumeTimeLocked() { } } +func (k *Kernel) incRunningTasks() { + for { + tasks := atomic.LoadInt64(&k.runningTasks) + if tasks != 0 { + // Standard case. Simply increment. + if !atomic.CompareAndSwapInt64(&k.runningTasks, tasks, tasks+1) { + continue + } + return + } + + // Transition from 0 -> 1. Synchronize with other transitions and timer. + k.runningTasksMu.Lock() + tasks = atomic.LoadInt64(&k.runningTasks) + if tasks != 0 { + // We're no longer the first task, no need to + // re-enable. + atomic.AddInt64(&k.runningTasks, 1) + k.runningTasksMu.Unlock() + return + } + + if !k.cpuClockTickerDisabled { + // Timer was never disabled. + atomic.StoreInt64(&k.runningTasks, 1) + k.runningTasksMu.Unlock() + return + } + + // We need to update cpuClock for all of the ticks missed while we + // slept, and then re-enable the timer. + // + // The Notify in Swap isn't sufficient. kernelCPUClockTicker.Notify + // always increments cpuClock by 1 regardless of the number of + // expirations as a heuristic to avoid over-accounting in cases of CPU + // throttling. + // + // We want to cover the normal case, when all time should be accounted, + // so we increment for all expirations. Throttling is less concerning + // here because the ticker is only disabled from Notify. This means + // that Notify must schedule and compensate for the throttled period + // before the timer is disabled. Throttling while the timer is disabled + // doesn't matter, as nothing is running or reading cpuClock anyways. + // + // S/R also adds complication, as there are two cases. Recall that + // monotonicClock will jump forward on restore. + // + // 1. If the ticker is enabled during save, then on Restore Notify is + // called with many expirations, covering the time jump, but cpuClock + // is only incremented by 1. + // + // 2. If the ticker is disabled during save, then after Restore the + // first wakeup will call this function and cpuClock will be + // incremented by the number of expirations across the S/R. + // + // These cause very different value of cpuClock. But again, since + // nothing was running while the ticker was disabled, those differences + // don't matter. + setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now()) + if exp > 0 { + atomic.AddUint64(&k.cpuClock, exp) + } + + // Now that cpuClock is updated it is safe to allow other tasks to + // transition to running. + atomic.StoreInt64(&k.runningTasks, 1) + + // N.B. we must unlock before calling Swap to maintain lock ordering. + // + // cpuClockTickerDisabled need not wait until after Swap to become + // true. It is sufficient that the timer *will* be enabled. + k.cpuClockTickerDisabled = false + k.runningTasksMu.Unlock() + + // This won't call Notify (unless it's been ClockTick since setting.At + // above). This means we skip the thread group work in Notify. However, + // since nothing was running while we were disabled, none of the timers + // could have expired. + k.cpuClockTicker.Swap(setting) + + return + } +} + +func (k *Kernel) decRunningTasks() { + tasks := atomic.AddInt64(&k.runningTasks, -1) + if tasks < 0 { + panic(fmt.Sprintf("Invalid running count %d", tasks)) + } + + // Nothing to do. The next CPU clock tick will disable the timer if + // there is still nothing running. This provides approximately one tick + // of slack in which we can switch back and forth between idle and + // active without an expensive transition. +} + // WaitExited blocks until all tasks in k have exited. func (k *Kernel) WaitExited() { k.tasks.liveGoroutines.Wait() @@ -1180,6 +1322,7 @@ func (k *Kernel) ListSockets() []*SocketEntry { return socks } +// supervisorContext is a privileged context. type supervisorContext struct { context.NoopSleeper log.Logger diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 2ce8952e2..9d34f6d4d 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "buffer_list", out = "buffer_list.go", @@ -25,8 +24,10 @@ go_library( "device.go", "node.go", "pipe.go", + "pipe_util.go", "reader.go", "reader_writer.go", + "vfs.go", "writer.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe", @@ -41,6 +42,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/safemem", "//pkg/sentry/usermem", + "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index a2dc72204..4a19ab7ce 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -91,10 +90,10 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi switch { case flags.Read && !flags.Write: // O_RDONLY. r := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.rWakeup) + newHandleLocked(&i.rWakeup) if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() { - if !i.waitFor(&i.wWakeup, ctx) { + if !waitFor(&i.mu, &i.wWakeup, ctx) { r.DecRef() return nil, syserror.ErrInterrupted } @@ -107,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Write && !flags.Read: // O_WRONLY. w := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.wWakeup) + newHandleLocked(&i.wWakeup) if i.p.isNamed && !i.p.HasReaders() { // On a nonblocking, write-only open, the open fails with ENXIO if the @@ -117,7 +116,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi return nil, syserror.ENXIO } - if !i.waitFor(&i.rWakeup, ctx) { + if !waitFor(&i.mu, &i.rWakeup, ctx) { w.DecRef() return nil, syserror.ErrInterrupted } @@ -127,8 +126,8 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi case flags.Read && flags.Write: // O_RDWR. // Pipes opened for read-write always succeeds without blocking. rw := i.p.Open(ctx, d, flags) - i.newHandleLocked(&i.rWakeup) - i.newHandleLocked(&i.wWakeup) + newHandleLocked(&i.rWakeup) + newHandleLocked(&i.wWakeup) return rw, nil default: @@ -136,65 +135,6 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi } } -// waitFor blocks until the underlying pipe has at least one reader/writer is -// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this -// function will block for either readers or writers, depending on where -// 'wakeupChan' points. -// -// f.mu must be held by the caller. waitFor returns with f.mu held, but it will -// drop f.mu before blocking for any reader/writers. -func (i *inodeOperations) waitFor(wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool { - // Ideally this function would simply use a condition variable. However, the - // wait needs to be interruptible via 'sleeper', so we must sychronize via a - // channel. The synchronization below relies on the fact that closing a - // channel unblocks all receives on the channel. - - // Does an appropriate wakeup channel already exist? If not, create a new - // one. This is all done under f.mu to avoid races. - if *wakeupChan == nil { - *wakeupChan = make(chan struct{}) - } - - // Grab a local reference to the wakeup channel since it may disappear as - // soon as we drop f.mu. - wakeup := *wakeupChan - - // Drop the lock and prepare to sleep. - i.mu.Unlock() - cancel := sleeper.SleepStart() - - // Wait for either a new reader/write to be signalled via 'wakeup', or - // for the sleep to be cancelled. - select { - case <-wakeup: - sleeper.SleepFinish(true) - case <-cancel: - sleeper.SleepFinish(false) - } - - // Take the lock and check if we were woken. If we were woken and - // interrupted, the former takes priority. - i.mu.Lock() - select { - case <-wakeup: - return true - default: - return false - } -} - -// newHandleLocked signals a new pipe reader or writer depending on where -// 'wakeupChan' points. This unblocks any corresponding reader or writer -// waiting for the other end of the channel to be opened, see Fifo.waitFor. -// -// i.mu must be held. -func (*inodeOperations) newHandleLocked(wakeupChan *chan struct{}) { - if *wakeupChan != nil { - close(*wakeupChan) - *wakeupChan = nil - } -} - func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error { return syserror.EPIPE } diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index adbad7764..16fa80abe 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs. } func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(contexttest.Context(t), true, DefaultPipeSize, usermem.PageSize) + return NewPipe(true, DefaultPipeSize, usermem.PageSize) } func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(contexttest.Context(t), false, DefaultPipeSize, usermem.PageSize) + return NewPipe(false, DefaultPipeSize, usermem.PageSize) } // assertRecvBlocks ensures that a recv attempt on c blocks for at least diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 93b50669f..1a1b38f83 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -98,7 +98,7 @@ type Pipe struct { // NewPipe initializes and returns a pipe. // // N.B. The size and atomicIOBytes will be bounded. -func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { +func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } @@ -111,17 +111,33 @@ func NewPipe(ctx context.Context, isNamed bool, sizeBytes, atomicIOBytes int64) if atomicIOBytes > sizeBytes { atomicIOBytes = sizeBytes } - return &Pipe{ - isNamed: isNamed, - max: sizeBytes, - atomicIOBytes: atomicIOBytes, + var p Pipe + initPipe(&p, isNamed, sizeBytes, atomicIOBytes) + return &p +} + +func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) { + if sizeBytes < MinimumPipeSize { + sizeBytes = MinimumPipeSize + } + if sizeBytes > MaximumPipeSize { + sizeBytes = MaximumPipeSize + } + if atomicIOBytes <= 0 { + atomicIOBytes = 1 + } + if atomicIOBytes > sizeBytes { + atomicIOBytes = sizeBytes } + pipe.isNamed = isNamed + pipe.max = sizeBytes + pipe.atomicIOBytes = atomicIOBytes } // NewConnectedPipe initializes a pipe and returns a pair of objects // representing the read and write ends of the pipe. func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { - p := NewPipe(ctx, false /* isNamed */, sizeBytes, atomicIOBytes) + p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes) // Build an fs.Dirent for the pipe which will be shared by both // returned files. diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go new file mode 100644 index 000000000..ef9641e6a --- /dev/null +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -0,0 +1,213 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "io" + "math" + "sync" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// This file contains Pipe file functionality that is tied to neither VFS nor +// the old fs architecture. + +// Release cleans up the pipe's state. +func (p *Pipe) Release() { + p.rClose() + p.wClose() + + // Wake up readers and writers. + p.Notify(waiter.EventIn | waiter.EventOut) +} + +// Read reads from the Pipe into dst. +func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) { + n, err := p.read(ctx, readOps{ + left: func() int64 { + return dst.NumBytes() + }, + limit: func(l int64) { + dst = dst.TakeFirst64(l) + }, + read: func(buf *buffer) (int64, error) { + n, err := dst.CopyOutFrom(ctx, buf) + dst = dst.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventOut) + } + return n, err +} + +// WriteTo writes to w from the Pipe. +func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) { + ops := readOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + read: func(buf *buffer) (int64, error) { + n, err := buf.ReadToWriter(w, count, dup) + count -= n + return n, err + }, + } + if dup { + // There is no notification for dup operations. + return p.dup(ctx, ops) + } + n, err := p.read(ctx, ops) + if n > 0 { + p.Notify(waiter.EventOut) + } + return n, err +} + +// Write writes to the Pipe from src. +func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) { + n, err := p.write(ctx, writeOps{ + left: func() int64 { + return src.NumBytes() + }, + limit: func(l int64) { + src = src.TakeFirst64(l) + }, + write: func(buf *buffer) (int64, error) { + n, err := src.CopyInTo(ctx, buf) + src = src.DropFirst64(n) + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventIn) + } + return n, err +} + +// ReadFrom reads from r to the Pipe. +func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) { + n, err := p.write(ctx, writeOps{ + left: func() int64 { + return count + }, + limit: func(l int64) { + count = l + }, + write: func(buf *buffer) (int64, error) { + n, err := buf.WriteFromReader(r, count) + count -= n + return n, err + }, + }) + if n > 0 { + p.Notify(waiter.EventIn) + } + return n, err +} + +// Readiness returns the ready events in the underlying pipe. +func (p *Pipe) Readiness(mask waiter.EventMask) waiter.EventMask { + return p.rwReadiness() & mask +} + +// Ioctl implements ioctls on the Pipe. +func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + // Switch on ioctl request. + switch int(args[1].Int()) { + case linux.FIONREAD: + v := p.queued() + if v > math.MaxInt32 { + v = math.MaxInt32 // Silently truncate. + } + // Copy result to user-space. + _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + default: + return 0, syscall.ENOTTY + } +} + +// waitFor blocks until the underlying pipe has at least one reader/writer is +// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this +// function will block for either readers or writers, depending on where +// 'wakeupChan' points. +// +// mu must be held by the caller. waitFor returns with mu held, but it will +// drop mu before blocking for any reader/writers. +func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool { + // Ideally this function would simply use a condition variable. However, the + // wait needs to be interruptible via 'sleeper', so we must sychronize via a + // channel. The synchronization below relies on the fact that closing a + // channel unblocks all receives on the channel. + + // Does an appropriate wakeup channel already exist? If not, create a new + // one. This is all done under f.mu to avoid races. + if *wakeupChan == nil { + *wakeupChan = make(chan struct{}) + } + + // Grab a local reference to the wakeup channel since it may disappear as + // soon as we drop f.mu. + wakeup := *wakeupChan + + // Drop the lock and prepare to sleep. + mu.Unlock() + cancel := sleeper.SleepStart() + + // Wait for either a new reader/write to be signalled via 'wakeup', or + // for the sleep to be cancelled. + select { + case <-wakeup: + sleeper.SleepFinish(true) + case <-cancel: + sleeper.SleepFinish(false) + } + + // Take the lock and check if we were woken. If we were woken and + // interrupted, the former takes priority. + mu.Lock() + select { + case <-wakeup: + return true + default: + return false + } +} + +// newHandleLocked signals a new pipe reader or writer depending on where +// 'wakeupChan' points. This unblocks any corresponding reader or writer +// waiting for the other end of the channel to be opened, see Fifo.waitFor. +// +// Precondition: the mutex protecting wakeupChan must be held. +func newHandleLocked(wakeupChan *chan struct{}) { + if *wakeupChan != nil { + close(*wakeupChan) + *wakeupChan = nil + } +} diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go index 7c307f013..b4d29fc77 100644 --- a/pkg/sentry/kernel/pipe/reader_writer.go +++ b/pkg/sentry/kernel/pipe/reader_writer.go @@ -16,16 +16,12 @@ package pipe import ( "io" - "math" - "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/waiter" ) // ReaderWriter satisfies the FileOperations interface and services both @@ -45,124 +41,27 @@ type ReaderWriter struct { *Pipe } -// Release implements fs.FileOperations.Release. -func (rw *ReaderWriter) Release() { - rw.Pipe.rClose() - rw.Pipe.wClose() - - // Wake up readers and writers. - rw.Pipe.Notify(waiter.EventIn | waiter.EventOut) -} - // Read implements fs.FileOperations.Read. func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.read(ctx, readOps{ - left: func() int64 { - return dst.NumBytes() - }, - limit: func(l int64) { - dst = dst.TakeFirst64(l) - }, - read: func(buf *buffer) (int64, error) { - n, err := dst.CopyOutFrom(ctx, buf) - dst = dst.DropFirst64(n) - return n, err - }, - }) - if n > 0 { - rw.Pipe.Notify(waiter.EventOut) - } - return n, err + return rw.Pipe.Read(ctx, dst) } // WriteTo implements fs.FileOperations.WriteTo. func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) { - ops := readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(buf *buffer) (int64, error) { - n, err := buf.ReadToWriter(w, count, dup) - count -= n - return n, err - }, - } - if dup { - // There is no notification for dup operations. - return rw.Pipe.dup(ctx, ops) - } - n, err := rw.Pipe.read(ctx, ops) - if n > 0 { - rw.Pipe.Notify(waiter.EventOut) - } - return n, err + return rw.Pipe.WriteTo(ctx, w, count, dup) } // Write implements fs.FileOperations.Write. func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - n, err := rw.Pipe.write(ctx, writeOps{ - left: func() int64 { - return src.NumBytes() - }, - limit: func(l int64) { - src = src.TakeFirst64(l) - }, - write: func(buf *buffer) (int64, error) { - n, err := src.CopyInTo(ctx, buf) - src = src.DropFirst64(n) - return n, err - }, - }) - if n > 0 { - rw.Pipe.Notify(waiter.EventIn) - } - return n, err + return rw.Pipe.Write(ctx, src) } // ReadFrom implements fs.FileOperations.WriteTo. func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { - n, err := rw.Pipe.write(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(buf *buffer) (int64, error) { - n, err := buf.WriteFromReader(r, count) - count -= n - return n, err - }, - }) - if n > 0 { - rw.Pipe.Notify(waiter.EventIn) - } - return n, err -} - -// Readiness returns the ready events in the underlying pipe. -func (rw *ReaderWriter) Readiness(mask waiter.EventMask) waiter.EventMask { - return rw.Pipe.rwReadiness() & mask + return rw.Pipe.ReadFrom(ctx, r, count) } // Ioctl implements fs.FileOperations.Ioctl. func (rw *ReaderWriter) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - // Switch on ioctl request. - switch int(args[1].Int()) { - case linux.FIONREAD: - v := rw.queued() - if v > math.MaxInt32 { - v = math.MaxInt32 // Silently truncate. - } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) - return 0, err - default: - return 0, syscall.ENOTTY - } + return rw.Pipe.Ioctl(ctx, io, args) } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go new file mode 100644 index 000000000..6416e0dd8 --- /dev/null +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -0,0 +1,220 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pipe + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" +) + +// This file contains types enabling the pipe package to be used with the vfs +// package. + +// VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should +// not be copied. +type VFSPipe struct { + // mu protects the fields below. + mu sync.Mutex `state:"nosave"` + + // pipe is the underlying pipe. + pipe Pipe + + // Channels for synchronizing the creation of new readers and writers + // of this fifo. See waitFor and newHandleLocked. + // + // These are not saved/restored because all waiters are unblocked on + // save, and either automatically restart (via ERESTARTSYS) or return + // EINTR on resume. On restarts via ERESTARTSYS, the appropriate + // channel will be recreated. + rWakeup chan struct{} `state:"nosave"` + wWakeup chan struct{} `state:"nosave"` +} + +// NewVFSPipe returns an initialized VFSPipe. +func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe { + var vp VFSPipe + initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes) + return &vp +} + +// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics +// during open: +// +// "Normally, opening the FIFO blocks until the other end is opened also. A +// process can open a FIFO in nonblocking mode. In this case, opening for +// read-only will succeed even if no-one has opened on the write side yet, +// opening for write-only will fail with ENXIO (no such device or address) +// unless the other end has already been opened. Under Linux, opening a FIFO +// for read and write will succeed both in blocking and nonblocking mode. POSIX +// leaves this behavior undefined. This can be used to open a FIFO for writing +// while there are no readers available." - fifo(7) +func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) { + vp.mu.Lock() + defer vp.mu.Unlock() + + readable := vfs.MayReadFileWithOpenFlags(flags) + writable := vfs.MayWriteFileWithOpenFlags(flags) + if !readable && !writable { + return nil, syserror.EINVAL + } + + vfd, err := vp.open(rp, vfsd, vfsfd, flags) + if err != nil { + return nil, err + } + + switch { + case readable && writable: + // Pipes opened for read-write always succeed without blocking. + newHandleLocked(&vp.rWakeup) + newHandleLocked(&vp.wWakeup) + + case readable: + newHandleLocked(&vp.rWakeup) + // If this pipe is being opened as nonblocking and there's no + // writer, we have to wait for a writer to open the other end. + if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) { + return nil, syserror.EINTR + } + + case writable: + newHandleLocked(&vp.wWakeup) + + if !vp.pipe.HasReaders() { + // Nonblocking, write-only opens fail with ENXIO when + // the read side isn't open yet. + if flags&linux.O_NONBLOCK != 0 { + return nil, syserror.ENXIO + } + // Wait for a reader to open the other end. + if !waitFor(&vp.mu, &vp.rWakeup, ctx) { + return nil, syserror.EINTR + } + } + + default: + panic("invalid pipe flags: must be readable, writable, or both") + } + + return vfd, nil +} + +// Preconditions: vp.mu must be held. +func (vp *VFSPipe) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) { + var fd VFSPipeFD + fd.flags = flags + fd.readable = vfs.MayReadFileWithOpenFlags(flags) + fd.writable = vfs.MayWriteFileWithOpenFlags(flags) + fd.vfsfd = vfsfd + fd.pipe = &vp.pipe + if fd.writable { + // The corresponding Mount.EndWrite() is in VFSPipe.Release(). + if err := rp.Mount().CheckBeginWrite(); err != nil { + return nil, err + } + } + + switch { + case fd.readable && fd.writable: + vp.pipe.rOpen() + vp.pipe.wOpen() + case fd.readable: + vp.pipe.rOpen() + case fd.writable: + vp.pipe.wOpen() + default: + panic("invalid pipe flags: must be readable, writable, or both") + } + + return &fd, nil +} + +// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is +// expected that filesystesm will use this in a struct implementing +// vfs.FileDescriptionImpl. +type VFSPipeFD struct { + pipe *Pipe + flags uint32 + readable bool + writable bool + vfsfd *vfs.FileDescription +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *VFSPipeFD) Release() { + var event waiter.EventMask + if fd.readable { + fd.pipe.rClose() + event |= waiter.EventIn + } + if fd.writable { + fd.pipe.wClose() + event |= waiter.EventOut + } + if event == 0 { + panic("invalid pipe flags: must be readable, writable, or both") + } + + if fd.writable { + fd.vfsfd.VirtualDentry().Mount().EndWrite() + } + + fd.pipe.Notify(event) +} + +// OnClose implements vfs.FileDescriptionImpl.OnClose. +func (fd *VFSPipeFD) OnClose(_ context.Context) error { + return nil +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { + if !fd.readable { + return 0, syserror.EINVAL + } + + return fd.pipe.Read(ctx, dst) +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { + if !fd.writable { + return 0, syserror.EINVAL + } + + return fd.pipe.Write(ctx, src) +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return fd.pipe.Ioctl(ctx, uio, args) +} diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go index c5d095af7..2e861a5a8 100644 --- a/pkg/sentry/kernel/posixtimer.go +++ b/pkg/sentry/kernel/posixtimer.go @@ -117,9 +117,9 @@ func (it *IntervalTimer) signalRejectedLocked() { } // Notify implements ktime.TimerListener.Notify. -func (it *IntervalTimer) Notify(exp uint64) { +func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { if it.target == nil { - return + return ktime.Setting{}, false } it.target.tg.pidns.owner.mu.RLock() @@ -129,7 +129,7 @@ func (it *IntervalTimer) Notify(exp uint64) { if it.sigpending { it.overrunCur += exp - return + return ktime.Setting{}, false } // sigpending must be set before sendSignalTimerLocked() so that it can be @@ -148,6 +148,8 @@ func (it *IntervalTimer) Notify(exp uint64) { if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil { it.signalRejectedLocked() } + + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. Users of Timer should call diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 80e5e5da3..f4c00cd86 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "waiter_list", out = "waiter_list.go", diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index aa7471eb6..cd48945e6 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "shm", srcs = [ diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 06fd5ec88..4b08d7d72 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -121,7 +121,10 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // Readiness implements waiter.Waitable.Readiness. func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask { - return mask & waiter.EventIn + if mask&waiter.EventIn != 0 && s.target.PendingSignals()&s.Mask() != 0 { + return waiter.EventIn // Pending signals. + } + return 0 } // EventRegister implements waiter.Waitable.EventRegister. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index c82ef5486..9be3dae3c 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -709,9 +709,9 @@ func (t *Task) FDTable() *FDTable { return t.fdTable } -// GetFile is a convenience wrapper t.FDTable().GetFile. +// GetFile is a convenience wrapper for t.FDTable().Get. // -// Precondition: same as FDTable. +// Precondition: same as FDTable.Get. func (t *Task) GetFile(fd int32) *fs.File { f, _ := t.fdTable.Get(fd) return f diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 8639d379f..bb5560acf 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -18,10 +18,8 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -132,30 +130,21 @@ func (t *Task) Stack() *arch.Stack { return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())} } -// LoadTaskImage loads filename into a new TaskContext. +// LoadTaskImage loads a specified file into a new TaskContext. // -// It takes several arguments: -// * mounts: MountNamespace to lookup filename in -// * root: Root to lookup filename under -// * wd: Working directory to lookup filename under -// * maxTraversals: maximum number of symlinks to follow -// * filename: path to binary to load -// * file: an open fs.File object of the binary to load. If set, -// file will be loaded and not filename. -// * argv: Binary argv -// * envv: Binary envv -// * fs: Binary FeatureSet -func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, file *fs.File, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) { - // If File is not nil, we should load that instead of resolving filename. - if file != nil { - filename = file.MappedName(ctx) +// args.MemoryManager does not need to be set by the caller. +func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) { + // If File is not nil, we should load that instead of resolving Filename. + if args.File != nil { + args.Filename = args.File.MappedName(ctx) } // Prepare a new user address space to load into. m := mm.NewMemoryManager(k, k) defer m.DecUsers(ctx) + args.MemoryManager = m - os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv, envv, k.extraAuxv, k.vdso) + os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) if err != nil { return nil, err } diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go index 78ff14b20..ce3e6ef28 100644 --- a/pkg/sentry/kernel/task_identity.go +++ b/pkg/sentry/kernel/task_identity.go @@ -465,8 +465,8 @@ func (t *Task) SetKeepCaps(k bool) { // disables the features we don't support anyway, is always set. This // drastically simplifies this function. // -// - We don't implement AT_SECURE, because no_new_privs always being set means -// that the conditions that require AT_SECURE never arise. (Compare Linux's +// - We don't set AT_SECURE = 1, because no_new_privs always being set means +// that the conditions that require AT_SECURE = 1 never arise. (Compare Linux's // security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().) // // - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index e76c069b0..8b148db35 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -126,12 +126,22 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) { t.gosched.Timestamp = now t.gosched.State = state t.goschedSeq.EndWrite() + + if state != TaskGoroutineRunningApp { + // Task is blocking/stopping. + t.k.decRunningTasks() + } } // Preconditions: The caller must be running on the task goroutine, and leaving // a state indicated by a previous call to // t.accountTaskGoroutineEnter(state). func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) { + if state != TaskGoroutineRunningApp { + // Task is unblocking/continuing. + t.k.incRunningTasks() + } + now := t.k.CPUClockNow() if t.gosched.State != state { panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys)) @@ -330,7 +340,7 @@ func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker { } // Notify implements ktime.TimerListener.Notify. -func (ticker *kernelCPUClockTicker) Notify(exp uint64) { +func (ticker *kernelCPUClockTicker) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { // Only increment cpuClock by 1 regardless of the number of expirations. // This approximately compensates for cases where thread throttling or bad // Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and @@ -426,6 +436,27 @@ func (ticker *kernelCPUClockTicker) Notify(exp uint64) { tgs[i] = nil } ticker.tgs = tgs[:0] + + // If nothing is running, we can disable the timer. + tasks := atomic.LoadInt64(&ticker.k.runningTasks) + if tasks == 0 { + ticker.k.runningTasksMu.Lock() + defer ticker.k.runningTasksMu.Unlock() + tasks := atomic.LoadInt64(&ticker.k.runningTasks) + if tasks != 0 { + // Raced with a 0 -> 1 transition. + return setting, false + } + + // Stop the timer. We must cache the current setting so the + // kernel can access it without violating the lock order. + ticker.k.cpuClockTickerSetting = setting + ticker.k.cpuClockTickerDisabled = true + setting.Enabled = false + return setting, true + } + + return setting, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 0eef24bfb..72568d296 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -511,8 +511,9 @@ type itimerRealListener struct { } // Notify implements ktime.TimerListener.Notify. -func (l *itimerRealListener) Notify(exp uint64) { +func (l *itimerRealListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) { l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM)) + return ktime.Setting{}, false } // Destroy implements ktime.TimerListener.Destroy. diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 9beae4b31..31847e1df 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "time", srcs = [ diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index aa6c75d25..107394183 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -280,13 +280,16 @@ func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask { // A TimerListener receives expirations from a Timer. type TimerListener interface { // Notify is called when its associated Timer expires. exp is the number of - // expirations. + // expirations. setting is the next timer Setting. // // Notify is called with the associated Timer's mutex locked, so Notify // must not take any locks that precede Timer.mu in lock order. // + // If Notify returns true, the timer will use the returned setting + // rather than the passed one. + // // Preconditions: exp > 0. - Notify(exp uint64) + Notify(exp uint64, setting Setting) (newSetting Setting, update bool) // Destroy is called when the timer is destroyed. Destroy() @@ -533,7 +536,9 @@ func (t *Timer) Tick() { s, exp := t.setting.At(now) t.setting = s if exp > 0 { - t.listener.Notify(exp) + if newS, ok := t.listener.Notify(exp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) } @@ -588,7 +593,9 @@ func (t *Timer) Get() (Time, Setting) { s, exp := t.setting.At(now) t.setting = s if exp > 0 { - t.listener.Notify(exp) + if newS, ok := t.listener.Notify(exp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) return now, s @@ -620,7 +627,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) { } oldS, oldExp := t.setting.At(now) if oldExp > 0 { - t.listener.Notify(oldExp) + t.listener.Notify(oldExp, oldS) + // N.B. The returned Setting doesn't matter because we're about + // to overwrite. } if f != nil { f() @@ -628,7 +637,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) { newS, newExp := s.At(now) t.setting = newS if newExp > 0 { - t.listener.Notify(newExp) + if newS, ok := t.listener.Notify(newExp, t.setting); ok { + t.setting = newS + } } t.resetKickerLocked(now) return now, oldS @@ -683,11 +694,13 @@ func NewChannelNotifier() (TimerListener, <-chan struct{}) { } // Notify implements ktime.TimerListener.Notify. -func (c *ChannelNotifier) Notify(uint64) { +func (c *ChannelNotifier) Notify(uint64, Setting) (Setting, bool) { select { case c.tchan <- struct{}{}: default: } + + return Setting{}, false } // Destroy implements ktime.TimerListener.Destroy and will close the channel. diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 59649c770..156e67bf8 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "limits", srcs = [ diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 3b322f5f3..2890393bd 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_embed_data") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_embed_data( name = "vdso_bin", src = "//vdso:vdso.so", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index ba9c9ce12..c2c3ec06e 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -323,18 +323,22 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf. return syserror.ENOEXEC } + // N.B. Linux uses vm_brk_flags to map these pages, which only + // honors the X bit, always mapping at least RW. ignoring These + // pages are not included in the final brk region. + prot := usermem.ReadWrite + if phdr.Flags&elf.PF_X == elf.PF_X { + prot.Execute = true + } + if _, err := m.MMap(ctx, memmap.MMapOpts{ Length: uint64(anonSize), Addr: anonAddr, // Fixed without Unmap will fail the mmap if something is // already at addr. - Fixed: true, - Private: true, - // N.B. Linux uses vm_brk to map these pages, ignoring - // the segment protections, instead always mapping RW. - // These pages are not included in the final brk - // region. - Perms: usermem.ReadWrite, + Fixed: true, + Private: true, + Perms: prot, MaxPerms: usermem.AnyAccess, }); err != nil { ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err) @@ -620,15 +624,15 @@ func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, in return loadParsedELF(ctx, m, f, info, 0) } -// loadELF loads f into the Task address space. +// loadELF loads args.File into the Task address space. // // If loadELF returns ErrSwitchFile it should be called again with the returned // path and argv. // // Preconditions: -// * f is an ELF file -func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) { - bin, ac, err := loadInitialELF(ctx, m, fs, f) +// * args.File is an ELF file +func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) { + bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File) if err != nil { ctx.Infof("Error loading binary: %v", err) return loadedELF{}, nil, err @@ -636,7 +640,14 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace var interp loadedELF if bin.interpreter != "" { - d, i, err := openPath(ctx, mounts, root, wd, maxTraversals, bin.interpreter) + // Even if we do not allow the final link of the script to be + // resolved, the interpreter should still be resolved if it is + // a symlink. + args.ResolveFinal = true + // Refresh the traversal limit. + *args.RemainingTraversals = linux.MaxSymlinkTraversals + args.Filename = bin.interpreter + d, i, err := openPath(ctx, args) if err != nil { ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err) return loadedELF{}, nil, err @@ -645,7 +656,7 @@ func loadELF(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace // We don't need the Dirent. d.DecRef() - interp, err = loadInterpreterELF(ctx, m, i, bin) + interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin) if err != nil { ctx.Infof("Error loading interpreter: %v", err) return loadedELF{}, nil, err diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index f6f1ae762..b03eeb005 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package loader loads a binary into a MemoryManager. +// Package loader loads an executable file into a MemoryManager. package loader import ( @@ -20,6 +20,7 @@ import ( "fmt" "io" "path" + "strings" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" @@ -35,6 +36,54 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// LoadArgs holds specifications for an executable file to be loaded. +type LoadArgs struct { + // MemoryManager is the memory manager to load the executable into. + MemoryManager *mm.MemoryManager + + // Mounts is the mount namespace in which to look up Filename. + Mounts *fs.MountNamespace + + // Root is the root directory under which to look up Filename. + Root *fs.Dirent + + // WorkingDirectory is the working directory under which to look up + // Filename. + WorkingDirectory *fs.Dirent + + // RemainingTraversals is the maximum number of symlinks to follow to + // resolve Filename. This counter is passed by reference to keep it + // updated throughout the call stack. + RemainingTraversals *uint + + // ResolveFinal indicates whether the final link of Filename should be + // resolved, if it is a symlink. + ResolveFinal bool + + // Filename is the path for the executable. + Filename string + + // File is an open fs.File object of the executable. If File is not + // nil, then File will be loaded and Filename will be ignored. + File *fs.File + + // CloseOnExec indicates that the executable (or one of its parent + // directories) was opened with O_CLOEXEC. If the executable is an + // interpreter script, then cause an ENOENT error to occur, since the + // script would otherwise be inaccessible to the interpreter. + CloseOnExec bool + + // Argv is the vector of arguments to pass to the executable. + Argv []string + + // Envv is the vector of environment variables to pass to the + // executable. + Envv []string + + // Features specifies the CPU feature set for the executable. + Features *cpuid.FeatureSet +} + // readFull behaves like io.ReadFull for an *fs.File. func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { var total int64 @@ -51,80 +100,82 @@ func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset in return total, nil } -// openPath opens name for loading. +// openPath opens args.Filename and checks that it is valid for loading. // -// openPath returns the fs.Dirent and an *fs.File for name, which is not +// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not // installed in the Task FDTable. The caller takes ownership of both. // -// name must be a readable, executable, regular file. -func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, name string) (*fs.Dirent, *fs.File, error) { - if name == "" { +// args.Filename must be a readable, executable, regular file. +func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) { + if args.Filename == "" { ctx.Infof("cannot open empty name") return nil, nil, syserror.ENOENT } - d, err := mm.FindInode(ctx, root, wd, name, maxTraversals) + var d *fs.Dirent + var err error + if args.ResolveFinal { + d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } else { + d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals) + } if err != nil { return nil, nil, err } - - // Open file will take a reference to Dirent, so destroy this one. + // Defer a DecRef for the sake of failure cases. defer d.DecRef() - return openFile(ctx, nil, d, name) -} + if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) { + return nil, nil, syserror.ELOOP + } -// openFile performs checks on a file to be executed. If provided a *fs.File, -// openFile takes that file's Dirent and performs checks on it. If provided a -// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's -// Inode and performs checks on that. -// -// openFile returns an *fs.File and *fs.Dirent, and the caller takes ownership -// of both. -// -// "dirent" and "file" must not both be nil and point to a readable, executable, regular file. -func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) { - // file and dirent must not be nil. - if dirent == nil && file == nil { - ctx.Infof("dirent and file cannot both be nil.") - return nil, nil, syserror.ENOENT + if err := checkPermission(ctx, d); err != nil { + return nil, nil, err } - if file != nil { - dirent = file.Dirent + // If they claim it's a directory, then make sure. + // + // N.B. we reject directories below, but we must first reject + // non-directories passed as directories. + if strings.HasSuffix(args.Filename, "/") && !fs.IsDir(d.Inode.StableAttr) { + return nil, nil, syserror.ENOTDIR } - // Perform permissions checks on the file. - if err := checkFile(ctx, dirent, name); err != nil { + if err := checkIsRegularFile(ctx, d, args.Filename); err != nil { return nil, nil, err } - if file == nil { - var ferr error - if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil { - return nil, nil, ferr - } - } else { - // GetFile takes a reference to the created file, so make one in the case - // that the file reference already existed. - file.IncRef() + f, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) + if err != nil { + return nil, nil, err + } + // Defer a DecRef for the sake of failure cases. + defer f.DecRef() + + if err := checkPread(ctx, f, args.Filename); err != nil { + return nil, nil, err + } + + d.IncRef() + f.IncRef() + return d, f, err +} + +// checkFile performs checks on a file to be executed. +func checkFile(ctx context.Context, f *fs.File, filename string) error { + if err := checkPermission(ctx, f.Dirent); err != nil { + return err } - // We must be able to read at arbitrary offsets. - if !file.Flags().Pread { - file.DecRef() - ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags()) - return nil, nil, syserror.EACCES + if err := checkIsRegularFile(ctx, f.Dirent, filename); err != nil { + return err } - // Grab reference for caller. - dirent.IncRef() - return dirent, file, nil + return checkPread(ctx, f, filename) } -// checkFile performs file permissions checks for binaries called in openPath -// and openFile -func checkFile(ctx context.Context, d *fs.Dirent, name string) error { +// checkPermission checks whether the file is readable and executable. +func checkPermission(ctx context.Context, d *fs.Dirent) error { perms := fs.PermMask{ // TODO(gvisor.dev/issue/160): Linux requires only execute // permission, not read. However, our backing filesystems may @@ -135,26 +186,26 @@ func checkFile(ctx context.Context, d *fs.Dirent, name string) error { Read: true, Execute: true, } - if err := d.Inode.CheckPermission(ctx, perms); err != nil { - return err - } + return d.Inode.CheckPermission(ctx, perms) +} - // If they claim it's a directory, then make sure. - // - // N.B. we reject directories below, but we must first reject - // non-directories passed as directories. - if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) { - return syserror.ENOTDIR +// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc. +func checkIsRegularFile(ctx context.Context, d *fs.Dirent, filename string) error { + attr := d.Inode.StableAttr + if !fs.IsRegular(attr) { + ctx.Infof("%s is not regular: %v", filename, attr) + return syserror.EACCES } + return nil +} - // No exec-ing directories, pipes, etc! - if !fs.IsRegular(d.Inode.StableAttr) { - ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr) +// checkPread checks whether we can read the file at arbitrary offsets. +func checkPread(ctx context.Context, f *fs.File, filename string) error { + if !f.Flags().Pread { + ctx.Infof("%s cannot be read at an offset: %+v", filename, f.Flags()) return syserror.EACCES } - return nil - } // allocStack allocates and maps a stack in to any available part of the address space. @@ -173,45 +224,49 @@ const ( maxLoaderAttempts = 6 ) -// loadBinary loads a binary that is pointed to by "file". If nil, the path -// "filename" is resolved and loaded. +// loadExecutable loads an executable that is pointed to by args.File. If nil, +// the path args.Filename is resolved and loaded. If the executable is an +// interpreter script rather than an ELF, the binary of the corresponding +// interpreter will be loaded. // // It returns: // * loadedELF, description of the loaded binary // * arch.Context matching the binary arch // * fs.Dirent of the binary file -// * Possibly updated argv -func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, features *cpuid.FeatureSet, filename string, passedFile *fs.File, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) { +// * Possibly updated args.Argv +func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) { for i := 0; i < maxLoaderAttempts; i++ { var ( d *fs.Dirent - f *fs.File err error ) - if passedFile == nil { - d, f, err = openPath(ctx, mounts, root, wd, remainingTraversals, filename) - + if args.File == nil { + d, args.File, err = openPath(ctx, args) + // We will return d in the successful case, but defer a DecRef for the + // sake of intermediate loops and failure cases. + if d != nil { + defer d.DecRef() + } + if args.File != nil { + defer args.File.DecRef() + } } else { - d, f, err = openFile(ctx, passedFile, nil, "") - // Set to nil in case we loop on a Interpreter Script. - passedFile = nil + d = args.File.Dirent + d.IncRef() + defer d.DecRef() + err = checkFile(ctx, args.File, args.Filename) } - if err != nil { - ctx.Infof("Error opening %s: %v", filename, err) + ctx.Infof("Error opening %s: %v", args.Filename, err) return loadedELF{}, nil, nil, nil, err } - defer f.DecRef() - // We will return d in the successful case, but defer a DecRef - // for intermediate loops and failure cases. - defer d.DecRef() // Check the header. Is this an ELF or interpreter script? var hdr [4]uint8 // N.B. We assume that reading from a regular file cannot block. - _, err = readFull(ctx, f, usermem.BytesIOSequence(hdr[:]), 0) - // Allow unexpected EOF, as a valid executable could be only three - // bytes (e.g., #!a). + _, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0) + // Allow unexpected EOF, as a valid executable could be only three bytes + // (e.g., #!a). if err != nil && err != io.ErrUnexpectedEOF { if err == io.EOF { err = syserror.ENOEXEC @@ -221,33 +276,38 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp switch { case bytes.Equal(hdr[:], []byte(elfMagic)): - loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, features, f) + loaded, ac, err := loadELF(ctx, args) if err != nil { ctx.Infof("Error loading ELF: %v", err) return loadedELF{}, nil, nil, nil, err } // An ELF is always terminal. Hold on to d. d.IncRef() - return loaded, ac, d, argv, err + return loaded, ac, d, args.Argv, err case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)): - newpath, newargv, err := parseInterpreterScript(ctx, filename, f, argv) + if args.CloseOnExec { + return loadedELF{}, nil, nil, nil, syserror.ENOENT + } + args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv) if err != nil { ctx.Infof("Error loading interpreter script: %v", err) return loadedELF{}, nil, nil, nil, err } - filename = newpath - argv = newargv + // Refresh the traversal limit for the interpreter. + *args.RemainingTraversals = linux.MaxSymlinkTraversals default: ctx.Infof("Unknown magic: %v", hdr) return loadedELF{}, nil, nil, nil, syserror.ENOEXEC } + // Set to nil in case we loop on a Interpreter Script. + args.File = nil } return loadedELF{}, nil, nil, nil, syserror.ELOOP } -// Load loads "file" into a MemoryManager. If file is nil, the path "filename" -// is resolved and loaded instead. +// Load loads args.File into a MemoryManager. If args.File is nil, the path +// args.Filename is resolved and loaded instead. // // If Load returns ErrSwitchFile it should be called again with the returned // path and argv. @@ -255,37 +315,37 @@ func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamesp // Preconditions: // * The Task MemoryManager is empty. // * Load is called on the Task goroutine. -func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, file *fs.File, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { - // Load the binary itself. - loaded, ac, d, argv, err := loadBinary(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv) +func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { + // Load the executable itself. + loaded, ac, d, newArgv, err := loadExecutable(ctx, args) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } defer d.DecRef() // Load the VDSO. - vdsoAddr, err := loadVDSO(ctx, m, vdso, loaded) + vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) } // Setup the heap. brk starts at the next page after the end of the - // binary. Userspace can assume that the remainer of the page after + // executable. Userspace can assume that the remainer of the page after // loaded.end is available for its use. e, ok := loaded.end.RoundUp() if !ok { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("brk overflows: %#x", loaded.end), linux.ENOEXEC) } - m.BrkSetup(ctx, e) + args.MemoryManager.BrkSetup(ctx, e) // Allocate our stack. - stack, err := allocStack(ctx, m, ac) + stack, err := allocStack(ctx, args.MemoryManager, ac) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to allocate stack: %v", err), syserr.FromError(err).ToLinux()) } // Push the original filename to the stack, for AT_EXECFN. - execfn, err := stack.Push(filename) + execfn, err := stack.Push(args.Filename) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux()) } @@ -308,6 +368,9 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())}, arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())}, arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())}, + // The conditions that require AT_SECURE = 1 never arise. See + // kernel.Task.updateCredsForExecLocked. + arch.AuxEntry{linux.AT_SECURE, 0}, arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC}, arch.AuxEntry{linux.AT_EXECFN, execfn}, arch.AuxEntry{linux.AT_RANDOM, random}, @@ -316,11 +379,12 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r }...) auxv = append(auxv, extraAuxv...) - sl, err := stack.Load(argv, envv, auxv) + sl, err := stack.Load(newArgv, args.Envv, auxv) if err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load stack: %v", err), syserr.FromError(err).ToLinux()) } + m := args.MemoryManager m.SetArgvStart(sl.ArgvStart) m.SetArgvEnd(sl.ArgvEnd) m.SetEnvvStart(sl.EnvvStart) @@ -331,7 +395,7 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r ac.SetIP(uintptr(loaded.entry)) ac.SetStack(uintptr(stack.Bottom)) - name := path.Base(filename) + name := path.Base(args.Filename) if len(name) > linux.TASK_COMM_LEN-1 { name = name[:linux.TASK_COMM_LEN-1] } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 9687e7e76..3ef84245b 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "mappable_range", out = "mappable_range.go", diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index b35c8c673..a804b8b5c 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "file_refcount_set", out = "file_refcount_set.go", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 3fd904c67..f404107af 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "evictable_range", out = "evictable_range.go", diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index 1effc7735..aafce1d00 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -16,6 +16,7 @@ package pgalloc import ( "bytes" + "context" "fmt" "io" "runtime" @@ -29,7 +30,7 @@ import ( ) // SaveTo writes f's state to the given stream. -func (f *MemoryFile) SaveTo(w io.Writer) error { +func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { // Wait for reclaim. f.mu.Lock() defer f.mu.Unlock() @@ -78,10 +79,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // Save metadata. - if err := state.Save(w, &f.fileSize, nil); err != nil { + if err := state.Save(ctx, w, &f.fileSize, nil); err != nil { return err } - if err := state.Save(w, &f.usage, nil); err != nil { + if err := state.Save(ctx, w, &f.usage, nil); err != nil { return err } @@ -114,9 +115,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // LoadFrom loads MemoryFile state from the given stream. -func (f *MemoryFile) LoadFrom(r io.Reader) error { +func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { // Load metadata. - if err := state.Load(r, &f.fileSize, nil); err != nil { + if err := state.Load(ctx, r, &f.fileSize, nil); err != nil { return err } if err := f.file.Truncate(f.fileSize); err != nil { @@ -124,7 +125,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error { } newMappings := make([]uintptr, f.fileSize>>chunkShift) f.mappings.Store(newMappings) - if err := state.Load(r, &f.usage, nil); err != nil { + if err := state.Load(ctx, r, &f.usage, nil); err != nil { return err } diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 9aa6ec507..157bffa81 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "file_range", out = "file_range.go", diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 31fa48ec5..6803d488c 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -23,6 +23,7 @@ go_library( "machine.go", "machine_amd64.go", "machine_amd64_unsafe.go", + "machine_arm64.go", "machine_unsafe.go", "physical_map.go", "virtual_map.go", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index acd41f73d..ea8b9632e 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -127,7 +127,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac // not have physical mappings, the KVM module may inject // spurious exceptions when emulation fails (i.e. it tries to // emulate because the RIP is pointed at those pages). - as.machine.mapPhysical(physical, length) + as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) // Install the page table mappings. Note that the ordering is // important; if the pagetable mappings were installed before diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go index 80942e9c9..3f35414bb 100644 --- a/pkg/sentry/platform/kvm/allocator.go +++ b/pkg/sentry/platform/kvm/allocator.go @@ -54,7 +54,7 @@ func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { // //go:nosplit func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { - virtualStart, physicalStart, _, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions) if !ok { panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) } diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index b97476053..f6459cda9 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -46,9 +46,9 @@ func yield() { // calculateBluepillFault calculates the fault address range. // //go:nosplit -func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) { +func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) { alignedPhysical := physical &^ uintptr(usermem.PageSize-1) - for _, pr := range physicalRegions { + for _, pr := range phyRegions { end := pr.physical + pr.length if physical < pr.physical || physical >= end { continue @@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng // The corresponding virtual address is returned. This may throw on error. // //go:nosplit -func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { +func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) { // Paging fault: we need to map the underlying physical pages for this // fault. This all has to be done in this function because we're in a // signal handler context. (We can't call any functions that might // split the stack.) - virtualStart, physicalStart, length, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { return 0, false } @@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { yield() // Race with another call. slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0)) } - errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart) + errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags) if errno == 0 { // Successfully added region; we can increment nextSlot and // allow another set to proceed here. diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 7e8e9f42a..ca011ef78 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -80,13 +80,17 @@ func bluepillHandler(context unsafe.Pointer) { // interrupted KVM. Since we're in a signal handler // currently, all signals are masked and the signal // must have been delivered directly to this thread. + timeout := syscall.Timespec{} sig, _, errno := syscall.RawSyscall6( syscall.SYS_RT_SIGTIMEDWAIT, uintptr(unsafe.Pointer(&bounceSignalMask)), - 0, // siginfo. - 0, // timeout. - 8, // sigset size. + 0, // siginfo. + uintptr(unsafe.Pointer(&timeout)), // timeout. + 8, // sigset size. 0, 0) + if errno == syscall.EAGAIN { + continue + } if errno != 0 { throw("error waiting for pending signal") } @@ -162,7 +166,7 @@ func bluepillHandler(context unsafe.Pointer) { // For MMIO, the physical address is the first data item. physical := uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical) + virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) if !ok { c.die(bluepillArchContext(context), "invalid physical address") return diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index d05f05c29..766131d60 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -62,3 +62,10 @@ const ( _KVM_NR_INTERRUPTS = 0x100 _KVM_NR_CPUID_ENTRIES = 0x100 ) + +// KVM kvm_memory_region::flags. +const ( + _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0 + _KVM_MEM_READONLY = uint32(1) << 1 + _KVM_MEM_FLAGS_NONE = 0 +) diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index cc6c138b2..7d02ebf19 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -215,6 +215,17 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) + var physicalRegionsReadOnly []physicalRegion + var physicalRegionsAvailable []physicalRegion + + physicalRegionsReadOnly = rdonlyRegionsForSetMem() + physicalRegionsAvailable = availableRegionsForSetMem() + + // Map all read-only regions. + for _, r := range physicalRegionsReadOnly { + m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) + } + // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to @@ -223,6 +234,13 @@ func newMachine(vm int) (*machine, error) { if excludeVirtualRegion(vr) { return // skip region. } + + for _, r := range physicalRegionsReadOnly { + if vr.virtual == r.virtual { + return + } + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -236,7 +254,7 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length) + m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) virtual += length } }) @@ -256,9 +274,9 @@ func newMachine(vm int) (*machine, error) { // not available. This attempts to be efficient for calls in the hot path. // // This panics on error. -func (m *machine) mapPhysical(physical, length uintptr) { +func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) { for end := physical + length; physical < end; { - _, physicalStart, length, ok := calculateBluepillFault(physical) + _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { // Should never happen. panic("mapPhysical on unknown physical address") @@ -266,7 +284,7 @@ func (m *machine) mapPhysical(physical, length uintptr) { if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { // Not present in the cache; requires setting the slot. - if _, ok := handleBluepillFault(m, physical); !ok { + if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok { panic("handleBluepillFault failed") } } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c1cbe33be..b99fe425e 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -355,3 +355,13 @@ func (m *machine) retryInGuest(fn func()) { } } } + +// On x86 platform, the flags for "setMemoryRegion" can always be set as 0. +// There is no need to return read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + return nil +} + +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 506ec9af1..61227cafb 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -26,30 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" ) -// setMemoryRegion initializes a region. -// -// This may be called from bluepillHandler, and therefore returns an errno -// directly (instead of wrapping in an error) to avoid allocations. -// -//go:nosplit -func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { - userRegion := userMemoryRegion{ - slot: uint32(slot), - flags: 0, - guestPhysAddr: uint64(physical), - memorySize: uint64(length), - userspaceAddr: uint64(virtual), - } - - // Set the region. - _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(m.fd), - _KVM_SET_USER_MEMORY_REGION, - uintptr(unsafe.Pointer(&userRegion))) - return errno -} - // loadSegments copies the current segments. // // This may be called from within the signal context and throws on error. diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go new file mode 100644 index 000000000..b7e2cfb9d --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -0,0 +1,61 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +// Get all read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + var rdonlyRegions []region + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return + } + + if !vr.accessType.Write && vr.accessType.Read { + rdonlyRegions = append(rdonlyRegions, vr.region) + } + }) + + for _, r := range rdonlyRegions { + physical, _, ok := translateToPhysical(r.virtual) + if !ok { + continue + } + + phyRegions = append(phyRegions, physicalRegion{ + region: region{ + virtual: r.virtual, + length: r.length, + }, + physical: physical, + }) + } + + return phyRegions +} + +// Get all available physicalRegions. +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + var excludeRegions []region + applyVirtualRegions(func(vr virtualRegion) { + if !vr.accessType.Write { + excludeRegions = append(excludeRegions, vr.region) + } + }) + + phyRegions = computePhysicalRegions(excludeRegions) + + return phyRegions +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 405e00292..ed9433311 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. @@ -35,6 +35,30 @@ func entersyscall() //go:linkname exitsyscall runtime.exitsyscall func exitsyscall() +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. +// +//go:nosplit +func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno { + userRegion := userMemoryRegion{ + slot: uint32(slot), + flags: uint32(flags), + guestPhysAddr: uint64(physical), + memorySize: uint64(length), + userspaceAddr: uint64(virtual), + } + + // Set the region. + _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_USER_MEMORY_REGION, + uintptr(unsafe.Pointer(&userRegion))) + return errno +} + // mapRunData maps the vCPU run data. func mapRunData(fd int) (*runData, error) { r, _, errno := syscall.RawSyscall6( diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD index 77a449a8b..b0e45f159 100644 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ b/pkg/sentry/platform/kvm/testutil/BUILD @@ -9,6 +9,8 @@ go_library( "testutil.go", "testutil_amd64.go", "testutil_amd64.s", + "testutil_arm64.go", + "testutil_arm64.s", ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil", visibility = ["//pkg/sentry/platform/kvm:__pkg__"], diff --git a/pkg/sentry/platform/kvm/testutil/testutil.go b/pkg/sentry/platform/kvm/testutil/testutil.go index 6cf2359a3..5c1efa0fd 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil.go +++ b/pkg/sentry/platform/kvm/testutil/testutil.go @@ -41,9 +41,6 @@ func TwiddleRegsFault() // TwiddleRegsSyscall twiddles registers then executes a syscall. func TwiddleRegsSyscall() -// TwiddleSegments reads segments into known registers. -func TwiddleSegments() - // FloatingPointWorks is a floating point test. // // It returns true or false. diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go index 203d71528..4c108abbf 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go +++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go @@ -21,6 +21,9 @@ import ( "syscall" ) +// TwiddleSegments reads segments into known registers. +func TwiddleSegments() + // SetTestTarget sets the rip appropriately. func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { regs.Rip = uint64(reflect.ValueOf(fn).Pointer()) diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go new file mode 100644 index 000000000..40b2e4acc --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go @@ -0,0 +1,59 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package testutil + +import ( + "fmt" + "reflect" + "syscall" +) + +// SetTestTarget sets the rip appropriately. +func SetTestTarget(regs *syscall.PtraceRegs, fn func()) { + regs.Pc = uint64(reflect.ValueOf(fn).Pointer()) +} + +// SetTouchTarget sets rax appropriately. +func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) { + if target != nil { + regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer()) + } else { + regs.Regs[8] = 0 + } +} + +// RewindSyscall rewinds a syscall RIP. +func RewindSyscall(regs *syscall.PtraceRegs) { + regs.Pc -= 4 +} + +// SetTestRegs initializes registers to known values. +func SetTestRegs(regs *syscall.PtraceRegs) { + for i := 0; i <= 30; i++ { + regs.Regs[i] = uint64(i) + 1 + } +} + +// CheckTestRegs checks that registers were twiddled per TwiddleRegs. +func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) { + for i := 0; i <= 30; i++ { + if need := ^uint64(i + 1); regs.Regs[i] != need { + err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need) + } + } + return +} diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s new file mode 100644 index 000000000..2cd28b2d2 --- /dev/null +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -0,0 +1,91 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +// test_util_arm64.s provides ARM64 test functions. + +#include "funcdata.h" +#include "textflag.h" + +#define SYS_GETPID 172 + +// This function simulates the getpid syscall. +TEXT ·Getpid(SB),NOSPLIT,$0 + NO_LOCAL_POINTERS + MOVD $SYS_GETPID, R8 + SVC + RET + +TEXT ·Touch(SB),NOSPLIT,$0 +start: + MOVD 0(R8), R1 + MOVD $SYS_GETPID, R8 // getpid + SVC + B start + +TEXT ·HaltLoop(SB),NOSPLIT,$0 +start: + HLT + B start + +// This function simulates a loop of syscall. +TEXT ·SyscallLoop(SB),NOSPLIT,$0 +start: + SVC + B start + +TEXT ·SpinLoop(SB),NOSPLIT,$0 +start: + B start + +// MVN: bitwise logical NOT +// This case simulates an application that modified R0-R30. +#define TWIDDLE_REGS() \ + MVN R0, R0; \ + MVN R1, R1; \ + MVN R2, R2; \ + MVN R3, R3; \ + MVN R4, R4; \ + MVN R5, R5; \ + MVN R6, R6; \ + MVN R7, R7; \ + MVN R8, R8; \ + MVN R9, R9; \ + MVN R10, R10; \ + MVN R11, R11; \ + MVN R12, R12; \ + MVN R13, R13; \ + MVN R14, R14; \ + MVN R15, R15; \ + MVN R16, R16; \ + MVN R17, R17; \ + MVN R18_PLATFORM, R18_PLATFORM; \ + MVN R19, R19; \ + MVN R20, R20; \ + MVN R21, R21; \ + MVN R22, R22; \ + MVN R23, R23; \ + MVN R24, R24; \ + MVN R25, R25; \ + MVN R26, R26; \ + MVN R27, R27; \ + MVN g, g; \ + MVN R29, R29; \ + MVN R30, R30; + +TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 + TWIDDLE_REGS() + SVC + RET // never reached diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 9f0ecfbe4..ddb1f41e3 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -327,6 +327,20 @@ func (t *thread) dumpAndPanic(message string) { panic(message) } +func (t *thread) unexpectedStubExit() { + msg, err := t.getEventMessage() + status := syscall.WaitStatus(msg) + if status.Signaled() && status.Signal() == syscall.SIGKILL { + // SIGKILL can be only sent by an user or OOM-killer. In both + // these cases, we don't need to panic. There is no reasons to + // think that something wrong in gVisor. + log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid) + pid := os.Getpid() + syscall.Tgkill(pid, pid, syscall.Signal(syscall.SIGKILL)) + } + t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) +} + // wait waits for a stop event. // // Precondition: outcome is a valid waitOutcome. @@ -355,8 +369,7 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal { } if stopSig == syscall.SIGTRAP { if status.TrapCause() == syscall.PTRACE_EVENT_EXIT { - msg, err := t.getEventMessage() - t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err)) + t.unexpectedStubExit() } // Re-encode the trap cause the way it's expected. return stopSig | syscall.Signal(status.TrapCause()<<8) diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index c075b5f91..3782d4332 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -129,6 +129,9 @@ func createStub() (*thread, error) { // transitively) will be killed as well. It's simply not possible to // safely handle a single stub getting killed: the exact state of // execution is unknown and not recoverable. + // + // In addition, we set the PTRACE_O_TRACEEXIT option to log more + // information about a stub process when it receives a fatal signal. return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction) } diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index b80a3604d..2ae6b9f9d 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 939a0033a..f1af18265 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -4,43 +4,66 @@ load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) go_template( - name = "defs", - srcs = select( - { - "@bazel_tools//src/conditions:linux_aarch64": ["defs.go", "defs_arm64.go", "offsets_arm64.go", "aarch64.go",], - "//conditions:default": ["defs.go", "defs_amd64.go", "offsets_amd64.go", "x86.go",], - }, - ), + name = "defs_amd64", + srcs = [ + "defs.go", + "defs_amd64.go", + "offsets_amd64.go", + "x86.go", + ], + visibility = [":__subpackages__"], +) + +go_template( + name = "defs_arm64", + srcs = [ + "aarch64.go", + "defs.go", + "defs_arm64.go", + "offsets_arm64.go", + ], visibility = [":__subpackages__"], ) go_template_instance( - name = "defs_impl", - out = "defs_impl.go", + name = "defs_impl_amd64", + out = "defs_impl_amd64.go", package = "ring0", - template = ":defs", + template = ":defs_amd64", +) + +go_template_instance( + name = "defs_impl_arm64", + out = "defs_impl_arm64.go", + package = "ring0", + template = ":defs_arm64", +) + +genrule( + name = "entry_impl_amd64", + srcs = ["entry_amd64.s"], + outs = ["entry_impl_amd64.s"], + cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) genrule( - name = "entry_impl", - srcs = ["entry_amd64.s", "entry_arm64.s"], - outs = ["entry_impl.s"], - cmd = select( - { - "@bazel_tools//src/conditions:linux_aarch64": "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", - "//conditions:default": "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", - }, - ), + name = "entry_impl_arm64", + srcs = ["entry_arm64.s"], + outs = ["entry_impl_arm64.s"], + cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) go_library( name = "ring0", srcs = [ - "defs_impl.go", + "defs_impl_amd64.go", + "defs_impl_arm64.go", "entry_amd64.go", "entry_arm64.go", - "entry_impl.s", + "entry_impl_amd64.s", + "entry_impl_arm64.s", "kernel.go", "kernel_amd64.go", "kernel_arm64.go", diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index d7029d5a9..42076fb04 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -1,20 +1,27 @@ load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") +go_template_instance( + name = "defs_impl_arm64", + out = "defs_impl_arm64.go", + package = "main", + template = "//pkg/sentry/platform/ring0:defs_arm64", +) go_template_instance( - name = "defs_impl", - out = "defs_impl.go", + name = "defs_impl_amd64", + out = "defs_impl_amd64.go", package = "main", - template = "//pkg/sentry/platform/ring0:defs", + template = "//pkg/sentry/platform/ring0:defs_amd64", ) go_binary( name = "gen_offsets", srcs = [ - "defs_impl.go", + "defs_impl_amd64.go", + "defs_impl_arm64.go", "main.go", ], visibility = ["//pkg/sentry/platform/ring0:__pkg__"], diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index ea090b686..934a90378 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,10 +1,9 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - go_template( name = "generic_walker", srcs = [ diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index eace3766d..c303435d5 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -23,7 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -// TODO(b/34161764): Move to pkg/abi/linux along with definitions in +// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in // pkg/sentry/arch. type sigaction struct { handler uintptr diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 3300f9a6b..26176b10d 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "socket", srcs = ["socket.go"], diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 81dbd7309..4a6e83a8b 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "control", srcs = ["control.go"], diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a951f1bb0..8b66a719d 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "hostinet", srcs = [ @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip/stack", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 3a4fdec47..e67b46c9e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -16,8 +16,11 @@ package hostinet import ( "fmt" + "io" "io/ioutil" "os" + "reflect" + "strconv" "strings" "syscall" @@ -26,7 +29,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) var defaultRecvBufSize = inet.TCPBufferSize{ @@ -51,6 +56,8 @@ type Stack struct { tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool + netDevFile *os.File + netSNMPFile *os.File } // NewStack returns an empty Stack containing no configuration. @@ -98,6 +105,18 @@ func (s *Stack) Configure() error { log.Warningf("Failed to read if TCP SACK if enabled, setting to true") } + if f, err := os.Open("/proc/net/dev"); err != nil { + log.Warningf("Failed to open /proc/net/dev: %v", err) + } else { + s.netDevFile = f + } + + if f, err := os.Open("/proc/net/snmp"); err != nil { + log.Warningf("Failed to open /proc/net/snmp: %v", err) + } else { + s.netSNMPFile = f + } + return nil } @@ -326,9 +345,95 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// getLine reads one line from proc file, with specified prefix. +// The last argument, withHeader, specifies if it contains line header. +func getLine(f *os.File, prefix string, withHeader bool) string { + data := make([]byte, 4096) + + if _, err := f.Seek(0, 0); err != nil { + return "" + } + + if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF { + return "" + } + + prefix = prefix + ":" + lines := strings.Split(string(data), "\n") + for _, l := range lines { + l = strings.TrimSpace(l) + if strings.HasPrefix(l, prefix) { + if withHeader { + withHeader = false + continue + } + return l + } + } + return "" +} + +func toSlice(i interface{}) []uint64 { + v := reflect.Indirect(reflect.ValueOf(i)) + return v.Slice(0, v.Len()).Interface().([]uint64) +} + // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserror.EOPNOTSUPP + var ( + snmpTCP bool + rawLine string + sliceStat []uint64 + ) + + switch stat.(type) { + case *inet.StatDev: + if s.netDevFile == nil { + return fmt.Errorf("/proc/net/dev is not opened for hostinet") + } + rawLine = getLine(s.netDevFile, arg, false /* with no header */) + case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite: + if s.netSNMPFile == nil { + return fmt.Errorf("/proc/net/snmp is not opened for hostinet") + } + rawLine = getLine(s.netSNMPFile, arg, true) + default: + return syserr.ErrEndpointOperation.ToError() + } + + if rawLine == "" { + return fmt.Errorf("Failed to get raw line") + } + + parts := strings.SplitN(rawLine, ":", 2) + if len(parts) != 2 { + return fmt.Errorf("Failed to get prefix from: %q", rawLine) + } + + sliceStat = toSlice(stat) + fields := strings.Fields(strings.TrimSpace(parts[1])) + if len(fields) != len(sliceStat) { + return fmt.Errorf("Failed to parse fields: %q", rawLine) + } + if _, ok := stat.(*inet.StatSNMPTCP); ok { + snmpTCP = true + } + for i := 0; i < len(sliceStat); i++ { + var err error + if snmpTCP && i == 3 { + var tmp int64 + // MaxConn field is signed, RFC 2012. + tmp, err = strconv.ParseInt(fields[i], 10, 64) + sliceStat[i] = uint64(tmp) // Convert back to int before use. + } else { + sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) + } + if err != nil { + return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + } + } + + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -338,3 +443,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 354a0d6ee..5eb06bbf4 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "netfilter", srcs = [ diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 45ebb2a0e..79589e3c8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "netlink", srcs = [ @@ -20,7 +20,9 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", + "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 445080aa4..463544c1a 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "port", srcs = ["port.go"], diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 689cad997..be005df24 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -30,6 +30,13 @@ type Protocol interface { // Protocol returns the Linux netlink protocol value. Protocol() int + // CanSend returns true if this protocol may ever send messages. + // + // TODO(gvisor.dev/issue/1119): This is a workaround to allow + // advertising support for otherwise unimplemented features on sockets + // that will never send messages, thus making those features no-ops. + CanSend() bool + // ProcessMessage processes a single message from userspace. // // If err == nil, any messages added to ms will be sent back to the diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 5dc8533ec..1d4912753 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "route", srcs = ["protocol.go"], diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index cc70ac237..6b4a0ecf4 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -61,6 +61,11 @@ func (p *Protocol) Protocol() int { return linux.NETLINK_ROUTE } +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return true +} + // dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests. func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d0aab293d..4a1b87a9a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -27,7 +27,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -52,6 +54,8 @@ const ( maxSendBufferSize = 4 << 20 // 4MB ) +var errNoFilter = syserr.New("no filter attached", linux.ENOENT) + // netlinkSocketDevice is the netlink socket virtual device. var netlinkSocketDevice = device.NewAnonDevice() @@ -60,7 +64,7 @@ var netlinkSocketDevice = device.NewAnonDevice() // This implementation only supports userspace sending and receiving messages // to/from the kernel. // -// Socket implements socket.Socket. +// Socket implements socket.Socket and transport.Credentialer. // // +stateify savable type Socket struct { @@ -103,9 +107,19 @@ type Socket struct { // sendBufferSize is the send buffer "size". We don't actually have a // fixed buffer but only consume this many bytes. sendBufferSize uint32 + + // passcred indicates if this socket wants SCM credentials. + passcred bool + + // filter indicates that this socket has a BPF filter "installed". + // + // TODO(gvisor.dev/issue/1119): We don't actually support filtering, + // this is just bookkeeping for tracking add/remove. + filter bool } var _ socket.Socket = (*Socket)(nil) +var _ transport.Credentialer = (*Socket)(nil) // NewSocket creates a new Socket. func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { @@ -171,6 +185,22 @@ func (s *Socket) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } +// Passcred implements transport.Credentialer.Passcred. +func (s *Socket) Passcred() bool { + s.mu.Lock() + passcred := s.passcred + s.mu.Unlock() + return passcred +} + +// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. +func (s *Socket) ConnectedPasscred() bool { + // This socket is connected to the kernel, which doesn't need creds. + // + // This is arbitrary, as ConnectedPasscred on this type has no callers. + return false +} + // Ioctl implements fs.FileOperations.Ioctl. func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) { // TODO(b/68878065): no ioctls supported. @@ -308,9 +338,20 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem. // We don't have limit on receiving size. return int32(math.MaxInt32), nil + case linux.SO_PASSCRED: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + var passcred int32 + if s.Passcred() { + passcred = 1 + } + return passcred, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } + case linux.SOL_NETLINK: switch name { case linux.NETLINK_BROADCAST_ERROR, @@ -347,6 +388,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy s.sendBufferSize = size s.mu.Unlock() return nil + case linux.SO_RCVBUF: if len(opt) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -354,6 +396,52 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy // We don't have limit on receiving size. So just accept anything as // valid for compatibility. return nil + + case linux.SO_PASSCRED: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + passcred := usermem.ByteOrder.Uint32(opt) + + s.mu.Lock() + s.passcred = passcred != 0 + s.mu.Unlock() + return nil + + case linux.SO_ATTACH_FILTER: + // TODO(gvisor.dev/issue/1119): We don't actually + // support filtering. If this socket can't ever send + // messages, then there is nothing to filter and we can + // advertise support. Otherwise, be conservative and + // return an error. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + s.filter = true + s.mu.Unlock() + return nil + + case linux.SO_DETACH_FILTER: + // TODO(gvisor.dev/issue/1119): See above. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + filter := s.filter + s.filter = false + s.mu.Unlock() + + if !filter { + return errNoFilter + } + + return nil + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -416,6 +504,24 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have Peek: flags&linux.MSG_PEEK != 0, } + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. However, the memory manager for the destination will not read + // the endpoint if the destination is zero length. + // + // In order for the endpoint to be read when the destination size is zero, + // we must cause a read of the endpoint by using a separate fake zero + // length block sequence and calling the EndpointReader directly. + if trunc && dst.Addrs.NumBytes() == 0 { + // Perform a read to a zero byte block sequence. We can ignore the + // original destination since it was zero bytes. The length returned by + // ReadToBlocks is ignored and we return the full message length to comply + // with MSG_TRUNC. + _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0)))) + return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + } + if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { @@ -464,6 +570,26 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ }) } +// kernelSCM implements control.SCMCredentials with credentials that represent +// the kernel itself rather than a Task. +// +// +stateify savable +type kernelSCM struct{} + +// Equals implements transport.CredentialsControlMessage.Equals. +func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool { + _, ok := oc.(kernelSCM) + return ok +} + +// Credentials implements control.SCMCredentials.Credentials. +func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) { + return 0, auth.RootUID, auth.RootGID +} + +// kernelCreds is the concrete version of kernelSCM used in all creds. +var kernelCreds = &kernelSCM{} + // sendResponse sends the response messages in ms back to userspace. func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { // Linux combines multiple netlink messages into a single datagram. @@ -472,10 +598,15 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error bufs = append(bufs, m.Finalize()) } + // All messages are from the kernel. + cms := transport.ControlMessages{ + Credentials: kernelCreds, + } + if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if err != nil && err != syserr.ErrWouldBlock { @@ -499,7 +630,10 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error PortID: uint32(ms.PortID), }) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) + // Add the dump_done_errno payload. + m.Put(int64(0)) + + _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD new file mode 100644 index 000000000..0777f3baf --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -0,0 +1,17 @@ +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "uevent", + srcs = ["protocol.go"], + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/context", + "//pkg/sentry/kernel", + "//pkg/sentry/socket/netlink", + "//pkg/syserr", + ], +) diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go new file mode 100644 index 000000000..b5d7808d7 --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -0,0 +1,60 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol. +// +// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does +// not support any device events, so these sockets never send any messages. +package uevent + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" + "gvisor.dev/gvisor/pkg/syserr" +) + +// Protocol implements netlink.Protocol. +// +// +stateify savable +type Protocol struct{} + +var _ netlink.Protocol = (*Protocol)(nil) + +// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol. +func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) { + return &Protocol{}, nil +} + +// Protocol implements netlink.Protocol.Protocol. +func (p *Protocol) Protocol() int { + return linux.NETLINK_KOBJECT_UEVENT +} + +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return false +} + +// ProcessMessage implements netlink.Protocol.ProcessMessage. +func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { + // Silently ignore all messages. + return nil +} + +// init registers the NETLINK_KOBJECT_UEVENT provider. +func init() { + netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol) +} diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/netstack/BUILD index e927821e1..e414d8055 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -1,17 +1,17 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( - name = "epsocket", + name = "netstack", srcs = [ "device.go", - "epsocket.go", + "netstack.go", "provider.go", "save_restore.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/epsocket", + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack", visibility = [ "//pkg/sentry:internal", ], diff --git a/pkg/sentry/socket/epsocket/device.go b/pkg/sentry/socket/netstack/device.go index 85484d5b1..fbeb89fb8 100644 --- a/pkg/sentry/socket/epsocket/device.go +++ b/pkg/sentry/socket/netstack/device.go @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import "gvisor.dev/gvisor/pkg/sentry/device" -// epsocketDevice is the endpoint socket virtual device. -var epsocketDevice = device.NewAnonDevice() +// netstackDevice is the endpoint socket virtual device. +var netstackDevice = device.NewAnonDevice() diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/netstack/netstack.go index 3e66f9cbb..d92399efd 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package epsocket provides an implementation of the socket.Socket interface +// Package netstack provides an implementation of the socket.Socket interface // that is backed by a tcpip.Endpoint. // // It does not depend on any particular endpoint implementation, and thus can @@ -22,7 +22,7 @@ // Lock ordering: netstack => mm: ioSequencePayload copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. -package epsocket +package netstack import ( "bytes" @@ -53,6 +53,7 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -137,15 +138,19 @@ var Metrics = tcpip.Stats{ }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), + CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), + EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), @@ -155,6 +160,7 @@ var Metrics = tcpip.Stats{ ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."), InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."), SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."), + SegmentSendErrors: mustCreateMetric("/netstack/tcp/segment_send_errors", "Number of TCP segments failed to be sent."), ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."), ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."), Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."), @@ -170,13 +176,18 @@ var Metrics = tcpip.Stats{ UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."), ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."), MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), - PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent via sendUDP."), + PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), + PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), }, } +// DefaultTTL is linux's default TTL. All network protocols in all stacks used +// with this package must have this value set as their default TTL. +const DefaultTTL = 64 + const sizeOfInt32 int = 4 -var errStackType = syserr.New("expected but did not receive an epsocket.Stack", linux.EINVAL) +var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) // ntohs converts a 16-bit number from network byte order to host byte order. It // assumes that the host is little endian. @@ -262,20 +273,20 @@ type SocketOperations struct { // valid when timestampValid is true. It is protected by readMu. timestampNS int64 - // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket - // level, because it takes into account data from readView. + // sockOptInq corresponds to TCP_INQ. It is implemented at this level + // because it takes into account data from readView. sockOptInq bool } // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { + if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { return nil, syserr.TranslateNetstackError(err) } } - dirent := socket.NewDirent(t, epsocketDevice) + dirent := socket.NewDirent(t, netstackDevice) defer dirent.DecRef() return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ Queue: queue, @@ -288,6 +299,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{})) var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{})) +var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{})) // bytesToIPAddress converts an IPv4 or IPv6 address from the user to the // netstack representation taking any addresses into account. @@ -299,12 +311,12 @@ func bytesToIPAddress(addr []byte) tcpip.Address { } // AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and -// AF_INET6 addresses. +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. // // strict indicates whether addresses with the AF_UNSPEC family are accepted of not. // -// AddressAndFamily returns an address, its family. +// AddressAndFamily returns an address and its family. func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { @@ -363,6 +375,22 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } return out, family, nil + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(b/129292371): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + case linux.AF_UNSPEC: return tcpip.FullAddress{}, family, nil @@ -760,7 +788,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // tcpip.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is - // implemented specifically for epsocket.SocketOperations rather than + // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket // options where the implementation is not shared, as unix sockets need // their own support for SO_TIMESTAMP. @@ -833,7 +861,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return getSockOptIPv6(t, ep, name, outLen) case linux.SOL_IP: - return getSockOptIP(t, ep, name, outLen) + return getSockOptIP(t, ep, name, outLen, family) case linux.SOL_UDP, linux.SOL_ICMPV6, @@ -942,6 +970,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BINDTODEVICE: + var v tcpip.BindToDeviceOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if len(v) == 0 { + return []byte{}, nil + } + if outLen < linux.IFNAMSIZ { + return nil, syserr.ErrInvalidArgument + } + return append([]byte(v), 0), nil + case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1014,8 +1055,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.ErrInvalidArgument } - var v tcpip.DelayOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1132,6 +1173,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa copy(b, v) return b, nil + case linux.TCP_LINGER2: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPLingerTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1156,6 +1209,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_TCLASS: + // Length handling for parity with Linux. + if outLen == 0 { + return make([]byte, 0), nil + } + var v tcpip.IPv6TrafficClassOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + uintv := uint32(v) + // Linux truncates the output binary to outLen. + ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + // Handle cases where outLen is lesser than sizeOfInt32. + if len(ib) > outLen { + ib = ib[:outLen] + } + return ib, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1163,8 +1235,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { switch name { + case linux.IP_TTL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TTLOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + // Fill in the default value, if needed. + if v == 0 { + v = DefaultTTL + } + + return int32(v), nil + case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1206,6 +1295,20 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac } return int32(0), nil + case linux.IP_TOS: + // Length handling for parity with Linux. + if outLen == 0 { + return []byte(nil), nil + } + var v tcpip.IPv4TOSOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if outLen < sizeOfInt32 { + return uint8(v), nil + } + return int32(v), nil + default: emitUnimplementedEventIP(t, name) } @@ -1216,7 +1319,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac // tcpip.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is - // implemented specifically for epsocket.SocketOperations rather than + // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket // options where the implementation is not shared, as unix sockets need // their own support for SO_TIMESTAMP. @@ -1305,6 +1408,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BINDTODEVICE: + n := bytes.IndexByte(optVal, 0) + if n == -1 { + n = len(optVal) + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -1399,11 +1509,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o tcpip.DelayOption + var o int if v == 0 { o = 1 } - return syserr.TranslateNetstackError(ep.SetSockOpt(o)) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1458,6 +1568,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return nil + case linux.TCP_LINGER2: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1497,6 +1615,19 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_TCLASS: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < -1 || v > 255 { + return syserr.ErrInvalidArgument + } + if v == -1 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + default: emitUnimplementedEventIPv6(t, name) } @@ -1628,6 +1759,30 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s t.Kernel().EmitUnimplementedEvent(t) return syserr.ErrInvalidArgument + case linux.IP_TTL: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + // -1 means default TTL. + if v == -1 { + v = 0 + } else if v < 1 || v > 255 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + + case linux.IP_TOS: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1651,9 +1806,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, - linux.IP_TOS, linux.IP_TRANSPARENT, - linux.IP_TTL, linux.IP_UNBLOCK_SOURCE, linux.IP_UNICAST_IF, linux.IP_XFRM_POLICY, @@ -1838,12 +1991,14 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(2 + l) } return &out, uint32(3 + l) + case linux.AF_INET: var out linux.SockAddrInet copy(out.Addr[:], addr.Addr) out.Family = linux.AF_INET out.Port = htons(addr.Port) - return &out, uint32(binary.Size(out)) + return &out, uint32(sockAddrInetSize) + case linux.AF_INET6: var out linux.SockAddrInet6 if len(addr.Addr) == 4 { @@ -1859,7 +2014,17 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) if isLinkLocal(addr.Addr) { out.Scope_id = uint32(addr.NIC) } - return &out, uint32(binary.Size(out)) + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(b/129292371): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + default: return nil, 0 } @@ -2215,7 +2380,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { @@ -2518,7 +2683,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Flag values and meanings are described in greater detail in netdevice(7) in // the SIOCGIFFLAGS section. func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) { - // epsocket should only ever be passed an epsocket.Stack. + // We should only ever be passed a netstack.Stack. epstack, ok := stack.(*Stack) if !ok { return 0, errStackType diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/netstack/provider.go index 421f93dc4..2d2c1ba2a 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "syscall" @@ -62,10 +62,14 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in } case linux.SOCK_RAW: + // TODO(b/142504697): "In order to create a raw socket, a + // process must have the CAP_NET_RAW capability in the user + // namespace that governs its network namespace." - raw(7) + // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { - return 0, true, syserr.ErrPermissionDenied + return 0, true, syserr.ErrNotPermitted } switch protocol { @@ -85,7 +89,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in return 0, true, syserr.ErrProtocolNotSupported } -// Socket creates a new socket object for the AF_INET or AF_INET6 family. +// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET +// family. func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Fail right away if we don't have a stack. stack := t.NetworkContext() @@ -99,6 +104,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return nil, nil } + // Packet sockets are handled separately, since they are neither INET + // nor INET6 specific. + if p.family == linux.AF_PACKET { + return packetSocket(t, eps, stype, protocol) + } + // Figure out the transport protocol. transProto, associated, err := getTransportProtocol(t, stype, protocol) if err != nil { @@ -121,12 +132,47 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return New(t, p.family, stype, int(transProto), wq, ep) } +func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { + // TODO(b/142504697): "In order to create a packet socket, a process + // must have the CAP_NET_RAW capability in the user namespace that + // governs its network namespace." - packet(7) + + // Packet sockets require CAP_NET_RAW. + creds := auth.CredentialsFromContext(t) + if !creds.HasCapability(linux.CAP_NET_RAW) { + return nil, syserr.ErrNotPermitted + } + + // "cooked" packets don't contain link layer information. + var cooked bool + switch stype { + case linux.SOCK_DGRAM: + cooked = true + case linux.SOCK_RAW: + cooked = false + default: + return nil, syserr.ErrProtocolNotSupported + } + + // protocol is passed in network byte order, but netstack wants it in + // host order. + netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + + wq := &waiter.Queue{} + ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return New(t, linux.AF_PACKET, stype, protocol, wq, ep) +} + // Pair just returns nil sockets (not supported). func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } -// init registers socket providers for AF_INET and AF_INET6. +// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET. func init() { // Providers backed by netstack. p := []provider{ @@ -138,6 +184,9 @@ func init() { family: linux.AF_INET6, netProto: ipv6.ProtocolNumber, }, + { + family: linux.AF_PACKET, + }, } for i := range p { diff --git a/pkg/sentry/socket/epsocket/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go index f7b8c10cc..c7aaf722a 100644 --- a/pkg/sentry/socket/epsocket/save_restore.go +++ b/pkg/sentry/socket/netstack/save_restore.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/netstack/stack.go index 7cf7ff735..a0db2d4fd 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "gvisor.dev/gvisor/pkg/abi/linux" @@ -144,7 +144,98 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserr.ErrEndpointOperation.ToError() + switch stats := stat.(type) { + case *inet.StatSNMPIP: + ip := Metrics.IP + *stats = inet.StatSNMPIP{ + 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. + 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + ip.InvalidAddressesReceived.Value(), // InAddrErrors. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + } + case *inet.StatSNMPICMP: + in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + *stats = inet.StatSNMPICMP{ + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs. + Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors. + in.DstUnreachable.Value(), // InDestUnreachs. + in.TimeExceeded.Value(), // InTimeExcds. + in.ParamProblem.Value(), // InParmProbs. + in.SrcQuench.Value(), // InSrcQuenchs. + in.Redirect.Value(), // InRedirects. + in.Echo.Value(), // InEchos. + in.EchoReply.Value(), // InEchoReps. + in.Timestamp.Value(), // InTimestamps. + in.TimestampReply.Value(), // InTimestampReps. + in.InfoRequest.Value(), // InAddrMasks. + in.InfoReply.Value(), // InAddrMaskReps. + 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs. + Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. + } + case *inet.StatSNMPTCP: + tcp := Metrics.TCP + // RFC 2012 (updates 1213): SNMPv2-MIB-TCP. + *stats = inet.StatSNMPTCP{ + 1, // RtoAlgorithm. + 200, // RtoMin. + 120000, // RtoMax. + (1<<64 - 1), // MaxConn. + tcp.ActiveConnectionOpenings.Value(), // ActiveOpens. + tcp.PassiveConnectionOpenings.Value(), // PassiveOpens. + tcp.FailedConnectionAttempts.Value(), // AttemptFails. + tcp.EstablishedResets.Value(), // EstabResets. + tcp.CurrentEstablished.Value(), // CurrEstab. + tcp.ValidSegmentsReceived.Value(), // InSegs. + tcp.SegmentsSent.Value(), // OutSegs. + tcp.Retransmits.Value(), // RetransSegs. + tcp.InvalidSegmentsReceived.Value(), // InErrs. + tcp.ResetsSent.Value(), // OutRsts. + tcp.ChecksumErrors.Value(), // InCsumErrors. + } + case *inet.StatSNMPUDP: + udp := Metrics.UDP + *stats = inet.StatSNMPUDP{ + udp.PacketsReceived.Value(), // InDatagrams. + udp.UnknownPortErrors.Value(), // NoPorts. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors. + udp.PacketsSent.Value(), // OutDatagrams. + udp.ReceiveBufferErrors.Value(), // RcvbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors. + 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti. + } + default: + return syserr.ErrEndpointOperation.ToError() + } + return nil } // RouteTable implements inet.Stack.RouteTable. @@ -200,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() { func (s *Stack) Resume() { s.Stack.Resume() } + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { + return s.Stack.RegisteredEndpoints() +} + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { + return s.Stack.CleanupEndpoints() +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { + s.Stack.RestoreCleanupEndpoints(es) +} diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 3a6baa308..4668b87d1 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -37,6 +37,7 @@ go_library( "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", "//pkg/unet", "//pkg/waiter", ], diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 5dcb6b455..f7878a760 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/unet" ) @@ -165,3 +166,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index da9977fde..5b6a154f6 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "unix", srcs = [ @@ -24,7 +24,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", + "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 0b0240336..788ad70d2 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") +package(licenses = ["notice"]) + go_template_instance( name = "transport_message_list", out = "transport_message_list.go", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 4bd15808a..dea11e253 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -220,6 +220,11 @@ func (e *connectionedEndpoint) Close() { case e.Connected(): e.connected.CloseSend() e.receiver.CloseRecv() + // Still have unread data? If yes, we set this into the write + // end so that the peer can get ECONNRESET) when it does read. + if e.receiver.RecvQueuedSize() > 0 { + e.connected.CloseUnread() + } c = e.connected r = e.receiver e.connected = nil diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index 0415fae9a..e27b1c714 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -33,6 +33,7 @@ type queue struct { mu sync.Mutex `state:"nosave"` closed bool + unread bool used int64 limit int64 dataList messageList @@ -161,6 +162,9 @@ func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) { err := syserr.ErrWouldBlock if q.closed { err = syserr.ErrClosedForReceive + if q.unread { + err = syserr.ErrConnectionReset + } } q.mu.Unlock() @@ -188,7 +192,9 @@ func (q *queue) Peek() (*message, *syserr.Error) { if q.dataList.Front() == nil { err := syserr.ErrWouldBlock if q.closed { - err = syserr.ErrClosedForReceive + if err = syserr.ErrClosedForReceive; q.unread { + err = syserr.ErrConnectionReset + } } return nil, err } @@ -208,3 +214,11 @@ func (q *queue) QueuedSize() int64 { func (q *queue) MaxQueueSize() int64 { return q.limit } + +// CloseUnread sets flag to indicate that the peer is closed (not shutdown) +// with unread data. So if read on this queue shall return ECONNRESET error. +func (q *queue) CloseUnread() { + q.mu.Lock() + defer q.mu.Unlock() + q.unread = true +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 1867b3a5c..529a7a7a9 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -608,6 +608,10 @@ type ConnectedEndpoint interface { // Release releases any resources owned by the ConnectedEndpoint. It should // be called before droping all references to a ConnectedEndpoint. Release() + + // CloseUnread sets the fact that this end is closed with unread data to + // the peer socket. + CloseUnread() } // +stateify savable @@ -711,6 +715,11 @@ func (e *connectedEndpoint) Release() { e.writeQueue.DecRef() } +// CloseUnread implements ConnectedEndpoint.CloseUnread. +func (e *connectedEndpoint) CloseUnread() { + e.writeQueue.CloseUnread() +} + // baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless // unix domain socket Endpoint implementations. // diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 0d0cb68df..1aaae8487 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -31,7 +31,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" @@ -40,8 +40,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// SocketOperations is a Unix socket. It is similar to an epsocket, except it -// is backed by a transport.Endpoint instead of a tcpip.Endpoint. +// SocketOperations is a Unix socket. It is similar to a netstack socket, +// except it is backed by a transport.Endpoint instead of a tcpip.Endpoint. // // +stateify savable type SocketOperations struct { @@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { return "", err } @@ -143,7 +143,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, return nil, 0, syserr.TranslateNetstackError(err) } - a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr) + a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -155,19 +155,19 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, return nil, 0, syserr.TranslateNetstackError(err) } - a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr) + a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - return epsocket.Ioctl(ctx, s.ep, io, args) + return netstack.Ioctl(ctx, s.ep, io, args) } // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { - return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by @@ -474,13 +474,13 @@ func (s *SocketOperations) EventUnregister(e *waiter.Entry) { // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { - return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal) + return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { - f, err := epsocket.ConvertShutdown(how) + f, err := netstack.ConvertShutdown(how) if err != nil { return err } @@ -546,7 +546,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -581,7 +581,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -595,7 +595,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags total += n } - if err != nil || !waitAll || isPacket || n >= dst.NumBytes() { + streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil + if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed { if total > 0 { err = nil } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 7d7b42eba..72ebf766d 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -32,8 +32,8 @@ go_library( "//pkg/sentry/arch", "//pkg/sentry/kernel", "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", "//pkg/sentry/socket/netlink", + "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", "//pkg/sentry/usermem", ], diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f779186ad..94334f6d2 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/control" - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/usermem" ) @@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */) + fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index e76ee27d2..4c0bf96e4 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -1,13 +1,15 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "linux", srcs = [ "error.go", "flags.go", "linux64.go", + "linux64_amd64.go", + "linux64_arm64.go", "sigset.go", "sys_aio.go", "sys_capability.go", @@ -77,6 +79,7 @@ go_library( "//pkg/sentry/kernel/signalfd", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", + "//pkg/sentry/loader", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/safemem", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 61acd0abd..68589a377 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,377 +15,13 @@ // Package linux provides syscall tables for amd64 Linux. package linux -import ( - "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/syscalls" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserror" -) - -// AUDIT_ARCH_X86_64 identifies the Linux syscall API on AMD64, and is taken -// from <linux/audit.h>. -const _AUDIT_ARCH_X86_64 = 0xc000003e +const ( + // LinuxSysname is the OS name advertised by gVisor. + LinuxSysname = "Linux" -// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall -// numbers from Linux 4.4. -var AMD64 = &kernel.SyscallTable{ - OS: abi.Linux, - Arch: arch.AMD64, - Version: kernel.Version{ - // Version 4.4 is chosen as a stable, longterm version of Linux, which - // guides the interface provided by this syscall table. The build - // version is that for a clean build with default kernel config, at 5 - // minutes after v4.4 was tagged. - Sysname: "Linux", - Release: "4.4", - Version: "#1 SMP Sun Jan 10 15:06:54 PST 2016", - }, - AuditNumber: _AUDIT_ARCH_X86_64, - Table: map[uintptr]kernel.Syscall{ - 0: syscalls.Supported("read", Read), - 1: syscalls.Supported("write", Write), - 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil), - 3: syscalls.Supported("close", Close), - 4: syscalls.Supported("stat", Stat), - 5: syscalls.Supported("fstat", Fstat), - 6: syscalls.Supported("lstat", Lstat), - 7: syscalls.Supported("poll", Poll), - 8: syscalls.Supported("lseek", Lseek), - 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil), - 10: syscalls.Supported("mprotect", Mprotect), - 11: syscalls.Supported("munmap", Munmap), - 12: syscalls.Supported("brk", Brk), - 13: syscalls.Supported("rt_sigaction", RtSigaction), - 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask), - 15: syscalls.Supported("rt_sigreturn", RtSigreturn), - 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), - 17: syscalls.Supported("pread64", Pread64), - 18: syscalls.Supported("pwrite64", Pwrite64), - 19: syscalls.Supported("readv", Readv), - 20: syscalls.Supported("writev", Writev), - 21: syscalls.Supported("access", Access), - 22: syscalls.Supported("pipe", Pipe), - 23: syscalls.Supported("select", Select), - 24: syscalls.Supported("sched_yield", SchedYield), - 25: syscalls.Supported("mremap", Mremap), - 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), - 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), - 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), - 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), - 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), - 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), - 32: syscalls.Supported("dup", Dup), - 33: syscalls.Supported("dup2", Dup2), - 34: syscalls.Supported("pause", Pause), - 35: syscalls.Supported("nanosleep", Nanosleep), - 36: syscalls.Supported("getitimer", Getitimer), - 37: syscalls.Supported("alarm", Alarm), - 38: syscalls.Supported("setitimer", Setitimer), - 39: syscalls.Supported("getpid", Getpid), - 40: syscalls.Supported("sendfile", Sendfile), - 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), - 42: syscalls.Supported("connect", Connect), - 43: syscalls.Supported("accept", Accept), - 44: syscalls.Supported("sendto", SendTo), - 45: syscalls.Supported("recvfrom", RecvFrom), - 46: syscalls.Supported("sendmsg", SendMsg), - 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), - 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), - 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), - 50: syscalls.Supported("listen", Listen), - 51: syscalls.Supported("getsockname", GetSockName), - 52: syscalls.Supported("getpeername", GetPeerName), - 53: syscalls.Supported("socketpair", SocketPair), - 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), - 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), - 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), - 57: syscalls.Supported("fork", Fork), - 58: syscalls.Supported("vfork", Vfork), - 59: syscalls.Supported("execve", Execve), - 60: syscalls.Supported("exit", Exit), - 61: syscalls.Supported("wait4", Wait4), - 62: syscalls.Supported("kill", Kill), - 63: syscalls.Supported("uname", Uname), - 64: syscalls.Supported("semget", Semget), - 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), - 67: syscalls.Supported("shmdt", Shmdt), - 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), - 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), - 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), - 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), - 76: syscalls.Supported("truncate", Truncate), - 77: syscalls.Supported("ftruncate", Ftruncate), - 78: syscalls.Supported("getdents", Getdents), - 79: syscalls.Supported("getcwd", Getcwd), - 80: syscalls.Supported("chdir", Chdir), - 81: syscalls.Supported("fchdir", Fchdir), - 82: syscalls.Supported("rename", Rename), - 83: syscalls.Supported("mkdir", Mkdir), - 84: syscalls.Supported("rmdir", Rmdir), - 85: syscalls.Supported("creat", Creat), - 86: syscalls.Supported("link", Link), - 87: syscalls.Supported("unlink", Unlink), - 88: syscalls.Supported("symlink", Symlink), - 89: syscalls.Supported("readlink", Readlink), - 90: syscalls.Supported("chmod", Chmod), - 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), - 92: syscalls.Supported("chown", Chown), - 93: syscalls.Supported("fchown", Fchown), - 94: syscalls.Supported("lchown", Lchown), - 95: syscalls.Supported("umask", Umask), - 96: syscalls.Supported("gettimeofday", Gettimeofday), - 97: syscalls.Supported("getrlimit", Getrlimit), - 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), - 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), - 100: syscalls.Supported("times", Times), - 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), - 102: syscalls.Supported("getuid", Getuid), - 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), - 104: syscalls.Supported("getgid", Getgid), - 105: syscalls.Supported("setuid", Setuid), - 106: syscalls.Supported("setgid", Setgid), - 107: syscalls.Supported("geteuid", Geteuid), - 108: syscalls.Supported("getegid", Getegid), - 109: syscalls.Supported("setpgid", Setpgid), - 110: syscalls.Supported("getppid", Getppid), - 111: syscalls.Supported("getpgrp", Getpgrp), - 112: syscalls.Supported("setsid", Setsid), - 113: syscalls.Supported("setreuid", Setreuid), - 114: syscalls.Supported("setregid", Setregid), - 115: syscalls.Supported("getgroups", Getgroups), - 116: syscalls.Supported("setgroups", Setgroups), - 117: syscalls.Supported("setresuid", Setresuid), - 118: syscalls.Supported("getresuid", Getresuid), - 119: syscalls.Supported("setresgid", Setresgid), - 120: syscalls.Supported("getresgid", Getresgid), - 121: syscalls.Supported("getpgid", Getpgid), - 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) - 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) - 124: syscalls.Supported("getsid", Getsid), - 125: syscalls.Supported("capget", Capget), - 126: syscalls.Supported("capset", Capset), - 127: syscalls.Supported("rt_sigpending", RtSigpending), - 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), - 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), - 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend), - 131: syscalls.Supported("sigaltstack", Sigaltstack), - 132: syscalls.Supported("utime", Utime), - 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil), - 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil), - 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), - 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil), - 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), - 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), - 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}), - 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), - 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), - 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), - 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), - 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), - 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), - 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), - 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), - 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), - 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), - 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), - 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil), - 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil), - 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil), - 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), - 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil), - 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), - 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), - 161: syscalls.Supported("chroot", Chroot), - 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), - 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), - 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), - 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), - 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), - 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), - 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), - 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), - 170: syscalls.Supported("sethostname", Sethostname), - 171: syscalls.Supported("setdomainname", Setdomainname), - 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil), - 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil), - 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil), - 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), - 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), - 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), - 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), - 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations - 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), - 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), - 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), - 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil), - 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), - 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), - 186: syscalls.Supported("gettid", Gettid), - 187: syscalls.Supported("readahead", Readahead), - 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 200: syscalls.Supported("tkill", Tkill), - 201: syscalls.Supported("time", Time), - 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), - 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), - 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), - 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), - 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), - 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), - 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), - 213: syscalls.Supported("epoll_create", EpollCreate), - 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil), - 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil), - 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), - 217: syscalls.Supported("getdents64", Getdents64), - 218: syscalls.Supported("set_tid_address", SetTidAddress), - 219: syscalls.Supported("restart_syscall", RestartSyscall), - 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) - 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), - 222: syscalls.Supported("timer_create", TimerCreate), - 223: syscalls.Supported("timer_settime", TimerSettime), - 224: syscalls.Supported("timer_gettime", TimerGettime), - 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun), - 226: syscalls.Supported("timer_delete", TimerDelete), - 227: syscalls.Supported("clock_settime", ClockSettime), - 228: syscalls.Supported("clock_gettime", ClockGettime), - 229: syscalls.Supported("clock_getres", ClockGetres), - 230: syscalls.Supported("clock_nanosleep", ClockNanosleep), - 231: syscalls.Supported("exit_group", ExitGroup), - 232: syscalls.Supported("epoll_wait", EpollWait), - 233: syscalls.Supported("epoll_ctl", EpollCtl), - 234: syscalls.Supported("tgkill", Tgkill), - 235: syscalls.Supported("utimes", Utimes), - 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil), - 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), - 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), - 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), - 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) - 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), - 247: syscalls.Supported("waitid", Waitid), - 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), - 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), - 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), - 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) - 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) - 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil), - 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), - 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), - 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), - 257: syscalls.Supported("openat", Openat), - 258: syscalls.Supported("mkdirat", Mkdirat), - 259: syscalls.Supported("mknodat", Mknodat), - 260: syscalls.Supported("fchownat", Fchownat), - 261: syscalls.Supported("futimesat", Futimesat), - 262: syscalls.Supported("fstatat", Fstatat), - 263: syscalls.Supported("unlinkat", Unlinkat), - 264: syscalls.Supported("renameat", Renameat), - 265: syscalls.Supported("linkat", Linkat), - 266: syscalls.Supported("symlinkat", Symlinkat), - 267: syscalls.Supported("readlinkat", Readlinkat), - 268: syscalls.Supported("fchmodat", Fchmodat), - 269: syscalls.Supported("faccessat", Faccessat), - 270: syscalls.Supported("pselect", Pselect), - 271: syscalls.Supported("ppoll", Ppoll), - 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), - 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 275: syscalls.Supported("splice", Splice), - 276: syscalls.Supported("tee", Tee), - 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), - 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) - 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) - 280: syscalls.Supported("utimensat", Utimensat), - 281: syscalls.Supported("epoll_pwait", EpollPwait), - 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), - 283: syscalls.Supported("timerfd_create", TimerfdCreate), - 284: syscalls.Supported("eventfd", Eventfd), - 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), - 286: syscalls.Supported("timerfd_settime", TimerfdSettime), - 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), - 288: syscalls.Supported("accept4", Accept4), - 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), - 290: syscalls.Supported("eventfd2", Eventfd2), - 291: syscalls.Supported("epoll_create1", EpollCreate1), - 292: syscalls.Supported("dup3", Dup3), - 293: syscalls.Supported("pipe2", Pipe2), - 294: syscalls.Supported("inotify_init1", InotifyInit1), - 295: syscalls.Supported("preadv", Preadv), - 296: syscalls.Supported("pwritev", Pwritev), - 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), - 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), - 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), - 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), - 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), - 302: syscalls.Supported("prlimit64", Prlimit64), - 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), - 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), - 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), - 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), - 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), - 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) - 309: syscalls.Supported("getcpu", Getcpu), - 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), - 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), - 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), - 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), - 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) - 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) - 317: syscalls.Supported("seccomp", Seccomp), - 318: syscalls.Supported("getrandom", GetRandom), - 319: syscalls.Supported("memfd_create", MemfdCreate), - 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), - 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), - 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) - 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897) - 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + // LinuxRelease is the Linux release version number advertised by gVisor. + LinuxRelease = "4.4.0" - // Syscalls after 325 are "backports" from versions of Linux after 4.4. - 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), - 327: syscalls.Supported("preadv2", Preadv2), - 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), - 332: syscalls.Supported("statx", Statx), - }, - - Emulate: map[usermem.Addr]uintptr{ - 0xffffffffff600000: 96, // vsyscall gettimeofday(2) - 0xffffffffff600400: 201, // vsyscall time(2) - 0xffffffffff600800: 309, // vsyscall getcpu(2) - }, - Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { - t.Kernel().EmitUnimplementedEvent(t) - return 0, syserror.ENOSYS - }, -} + // LinuxVersion is the version info advertised by gVisor. + LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016" +) diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go new file mode 100644 index 000000000..81e4f93a6 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux64_amd64.go @@ -0,0 +1,386 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall +// numbers from Linux 4.4. +var AMD64 = &kernel.SyscallTable{ + OS: abi.Linux, + Arch: arch.AMD64, + Version: kernel.Version{ + // Version 4.4 is chosen as a stable, longterm version of Linux, which + // guides the interface provided by this syscall table. The build + // version is that for a clean build with default kernel config, at 5 + // minutes after v4.4 was tagged. + Sysname: LinuxSysname, + Release: LinuxRelease, + Version: LinuxVersion, + }, + AuditNumber: linux.AUDIT_ARCH_X86_64, + Table: map[uintptr]kernel.Syscall{ + 0: syscalls.Supported("read", Read), + 1: syscalls.Supported("write", Write), + 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil), + 3: syscalls.Supported("close", Close), + 4: syscalls.Supported("stat", Stat), + 5: syscalls.Supported("fstat", Fstat), + 6: syscalls.Supported("lstat", Lstat), + 7: syscalls.Supported("poll", Poll), + 8: syscalls.Supported("lseek", Lseek), + 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil), + 10: syscalls.Supported("mprotect", Mprotect), + 11: syscalls.Supported("munmap", Munmap), + 12: syscalls.Supported("brk", Brk), + 13: syscalls.Supported("rt_sigaction", RtSigaction), + 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask), + 15: syscalls.Supported("rt_sigreturn", RtSigreturn), + 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), + 17: syscalls.Supported("pread64", Pread64), + 18: syscalls.Supported("pwrite64", Pwrite64), + 19: syscalls.Supported("readv", Readv), + 20: syscalls.Supported("writev", Writev), + 21: syscalls.Supported("access", Access), + 22: syscalls.Supported("pipe", Pipe), + 23: syscalls.Supported("select", Select), + 24: syscalls.Supported("sched_yield", SchedYield), + 25: syscalls.Supported("mremap", Mremap), + 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), + 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), + 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), + 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), + 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), + 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), + 32: syscalls.Supported("dup", Dup), + 33: syscalls.Supported("dup2", Dup2), + 34: syscalls.Supported("pause", Pause), + 35: syscalls.Supported("nanosleep", Nanosleep), + 36: syscalls.Supported("getitimer", Getitimer), + 37: syscalls.Supported("alarm", Alarm), + 38: syscalls.Supported("setitimer", Setitimer), + 39: syscalls.Supported("getpid", Getpid), + 40: syscalls.Supported("sendfile", Sendfile), + 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), + 42: syscalls.Supported("connect", Connect), + 43: syscalls.Supported("accept", Accept), + 44: syscalls.Supported("sendto", SendTo), + 45: syscalls.Supported("recvfrom", RecvFrom), + 46: syscalls.Supported("sendmsg", SendMsg), + 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), + 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), + 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), + 50: syscalls.Supported("listen", Listen), + 51: syscalls.Supported("getsockname", GetSockName), + 52: syscalls.Supported("getpeername", GetPeerName), + 53: syscalls.Supported("socketpair", SocketPair), + 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), + 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), + 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), + 57: syscalls.Supported("fork", Fork), + 58: syscalls.Supported("vfork", Vfork), + 59: syscalls.Supported("execve", Execve), + 60: syscalls.Supported("exit", Exit), + 61: syscalls.Supported("wait4", Wait4), + 62: syscalls.Supported("kill", Kill), + 63: syscalls.Supported("uname", Uname), + 64: syscalls.Supported("semget", Semget), + 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 67: syscalls.Supported("shmdt", Shmdt), + 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), + 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), + 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), + 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), + 76: syscalls.Supported("truncate", Truncate), + 77: syscalls.Supported("ftruncate", Ftruncate), + 78: syscalls.Supported("getdents", Getdents), + 79: syscalls.Supported("getcwd", Getcwd), + 80: syscalls.Supported("chdir", Chdir), + 81: syscalls.Supported("fchdir", Fchdir), + 82: syscalls.Supported("rename", Rename), + 83: syscalls.Supported("mkdir", Mkdir), + 84: syscalls.Supported("rmdir", Rmdir), + 85: syscalls.Supported("creat", Creat), + 86: syscalls.Supported("link", Link), + 87: syscalls.Supported("unlink", Unlink), + 88: syscalls.Supported("symlink", Symlink), + 89: syscalls.Supported("readlink", Readlink), + 90: syscalls.Supported("chmod", Chmod), + 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), + 92: syscalls.Supported("chown", Chown), + 93: syscalls.Supported("fchown", Fchown), + 94: syscalls.Supported("lchown", Lchown), + 95: syscalls.Supported("umask", Umask), + 96: syscalls.Supported("gettimeofday", Gettimeofday), + 97: syscalls.Supported("getrlimit", Getrlimit), + 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), + 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), + 100: syscalls.Supported("times", Times), + 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), + 102: syscalls.Supported("getuid", Getuid), + 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), + 104: syscalls.Supported("getgid", Getgid), + 105: syscalls.Supported("setuid", Setuid), + 106: syscalls.Supported("setgid", Setgid), + 107: syscalls.Supported("geteuid", Geteuid), + 108: syscalls.Supported("getegid", Getegid), + 109: syscalls.Supported("setpgid", Setpgid), + 110: syscalls.Supported("getppid", Getppid), + 111: syscalls.Supported("getpgrp", Getpgrp), + 112: syscalls.Supported("setsid", Setsid), + 113: syscalls.Supported("setreuid", Setreuid), + 114: syscalls.Supported("setregid", Setregid), + 115: syscalls.Supported("getgroups", Getgroups), + 116: syscalls.Supported("setgroups", Setgroups), + 117: syscalls.Supported("setresuid", Setresuid), + 118: syscalls.Supported("getresuid", Getresuid), + 119: syscalls.Supported("setresgid", Setresgid), + 120: syscalls.Supported("getresgid", Getresgid), + 121: syscalls.Supported("getpgid", Getpgid), + 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 124: syscalls.Supported("getsid", Getsid), + 125: syscalls.Supported("capget", Capget), + 126: syscalls.Supported("capset", Capset), + 127: syscalls.Supported("rt_sigpending", RtSigpending), + 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), + 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), + 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend), + 131: syscalls.Supported("sigaltstack", Sigaltstack), + 132: syscalls.Supported("utime", Utime), + 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil), + 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil), + 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), + 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil), + 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil), + 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), + 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}), + 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), + 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), + 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), + 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), + 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), + 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), + 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), + 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), + 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), + 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), + 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil), + 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil), + 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil), + 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), + 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil), + 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), + 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), + 161: syscalls.Supported("chroot", Chroot), + 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), + 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), + 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), + 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), + 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), + 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), + 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), + 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), + 170: syscalls.Supported("sethostname", Sethostname), + 171: syscalls.Supported("setdomainname", Setdomainname), + 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil), + 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil), + 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil), + 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), + 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), + 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), + 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil), + 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations + 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), + 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), + 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil), + 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil), + 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil), + 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), + 186: syscalls.Supported("gettid", Gettid), + 187: syscalls.Supported("readahead", Readahead), + 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 200: syscalls.Supported("tkill", Tkill), + 201: syscalls.Supported("time", Time), + 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), + 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), + 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), + 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), + 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil), + 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), + 213: syscalls.Supported("epoll_create", EpollCreate), + 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil), + 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil), + 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), + 217: syscalls.Supported("getdents64", Getdents64), + 218: syscalls.Supported("set_tid_address", SetTidAddress), + 219: syscalls.Supported("restart_syscall", RestartSyscall), + 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), + 222: syscalls.Supported("timer_create", TimerCreate), + 223: syscalls.Supported("timer_settime", TimerSettime), + 224: syscalls.Supported("timer_gettime", TimerGettime), + 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun), + 226: syscalls.Supported("timer_delete", TimerDelete), + 227: syscalls.Supported("clock_settime", ClockSettime), + 228: syscalls.Supported("clock_gettime", ClockGettime), + 229: syscalls.Supported("clock_getres", ClockGetres), + 230: syscalls.Supported("clock_nanosleep", ClockNanosleep), + 231: syscalls.Supported("exit_group", ExitGroup), + 232: syscalls.Supported("epoll_wait", EpollWait), + 233: syscalls.Supported("epoll_ctl", EpollCtl), + 234: syscalls.Supported("tgkill", Tgkill), + 235: syscalls.Supported("utimes", Utimes), + 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil), + 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), + 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), + 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), + 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), + 247: syscalls.Supported("waitid", Waitid), + 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), + 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), + 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), + 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil), + 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), + 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), + 257: syscalls.Supported("openat", Openat), + 258: syscalls.Supported("mkdirat", Mkdirat), + 259: syscalls.Supported("mknodat", Mknodat), + 260: syscalls.Supported("fchownat", Fchownat), + 261: syscalls.Supported("futimesat", Futimesat), + 262: syscalls.Supported("fstatat", Fstatat), + 263: syscalls.Supported("unlinkat", Unlinkat), + 264: syscalls.Supported("renameat", Renameat), + 265: syscalls.Supported("linkat", Linkat), + 266: syscalls.Supported("symlinkat", Symlinkat), + 267: syscalls.Supported("readlinkat", Readlinkat), + 268: syscalls.Supported("fchmodat", Fchmodat), + 269: syscalls.Supported("faccessat", Faccessat), + 270: syscalls.Supported("pselect", Pselect), + 271: syscalls.Supported("ppoll", Ppoll), + 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), + 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 275: syscalls.Supported("splice", Splice), + 276: syscalls.Supported("tee", Tee), + 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), + 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) + 280: syscalls.Supported("utimensat", Utimensat), + 281: syscalls.Supported("epoll_pwait", EpollPwait), + 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), + 283: syscalls.Supported("timerfd_create", TimerfdCreate), + 284: syscalls.Supported("eventfd", Eventfd), + 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), + 286: syscalls.Supported("timerfd_settime", TimerfdSettime), + 287: syscalls.Supported("timerfd_gettime", TimerfdGettime), + 288: syscalls.Supported("accept4", Accept4), + 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), + 290: syscalls.Supported("eventfd2", Eventfd2), + 291: syscalls.Supported("epoll_create1", EpollCreate1), + 292: syscalls.Supported("dup3", Dup3), + 293: syscalls.Supported("pipe2", Pipe2), + 294: syscalls.Supported("inotify_init1", InotifyInit1), + 295: syscalls.Supported("preadv", Preadv), + 296: syscalls.Supported("pwritev", Pwritev), + 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), + 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), + 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), + 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 302: syscalls.Supported("prlimit64", Prlimit64), + 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), + 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), + 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), + 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) + 309: syscalls.Supported("getcpu", Getcpu), + 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), + 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), + 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) + 317: syscalls.Supported("seccomp", Seccomp), + 318: syscalls.Supported("getrandom", GetRandom), + 319: syscalls.Supported("memfd_create", MemfdCreate), + 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil), + 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), + 322: syscalls.Supported("execveat", Execveat), + 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) + 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) + 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + + // Syscalls after 325 are "backports" from versions of Linux after 4.4. + 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), + 327: syscalls.Supported("preadv2", Preadv2), + 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 332: syscalls.Supported("statx", Statx), + }, + + Emulate: map[usermem.Addr]uintptr{ + 0xffffffffff600000: 96, // vsyscall gettimeofday(2) + 0xffffffffff600400: 201, // vsyscall time(2) + 0xffffffffff600800: 309, // vsyscall getcpu(2) + }, + Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { + t.Kernel().EmitUnimplementedEvent(t) + return 0, syserror.ENOSYS + }, +} diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go new file mode 100644 index 000000000..a809115e0 --- /dev/null +++ b/pkg/sentry/syscalls/linux/linux64_arm64.go @@ -0,0 +1,313 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/syscalls" + "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/syserror" +) + +// ARM64 is a table of Linux arm64 syscall API with the corresponding syscall +// numbers from Linux 4.4. +var ARM64 = &kernel.SyscallTable{ + OS: abi.Linux, + Arch: arch.ARM64, + Version: kernel.Version{ + Sysname: LinuxSysname, + Release: LinuxRelease, + Version: LinuxVersion, + }, + AuditNumber: linux.AUDIT_ARCH_AARCH64, + Table: map[uintptr]kernel.Syscall{ + 0: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 1: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}), + 5: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 8: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 12: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 13: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 14: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 15: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 16: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 17: syscalls.Supported("getcwd", Getcwd), + 18: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil), + 19: syscalls.Supported("eventfd2", Eventfd2), + 20: syscalls.Supported("epoll_create1", EpollCreate1), + 21: syscalls.Supported("epoll_ctl", EpollCtl), + 22: syscalls.Supported("epoll_pwait", EpollPwait), + 23: syscalls.Supported("dup", Dup), + 24: syscalls.Supported("dup3", Dup3), + 26: syscalls.Supported("inotify_init1", InotifyInit1), + 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), + 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), + 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) + 32: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), + 33: syscalls.Supported("mknodat", Mknodat), + 34: syscalls.Supported("mkdirat", Mkdirat), + 35: syscalls.Supported("unlinkat", Unlinkat), + 36: syscalls.Supported("symlinkat", Symlinkat), + 37: syscalls.Supported("linkat", Linkat), + 38: syscalls.Supported("renameat", Renameat), + 39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil), + 40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil), + 41: syscalls.Error("pivot_root", syserror.EPERM, "", nil), + 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil), + 44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil), + 46: syscalls.Supported("ftruncate", Ftruncate), + 47: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil), + 48: syscalls.Supported("faccessat", Faccessat), + 49: syscalls.Supported("chdir", Chdir), + 50: syscalls.Supported("fchdir", Fchdir), + 51: syscalls.Supported("chroot", Chroot), + 52: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil), + 53: syscalls.Supported("fchmodat", Fchmodat), + 54: syscalls.Supported("fchownat", Fchownat), + 55: syscalls.Supported("fchown", Fchown), + 56: syscalls.Supported("openat", Openat), + 57: syscalls.Supported("close", Close), + 58: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil), + 59: syscalls.Supported("pipe2", Pipe2), + 60: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations + 61: syscalls.Supported("getdents64", Getdents64), + 62: syscalls.Supported("lseek", Lseek), + 63: syscalls.Supported("read", Read), + 64: syscalls.Supported("write", Write), + 65: syscalls.Supported("readv", Readv), + 66: syscalls.Supported("writev", Writev), + 67: syscalls.Supported("pread64", Pread64), + 68: syscalls.Supported("pwrite64", Pwrite64), + 69: syscalls.Supported("preadv", Preadv), + 70: syscalls.Supported("pwritev", Pwritev), + 71: syscalls.Supported("sendfile", Sendfile), + 72: syscalls.Supported("pselect", Pselect), + 73: syscalls.Supported("ppoll", Ppoll), + 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}), + 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098) + 77: syscalls.Supported("tee", Tee), + 78: syscalls.Supported("readlinkat", Readlinkat), + 80: syscalls.Supported("fstat", Fstat), + 81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil), + 82: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil), + 83: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil), + 84: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), + 85: syscalls.Supported("timerfd_create", TimerfdCreate), + 86: syscalls.Supported("timerfd_settime", TimerfdSettime), + 87: syscalls.Supported("timerfd_gettime", TimerfdGettime), + 88: syscalls.Supported("utimensat", Utimensat), + 89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil), + 90: syscalls.Supported("capget", Capget), + 91: syscalls.Supported("capset", Capset), + 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil), + 93: syscalls.Supported("exit", Exit), + 94: syscalls.Supported("exit_group", ExitGroup), + 95: syscalls.Supported("waitid", Waitid), + 96: syscalls.Supported("set_tid_address", SetTidAddress), + 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), + 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), + 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 101: syscalls.Supported("nanosleep", Nanosleep), + 102: syscalls.Supported("getitimer", Getitimer), + 103: syscalls.Supported("setitimer", Setitimer), + 104: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil), + 105: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil), + 106: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil), + 107: syscalls.Supported("timer_create", TimerCreate), + 108: syscalls.Supported("timer_gettime", TimerGettime), + 109: syscalls.Supported("timer_getoverrun", TimerGetoverrun), + 110: syscalls.Supported("timer_settime", TimerSettime), + 111: syscalls.Supported("timer_delete", TimerDelete), + 112: syscalls.Supported("clock_settime", ClockSettime), + 113: syscalls.Supported("clock_gettime", ClockGettime), + 114: syscalls.Supported("clock_getres", ClockGetres), + 115: syscalls.Supported("clock_nanosleep", ClockNanosleep), + 116: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil), + 117: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil), + 118: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil), + 119: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil), + 120: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil), + 121: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil), + 122: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil), + 123: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil), + 124: syscalls.Supported("sched_yield", SchedYield), + 125: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil), + 126: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil), + 127: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil), + 128: syscalls.Supported("restart_syscall", RestartSyscall), + 129: syscalls.Supported("kill", Kill), + 130: syscalls.Supported("tkill", Tkill), + 131: syscalls.Supported("tgkill", Tgkill), + 132: syscalls.Supported("sigaltstack", Sigaltstack), + 133: syscalls.Supported("rt_sigsuspend", RtSigsuspend), + 134: syscalls.Supported("rt_sigaction", RtSigaction), + 135: syscalls.Supported("rt_sigprocmask", RtSigprocmask), + 136: syscalls.Supported("rt_sigpending", RtSigpending), + 137: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait), + 138: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo), + 139: syscalls.Supported("rt_sigreturn", RtSigreturn), + 140: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil), + 141: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil), + 142: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil), + 143: syscalls.Supported("setregid", Setregid), + 144: syscalls.Supported("setgid", Setgid), + 145: syscalls.Supported("setreuid", Setreuid), + 146: syscalls.Supported("setuid", Setuid), + 147: syscalls.Supported("setresuid", Setresuid), + 148: syscalls.Supported("getresuid", Getresuid), + 149: syscalls.Supported("setresgid", Setresgid), + 150: syscalls.Supported("getresgid", Getresgid), + 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702) + 153: syscalls.Supported("times", Times), + 154: syscalls.Supported("setpgid", Setpgid), + 155: syscalls.Supported("getpgid", Getpgid), + 156: syscalls.Supported("getsid", Getsid), + 157: syscalls.Supported("setsid", Setsid), + 158: syscalls.Supported("getgroups", Getgroups), + 159: syscalls.Supported("setgroups", Setgroups), + 160: syscalls.Supported("uname", Uname), + 161: syscalls.Supported("sethostname", Sethostname), + 162: syscalls.Supported("setdomainname", Setdomainname), + 163: syscalls.Supported("getrlimit", Getrlimit), + 164: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil), + 165: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil), + 166: syscalls.Supported("umask", Umask), + 167: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil), + 168: syscalls.Supported("getcpu", Getcpu), + 169: syscalls.Supported("gettimeofday", Gettimeofday), + 170: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil), + 171: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil), + 172: syscalls.Supported("getpid", Getpid), + 173: syscalls.Supported("getppid", Getppid), + 174: syscalls.Supported("getuid", Getuid), + 175: syscalls.Supported("geteuid", Geteuid), + 176: syscalls.Supported("getgid", Getgid), + 177: syscalls.Supported("getegid", Getegid), + 178: syscalls.Supported("gettid", Gettid), + 179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil), + 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) + 186: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 187: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 190: syscalls.Supported("semget", Semget), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920) + 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), + 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), + 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), + 196: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil), + 197: syscalls.Supported("shmdt", Shmdt), + 198: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil), + 199: syscalls.Supported("socketpair", SocketPair), + 200: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil), + 201: syscalls.Supported("listen", Listen), + 202: syscalls.Supported("accept", Accept), + 203: syscalls.Supported("connect", Connect), + 204: syscalls.Supported("getsockname", GetSockName), + 205: syscalls.Supported("getpeername", GetPeerName), + 206: syscalls.Supported("sendto", SendTo), + 207: syscalls.Supported("recvfrom", RecvFrom), + 208: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil), + 209: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil), + 210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil), + 211: syscalls.Supported("sendmsg", SendMsg), + 212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil), + 213: syscalls.Supported("readahead", Readahead), + 214: syscalls.Supported("brk", Brk), + 215: syscalls.Supported("munmap", Munmap), + 216: syscalls.Supported("mremap", Mremap), + 217: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil), + 218: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil), + 219: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), + 220: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil), + 221: syscalls.Supported("execve", Execve), + 224: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil), + 225: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil), + 226: syscalls.Supported("mprotect", Mprotect), + 227: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil), + 228: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 229: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 230: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil), + 233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil), + 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil), + 235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}), + 236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil), + 237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil), + 238: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), + 239: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly) + 240: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), + 241: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil), + 242: syscalls.Supported("accept4", Accept4), + 243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil), + 260: syscalls.Supported("wait4", Wait4), + 261: syscalls.Supported("prlimit64", Prlimit64), + 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil), + 264: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 265: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil), + 266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil), + 267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil), + 268: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995) + 269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil), + 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}), + 272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil), + 273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil), + 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272) + 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772) + 277: syscalls.Supported("seccomp", Seccomp), + 278: syscalls.Supported("getrandom", GetRandom), + 279: syscalls.Supported("memfd_create", MemfdCreate), + 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), + 281: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836) + 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) + 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) + 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), + 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil), + 286: syscalls.Supported("preadv2", Preadv2), + 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil), + 291: syscalls.Supported("statx", Statx), + }, + Emulate: map[usermem.Addr]uintptr{}, + + Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { + t.Kernel().EmitUnimplementedEvent(t) + return 0, syserror.ENOSYS + }, +} diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 3bac4d90d..b5a72ce63 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - if optLen <= 0 { + if optLen < 0 { return 0, nil, syserror.EINVAL } if optLen > maxOptLen { diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index f0a292f2f..dd3a5807f 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -159,9 +159,14 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc }, outFile.Flags().NonBlocking) } + // Sendfile can't lose any data because inFD is always a regual file. + if n != 0 { + err = nil + } + // We can only pass a single file to handleIOError, so pick inFile // arbitrarily. This is used only for debugging purposes. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "sendfile", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile) } // Splice implements splice(2). @@ -245,12 +250,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if inOffset != 0 || outOffset != 0 { return 0, nil, syserror.ESPIPE } - default: - return 0, nil, syserror.EINVAL - } - // We may not refer to the same pipe; otherwise it's a continuous loop. - if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID { + // We may not refer to the same pipe; otherwise it's a continuous loop. + if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID { + return 0, nil, syserror.EINVAL + } + default: return 0, nil, syserror.EINVAL } @@ -305,6 +310,11 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo Dup: true, }, nonBlock) + // Tee doesn't change a state of inFD, so it can't lose any data. + if n != 0 { + err = nil + } + // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile) } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 8ab7ffa25..4115116ff 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -15,12 +15,15 @@ package linux import ( + "path" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" + "gvisor.dev/gvisor/pkg/sentry/loader" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" ) @@ -67,8 +70,22 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal argvAddr := args[1].Pointer() envvAddr := args[2].Pointer() - // Extract our arguments. - filename, err := t.CopyInString(filenameAddr, linux.PATH_MAX) + return execveat(t, linux.AT_FDCWD, filenameAddr, argvAddr, envvAddr, 0) +} + +// Execveat implements linux syscall execveat(2). +func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + dirFD := args[0].Int() + pathnameAddr := args[1].Pointer() + argvAddr := args[2].Pointer() + envvAddr := args[3].Pointer() + flags := args[4].Int() + + return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags) +} + +func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) { + pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX) if err != nil { return 0, nil, err } @@ -89,14 +106,66 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } } + if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 { + return 0, nil, syserror.EINVAL + } + atEmptyPath := flags&linux.AT_EMPTY_PATH != 0 + if !atEmptyPath && len(pathname) == 0 { + return 0, nil, syserror.ENOENT + } + resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0 + root := t.FSContext().RootDirectory() defer root.DecRef() - wd := t.FSContext().WorkingDirectory() - defer wd.DecRef() + + var wd *fs.Dirent + var executable *fs.File + var closeOnExec bool + if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) { + // Even if the pathname is absolute, we may still need the wd + // for interpreter scripts if the path of the interpreter is + // relative. + wd = t.FSContext().WorkingDirectory() + } else { + // Need to extract the given FD. + f, fdFlags := t.FDTable().Get(dirFD) + if f == nil { + return 0, nil, syserror.EBADF + } + defer f.DecRef() + closeOnExec = fdFlags.CloseOnExec + + if atEmptyPath && len(pathname) == 0 { + executable = f + } else { + wd = f.Dirent + wd.IncRef() + if !fs.IsDir(wd.Inode.StableAttr) { + return 0, nil, syserror.ENOTDIR + } + } + } + if wd != nil { + defer wd.DecRef() + } // Load the new TaskContext. - maxTraversals := uint(linux.MaxSymlinkTraversals) - tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet()) + remainingTraversals := uint(linux.MaxSymlinkTraversals) + loadArgs := loader.LoadArgs{ + Mounts: t.MountNamespace(), + Root: root, + WorkingDirectory: wd, + RemainingTraversals: &remainingTraversals, + ResolveFinal: resolveFinal, + Filename: pathname, + File: executable, + CloseOnExec: closeOnExec, + Argv: argv, + Envv: envv, + Features: t.Arch().FeatureSet(), + } + + tc, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index 4b3f043a2..b887fa9d7 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -15,6 +15,7 @@ package linux import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -228,41 +229,35 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use timer.Destroy() - var remaining time.Duration - // Did we just block for the entire duration? - if err == syserror.ETIMEDOUT { - remaining = 0 - } else { - remaining = dur - after.Sub(start) + switch err { + case syserror.ETIMEDOUT: + // Slept for entire timeout. + return nil + case syserror.ErrInterrupted: + // Interrupted. + remaining := dur - after.Sub(start) if remaining < 0 { remaining = time.Duration(0) } - } - // Copy out remaining time. - if err != nil && rem != usermem.Addr(0) { - timeleft := linux.NsecToTimespec(remaining.Nanoseconds()) - if err := copyTimespecOut(t, rem, &timeleft); err != nil { - return err + // Copy out remaining time. + if rem != 0 { + timeleft := linux.NsecToTimespec(remaining.Nanoseconds()) + if err := copyTimespecOut(t, rem, &timeleft); err != nil { + return err + } } - } - - // Did we just block for the entire duration? - if err == syserror.ETIMEDOUT { - return nil - } - // If interrupted, arrange for a restart with the remaining duration. - if err == syserror.ErrInterrupted { + // Arrange for a restart with the remaining duration. t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{ c: c, duration: remaining, rem: rem, }) return kernel.ERESTART_RESTARTBLOCK + default: + panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err)) } - - return err } // Nanosleep implements linux syscall Nanosleep(2). diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go index 271ace08e..748e8dd8d 100644 --- a/pkg/sentry/syscalls/linux/sys_utsname.go +++ b/pkg/sentry/syscalls/linux/sys_utsname.go @@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, syserror.EINVAL } - name, err := t.CopyInString(nameAddr, int(size)) - if err != nil { + name := make([]byte, size) + if _, err := t.CopyInBytes(nameAddr, name); err != nil { return 0, nil, err } - utsns.SetHostName(name) + utsns.SetHostName(string(name)) return 0, nil, nil } diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 27cd2c336..ad4b67806 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Pwritev2 implements linux syscall pwritev2(2). -// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality. func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { // While the syscall is // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags) diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index beb43ba13..d3a4cd943 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,10 +1,9 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "seqatomic_parameters", out = "seqatomic_parameters_unsafe.go", diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index a34c39540..c32fe3241 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "usage", srcs = [ diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index cc5d25762..684f59a6b 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "addr_range", out = "addr_range.go", diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go index 6eced660a..7b1f312b1 100644 --- a/pkg/sentry/usermem/usermem.go +++ b/pkg/sentry/usermem/usermem.go @@ -16,6 +16,7 @@ package usermem import ( + "bytes" "errors" "io" "strconv" @@ -270,11 +271,10 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts) // Look for the terminating zero byte, which may have occurred before // hitting err. - for i, c := range buf[done : done+n] { - if c == 0 { - return stringFromImmutableBytes(buf[:done+i]), nil - } + if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 { + return stringFromImmutableBytes(buf[:done+i]), nil } + done += n if err != nil { return stringFromImmutableBytes(buf[:done]), err diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 7eb2b2821..3a9665800 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -102,7 +102,7 @@ type FileDescriptionImpl interface { // OnClose is called when a file descriptor representing the // FileDescription is closed. Note that returning a non-nil error does not // prevent the file descriptor from being closed. - OnClose() error + OnClose(ctx context.Context) error // StatusFlags returns file description status flags, as for // fcntl(F_GETFL). @@ -180,7 +180,7 @@ type FileDescriptionImpl interface { // ConfigureMMap mutates opts to implement mmap(2) for the file. Most // implementations that support memory mapping can call // GenericConfigureMMap with the appropriate memmap.Mappable. - ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error + ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index ba230da72..4fbad7840 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -45,7 +45,7 @@ type FileDescriptionDefaultImpl struct{} // OnClose implements FileDescriptionImpl.OnClose analogously to // file_operations::flush == NULL in Linux. -func (FileDescriptionDefaultImpl) OnClose() error { +func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error { return nil } @@ -117,7 +117,7 @@ func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error { // ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to // file_operations::mmap == NULL in Linux. -func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error { +func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { return syserror.ENODEV } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b0511aa40..75e6c7dfa 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index 187e5410c..3aa73d911 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -31,14 +31,14 @@ type GetDentryOptions struct { // FilesystemImpl.MkdirAt(). type MkdirOptions struct { // Mode is the file mode bits for the created directory. - Mode uint16 + Mode linux.FileMode } // MknodOptions contains options to VirtualFilesystem.MknodAt() and // FilesystemImpl.MknodAt(). type MknodOptions struct { // Mode is the file type and mode bits for the created file. - Mode uint16 + Mode linux.FileMode // If Mode specifies a character or block device special file, DevMajor and // DevMinor are the major and minor device numbers for the created device. @@ -61,7 +61,7 @@ type OpenOptions struct { // If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the // created file. - Mode uint16 + Mode linux.FileMode } // ReadOptions contains options to FileDescription.PRead(), diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go index 23f2b9e08..abde0feaa 100644 --- a/pkg/sentry/vfs/syscalls.go +++ b/pkg/sentry/vfs/syscalls.go @@ -96,6 +96,26 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia } } +// MknodAt creates a file of the given mode at the given path. It returns an +// error from the syserror package. +func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { + rp, err := vfs.getResolvingPath(creds, pop) + if err != nil { + return nil + } + for { + if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil { + vfs.putResolvingPath(rp) + return nil + } + // Handle mount traversals. + if !rp.handleError(err) { + vfs.putResolvingPath(rp) + return err + } + } +} + // OpenAt returns a FileDescription providing access to the file at the given // path. A reference is taken on the returned FileDescription. func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { @@ -198,8 +218,6 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) err // // - VFS.LinkAt() // -// - VFS.MknodAt() -// // - VFS.ReadlinkAt() // // - VFS.RenameAt() diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 145102c0d..ecce6c69f 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -42,8 +42,35 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) -// DefaultTimeout is a resonable timeout value for most applications. -const DefaultTimeout = 3 * time.Minute +// Opts configures the watchdog. +type Opts struct { + // TaskTimeout is the amount of time to allow a task to execute the + // same syscall without blocking before it's declared stuck. + TaskTimeout time.Duration + + // TaskTimeoutAction indicates what action to take when a stuck tasks + // is detected. + TaskTimeoutAction Action + + // StartupTimeout is the amount of time to allow between watchdog + // creation and calling watchdog.Start. + StartupTimeout time.Duration + + // StartupTimeoutAction indicates what action to take when + // watchdog.Start is not called within the timeout. + StartupTimeoutAction Action +} + +// DefaultOpts is a default set of options for the watchdog. +var DefaultOpts = Opts{ + // Task timeout. + TaskTimeout: 3 * time.Minute, + TaskTimeoutAction: LogWarning, + + // Startup timeout. + StartupTimeout: 30 * time.Second, + StartupTimeoutAction: LogWarning, +} // descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period // is discounted from task's last update time. It's set high enough that small scheduling delays won't @@ -61,6 +88,7 @@ type Action int const ( // LogWarning logs warning message followed by stack trace. LogWarning Action = iota + // Panic will do the same logging as LogWarning and panic(). Panic ) @@ -80,17 +108,13 @@ func (a Action) String() string { // Watchdog is the main watchdog class. It controls a goroutine that periodically // analyses all tasks and reports if any of them appear to be stuck. type Watchdog struct { + // Configuration options are embedded. + Opts + // period indicates how often to check all tasks. It's calculated based on - // 'taskTimeout'. + // opts.TaskTimeout. period time.Duration - // taskTimeout is the amount of time to allow a task to execute the same syscall - // without blocking before it's declared stuck. - taskTimeout time.Duration - - // timeoutAction indicates what action to take when a stuck tasks is detected. - timeoutAction Action - // k is where the tasks come from. k *kernel.Kernel @@ -113,8 +137,12 @@ type Watchdog struct { // mu protects the fields below. mu sync.Mutex - // started is true if the watchdog has been started before. - started bool + // running is true if the watchdog is running. + running bool + + // startCalled is true if Start has ever been called. It remains true + // even if Stop is called. + startCalled bool } type offender struct { @@ -122,58 +150,81 @@ type offender struct { } // New creates a new watchdog. -func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog { - // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much. - period := taskTimeout / 4 - return &Watchdog{ - k: k, - period: period, - taskTimeout: taskTimeout, - timeoutAction: a, - offenders: make(map[*kernel.Task]*offender), - stop: make(chan struct{}), - done: make(chan struct{}), +func New(k *kernel.Kernel, opts Opts) *Watchdog { + // 4 is arbitrary, just don't want to prolong 'TaskTimeout' too much. + period := opts.TaskTimeout / 4 + w := &Watchdog{ + Opts: opts, + k: k, + period: period, + offenders: make(map[*kernel.Task]*offender), + stop: make(chan struct{}), + done: make(chan struct{}), + } + + // Handle StartupTimeout if it exists. + if w.StartupTimeout > 0 { + log.Infof("Watchdog waiting %v for startup", w.StartupTimeout) + go w.waitForStart() // S/R-SAFE: watchdog is stopped buring save and restarted after restore. } + + return w } // Start starts the watchdog. func (w *Watchdog) Start() { - if w.taskTimeout == 0 { - log.Infof("Watchdog disabled") - return - } - w.mu.Lock() defer w.mu.Unlock() - if w.started { + w.startCalled = true + + if w.running { return } + if w.TaskTimeout == 0 { + log.Infof("Watchdog task timeout disabled") + return + } w.lastRun = w.k.MonotonicClock().Now() - log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction) + log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction) go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore. - w.started = true + w.running = true } // Stop requests the watchdog to stop and wait for it. func (w *Watchdog) Stop() { - if w.taskTimeout == 0 { + if w.TaskTimeout == 0 { return } w.mu.Lock() defer w.mu.Unlock() - if !w.started { + if !w.running { return } log.Infof("Stopping watchdog") w.stop <- struct{}{} <-w.done - w.started = false + w.running = false log.Infof("Watchdog stopped") } +// waitForStart waits for Start to be called and takes action if it does not +// happen within the startup timeout. +func (w *Watchdog) waitForStart() { + <-time.After(w.StartupTimeout) + w.mu.Lock() + defer w.mu.Unlock() + if w.startCalled { + // We are fine. + return + } + var buf bytes.Buffer + buf.WriteString("Watchdog.Start() not called within %s:\n") + w.doAction(w.StartupTimeoutAction, false, &buf) +} + // loop is the main watchdog routine. It only returns when 'Stop()' is called. func (w *Watchdog) loop() { // Loop until someone stops it. @@ -202,7 +253,7 @@ func (w *Watchdog) runTurn() { select { case <-done: - case <-time.After(w.taskTimeout): + case <-time.After(w.TaskTimeout): // Report if the watchdog is not making progress. // No one is wathching the watchdog watcher though. w.reportStuckWatchdog() @@ -231,7 +282,7 @@ func (w *Watchdog) runTurn() { if tsched.State == kernel.TaskGoroutineRunningSys { lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick))) elapsed := now.Sub(lastUpdateTime) - discount - if elapsed > w.taskTimeout { + if elapsed > w.TaskTimeout { tc, ok := w.offenders[t] if !ok { // New stuck task detected. @@ -261,28 +312,34 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo tid := w.k.TaskSet().Root.IDOfTask(t) buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime))) } + buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine") - w.onStuckTask(newTaskFound, &buf) + + // Dump stack only if a new task is detected or if it sometime has + // passed since the last time a stack dump was generated. + skipStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod + w.doAction(w.TaskTimeoutAction, skipStack, &buf) } func (w *Watchdog) reportStuckWatchdog() { var buf bytes.Buffer buf.WriteString("Watchdog goroutine is stuck:\n") - w.onStuckTask(true, &buf) + w.doAction(w.TaskTimeoutAction, false, &buf) } -func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { - switch w.timeoutAction { +// doAction will take the given action. If the action is LogWarnind and +// skipStack is true, then the stack printing will be skipped. +func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) { + switch action { case LogWarning: - // Dump stack only if a new task is detected or if it sometime has passed since - // the last time a stack dump was generated. - if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod { + if skipStack { msg.WriteString("\n...[stack dump skipped]...") log.Warningf(msg.String()) - } else { - log.TracebackAll(msg.String()) - w.lastStackDump = time.Now() + return + } + log.TracebackAll(msg.String()) + w.lastStackDump = time.Now() case Panic: // Panic will skip over running tasks, which is likely the culprit here. So manually @@ -301,5 +358,8 @@ func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { case <-time.After(1 * time.Second): } panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String())) + default: + panic(fmt.Sprintf("Unknown watchdog action %v", action)) + } } diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index bdca80d37..a23c86fb1 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -7,6 +7,7 @@ go_library( name = "sleep", srcs = [ "commit_amd64.s", + "commit_arm64.s", "commit_asm.go", "commit_noasm.go", "sleep_unsafe.go", diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s new file mode 100644 index 000000000..d0ef15b20 --- /dev/null +++ b/pkg/sleep/commit_arm64.s @@ -0,0 +1,38 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +#define preparingG 1 + +// See commit_noasm.go for a description of commitSleep. +// +// func commitSleep(g uintptr, waitingG *uintptr) bool +TEXT ·commitSleep(SB),NOSPLIT,$0-24 + MOVD waitingG+8(FP), R0 + MOVD $preparingG, R1 + MOVD G+0(FP), R2 + + // Store the G in waitingG if it's still preparingG. If it's anything + // else it means a waker has aborted the sleep. +again: + LDAXR (R0), R3 + CMP R1, R3 + BNE ok + STLXR R2, (R0), R3 + CBNZ R3, again +ok: + CSET EQ, R0 + MOVB R0, ret+16(FP) + RET diff --git a/pkg/sleep/commit_asm.go b/pkg/sleep/commit_asm.go index 35e2cc337..75728a97d 100644 --- a/pkg/sleep/commit_asm.go +++ b/pkg/sleep/commit_asm.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 +// +build amd64 arm64 package sleep diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go index 686b1da3d..3af447fb9 100644 --- a/pkg/sleep/commit_noasm.go +++ b/pkg/sleep/commit_noasm.go @@ -13,7 +13,7 @@ // limitations under the License. // +build !race -// +build !amd64 +// +build !amd64,!arm64 package sleep diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 8f5e60a25..acbf0229b 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 329904457..be93750bf 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,11 +1,10 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "addr_range", out = "addr_range.go", diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 47e6b878a..590c241a3 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -133,6 +134,9 @@ func (os *objectState) findCycle() []*objectState { // to ensure that all callbacks are executed, otherwise the callback graph was // not acyclic. type decodeState struct { + // ctx is the decode context. + ctx context.Context + // objectByID is the set of objects in progress. objectsByID map[uint64]*objectState diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 5d9409a45..c5118d3a9 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -16,6 +16,7 @@ package state import ( "container/list" + "context" "encoding/binary" "fmt" "io" @@ -38,6 +39,9 @@ type queuedObject struct { // The encoding process is a breadth-first traversal of the object graph. The // inherent races and dependencies are much simpler than the decode case. type encodeState struct { + // ctx is the encode context. + ctx context.Context + // lastID is the last object ID. // // See idsByObject for context. Because of the special zero encoding diff --git a/pkg/state/map.go b/pkg/state/map.go index 7e6fefed4..4f3ebb0da 100644 --- a/pkg/state/map.go +++ b/pkg/state/map.go @@ -15,6 +15,7 @@ package state import ( + "context" "fmt" "reflect" "sort" @@ -219,3 +220,13 @@ func (m Map) AfterLoad(fn func()) { // data dependencies have been cleared. m.os.callbacks = append(m.os.callbacks, fn) } + +// Context returns the current context object. +func (m Map) Context() context.Context { + if m.es != nil { + return m.es.ctx + } else if m.ds != nil { + return m.ds.ctx + } + return context.Background() // No context. +} diff --git a/pkg/state/state.go b/pkg/state/state.go index d408ff84a..dbe507ab4 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -50,6 +50,7 @@ package state import ( + "context" "fmt" "io" "reflect" @@ -86,9 +87,10 @@ func UnwrapErrState(err error) error { } // Save saves the given object state. -func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { +func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { // Create the encoding state. es := &encodeState{ + ctx: ctx, idsByObject: make(map[uintptr]uint64), w: w, stats: stats, @@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { } // Load loads a checkpoint. -func Load(r io.Reader, rootPtr interface{}, stats *Stats) error { +func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { // Create the decoding state. ds := &decodeState{ + ctx: ctx, objectsByID: make(map[uint64]*objectState), deferred: make(map[uint64]*pb.Object), r: r, diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 7c24bbcda..d7221e9e8 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "io/ioutil" "math" "reflect" @@ -46,7 +47,7 @@ func runTest(t *testing.T, tests []TestCase) { saveBuffer := &bytes.Buffer{} saveObjectPtr := reflect.New(reflect.TypeOf(root)) saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Save failed unexpectedly: %v", err) continue } else if err != nil { @@ -56,7 +57,7 @@ func runTest(t *testing.T, tests []TestCase) { // Load a new copy of the object. loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Load failed unexpectedly: %v", err) continue } else if err != nil { @@ -624,7 +625,7 @@ func BenchmarkEncoding(b *testing.B) { bs := buildObject(b.N) var stats Stats b.StartTimer() - if err := Save(ioutil.Discard, bs, &stats); err != nil { + if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { b.Errorf("save failed: %v", err) } b.StopTimer() @@ -638,12 +639,12 @@ func BenchmarkDecoding(b *testing.B) { bs := buildObject(b.N) var newBS benchStruct buf := &bytes.Buffer{} - if err := Save(buf, bs, nil); err != nil { + if err := Save(context.Background(), buf, bs, nil); err != nil { b.Errorf("save failed: %v", err) } var stats Stats b.StartTimer() - if err := Load(buf, &newBS, &stats); err != nil { + if err := Load(context.Background(), buf, &newBS, &stats); err != nil { b.Errorf("load failed: %v", err) } b.StopTimer() diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 3fd9e3134..65d4d0cd8 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,12 +1,13 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "tcpip", srcs = [ + "packet_buffer.go", + "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", ], diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 8ced960bb..ee077ae83 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -151,10 +151,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -203,10 +201,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index b4e8d6810..d6c31bfa2 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "buffer", srcs = [ diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 4287464f3..48a2a2713 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -41,6 +41,11 @@ func NewPrependableFromView(v View) Prependable { return Prependable{buf: v, usedIdx: 0} } +// NewEmptyPrependableFromView creates a new prependable buffer from a View. +func NewEmptyPrependableFromView(v View) Prependable { + return Prependable{buf: v, usedIdx: len(v)} +} + // View returns a View of the backing buffer that contains all prepended // data so far. func (p Prependable) View() View { diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index 4cecfb989..b6fa6fc37 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", ], diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 096ad71ab..2f15bf1f1 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -22,6 +22,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -639,6 +640,8 @@ func ICMPv4Code(want byte) TransportChecker { // ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and // potentially additional ICMPv6 header fields. +// +// ICMPv6 will validate the checksum field before calling checkers. func ICMPv6(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -650,6 +653,10 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker { } icmp := header.ICMPv6(last.Payload()) + if got, want := icmp.Checksum(), header.ICMPv6Checksum(icmp, last.SourceAddress(), last.DestinationAddress(), buffer.VectorisedView{}); got != want { + t.Fatalf("Bad ICMPv6 checksum; got %d, want %d", got, want) + } + for _, f := range checkers { f(t, icmp) } @@ -686,3 +693,64 @@ func ICMPv6Code(want byte) TransportChecker { } } } + +// NDP creates a checker that checks that the packet contains a valid NDP +// message for type of ty, with potentially additional checks specified by +// checkers. +// +// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDP message as far as the size of the message (minSize) is concerned. The +// values within the message are up to checkers to validate. +func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + // Check normal ICMPv6 first. + ICMPv6( + ICMPv6Type(msgType), + ICMPv6Code(0))(t, h) + + last := h[len(h)-1] + + icmp := header.ICMPv6(last.Payload()) + if got := len(icmp.NDPPayload()); got < minSize { + t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) + } + + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// NDPNS creates a checker that checks that the packet contains a valid NDP +// Neighbor Solicitation message (as per the raw wire format), with potentially +// additional checks specified by checkers. +// +// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPNS message as far as the size of the messages concerned. The values within +// the message are up to checkers to validate. +func NDPNS(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...) +} + +// NDPNSTargetAddress creates a checker that checks the Target Address field of +// a header.NDPNeighborSolicit. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNS message as far as the size is concerned. +func NDPNSTargetAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + + if got := ns.TargetAddress(); got != want { + t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want) + } + } +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index b558350c3..a3485b35c 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library") - go_library( name = "header", srcs = [ @@ -17,6 +16,10 @@ go_library( "ipv4.go", "ipv6.go", "ipv6_fragment.go", + "ndp_neighbor_advert.go", + "ndp_neighbor_solicit.go", + "ndp_options.go", + "ndp_router_advert.go", "tcp.go", "udp.go", ], @@ -31,13 +34,26 @@ go_library( ) go_test( - name = "header_test", + name = "header_x_test", size = "small", srcs = [ + "checksum_test.go", "ipversion_test.go", "tcp_test.go", ], deps = [ ":header", + "//pkg/tcpip/buffer", + ], +) + +go_test( + name = "header_test", + size = "small", + srcs = [ + "eth_test.go", + "ndp_test.go", ], + embed = [":header"], + deps = ["//pkg/tcpip"], ) diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 39a4d69be..9749c7f4d 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -23,11 +23,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" ) -func calculateChecksum(buf []byte, initial uint32) uint16 { +func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { v := initial + if odd { + v += uint32(buf[0]) + buf = buf[1:] + } + l := len(buf) - if l&1 != 0 { + odd = l&1 != 0 + if odd { l-- v += uint32(buf[l]) << 8 } @@ -36,7 +42,7 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) } - return ChecksumCombine(uint16(v), uint16(v>>16)) + return ChecksumCombine(uint16(v), uint16(v>>16)), odd } // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the @@ -44,7 +50,8 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { - return calculateChecksum(buf, uint32(initial)) + s, _ := calculateChecksum(buf, false, uint32(initial)) + return s } // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in @@ -52,19 +59,40 @@ func Checksum(buf []byte, initial uint16) uint16 { // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - var odd bool + return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) +} + +// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the +// bytes in the given VectorizedView. +// +// The initial checksum must have been computed on an even number of bytes. +func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { + odd := false sum := initial for _, v := range vv.Views() { if len(v) == 0 { continue } - s := uint32(sum) - if odd { - s += uint32(v[0]) - v = v[1:] + + if off >= len(v) { + off -= len(v) + continue + } + v = v[off:] + + l := len(v) + if l > size { + l = size + } + v = v[:l] + + sum, odd = calculateChecksum(v, odd, uint32(sum)) + + size -= len(v) + if size == 0 { + break } - odd = len(v)&1 != 0 - sum = calculateChecksum(v, s) + off = 0 } return sum } diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go new file mode 100644 index 000000000..86b466c1c --- /dev/null +++ b/pkg/tcpip/header/checksum_test.go @@ -0,0 +1,109 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package header provides the implementation of the encoding and decoding of +// network protocol headers. +package header_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestChecksumVVWithOffset(t *testing.T) { + testCases := []struct { + name string + vv buffer.VectorisedView + off, size int + initial uint16 + want uint16 + }{ + { + name: "empty", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + }), + off: 0, + size: 0, + want: 0, + }, + { + name: "OneView", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + }), + off: 0, + size: 5, + want: 1294, + }, + { + name: "TwoViews", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 0, + size: 11, + want: 33819, + }, + { + name: "TwoViewsWithOffset", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 1, + size: 11, + want: 33819, + }, + { + name: "ThreeViewsWithOffset", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), + }), + off: 7, + size: 11, + want: 33819, + }, + { + name: "ThreeViewsWithInitial", + vv: buffer.NewVectorisedView(0, []buffer.View{ + buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), + buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), + buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), + }), + initial: 77, + off: 7, + size: 11, + want: 33896, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { + t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) + } + v := tc.vv.ToView() + v.TrimFront(tc.off) + v.CapLength(tc.size) + if got, want := header.Checksum(v, tc.initial), tc.want; got != want { + t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) + } + }) + } +} diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index 4c3d3311f..f5d2c127f 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -48,8 +48,48 @@ const ( // EthernetAddressSize is the size, in bytes, of an ethernet address. EthernetAddressSize = 6 + + // unspecifiedEthernetAddress is the unspecified ethernet address + // (all bits set to 0). + unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + + // unicastMulticastFlagMask is the mask of the least significant bit in + // the first octet (in network byte order) of an ethernet address that + // determines whether the ethernet address is a unicast or multicast. If + // the masked bit is a 1, then the address is a multicast, unicast + // otherwise. + // + // See the IEEE Std 802-2001 document for more details. Specifically, + // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf: + // "A 48-bit universal address consists of two parts. The first 24 bits + // correspond to the OUI as assigned by the IEEE, expect that the + // assignee may set the LSB of the first octet to 1 for group addresses + // or set it to 0 for individual addresses." + unicastMulticastFlagMask = 1 + + // unicastMulticastFlagByteIdx is the byte that holds the + // unicast/multicast flag. See unicastMulticastFlagMask. + unicastMulticastFlagByteIdx = 0 +) + +const ( + // EthernetProtocolAll is a catch-all for all protocols carried inside + // an ethernet frame. It is mainly used to create packet sockets that + // capture all traffic. + EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003 + + // EthernetProtocolPUP is the PARC Universial Packet protocol ethertype. + EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200 ) +// Ethertypes holds the protocol numbers describing the payload of an ethernet +// frame. These types aren't necessarily supported by netstack, but can be used +// to catch all traffic of a type via packet endpoints. +var Ethertypes = []tcpip.NetworkProtocolNumber{ + EthernetProtocolAll, + EthernetProtocolPUP, +} + // SourceAddress returns the "MAC source" field of the ethernet frame header. func (b Ethernet) SourceAddress() tcpip.LinkAddress { return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize]) @@ -72,3 +112,25 @@ func (b Ethernet) Encode(e *EthernetFields) { copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr) copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) } + +// IsValidUnicastEthernetAddress returns true if addr is a valid unicast +// ethernet address. +func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { + // Must be of the right length. + if len(addr) != EthernetAddressSize { + return false + } + + // Must not be unspecified. + if addr == unspecifiedEthernetAddress { + return false + } + + // Must not be a multicast. + if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 { + return false + } + + // addr is a valid unicast ethernet address. + return true +} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go new file mode 100644 index 000000000..6634c90f5 --- /dev/null +++ b/pkg/tcpip/header/eth_test.go @@ -0,0 +1,68 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestIsValidUnicastEthernetAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.LinkAddress + expected bool + }{ + { + "Nil", + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty", + tcpip.LinkAddress(""), + false, + }, + { + "InvalidLength", + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Unspecified", + unspecifiedEthernetAddress, + false, + }, + { + "Multicast", + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + false, + }, + { + "Valid", + tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), + true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected { + t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected) + } + }) + } +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 1125a7d14..b4037b6c8 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -25,6 +25,12 @@ import ( type ICMPv6 []byte const ( + // ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the + // sum of the size of the ICMPv6 Type, Code and Checksum fields, as + // per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6 + // message body begins. + ICMPv6HeaderSize = 4 + // ICMPv6MinimumSize is the minimum size of a valid ICMP packet. ICMPv6MinimumSize = 8 @@ -37,10 +43,16 @@ const ( // ICMPv6NeighborSolicitMinimumSize is the minimum size of a // neighbor solicitation packet. - ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16 + ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize + + // ICMPv6NeighborAdvertMinimumSize is the minimum size of a + // neighbor advertisement packet. + ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize - // ICMPv6NeighborAdvertSize is size of a neighbor advertisement. - ICMPv6NeighborAdvertSize = 32 + // ICMPv6NeighborAdvertSize is size of a neighbor advertisement + // including the NDP Target Link Layer option for an Ethernet + // address. + ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. ICMPv6EchoMinimumSize = 8 @@ -68,6 +80,13 @@ const ( // icmpv6SequenceOffset is the offset of the sequence field // in a ICMPv6 Echo Request/Reply message. icmpv6SequenceOffset = 6 + + // NDPHopLimit is the expected IP hop limit value of 255 for received + // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, + // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently + // drop the NDP packet. All outgoing NDP packets must use this value for + // its IP hop limit field. + NDPHopLimit = 255 ) // ICMPv6Type is the ICMP type field described in RFC 4443 and friends. @@ -113,7 +132,7 @@ func (b ICMPv6) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:]) } -// SetChecksum calculates and sets the ICMP checksum field. +// SetChecksum sets the ICMP checksum field. func (b ICMPv6) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum) } @@ -166,12 +185,19 @@ func (b ICMPv6) SetSequence(sequence uint16) { binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) } +// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6 +// packet's message body as defined by RFC 4443 section 2.1; the portion of the +// ICMPv6 buffer after the first ICMPv6HeaderSize bytes. +func (b ICMPv6) NDPPayload() []byte { + return b[ICMPv6HeaderSize:] +} + // Payload implements Transport.Payload. func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } -// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header, +// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { // Calculate the IPv6 pseudo-header upper-layer checksum. diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 554632a64..e5360e7c1 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -284,7 +284,7 @@ func (b IPv4) IsValid(pktSize int) bool { hlen := int(b.HeaderLength()) tlen := int(b.TotalLength()) - if hlen > tlen || tlen > pktSize { + if hlen < IPv4MinimumSize || hlen > tlen || tlen > pktSize { return false } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 9d3abc0e4..0caa51c1e 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -87,11 +87,14 @@ const ( // section 5. IPv6MinimumMTU = 1280 - // IPv6Any is the non-routable IPv6 "any" meta address. + // IPv6Any is the non-routable IPv6 "any" meta address. It is also + // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ) -// IPv6EmptySubnet is the empty IPv6 subnet. +// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the +// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to +// be contained within this subnet. var IPv6EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) if err != nil { @@ -100,6 +103,15 @@ var IPv6EmptySubnet = func() tcpip.Subnet { return subnet }() +// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined +// by RFC 4291 section 2.5.6. +// +// The prefix is fe80::/64 +var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{ + Address: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 64, +} + // PayloadLength returns the value of the "payload length" field of the ipv6 // header. func (b IPv6) PayloadLength() uint16 { diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go new file mode 100644 index 000000000..505c92668 --- /dev/null +++ b/pkg/tcpip/header/ndp_neighbor_advert.go @@ -0,0 +1,110 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import "gvisor.dev/gvisor/pkg/tcpip" + +// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will +// only contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.4 for more details. +type NDPNeighborAdvert []byte + +const ( + // NDPNAMinimumSize is the minimum size of a valid NDP Neighbor + // Advertisement message (body of an ICMPv6 packet). + NDPNAMinimumSize = 20 + + // ndpNATargetAddressOffset is the start of the Target Address + // field within an NDPNeighborAdvert. + ndpNATargetAddressOffset = 4 + + // ndpNAOptionsOffset is the start of the NDP options in an + // NDPNeighborAdvert. + ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize + + // ndpNAFlagsOffset is the offset of the flags within an + // NDPNeighborAdvert + ndpNAFlagsOffset = 0 + + // ndpNARouterFlagMask is the mask of the Router Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNARouterFlagMask = (1 << 7) + + // ndpNASolicitedFlagMask is the mask of the Solicited Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNASolicitedFlagMask = (1 << 6) + + // ndpNAOverrideFlagMask is the mask of the Override Flag field in + // the flags byte within in an NDPNeighborAdvert. + ndpNAOverrideFlagMask = (1 << 5) +) + +// TargetAddress returns the value within the Target Address field. +func (b NDPNeighborAdvert) TargetAddress() tcpip.Address { + return tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]) +} + +// SetTargetAddress sets the value within the Target Address field. +func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) { + copy(b[ndpNATargetAddressOffset:][:IPv6AddressSize], addr) +} + +// RouterFlag returns the value of the Router Flag field. +func (b NDPNeighborAdvert) RouterFlag() bool { + return b[ndpNAFlagsOffset]&ndpNARouterFlagMask != 0 +} + +// SetRouterFlag sets the value in the Router Flag field. +func (b NDPNeighborAdvert) SetRouterFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNARouterFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNARouterFlagMask + } +} + +// SolicitedFlag returns the value of the Solicited Flag field. +func (b NDPNeighborAdvert) SolicitedFlag() bool { + return b[ndpNAFlagsOffset]&ndpNASolicitedFlagMask != 0 +} + +// SetSolicitedFlag sets the value in the Solicited Flag field. +func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNASolicitedFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNASolicitedFlagMask + } +} + +// OverrideFlag returns the value of the Override Flag field. +func (b NDPNeighborAdvert) OverrideFlag() bool { + return b[ndpNAFlagsOffset]&ndpNAOverrideFlagMask != 0 +} + +// SetOverrideFlag sets the value in the Override Flag field. +func (b NDPNeighborAdvert) SetOverrideFlag(f bool) { + if f { + b[ndpNAFlagsOffset] |= ndpNAOverrideFlagMask + } else { + b[ndpNAFlagsOffset] &^= ndpNAOverrideFlagMask + } +} + +// Options returns an NDPOptions of the the options body. +func (b NDPNeighborAdvert) Options() NDPOptions { + return NDPOptions(b[ndpNAOptionsOffset:]) +} diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go new file mode 100644 index 000000000..3a1b8e139 --- /dev/null +++ b/pkg/tcpip/header/ndp_neighbor_solicit.go @@ -0,0 +1,52 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import "gvisor.dev/gvisor/pkg/tcpip" + +// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only +// contain the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.3 for more details. +type NDPNeighborSolicit []byte + +const ( + // NDPNSMinimumSize is the minimum size of a valid NDP Neighbor + // Solicitation message (body of an ICMPv6 packet). + NDPNSMinimumSize = 20 + + // ndpNSTargetAddessOffset is the start of the Target Address + // field within an NDPNeighborSolicit. + ndpNSTargetAddessOffset = 4 + + // ndpNSOptionsOffset is the start of the NDP options in an + // NDPNeighborSolicit. + ndpNSOptionsOffset = ndpNSTargetAddessOffset + IPv6AddressSize +) + +// TargetAddress returns the value within the Target Address field. +func (b NDPNeighborSolicit) TargetAddress() tcpip.Address { + return tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]) +} + +// SetTargetAddress sets the value within the Target Address field. +func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) { + copy(b[ndpNSTargetAddessOffset:][:IPv6AddressSize], addr) +} + +// Options returns an NDPOptions of the the options body. +func (b NDPNeighborSolicit) Options() NDPOptions { + return NDPOptions(b[ndpNSOptionsOffset:]) +} diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go new file mode 100644 index 000000000..a2b9d7435 --- /dev/null +++ b/pkg/tcpip/header/ndp_options.go @@ -0,0 +1,463 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "errors" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // NDPTargetLinkLayerAddressOptionType is the type of the Target + // Link-Layer Address option, as per RFC 4861 section 4.6.1. + NDPTargetLinkLayerAddressOptionType = 2 + + // ndpTargetEthernetLinkLayerAddressSize is the size of a Target + // Link Layer Option for an Ethernet address. + ndpTargetEthernetLinkLayerAddressSize = 8 + + // NDPPrefixInformationType is the type of the Prefix Information + // option, as per RFC 4861 section 4.6.2. + NDPPrefixInformationType = 3 + + // ndpPrefixInformationLength is the expected length, in bytes, of the + // body of an NDP Prefix Information option, as per RFC 4861 section + // 4.6.2 which specifies that the Length field is 4. Given this, the + // expected length, in bytes, is 30 becuase 4 * lengthByteUnits (8) - 2 + // (Type & Length) = 30. + ndpPrefixInformationLength = 30 + + // ndpPrefixInformationPrefixLengthOffset is the offset of the Prefix + // Length field within an NDPPrefixInformation. + ndpPrefixInformationPrefixLengthOffset = 0 + + // ndpPrefixInformationFlagsOffset is the offset of the flags byte + // within an NDPPrefixInformation. + ndpPrefixInformationFlagsOffset = 1 + + // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationOnLinkFlagMask = (1 << 7) + + // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the + // Autonomous Address-Configuration flag field in the flags byte within + // an NDPPrefixInformation. + ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + + // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 + // field in the flags byte within an NDPPrefixInformation. + ndpPrefixInformationReserved1FlagsMask = 63 + + // ndpPrefixInformationValidLifetimeOffset is the start of the 4-byte + // Valid Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationValidLifetimeOffset = 2 + + // ndpPrefixInformationPreferredLifetimeOffset is the start of the + // 4-byte Preferred Lifetime field within an NDPPrefixInformation. + ndpPrefixInformationPreferredLifetimeOffset = 6 + + // ndpPrefixInformationReserved2Offset is the start of the 4-byte + // Reserved2 field within an NDPPrefixInformation. + ndpPrefixInformationReserved2Offset = 10 + + // ndpPrefixInformationReserved2Length is the length of the Reserved2 + // field. + // + // It is 4 bytes. + ndpPrefixInformationReserved2Length = 4 + + // ndpPrefixInformationPrefixOffset is the start of the Prefix field + // within an NDPPrefixInformation. + ndpPrefixInformationPrefixOffset = 14 + + // NDPPrefixInformationInfiniteLifetime is a value that represents + // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix + // Information option. Its value is (2^32 - 1)s = 4294967295s + NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295 + + // lengthByteUnits is the multiplier factor for the Length field of an + // NDP option. That is, the length field for NDP options is in units of + // 8 octets, as per RFC 4861 section 4.6. + lengthByteUnits = 8 +) + +// NDPOptionIterator is an iterator of NDPOption. +// +// Note, between when an NDPOptionIterator is obtained and last used, no changes +// to the NDPOptions may happen. Doing so may cause undefined and unexpected +// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first +// few NDPOption then modify the backing NDPOptions so long as the +// NDPOptionIterator obtained before modification is no longer used. +type NDPOptionIterator struct { + // The NDPOptions this NDPOptionIterator is iterating over. + opts NDPOptions +} + +// Potential errors when iterating over an NDPOptions. +var ( + ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") + ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") + ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") +) + +// Next returns the next element in the backing NDPOptions, or true if we are +// done, or false if an error occured. +// +// The return can be read as option, done, error. Note, option should only be +// used if done is false and error is nil. +func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { + for { + // Do we still have elements to look at? + if len(i.opts) == 0 { + return nil, true, nil + } + + // Do we have enough bytes for an NDP option that has a Length + // field of at least 1? Note, 0 in the Length field is invalid. + if len(i.opts) < lengthByteUnits { + return nil, true, ErrNDPOptBufExhausted + } + + // Get the Type field. + t := i.opts[0] + + // Get the Length field. + l := i.opts[1] + + // This would indicate an erroneous NDP option as the Length + // field should never be 0. + if l == 0 { + return nil, true, ErrNDPOptZeroLength + } + + // How many bytes are in the option body? + numBytes := int(l) * lengthByteUnits + numBodyBytes := numBytes - 2 + + potentialBody := i.opts[2:] + + // This would indicate an erroenous NDPOptions buffer as we ran + // out of the buffer in the middle of an NDP option. + if left := len(potentialBody); left < numBodyBytes { + return nil, true, ErrNDPOptBufExhausted + } + + // Get only the options body, leaving the rest of the options + // buffer alone. + body := potentialBody[:numBodyBytes] + + // Update opts with the remaining options body. + i.opts = i.opts[numBytes:] + + switch t { + case NDPTargetLinkLayerAddressOptionType: + return NDPTargetLinkLayerAddressOption(body), false, nil + + case NDPPrefixInformationType: + // Make sure the length of a Prefix Information option + // body is ndpPrefixInformationLength, as per RFC 4861 + // section 4.6.2. + if numBodyBytes != ndpPrefixInformationLength { + return nil, true, ErrNDPOptMalformedBody + } + + return NDPPrefixInformation(body), false, nil + default: + // We do not yet recognize the option, just skip for + // now. This is okay because RFC 4861 allows us to + // skip/ignore any unrecognized options. However, + // we MUST recognized all the options in RFC 4861. + // + // TODO(b/141487990): Handle all NDP options as defined + // by RFC 4861. + } + } +} + +// NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6. +type NDPOptions []byte + +// Iter returns an iterator of NDPOption. +// +// If check is true, Iter will do an integrity check on the options by iterating +// over it and returning an error if detected. +// +// See NDPOptionIterator for more information. +func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { + it := NDPOptionIterator{opts: b} + + if check { + for it2 := it; true; { + if _, done, err := it2.Next(); err != nil || done { + return it, err + } + } + } + + return it, nil +} + +// Serialize serializes the provided list of NDP options into o. +// +// Note, b must be of sufficient size to hold all the options in s. See +// NDPOptionsSerializer.Length for details on the getting the total size +// of a serialized NDPOptionsSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { + done := 0 + + for _, o := range s { + l := paddedLength(o) + + if l == 0 { + continue + } + + b[0] = o.Type() + + // We know this safe because paddedLength would have returned + // 0 if o had an invalid length (> 255 * lengthByteUnits). + b[1] = uint8(l / lengthByteUnits) + + // Serialize NDP option body. + used := o.serializeInto(b[2:]) + + // Zero out remaining (padding) bytes, if any exists. + for i := used + 2; i < l; i++ { + b[i] = 0 + } + + b = b[l:] + done += l + } + + return done +} + +// NDPOption is the set of functions to be implemented by all NDP option types. +type NDPOption interface { + // Type returns the type of the receiver. + Type() uint8 + + // Length returns the length of the body of the receiver, in bytes. + Length() int + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // Length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) int +} + +// paddedLength returns the length of o, in bytes, with any padding bytes, if +// required. +func paddedLength(o NDPOption) int { + l := o.Length() + + if l == 0 { + return 0 + } + + // Length excludes the 2 Type and Length bytes. + l += 2 + + // Add extra bytes if needed to make sure the option is + // lengthByteUnits-byte aligned. We do this by adding lengthByteUnits-1 + // to l and then stripping off the last few LSBits from l. This will + // make sure that l is rounded up to the nearest unit of + // lengthByteUnits. This works since lengthByteUnits is a power of 2 + // (= 8). + mask := lengthByteUnits - 1 + l += mask + l &^= mask + + if l/lengthByteUnits > 255 { + // Should never happen because an option can only have a max + // value of 255 for its Length field, so just return 0 so this + // option does not get serialized. + // + // Returning 0 here will make sure that this option does not get + // serialized when NDPOptions.Serialize is called with the + // NDPOptionsSerializer that holds this option, effectively + // skipping this option during serialization. Also note that + // a value of zero for the Length field in an NDP option is + // invalid so this is another sign to the caller that this NDP + // option is malformed, as per RFC 4861 section 4.6. + return 0 + } + + return l +} + +// NDPOptionsSerializer is a serializer for NDP options. +type NDPOptionsSerializer []NDPOption + +// Length returns the total number of bytes required to serialize. +func (b NDPOptionsSerializer) Length() int { + l := 0 + + for _, o := range b { + l += paddedLength(o) + } + + return l +} + +// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option +// as defined by RFC 4861 section 4.6.1. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPTargetLinkLayerAddressOption tcpip.LinkAddress + +// Type implements NDPOption.Type. +func (o NDPTargetLinkLayerAddressOption) Type() uint8 { + return NDPTargetLinkLayerAddressOptionType +} + +// Length implements NDPOption.Length. +func (o NDPTargetLinkLayerAddressOption) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + +// NDPPrefixInformation is the NDP Prefix Information option as defined by +// RFC 4861 section 4.6.2. +// +// The length, in bytes, of a valid NDP Prefix Information option body MUST be +// ndpPrefixInformationLength bytes. +type NDPPrefixInformation []byte + +// Type implements NDPOption.Type. +func (o NDPPrefixInformation) Type() uint8 { + return NDPPrefixInformationType +} + +// Length implements NDPOption.Length. +func (o NDPPrefixInformation) Length() int { + return ndpPrefixInformationLength +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPPrefixInformation) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the Reserved1 field. + b[ndpPrefixInformationFlagsOffset] &^= ndpPrefixInformationReserved1FlagsMask + + // Zero out the Reserved2 field. + reserved2 := b[ndpPrefixInformationReserved2Offset:][:ndpPrefixInformationReserved2Length] + for i := range reserved2 { + reserved2[i] = 0 + } + + return used +} + +// PrefixLength returns the value in the number of leading bits in the Prefix +// that are valid. +// +// Valid values are in the range [0, 128], but o may not always contain valid +// values. It is up to the caller to valdiate the Prefix Information option. +func (o NDPPrefixInformation) PrefixLength() uint8 { + return o[ndpPrefixInformationPrefixLengthOffset] +} + +// OnLinkFlag returns true of the prefix is considered on-link. On-link means +// that a forwarding node is not needed to send packets to other nodes on the +// same prefix. +// +// Note, when this function returns false, no statement is made about the +// on-link property of a prefix. That is, if OnLinkFlag returns false, the +// caller MUST NOT conclude that the prefix is off-link and MUST NOT update any +// previously stored state for this prefix about its on-link status. +func (o NDPPrefixInformation) OnLinkFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationOnLinkFlagMask != 0 +} + +// AutonomousAddressConfigurationFlag returns true if the prefix can be used for +// Stateless Address Auto-Configuration (as specified in RFC 4862). +func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { + return o[ndpPrefixInformationFlagsOffset]&ndpPrefixInformationAutoAddrConfFlagMask != 0 +} + +// ValidLifetime returns the length of time that the prefix is valid for the +// purpose of on-link determination. This value is relative to the send time of +// the packet that the Prefix Information option was present in. +// +// Note, a value of 0 implies the prefix should not be considered as on-link, +// and a value of infinity/forever is represented by +// NDPPrefixInformationInfiniteLifetime. +func (o NDPPrefixInformation) ValidLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) +} + +// PreferredLifetime returns the length of time that an address generated from +// the prefix via Stateless Address Auto-Configuration remains preferred. This +// value is relative to the send time of the packet that the Prefix Information +// option was present in. +// +// Note, a value of 0 implies that addresses generated from the prefix should +// no longer remain preferred, and a value of infinity is represented by +// NDPPrefixInformationInfiniteLifetime. +// +// Also note that the value of this field MUST NOT exceed the Valid Lifetime +// field to avoid preferring addresses that are no longer valid, for the +// purpose of Stateless Address Auto-Configuration. +func (o NDPPrefixInformation) PreferredLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.6.2. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationPreferredLifetimeOffset:])) +} + +// Prefix returns an IPv6 address or a prefix of an IPv6 address. The Prefix +// Length field (see NDPPrefixInformation.PrefixLength) contains the number +// of valid leading bits in the prefix. +// +// Hosts SHOULD ignore an NDP Prefix Information option where the Prefix field +// holds the link-local prefix (fe80::). +func (o NDPPrefixInformation) Prefix() tcpip.Address { + return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize]) +} diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go new file mode 100644 index 000000000..bf7610863 --- /dev/null +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -0,0 +1,112 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "time" +) + +// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.2 for more details. +type NDPRouterAdvert []byte + +const ( + // NDPRAMinimumSize is the minimum size of a valid NDP Router + // Advertisement message (body of an ICMPv6 packet). + NDPRAMinimumSize = 12 + + // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field + // within an NDPRouterAdvert. + ndpRACurrHopLimitOffset = 0 + + // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags + // within an NDPRouterAdvert. + ndpRAFlagsOffset = 1 + + // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address + // Configuration flag within the bit-field/flags byte of an + // NDPRouterAdvert. + ndpRAManagedAddrConfFlagMask = (1 << 7) + + // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag + // within the bit-field/flags byte of an NDPRouterAdvert. + ndpRAOtherConfFlagMask = (1 << 6) + + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime + // field within an NDPRouterAdvert. + ndpRARouterLifetimeOffset = 2 + + // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time + // field within an NDPRouterAdvert. + ndpRAReachableTimeOffset = 4 + + // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer + // field within an NDPRouterAdvert. + ndpRARetransTimerOffset = 8 + + // ndpRAOptionsOffset is the start of the NDP options in an + // NDPRouterAdvert. + ndpRAOptionsOffset = 12 +) + +// CurrHopLimit returns the value of the Curr Hop Limit field. +func (b NDPRouterAdvert) CurrHopLimit() uint8 { + return b[ndpRACurrHopLimitOffset] +} + +// ManagedAddrConfFlag returns the value of the Managed Address Configuration +// flag. +func (b NDPRouterAdvert) ManagedAddrConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0 +} + +// OtherConfFlag returns the value of the Other Configuration flag. +func (b NDPRouterAdvert) OtherConfFlag() bool { + return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 +} + +// RouterLifetime returns the lifetime associated with the default router. A +// value of 0 means the source of the Router Advertisement is not a default +// router and SHOULD NOT appear on the default router list. Note, a value of 0 +// only means that the router should not be used as a default router, it does +// not apply to other information contained in the Router Advertisement. +func (b NDPRouterAdvert) RouterLifetime() time.Duration { + // The field is the time in seconds, as per RFC 4861 section 4.2. + return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:])) +} + +// ReachableTime returns the time that a node assumes a neighbor is reachable +// after having received a reachability confirmation. A value of 0 means +// that it is unspecified by the source of the Router Advertisement message. +func (b NDPRouterAdvert) ReachableTime() time.Duration { + // The field is the time in milliseconds, as per RFC 4861 section 4.2. + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:])) +} + +// RetransTimer returns the time between retransmitted Neighbor Solicitation +// messages. A value of 0 means that it is unspecified by the source of the +// Router Advertisement message. +func (b NDPRouterAdvert) RetransTimer() time.Duration { + // The field is the time in milliseconds, as per RFC 4861 section 4.2. + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:])) +} + +// Options returns an NDPOptions of the the options body. +func (b NDPRouterAdvert) Options() NDPOptions { + return NDPOptions(b[ndpRAOptionsOffset:]) +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go new file mode 100644 index 000000000..ad6daafcd --- /dev/null +++ b/pkg/tcpip/header/ndp_test.go @@ -0,0 +1,570 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "bytes" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit. +func TestNDPNeighborSolicit(t *testing.T) { + b := []byte{ + 0, 0, 0, 0, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + } + + // Test getting the Target Address. + ns := NDPNeighborSolicit(b) + addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + if got := ns.TargetAddress(); got != addr { + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) + } + + // Test updating the Target Address. + addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + ns.SetTargetAddress(addr2) + if got := ns.TargetAddress(); got != addr2 { + t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) + } + // Make sure the address got updated in the backing buffer. + if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 { + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) + } +} + +// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert. +func TestNDPNeighborAdvert(t *testing.T) { + b := []byte{ + 160, 0, 0, 0, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + } + + // Test getting the Target Address. + na := NDPNeighborAdvert(b) + addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + if got := na.TargetAddress(); got != addr { + t.Errorf("got TargetAddress = %s, want %s", got, addr) + } + + // Test getting the Router Flag. + if got := na.RouterFlag(); !got { + t.Errorf("got RouterFlag = false, want = true") + } + + // Test getting the Solicited Flag. + if got := na.SolicitedFlag(); got { + t.Errorf("got SolicitedFlag = true, want = false") + } + + // Test getting the Override Flag. + if got := na.OverrideFlag(); !got { + t.Errorf("got OverrideFlag = false, want = true") + } + + // Test updating the Target Address. + addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + na.SetTargetAddress(addr2) + if got := na.TargetAddress(); got != addr2 { + t.Errorf("got TargetAddress = %s, want %s", got, addr2) + } + // Make sure the address got updated in the backing buffer. + if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 { + t.Errorf("got targetaddress buffer = %s, want %s", got, addr2) + } + + // Test updating the Router Flag. + na.SetRouterFlag(false) + if got := na.RouterFlag(); got { + t.Errorf("got RouterFlag = true, want = false") + } + + // Test updating the Solicited Flag. + na.SetSolicitedFlag(true) + if got := na.SolicitedFlag(); !got { + t.Errorf("got SolicitedFlag = false, want = true") + } + + // Test updating the Override Flag. + na.SetOverrideFlag(false) + if got := na.OverrideFlag(); got { + t.Errorf("got OverrideFlag = true, want = false") + } + + // Make sure flags got updated in the backing buffer. + if got := b[ndpNAFlagsOffset]; got != 64 { + t.Errorf("got flags byte = %d, want = 64") + } +} + +func TestNDPRouterAdvert(t *testing.T) { + b := []byte{ + 64, 128, 1, 2, + 3, 4, 5, 6, + 7, 8, 9, 10, + } + + ra := NDPRouterAdvert(b) + + if got := ra.CurrHopLimit(); got != 64 { + t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) + } + + if got := ra.ManagedAddrConfFlag(); !got { + t.Errorf("got ManagedAddrConfFlag = false, want = true") + } + + if got := ra.OtherConfFlag(); got { + t.Errorf("got OtherConfFlag = true, want = false") + } + + if got, want := ra.RouterLifetime(), time.Second*258; got != want { + t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) + } + + if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { + t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) + } + + if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { + t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) + } +} + +// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the +// Ethernet address from an NDPTargetLinkLayerAddressOption. +func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { + tests := []struct { + name string + buf []byte + expected tcpip.LinkAddress + }{ + { + "ValidMAC", + []byte{1, 2, 3, 4, 5, 6}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + { + "TLLBodyTooShort", + []byte{1, 2, 3, 4, 5}, + tcpip.LinkAddress([]byte(nil)), + }, + { + "TLLBodyLargerThanNeeded", + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tll := NDPTargetLinkLayerAddressOption(test.buf) + if got := tll.EthernetAddress(); got != test.expected { + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected) + } + }) + } + +} + +// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a +// NDPTargetLinkLayerAddressOption. +func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { + tests := []struct { + name string + buf []byte + expectedBuf []byte + addr tcpip.LinkAddress + }{ + { + "Ethernet", + make([]byte, 8), + []byte{2, 1, 1, 2, 3, 4, 5, 6}, + "\x01\x02\x03\x04\x05\x06", + }, + { + "Padding", + []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + "\x01\x02\x03\x04\x05\x06\x07\x08", + }, + { + "Empty", + []byte{}, + []byte{}, + "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + serializer := NDPOptionsSerializer{ + NDPTargetLinkLayerAddressOption(test.addr), + } + if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { + t.Fatalf("got Length = %d, want = %d", got, want) + } + opts.Serialize(serializer) + if !bytes.Equal(test.buf, test.expectedBuf) { + t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + if len(test.expectedBuf) > 0 { + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + tll := next.(NDPTargetLinkLayerAddressOption) + if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { + t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + + if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { + t.Errorf("got tll.MACAddress = %s, want = %s", got, want) + } + } + + // Iterator should not return anything else. + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + +// TestNDPPrefixInformationOption tests the field getters and serialization of a +// NDPPrefixInformation. +func TestNDPPrefixInformationOption(t *testing.T) { + b := []byte{ + 43, 127, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 5, 5, 5, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPPrefixInformation(b), + } + opts.Serialize(serializer) + expectedBuf := []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + if !bytes.Equal(targetBuf, expectedBuf) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + pi := next.(NDPPrefixInformation) + + if got := pi.Type(); got != 3 { + t.Errorf("got Type = %d, want = 3", got) + } + + if got := pi.Length(); got != 30 { + t.Errorf("got Length = %d, want = 30", got) + } + + if got := pi.PrefixLength(); got != 43 { + t.Errorf("got PrefixLength = %d, want = 43", got) + } + + if pi.OnLinkFlag() { + t.Error("got OnLinkFlag = true, want = false") + } + + if !pi.AutonomousAddressConfigurationFlag() { + t.Error("got AutonomousAddressConfigurationFlag = false, want = true") + } + + if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want { + t.Errorf("got ValidLifetime = %d, want = %d", got, want) + } + + if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want { + t.Errorf("got PreferredLifetime = %d, want = %d", got, want) + } + + if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want { + t.Errorf("got Prefix = %s, want = %s", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions +// the iterator was returned for is malformed. +func TestNDPOptionsIterCheck(t *testing.T) { + tests := []struct { + name string + buf []byte + expected error + }{ + { + "ZeroLengthField", + []byte{0, 0, 0, 0, 0, 0, 0, 0}, + ErrNDPOptZeroLength, + }, + { + "ValidTargetLinkLayerAddressOption", + []byte{2, 1, 1, 2, 3, 4, 5, 6}, + nil, + }, + { + "TooSmallTargetLinkLayerAddressOption", + []byte{2, 1, 1, 2, 3, 4, 5}, + ErrNDPOptBufExhausted, + }, + { + "ValidPrefixInformation", + []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "TooSmallPrefixInformation", + []byte{ + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, + }, + ErrNDPOptBufExhausted, + }, + { + "InvalidPrefixInformationLength", + []byte{ + 3, 3, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + }, + ErrNDPOptMalformedBody, + }, + { + "ValidTargetLinkLayerAddressWithPrefixInformation", + []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + { + "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", + []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // 255 is an unrecognized type. If 255 ends up + // being the type for some recognized type, + // update 255 to some other unrecognized value. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + }, + nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + + if _, err := opts.Iter(true); err != test.expected { + t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected) + } + + // test.buf may be malformed but we chose not to check + // the iterator so it must return true. + if _, err := opts.Iter(false); err != nil { + t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) + } + }) + } +} + +// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note, +// this test does not actually check any of the option's getters, it simply +// checks the option Type and Body. We have other tests that tests the option +// field gettings given an option body and don't need to duplicate those tests +// here. +func TestNDPOptionsIter(t *testing.T) { + buf := []byte{ + // Target Link-Layer Address. + 2, 1, 1, 2, 3, 4, 5, 6, + + // 255 is an unrecognized type. If 255 ends up being the type + // for some recognized type, update 255 to some other + // unrecognized value. Note, this option should be skipped when + // iterating. + 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8, + + // Prefix information. + 3, 4, 43, 64, + 1, 2, 3, 4, + 5, 6, 7, 8, + 0, 0, 0, 0, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 21, 22, 23, 24, + } + + opts := NDPOptions(buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Test the first (Taret Link-Layer) option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + } + + // Test the next (Prefix Information) option. + // Note, the unrecognized option should be skipped. + next, done, err = it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPPrefixInformationType { + t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 3fc14bacd..cc5f531e2 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "iptables", srcs = [ diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 18adb2085..22eefb564 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -65,14 +65,14 @@ func (e *Endpoint) Drain() int { } } -// Inject injects an inbound packet. -func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.InjectLinkAddr(protocol, "", vv) +// InjectInbound injects an inbound packet. +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil)) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets @@ -96,7 +96,7 @@ func (e *Endpoint) MTU() uint32 { func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { caps := stack.LinkEndpointCapabilities(0) if e.GSO { - caps |= stack.CapabilityGSO + caps |= stack.CapabilityHardwareGSO } return caps } @@ -134,5 +134,49 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } +// WritePackets stores outbound packets into the channel. +func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + payloadView := payload.ToView() + n := 0 +packetLoop: + for i := range hdrs { + hdr := &hdrs[i].Hdr + off := hdrs[i].Off + size := hdrs[i].Size + p := PacketInfo{ + Header: hdr.View(), + Proto: protocol, + Payload: buffer.NewViewFromBytes(payloadView[off : off+size]), + GSO: gso, + } + + select { + case e.C <- p: + n++ + default: + break packetLoop + } + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { + p := PacketInfo{ + Header: packet.ToView(), + Proto: 0, + Payload: buffer.View{}, + GSO: nil, + } + + select { + case e.C <- p: + default: + } + + return nil +} + // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 584db710e..edef7db26 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -82,7 +82,19 @@ const ( PacketMMap ) -// An endpoint implements the link-layer using a message-oriented file descriptor. +func (p PacketDispatchMode) String() string { + switch p { + case Readv: + return "Readv" + case RecvMMsg: + return "RecvMMsg" + case PacketMMap: + return "PacketMMap" + default: + return fmt.Sprintf("unknown packet dispatch mode %v", p) + } +} + type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. The endpoint will dispatch from all inbound channels as well as @@ -153,6 +165,9 @@ type Options struct { // disabled. GSOMaxSize uint32 + // SoftwareGSOEnabled indicates whether software GSO is enabled or not. + SoftwareGSOEnabled bool + // PacketDispatchMode specifies the type of inbound dispatcher to be // used for this endpoint. PacketDispatchMode PacketDispatchMode @@ -166,6 +181,14 @@ type Options struct { RXChecksumOffload bool } +// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT +// support in the host kernel. This allows us to use multiple FD's to receive +// from the same underlying NIC. The fanoutID needs to be the same for a given +// set of FD's that point to the same NIC. Trying to set the PACKET_FANOUT +// option for an FD with a fanoutID already in use by another FD for a different +// NIC will return an EINVAL. +var fanoutID = 1 + // New creates a new fd-based endpoint. // // Makes fd non-blocking, but does not take ownership of fd, which must remain @@ -222,7 +245,11 @@ func New(opts *Options) (stack.LinkEndpoint, error) { } if isSocket { if opts.GSOMaxSize != 0 { - e.caps |= stack.CapabilityGSO + if opts.SoftwareGSOEnabled { + e.caps |= stack.CapabilitySoftwareGSO + } else { + e.caps |= stack.CapabilityHardwareGSO + } e.gsoMaxSize = opts.GSOMaxSize } } @@ -233,6 +260,10 @@ func New(opts *Options) (stack.LinkEndpoint, error) { e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher) } + // Increment fanoutID to ensure that we don't re-use the same fanoutID for + // the next endpoint. + fanoutID++ + return e, nil } @@ -253,7 +284,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher case *unix.SockaddrLinklayer: // enable PACKET_FANOUT mode is the underlying socket is // of type AF_PACKET. - const fanoutID = 1 const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG fanoutArg := fanoutID | fanoutType<<16 if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { @@ -374,7 +404,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen eth.Encode(ethHdr) } - if e.Capabilities()&stack.CapabilityGSO != 0 { + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) if gso != nil { @@ -407,8 +437,130 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil) } -// WriteRawPacket writes a raw packet directly to the file descriptor. -func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { +// WritePackets writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + var ethHdrBuf []byte + // hdr + data + iovLen := 2 + if e.hdrSize > 0 { + // Add ethernet header if needed. + ethHdrBuf = make([]byte, header.EthernetMinimumSize) + eth := header.Ethernet(ethHdrBuf) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } + + // Preserve the src address if it's set in the route. + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + iovLen++ + } + + n := len(hdrs) + + views := payload.Views() + /* + * Each bondary in views can add one more iovec. + * + * payload | | | | + * ----------------------------- + * packets | | | | | | | + * ----------------------------- + * iovecs | | | | | | | | | + */ + iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) + mmsgHdrs := make([]rawfile.MMsgHdr, n) + + iovecIdx := 0 + viewIdx := 0 + viewOff := 0 + off := 0 + nextOff := 0 + for i := range hdrs { + prevIovecIdx := iovecIdx + mmsgHdr := &mmsgHdrs[i] + mmsgHdr.Msg.Iov = &iovec[iovecIdx] + packetSize := hdrs[i].Size + hdr := &hdrs[i].Hdr + + off = hdrs[i].Off + if off != nextOff { + // We stop in a different point last time. + size := packetSize + viewIdx = 0 + viewOff = 0 + for size > 0 { + if size >= len(views[viewIdx]) { + viewIdx++ + viewOff = 0 + size -= len(views[viewIdx]) + } else { + viewOff = size + size = 0 + } + } + } + nextOff = off + packetSize + + if ethHdrBuf != nil { + v := &iovec[iovecIdx] + v.Base = ðHdrBuf[0] + v.Len = uint64(len(ethHdrBuf)) + iovecIdx++ + } + + v := &iovec[iovecIdx] + hdrView := hdr.View() + v.Base = &hdrView[0] + v.Len = uint64(len(hdrView)) + iovecIdx++ + + for packetSize > 0 { + vec := &iovec[iovecIdx] + iovecIdx++ + + v := views[viewIdx] + vec.Base = &v[viewOff] + s := len(v) - viewOff + if s <= packetSize { + viewIdx++ + viewOff = 0 + } else { + s = packetSize + viewOff += s + } + vec.Len = uint64(s) + packetSize -= s + } + + mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) + } + + packets := 0 + for packets < n { + sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs) + if err != nil { + return packets, err + } + packets += sent + mmsgHdrs = mmsgHdrs[sent:] + } + return packets, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { + return rawfile.NonBlockingWrite(e.fds[0], packet.ToView()) +} + +// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. +func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) } @@ -445,9 +597,9 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher } -// Inject injects an inbound packet. -func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv) +// InjectInbound injects an inbound packet. +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 04406bc9a..7e08e033b 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents buffer.View + contents tcpip.PacketBuffer } type context struct { @@ -92,8 +92,8 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - c.ch <- packetInfo{remote, protocol, vv.ToView()} +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + c.ch <- packetInfo{remote, protocol, pkt} } func TestNoEthernetProperties(t *testing.T) { @@ -293,11 +293,12 @@ func TestDeliverPacket(t *testing.T) { b[i] = uint8(rand.Intn(256)) } + var hdr header.Ethernet if !eth { // So that it looks like an IPv4 packet. b[0] = 0x40 } else { - hdr := make(header.Ethernet, header.EthernetMinimumSize) + hdr = make(header.Ethernet, header.EthernetMinimumSize) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, @@ -315,14 +316,21 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: b, + raddr: raddr, + proto: proto, + contents: tcpip.PacketBuffer{ + Data: buffer.View(b).ToVectorisedView(), + LinkHeader: buffer.View(hdr), + }, } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } + // want.contents.Data will be a single + // view, so make pi do the same for the + // DeepEqual check. + pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 8bfeb97e4..62ed1e569 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -169,9 +169,10 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(pkt) + eth = header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -189,6 +190,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)})) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{ + Data: buffer.View(pkt).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 7ca217e5b..c67d684ce 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -53,7 +53,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &readVDispatcher{fd: fd, e: e} d.views = make([]buffer.View, len(BufConfig)) iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { iovLen++ } d.iovecs = make([]syscall.Iovec, iovLen) @@ -63,7 +63,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { func (d *readVDispatcher) allocateViews(bufConfig []int) { var vnetHdr [virtioNetHdrSize]byte vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. @@ -106,7 +106,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { if err != nil { return false, err } - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // Skip virtioNetHdr which is added before each packet, it // isn't used and it isn't in a view. n -= virtioNetHdrSize @@ -118,9 +118,10 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(d.views[0]) + eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,10 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, d.views[:used]) - vv.TrimFront(d.e.hdrSize) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -194,7 +198,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { } d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv) iovLen := len(BufConfig) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // virtioNetHdr is prepended before each packet. iovLen++ } @@ -225,7 +229,7 @@ func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) { for k := 0; k < len(d.views); k++ { var vnetHdr [virtioNetHdrSize]byte vnetHdrOff := 0 - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { // The kernel adds virtioNetHdr before each packet, but // we don't use it, so so we allocate a buffer for it, // add it in iovecs but don't add it in a view. @@ -261,7 +265,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { // Process each of received packets. for k := 0; k < nMsgs; k++ { n := int(d.msgHdrs[k].Len) - if d.e.Capabilities()&stack.CapabilityGSO != 0 { + if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } if n <= d.e.hdrSize { @@ -271,9 +275,10 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress + eth header.Ethernet ) if d.e.hdrSize > 0 { - eth := header.Ethernet(d.views[k][0]) + eth = header.Ethernet(d.views[k][0]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -291,9 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), d.views[k][:used]) - vv.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index 47a54845c..23e4d1418 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index b36629d2c..bc5d8a2f3 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -23,6 +23,7 @@ package loopback import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,21 +71,45 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*endpoint) Wait() {} + // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) // Because we're immediately turning around and writing the packet back to the // rx path, we intentionally don't preserve the remote and local link // addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + }) return nil } -// Wait implements stack.LinkEndpoint.Wait. -func (*endpoint) Wait() {} +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { + // Reject the packet if it's shorter than an ethernet header. + if packet.Size() < header.EthernetMinimumSize { + return tcpip.ErrBadAddress + } + + // There should be an ethernet header at the beginning of packet. + linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize]) + packet.TrimFront(len(linkHeader)) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ + Data: packet, + LinkHeader: buffer.View(linkHeader), + }) + + return nil +} diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 7c946101d..9a8e8ebfe 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -79,29 +79,47 @@ func (m *InjectableEndpoint) IsAttached() bool { return m.dispatcher != nil } -// Inject implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv) +// InjectInbound implements stack.InjectableLinkEndpoint. +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) +} + +// WritePackets writes outbound packets to the appropriate +// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if +// r.RemoteAddress has a route registered in this endpoint. +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + endpoint, ok := m.routes[r.RemoteAddress] + if !ok { + return 0, tcpip.ErrNoRoute + } + return endpoint.WritePackets(r, gso, hdrs, payload, protocol) } // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol) + return endpoint.WritePacket(r, gso, hdr, payload, protocol) } return tcpip.ErrNoRoute } -// WriteRawPacket writes outbound packets to the appropriate +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (m *InjectableEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { + // WriteRawPacket doesn't get a route or network address, so there's + // nowhere to write this. + return tcpip.ErrNoRoute +} + +// InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error { +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { endpoint, ok := m.routes[dest] if !ok { return tcpip.ErrNoRoute } - return endpoint.WriteRawPacket(dest, packet) + return endpoint.InjectOutbound(dest, packet) } // Wait implements stack.LinkEndpoint.Wait. diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 3086fec00..9cd300af8 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -31,7 +31,7 @@ import ( func TestInjectableEndpointRawDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - endpoint.WriteRawPacket(dstIP, []byte{0xFA}) + endpoint.InjectOutbound(dstIP, []byte{0xFA}) buf := make([]byte, ipv4.MaxTotalSize) bytesRead, err := sock.Read(buf) diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 2e8bc772a..05c7b8024 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -16,5 +16,8 @@ go_library( visibility = [ "//visibility:public", ], - deps = ["//pkg/tcpip"], + deps = [ + "//pkg/tcpip", + "@org_golang_x_sys//unix:go_default_library", + ], ) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index dda3b10a6..0b5a6cf49 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 7e286a3a6..44e25d475 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -22,6 +22,7 @@ import ( "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -101,6 +102,16 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { return nil } +// NonBlockingSendMMsg sends multiple messages on a socket. +func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { + n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) + if e != 0 { + return 0, TranslateErrno(e) + } + + return int(n), nil +} + // PollEvent represents the pollfd structure passed to a poll() system call. type PollEvent struct { FD int32 diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 9e71d4edf..2bace5298 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -212,6 +212,26 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa return nil } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { + v := packet.ToView() + // Transmit the packet. + e.mu.Lock() + ok := e.tx.transmit(v, buffer.View{}) + e.mu.Unlock() + + if !ok { + return tcpip.ErrWouldBlock + } + + return nil +} + // dispatchLoop reads packets from the rx queue in a loop and dispatches them // to the network stack. func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { @@ -253,8 +273,11 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { } // Send packet up the stack. - eth := header.Ethernet(b) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView()) + eth := header.Ethernet(b[:header.EthernetMinimumSize]) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 0e9ba0846..199406886 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -78,9 +78,10 @@ func (q *queueBuffers) cleanup() { } type packetInfo struct { - addr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - vv buffer.VectorisedView + addr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + vv buffer.VectorisedView + linkHeader buffer.View } type testContext struct { @@ -130,12 +131,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, proto: proto, - vv: vv.Clone(nil), + vv: pkt.Data.Clone(nil), }) c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index e401dce44..d71a03cd2 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -116,19 +116,19 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, vv.First(), nil) + logPacket("recv", protocol, pkt.Data.First(), nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - vs := vv.Views() - length := vv.Size() + vs := pkt.Data.Views() + length := pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil { panic(err) } for _, v := range vs { @@ -147,7 +147,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local panic(err) } } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher @@ -193,10 +193,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket implements the stack.LinkEndpoint interface. It is called by -// higher-level protocols to write packets; it just logs the packet and forwards -// the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, hdr.View(), gso) } @@ -218,28 +215,74 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen panic(err) } length -= len(hdrBuf) - if length > 0 { - for _, v := range payload.Views() { - if len(v) > length { - v = v[:length] - } - n, err := buf.Write(v) - if err != nil { - panic(err) - } - length -= n - if length == 0 { - break - } - } - } + logVectorisedView(payload, length, buf) if _, err := e.file.Write(buf.Bytes()); err != nil { panic(err) } } +} + +// WritePacket implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and +// forwards the request to the lower endpoint. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + e.dumpPacket(gso, hdr, payload, protocol) return e.lower.WritePacket(r, gso, hdr, payload, protocol) } +// WritePackets implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and +// forwards the request to the lower endpoint. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + view := payload.ToView() + for _, d := range hdrs { + e.dumpPacket(gso, d.Hdr, buffer.NewVectorisedView(d.Size, []buffer.View{view[d.Off:][:d.Size]}), protocol) + } + return e.lower.WritePackets(r, gso, hdrs, payload, protocol) +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { + if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { + logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) + } + if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { + length := packet.Size() + if length > int(e.maxPCAPLen) { + length = int(e.maxPCAPLen) + } + + buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(packet.Size()))); err != nil { + panic(err) + } + logVectorisedView(packet, length, buf) + if _, err := e.file.Write(buf.Bytes()); err != nil { + panic(err) + } + } + return e.lower.WriteRawPacket(packet) +} + +func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) { + if length <= 0 { + return + } + for _, v := range vv.Views() { + if len(v) > length { + v = v[:length] + } + n, err := buf.Write(v) + if err != nil { + panic(err) + } + length -= n + if length == 0 { + return + } + } +} + // Wait implements stack.LinkEndpoint.Wait. func (*endpoint) Wait() {} diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 5a1791cb5..b440970e0 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) e.dispatchGate.Leave() } @@ -109,6 +109,30 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return err } +// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if !e.writeGate.Enter() { + return len(hdrs), nil + } + + n, err := e.lower.WritePackets(r, gso, hdrs, payload, protocol) + e.writeGate.Leave() + return n, err +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WriteRawPacket(packet) + e.writeGate.Leave() + return err +} + // WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, // and waits for inflight ones to finish before returning. func (e *Endpoint) WaitWrite() { diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index ae23c96b7..df2e70e54 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.dispatchCount++ } @@ -70,6 +70,17 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P return nil } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += len(hdrs) + return len(hdrs), nil +} + +func (e *countedEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { + e.writeCount++ + return nil +} + // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} @@ -109,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 26cf1c528..0ee509ebe 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -42,7 +42,7 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache } @@ -58,7 +58,7 @@ func (e *endpoint) MTU() uint32 { } func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { @@ -79,16 +79,21 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error { return tcpip.ErrNotSupported } +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketDescriptor, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - v := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return @@ -97,7 +102,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) @@ -109,7 +114,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + fallthrough // also fill the cache from requests case header.ARPReply: + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) } } @@ -126,12 +135,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } return &endpoint{ - nicid: nicid, + nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, }, nil diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 88b57ec03..47098bfdc 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -102,7 +102,9 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + Data: v.ToVectorisedView(), + }) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index c5c7aad86..2cad0a0b6 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "reassembler_list", out = "reassembler_list.go", @@ -29,6 +28,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/tcpip", "//pkg/tcpip/buffer", ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 1628a82be..6da5238ec 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -17,6 +17,7 @@ package fragmentation import ( + "fmt" "log" "sync" "time" @@ -82,7 +83,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t // Process processes an incoming fragment belonging to an ID // and returns a complete packet when all the packets belonging to that ID have been received. -func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) { +func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { f.mu.Lock() r, ok := f.reassemblers[id] if ok && r.tooOld(f.timeout) { @@ -97,8 +98,15 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } f.mu.Unlock() - res, done, consumed := r.process(first, last, more, vv) - + res, done, consumed, err := r.process(first, last, more, vv) + if err != nil { + // We probably got an invalid sequence of fragments. Just + // discard the reassembler and move on. + f.mu.Lock() + f.release(r) + f.mu.Unlock() + return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err) + } f.mu.Lock() f.size += consumed if done { @@ -114,7 +122,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } } f.mu.Unlock() - return res, done + return res, done, nil } func (f *Fragmentation) release(r *reassembler) { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 799798544..72c0f53be 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -83,7 +83,10 @@ func TestFragmentationProcess(t *testing.T) { t.Run(c.comment, func(t *testing.T) { f := NewFragmentation(1024, 512, DefaultReassembleTimeout) for i, in := range c.in { - vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) + vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) + if err != nil { + t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) + } if !reflect.DeepEqual(vv, c.out[i].vv) { t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) } @@ -114,7 +117,10 @@ func TestReassemblingTimeout(t *testing.T) { time.Sleep(2 * timeout) // Send another fragment that completes a packet. // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done := f.Process(0, 1, 1, false, vv(1, "1")) + _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) + if err != nil { + t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) + } if done { t.Errorf("Fragmentation does not respect the reassembling timeout.") } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 8037f734b..9e002e396 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) { +func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -86,7 +86,7 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, false, consumed + return buffer.VectorisedView{}, false, consumed, nil } if r.updateHoles(first, last, more) { // We store the incoming packet only if it filled some holes. @@ -96,13 +96,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise } // Check if all the holes have been deleted and we are ready to reassamble. if r.deleted < len(r.holes) { - return buffer.VectorisedView{}, false, consumed + return buffer.VectorisedView{}, false, consumed, nil } res, err := r.heap.reassemble() if err != nil { - panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err)) + return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err) } - return res, true, consumed + return res, true, consumed, nil } func (r *reassembler) tooOld(timeout time.Duration) bool { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a9741622e..fe499d47e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -171,6 +171,15 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen return nil } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, hdr []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + return tcpip.ErrNotSupported +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, @@ -230,7 +239,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -270,7 +279,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -358,7 +369,9 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: vv, + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -421,13 +434,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag1.ToVectorisedView(), + }) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, frag2.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag2.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -460,7 +477,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -500,7 +517,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -510,6 +529,7 @@ func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } const mtu = 0xffff + const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" cases := []struct { name string expectedCount int @@ -561,7 +581,7 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + SrcAddr: outerSrcAddr, DstAddr: localIpv6Addr, }) @@ -608,8 +628,12 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + // Set ICMPv6 checksum. + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view[:len(view)-c.trunc].ToVectorisedView(), + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index a25756443..ce771631c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -24,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) + pkt.Data.TrimFront(hlen) p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return @@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) - vv := vv.Clone(nil) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -95,7 +99,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -104,19 +108,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b7b07a6c1..ac16c8add 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -39,12 +39,15 @@ const ( // TotalLength field of the ipv4 header. MaxTotalSize = 0xffff + // DefaultTTL is the default time-to-live value for this endpoint. + DefaultTTL = 64 + // buckets is the number of identifier buckets. buckets = 2048 ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint @@ -54,9 +57,9 @@ type endpoint struct { } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, @@ -70,7 +73,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi // DefaultTTL is the default time-to-live value for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -86,7 +89,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv4 endpoint ID. @@ -195,34 +198,44 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff return nil } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payload.Size()) + length := uint16(hdr.UsedLength() + payloadSize) id := uint32(0) if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) } ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, ID: uint16(id), - TTL: ttl, - Protocol: uint8(protocol), + TTL: params.TTL, + TOS: params.TOS, + Protocol: uint8(params.Protocol), SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { + ip := e.addIPHeader(r, &hdr, payload.Size(), params) if loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + NetworkHeader: buffer.View(ip), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -238,6 +251,23 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("multiple packets in local loop") + } + if loop&stack.PacketOut == 0 { + return len(hdrs), nil + } + + for i := range hdrs { + e.addIPHeader(r, &hdrs[i].Hdr, hdrs[i].Size, params) + } + n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -276,7 +306,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + e.HandlePacket(r, tcpip.PacketBuffer{ + Data: payload, + NetworkHeader: buffer.View(ip), + }) } if loop&stack.PacketOut == 0 { return nil @@ -289,24 +322,47 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { + r.Stats().IP.MalformedPacketsReceived.Increment() return } + pkt.NetworkHeader = headerView[:h.HeaderLength()] hlen := int(h.HeaderLength()) tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { + if pkt.Data.Size() == 0 { + // Drop the packet as it's marked as a fragment but has + // no payload. + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 + // Drop the packet if the fragmentOffset is incorrect. i.e the + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. + if last < h.FragmentOffset() { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } var ready bool - vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + var err error + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } if !ready { return } @@ -314,11 +370,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. @@ -327,6 +383,11 @@ func (e *endpoint) Close() {} type protocol struct { ids []uint32 hashIV uint32 + + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 } // Number returns the ipv4 protocol number. @@ -352,12 +413,34 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -391,5 +474,5 @@ func NewProtocol() stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ids: ids, hashIV: hashIV} + return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index b6641ccc3..01dfb5f20 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -47,13 +47,6 @@ func TestExcludeBroadcast(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: 1, @@ -305,7 +298,7 @@ func TestFragmentation(t *testing.T) { Payload: payload.Clone([]buffer.View{}), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) if err != nil { t.Errorf("err got %v, want %v", err, nil) } @@ -352,7 +345,7 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -369,3 +362,119 @@ func TestFragmentationErrors(t *testing.T) { }) } } + +func TestInvalidFragments(t *testing.T) { + // These packets have both IHL and TotalLength set to 0. + testCases := []struct { + name string + packets [][]byte + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + }{ + { + "ihl_totallen_zero_valid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + "ihl_totallen_zero_invalid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + // Total Length of 37(20 bytes IP header + 17 bytes of + // payload) + // Frag Offset of 0x1ffe = 8190*8 = 65520 + // Leading to the fragment end to be past 65535. + "ihl_totallen_valid_invalid_frag_offset_1", + [][]byte{ + {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + // The following 3 tests were found by running a fuzzer and were + // triggering a panic in the IPv4 reassembler code. + { + "ihl_less_than_ipv4_minimum_size_1", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_2", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_3", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "fragment_with_short_total_len_extra_payload", + [][]byte{ + {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + { + "multiple_fragments_with_more_fragments_set_to_false", + [][]byte{ + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + 1, + 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + const nicID tcpip.NICID = 42 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + }, + }) + + var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) + var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) + ep := channel.New(10, 1500, linkAddr) + s.CreateNIC(nicID, sniffer.New(ep)) + + for _, pkt := range tc.packets { + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { + t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { + t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b4d0295bf..6629951c6 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -21,21 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -const ( - // ndpHopLimit is the expected IP hop limit value of 255 for received - // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, - // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently - // drop the NDP packet. All outgoing NDP packets must use this value for - // its IP hop limit field. - ndpHopLimit = 255 -) - // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -49,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip the IP header, then handle the fragmentation header if there // is one. - vv.TrimFront(header.IPv6MinimumSize) + pkt.Data.TrimFront(header.IPv6MinimumSize) p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(vv.First()) + f := header.IPv6Fragment(pkt.Data.First()) if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. @@ -61,35 +52,54 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip fragmentation header and find out the actual protocol // number. - vv.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } h := header.ICMPv6(v) + iph := header.IPv6(netHeader) + + // Validate ICMPv6 checksum before processing the packet. + // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. + // This copy is used as extra payload during the checksum calculation. + payload := pkt.Data + payload.RemoveFirst() + if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + received.Invalid.Increment() + return + } // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field - // in the IPv6 header is not set to 255. + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. switch h.Type() { case header.ICMPv6NeighborSolicit, header.ICMPv6NeighborAdvert, header.ICMPv6RouterSolicit, header.ICMPv6RouterAdvert, header.ICMPv6RedirectMsg: - if header.IPv6(netHeader).HopLimit() != ndpHopLimit { + if iph.HopLimit() != header.NDPHopLimit { + received.Invalid.Increment() + return + } + + if h.Code() != 0 { received.Invalid.Increment() return } @@ -103,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) mtu := h.MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -113,33 +123,80 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch h.Code() { case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) - if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { + + ns := header.NDPNeighborSolicit(h.NDPPayload()) + targetAddr := ns.TargetAddress() + s := r.Stack() + rxNICID := r.NICID() + + isTentative, err := s.IsAddrTentative(rxNICID, targetAddr) + if err != nil { + // We will only get an error if rxNICID is unrecognized, + // which should not happen. For now short-circuit this + // packet. + // + // TODO(b/141002840): Handle this better? + return + } + + if isTentative { + // If the target address is tentative and the source + // of the packet is a unicast (specified) address, then + // the source of the packet is attempting to perform + // address resolution on the target. In this case, the + // solicitation is silently ignored, as per RFC 4862 + // section 5.4.3. + // + // If the target address is tentative and the source of + // the packet is the unspecified address (::), then we + // know another node is also performing DAD for the + // same address (since targetAddr is tentative for us, + // we know we are also performing DAD on it). In this + // case we let the stack know so it can handle such a + // scenario and do nothing further with the NDP NS. + if iph.SourceAddress() == header.IPv6Any { + s.DupTentativeAddrDetected(rxNICID, targetAddr) + } + + // Do not handle neighbor solicitations targeted + // to an address that is tentative on the received + // NIC any further. + return + } + + // At this point we know that targetAddr is not tentative on + // rxNICID so the packet is processed as defined in RFC 4861, + // as per RFC 4862 section 5.4.3. + + if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag - copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr) - pkt[icmpV6OptOffset] = ndpOptDstLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:]) + optsSerializer := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(targetAddr) + opts := na.Options() + opts.Serialize(optsSerializer) // ICMPv6 Neighbor Solicit messages are always sent to // specially crafted IPv6 multicast addresses. As a result, the @@ -152,9 +209,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + // TODO(tamird/ghanan): there exists an explicit NDP option that is + // used to update the neighbor table with link addresses for a + // neighbor from an NS (see the Source Link Layer option RFC + // 4861 section 4.6.1 and section 7.2.3). + // + // Furthermore, the entirety of NDP handling here seems to be + // contradicted by RFC 4861. + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) + + // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) + // + // 7.1.2. Validation of Neighbor Advertisements + // + // The IP Hop Limit field has a value of 255, i.e., the packet + // could not possibly have been forwarded by a router. + if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -166,10 +238,45 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) - e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) + + na := header.NDPNeighborAdvert(h.NDPPayload()) + targetAddr := na.TargetAddress() + stack := r.Stack() + rxNICID := r.NICID() + + isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr) + if err != nil { + // We will only get an error if rxNICID is unrecognized, + // which should not happen. For now short-circuit this + // packet. + // + // TODO(b/141002840): Handle this better? + return + } + + if isTentative { + // We just got an NA from a node that owns an address we + // are performing DAD on, implying the address is not + // unique. In this case we let the stack know so it can + // handle such a scenario and do nothing furthur with + // the NDP NA. + stack.DupTentativeAddrDetected(rxNICID, targetAddr) + return + } + + // At this point we know that the targetAddress is not tentative + // on rxNICID. However, targetAddr may still be assigned to + // rxNICID but not tentative (it could be permanent). Such a + // scenario is beyond the scope of RFC 4862. As such, we simply + // ignore such a scenario for now and proceed as normal. + // + // TODO(b/143147598): Handle the scenario described above. Also + // inform the netstack integration that a duplicate address was + // detected outside of DAD. + + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) } case header.ICMPv6EchoRequest: @@ -178,14 +285,13 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - - vv.TrimFront(header.ICMPv6EchoMinimumSize) + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(packet, h) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, hdr, pkt.Data, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -197,7 +303,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -209,8 +315,51 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.RouterSolicit.Increment() case header.ICMPv6RouterAdvert: + routerAddr := iph.SourceAddress() + + // + // Validate the RA as per RFC 4861 section 6.1.2. + // + + // Is the IP Source Address a link-local address? + if !header.IsV6LinkLocalAddress(routerAddr) { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + p := h.NDPPayload() + + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if len(p) < header.NDPRAMinimumSize { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + ra := header.NDPRouterAdvert(p) + opts := ra.Options() + + // Are options valid as per the wire format? + if _, err := opts.Iter(true); err != nil { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // + // At this point, we have a valid Router Advertisement, as far + // as RFC 4861 section 6.1.2 is concerned. + // + received.RouterAdvert.Increment() + // Tell the NIC to handle the RA. + stack := r.Stack() + rxNICID := r.NICID() + stack.HandleNDPRA(rxNICID, routerAddr, ra) + case header.ICMPv6RedirectMsg: received.RedirectMsg.Increment() @@ -262,7 +411,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: defaultIPv6HopLimit, + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 01f5a17ec..6037a1ef8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,7 +15,6 @@ package ipv6 import ( - "fmt" "reflect" "strings" "testing" @@ -31,7 +30,7 @@ import ( ) const ( - linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") ) @@ -66,7 +65,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { } type stubLinkAddressCache struct { @@ -132,7 +131,7 @@ func TestICMPCounts(t *testing.T) { {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, + {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, @@ -144,11 +143,13 @@ func TestICMPCounts(t *testing.T) { ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: r.DefaultTTL(), + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } for _, typ := range types { @@ -179,13 +180,10 @@ func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { t := v.Type() for i := 0; i < v.NumField(); i++ { v := v.Field(i) - switch v.Kind() { - case reflect.Ptr: - f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter)) - case reflect.Struct: + if s, ok := v.Interface().(*tcpip.StatCounter); ok { + f(t.Field(i).Name, s) + } else { visitStats(v, f) - default: - panic(fmt.Sprintf("unexpected type %s", v.Type())) } } } @@ -284,7 +282,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pkt.Header, pkt.Payload} size := len(pkt.Header) + len(pkt.Payload) vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + Data: vv, + }) } if pkt.Proto != ProtocolNumber { @@ -363,3 +363,537 @@ func TestLinkResolution(t *testing.T) { routeICMPv6Packet(t, args, nil) } } + +func TestICMPChecksumValidationSimple(t *testing.T) { + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + }, + { + "RouterSolicit", + header.ICMPv6RouterSolicit, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + "RouterAdvert", + header.ICMPv6RouterAdvert, + header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + "NeighborSolicit", + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborSolicitMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + "NeighborAdvert", + header.ICMPv6NeighborAdvert, + header.ICMPv6NeighborAdvertSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + "RedirectMsg", + header.ICMPv6RedirectMsg, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayload(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + icmpSize := size + payloadSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(typ) + payloadFn(pkt.Payload()) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + + payload := buffer.NewView(payloadSize) + payloadFn(payload) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size + payloadSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 7de6a4546..4cee848a1 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -21,6 +21,8 @@ package ipv6 import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,23 +37,24 @@ const ( // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff - // defaultIPv6HopLimit is the default hop limit for IPv6 Packets - // egressed by Netstack. - defaultIPv6HopLimit = 255 + // DefaultTTL is the default hop limit for IPv6 Packets egressed by + // Netstack. + DefaultTTL = 64 ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher + protocol *protocol } // DefaultTTL is the default hop limit for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -62,7 +65,7 @@ func (e *endpoint) MTU() uint32 { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv6 endpoint ID. @@ -94,25 +97,35 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { - length := uint16(hdr.UsedLength() + payload.Size()) +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { + length := uint16(hdr.UsedLength() + payloadSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, - NextHeader: uint8(protocol), - HopLimit: ttl, + NextHeader: uint8(params.Protocol), + HopLimit: params.TTL, + TrafficClass: params.TOS, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { + ip := e.addIPHeader(r, &hdr, payload.Size(), params) if loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + NetworkHeader: buffer.View(ip), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -123,6 +136,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("not implemented") + } + if loop&stack.PacketOut == 0 { + return len(hdrs), nil + } + + for i := range hdrs { + hdr := &hdrs[i].Hdr + size := hdrs[i].Size + e.addIPHeader(r, hdr, size, params) + } + + n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err +} + // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -132,30 +165,36 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vector // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv6(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { return } - vv.TrimFront(header.IPv6MinimumSize) - vv.CapLength(int(h.PayloadLength())) + pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] + pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data.CapLength(int(h.PayloadLength())) p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, vv) + e.handleICMP(r, headerView, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. func (*endpoint) Close() {} -type protocol struct{} +type protocol struct { + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 +} // Number returns the ipv6 protocol number. func (p *protocol) Number() tcpip.NetworkProtocolNumber { @@ -179,25 +218,48 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, + protocol: p, }, nil } // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -212,5 +274,5 @@ func calculateMTU(mtu uint32) uint32 { // NewProtocol returns an IPv6 network protocol. func NewProtocol() stack.NetworkProtocol { - return &protocol{} + return &protocol{defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 78c674c2c..1cbfa7278 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -55,7 +55,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stats := s.Stats().ICMP.V6PacketsReceived @@ -64,7 +66,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst } } -// testReceiveICMP tests receiving a UDP packet from src to dst. want is the +// testReceiveUDP tests receiving a UDP packet from src to dst. want is the // expected UDP received count after receiving the packet. func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { t.Helper() @@ -111,7 +113,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stat := s.Stats().UDP.PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index e30791fe3..0dbce14a0 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) @@ -97,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, hdr.View().ToVectorisedView()) + ep.HandlePacket(r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } types := []struct { @@ -109,7 +112,7 @@ func TestHopLimitValidation(t *testing.T) { {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }}, {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { @@ -150,7 +153,7 @@ func TestHopLimitValidation(t *testing.T) { // Receive the NDP packet with an invalid hop limit // value. - handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) // Invalid count should have increased. if got := invalid.Value(); got != 1 { @@ -164,7 +167,7 @@ func TestHopLimitValidation(t *testing.T) { } // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, ndpHopLimit, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) // Rx count of NDP packet of type typ.typ should have // increased. @@ -179,3 +182,191 @@ func TestHopLimitValidation(t *testing.T) { }) } } + +// TestRouterAdvertValidation tests that when the NIC is configured to handle +// NDP Router Advertisement packets, it validates the Router Advertisement +// properly before handling them. +func TestRouterAdvertValidation(t *testing.T) { + tests := []struct { + name string + src tcpip.Address + hopLimit uint8 + code uint8 + ndpPayload []byte + expectedSuccess bool + }{ + { + "OK", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + true, + }, + { + "NonLinkLocalSourceAddr", + addr1, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "HopLimitNot255", + lladdr0, + 254, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NonZeroCode", + lladdr0, + 255, + 1, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NDPPayloadTooSmall", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, + }, + false, + }, + { + "OKWithOptions", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + 2, 1, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + true, + }, + { + "OptionWithZeroLength", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + // Invalid as it has 0 length. + 2, 0, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } + + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + } + }) + } +} diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go new file mode 100644 index 000000000..10b04239d --- /dev/null +++ b/pkg/tcpip/packet_buffer.go @@ -0,0 +1,54 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +// +// +stateify savable +type PacketBuffer struct { + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other Views. SR dosen't support this, + // so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +func (pk PacketBuffer) Clone() PacketBuffer { + return PacketBuffer{ + Data: pk.Data.Clone(nil), + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + } +} diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go new file mode 100644 index 000000000..04c4cf136 --- /dev/null +++ b/pkg/tcpip/packet_buffer_state.go @@ -0,0 +1,26 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// beforeSave is invoked by stateify. +func (pk *PacketBuffer) beforeSave() { + // Non-Data fields may be slices of the Data field. This causes + // problems for SR, so during save we make each header independent. + pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) + pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) + pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +} diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 315780c0c..30cea8996 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -19,6 +19,7 @@ import ( "math" "math/rand" "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -27,6 +28,10 @@ const ( // FirstEphemeral is the first ephemeral port. FirstEphemeral = 16000 + // numEphemeralPorts it the mnumber of available ephemeral ports to + // Netstack. + numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1 + anyIPAddress tcpip.Address = "" ) @@ -40,6 +45,13 @@ type portDescriptor struct { type PortManager struct { mu sync.RWMutex allocatedPorts map[portDescriptor]bindAddresses + + // hint is used to pick ports ephemeral ports in a stable order for + // a given port offset. + // + // hint must be accessed using the portHint/incPortHint helpers. + // TODO(gvisor.dev/issue/940): S/R this field. + hint uint32 } type portNode struct { @@ -47,43 +59,76 @@ type portNode struct { refs int } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]portNode +// deviceNode is never empty. When it has no elements, it is removed from the +// map that references it. +type deviceNode map[tcpip.NICID]portNode -// isAvailable checks whether an IP address is available to bind to. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool { - if addr == anyIPAddress { - if len(b) == 0 { - return true - } +// isAvailable checks whether binding is possible by device. If not binding to a +// device, check against all portNodes. If binding to a specific device, check +// against the unspecified device and the provided device. +func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { + if bindToDevice == 0 { + // Trying to binding all devices. if !reuse { + // Can't bind because the (addr,port) is already bound. return false } - for _, n := range b { - if !n.reuse { + for _, p := range d { + if !p.reuse { + // Can't bind because the (addr,port) was previously bound without reuse. return false } } return true } - // If all addresses for this portDescriptor are already bound, no - // address is available. - if n, ok := b[anyIPAddress]; ok { - if !reuse { + if p, ok := d[0]; ok { + if !reuse || !p.reuse { return false } - if !n.reuse { + } + + if p, ok := d[bindToDevice]; ok { + if !reuse || !p.reuse { return false } } - if n, ok := b[addr]; ok { - if !reuse { + return true +} + +// bindAddresses is a set of IP addresses. +type bindAddresses map[tcpip.Address]deviceNode + +// isAvailable checks whether an IP address is available to bind to. If the +// address is the "any" address, check all other addresses. Otherwise, just +// check against the "any" address and the provided address. +func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { + if addr == anyIPAddress { + // If binding to the "any" address then check that there are no conflicts + // with all addresses. + for _, d := range b { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + return true + } + + // Check that there is no conflict with the "any" address. + if d, ok := b[anyIPAddress]; ok { + if !d.isAvailable(reuse, bindToDevice) { return false } - return n.reuse } + + // Check that this is no conflict with the provided address. + if d, ok := b[addr]; ok { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + return true } @@ -97,11 +142,40 @@ func NewPortManager() *PortManager { // is suitable for its needs, and stopping when a port is found or an error // occurs. func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { - count := uint16(math.MaxUint16 - FirstEphemeral + 1) - offset := uint16(rand.Int31n(int32(count))) + offset := uint32(rand.Int31n(numEphemeralPorts)) + return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) +} + +// portHint atomically reads and returns the s.hint value. +func (s *PortManager) portHint() uint32 { + return atomic.LoadUint32(&s.hint) +} + +// incPortHint atomically increments s.hint by 1. +func (s *PortManager) incPortHint() { + atomic.AddUint32(&s.hint, 1) +} - for i := uint16(0); i < count; i++ { - port = FirstEphemeral + (offset+i)%count +// PickEphemeralPortStable starts at the specified offset + s.portHint and +// iterates over all ephemeral ports, allowing the caller to decide whether a +// given port is suitable for its needs and stopping when a port is found or an +// error occurs. +func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) + if err == nil { + s.incPortHint() + } + return p, err + +} + +// pickEphemeralPort starts at the offset specified from the FirstEphemeral port +// and iterates over the number of ports specified by count and allows the +// caller to decide whether a given port is suitable for its needs, and stopping +// when a port is found or an error occurs. +func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + for i := uint32(0); i < count; i++ { + port = uint16(FirstEphemeral + (offset+i)%count) ok, err := testPort(port) if err != nil { return 0, err @@ -116,17 +190,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse) + return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse) { + if !addrs.isAvailable(addr, reuse, bindToDevice) { return false } } @@ -138,14 +212,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse) { + if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -153,13 +227,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil + return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { return false } @@ -171,11 +245,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m = make(bindAddresses) s.allocatedPorts[desc] = m } - if n, ok := m[addr]; ok { + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + if n, ok := d[bindToDevice]; ok { n.refs++ - m[addr] = n + d[bindToDevice] = n } else { - m[addr] = portNode{reuse: reuse, refs: 1} + d[bindToDevice] = portNode{reuse: reuse, refs: 1} } } @@ -184,22 +263,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { - n, ok := m[addr] + d, ok := m[addr] + if !ok { + continue + } + n, ok := d[bindToDevice] if !ok { continue } n.refs-- + d[bindToDevice] = n if n.refs == 0 { + delete(d, bindToDevice) + } + if len(d) == 0 { delete(m, addr) - } else { - m[addr] = n } if len(m) == 0 { delete(s.allocatedPorts, desc) diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 689401661..19f4833fc 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -15,6 +15,7 @@ package ports import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -34,6 +35,7 @@ type portReserveTestAction struct { want *tcpip.Error reuse bool release bool + device tcpip.NICID } func TestPortReservation(t *testing.T) { @@ -100,6 +102,112 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: anyIPAddress, release: true}, {port: 24, ip: anyIPAddress, reuse: false, want: nil}, }, + }, { + tname: "bind twice with device fails", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 3, want: nil}, + {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 1, want: nil}, + {port: 24, ip: fakeIPAddress, device: 2, want: nil}, + }, + }, { + tname: "bind to device and then without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + }, + }, { + tname: "binding with reuse and device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "mixing reuse and not reuse by binding to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: nil}, + }, + }, { + tname: "can't bind to 0 after mixing reuse and not reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind and release", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + + // Release the bind to device 0 and try again. + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + }, + }, { + tname: "bind twice with reuse once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "release an unreserved device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + // The below don't exist. + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + // Release all. + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -108,12 +216,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) @@ -125,7 +233,6 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { - pm := NewPortManager() customErr := &tcpip.Error{} for _, test := range []struct { name string @@ -169,9 +276,63 @@ func TestPickEphemeralPort(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { + pm := NewPortManager() if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) } }) } } + +func TestPickEphemeralPortStable(t *testing.T) { + customErr := &tcpip.Error{} + for _, test := range []struct { + name string + f func(port uint16) (bool, *tcpip.Error) + wantErr *tcpip.Error + wantPort uint16 + }{ + { + name: "no-port-available", + f: func(port uint16) (bool, *tcpip.Error) { + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + { + name: "port-tester-error", + f: func(port uint16) (bool, *tcpip.Error) { + return false, customErr + }, + wantErr: customErr, + }, + { + name: "only-port-16042-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port == FirstEphemeral+42 { + return true, nil + } + return false, nil + }, + wantPort: FirstEphemeral + 42, + }, + { + name: "only-port-under-16000-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port < FirstEphemeral { + return true, nil + } + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + } { + t.Run(test.name, func(t *testing.T) { + pm := NewPortManager() + portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) + if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { + t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + } + }) + } +} diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index 76b5f4ffa..29b7d761c 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -1,7 +1,7 @@ -package(licenses = ["notice"]) - load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_library( name = "seqnum", srcs = ["seqnum.go"], diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 28c49e8ff..460db3cf8 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "linkaddrentry_list", out = "linkaddrentry_list.go", @@ -23,6 +22,7 @@ go_library( "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", + "ndp.go", "nic.go", "registration.go", "route.go", @@ -36,6 +36,7 @@ go_library( ], deps = [ "//pkg/ilist", + "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", @@ -53,18 +54,26 @@ go_test( name = "stack_x_test", size = "small", srcs = [ + "ndp_test.go", "stack_test.go", + "transport_demuxer_test.go", "transport_test.go", ], deps = [ ":stack", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go new file mode 100644 index 000000000..8e49f7a56 --- /dev/null +++ b/pkg/tcpip/stack/ndp.go @@ -0,0 +1,534 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "log" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + // defaultDupAddrDetectTransmits is the default number of NDP Neighbor + // Solicitation messages to send when doing Duplicate Address Detection + // for a tentative address. + // + // Default = 1 (from RFC 4862 section 5.1) + defaultDupAddrDetectTransmits = 1 + + // defaultRetransmitTimer is the default amount of time to wait between + // sending NDP Neighbor solicitation messages. + // + // Default = 1s (from RFC 4861 section 10). + defaultRetransmitTimer = time.Second + + // defaultHandleRAs is the default configuration for whether or not to + // handle incoming Router Advertisements as a host. + // + // Default = true. + defaultHandleRAs = true + + // defaultDiscoverDefaultRouters is the default configuration for + // whether or not to discover default routers from incoming Router + // Advertisements as a host. + // + // Default = true. + defaultDiscoverDefaultRouters = true + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending NDP Neighbor solicitation messages. Note, RFC 4861 does + // not impose a minimum Retransmit Timer, but we do here to make sure + // the messages are not sent all at once. We also come to this value + // because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. + // Note, the unit of the RetransmitTimer field in the Router + // Advertisement is milliseconds. + // + // Min = 1ms. + minimumRetransmitTimer = time.Millisecond + + // MaxDiscoveredDefaultRouters is the maximum number of discovered + // default routers. The stack should stop discovering new routers after + // discovering MaxDiscoveredDefaultRouters routers. + // + // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and + // SHOULD be more. + // + // Max = 10. + MaxDiscoveredDefaultRouters = 10 +) + +// NDPDispatcher is the interface integrators of netstack must implement to +// receive and handle NDP related events. +type NDPDispatcher interface { + // OnDuplicateAddressDetectionStatus will be called when the DAD process + // for an address (addr) on a NIC (with ID nicID) completes. resolved + // will be set to true if DAD completed successfully (no duplicate addr + // detected); false otherwise (addr was detected to be a duplicate on + // the link the NIC is a part of, or it was stopped for some other + // reason, such as the address being removed). If an error occured + // during DAD, err will be set and resolved must be ignored. + // + // This function is permitted to block indefinitely without interfering + // with the stack's operation. + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + + // OnDefaultRouterDiscovered will be called when a new default router is + // discovered. Implementations must return true along with a new valid + // route table if the newly discovered router should be remembered. If + // an implementation returns false, the second return value will be + // ignored. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) + + // OnDefaultRouterInvalidated will be called when a discovered default + // router is invalidated. Implementers must return a new valid route + // table. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route +} + +// NDPConfigurations is the NDP configurations for the netstack. +type NDPConfigurations struct { + // The number of Neighbor Solicitation messages to send when doing + // Duplicate Address Detection for a tentative address. + // + // Note, a value of zero effectively disables DAD. + DupAddrDetectTransmits uint8 + + // The amount of time to wait between sending Neighbor solicitation + // messages. + // + // Must be greater than 0.5s. + RetransmitTimer time.Duration + + // HandleRAs determines whether or not Router Advertisements will be + // processed. + HandleRAs bool + + // DiscoverDefaultRouters determines whether or not default routers will + // be discovered from Router Advertisements. This configuration is + // ignored if HandleRAs is false. + DiscoverDefaultRouters bool +} + +// DefaultNDPConfigurations returns an NDPConfigurations populated with +// default values. +func DefaultNDPConfigurations() NDPConfigurations { + return NDPConfigurations{ + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + } +} + +// validate modifies an NDPConfigurations with valid values. If invalid values +// are present in c, the corresponding default values will be used instead. +// +// If RetransmitTimer is less than minimumRetransmitTimer, then a value of +// defaultRetransmitTimer will be used. +func (c *NDPConfigurations) validate() { + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } +} + +// ndpState is the per-interface NDP state. +type ndpState struct { + // The NIC this ndpState is for. + nic *NIC + + // configs is the per-interface NDP configurations. + configs NDPConfigurations + + // The DAD state to send the next NS message, or resolve the address. + dad map[tcpip.Address]dadState + + // The default routers discovered through Router Advertisements. + defaultRouters map[tcpip.Address]defaultRouterState +} + +// dadState holds the Duplicate Address Detection timer and channel to signal +// to the DAD goroutine that DAD should stop. +type dadState struct { + // The DAD timer to send the next NS message, or resolve the address. + timer *time.Timer + + // Used to let the DAD timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the NIC this dadState is associated with. + done *bool +} + +// defaultRouterState holds data associated with a default router discovered by +// a Router Advertisement. +type defaultRouterState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the default router (R) in + // a race condition (T1 is a goroutine that handles an RA from R and T2 + // is the goroutine that handles R's invalidation timer firing): + // T1: Receive a new RA from R + // T1: Obtain the NIC's lock before processing the RA + // T2: R's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends R's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates R immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // signal the timer using this channel to not invalidate R, so that once + // T2 obtains the lock, it will see that there is an event on this + // channel and do nothing further. + doNotInvalidateC chan struct{} +} + +// startDuplicateAddressDetection performs Duplicate Address Detection. +// +// This function must only be called by IPv6 addresses that are currently +// tentative. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { + // addr must be a valid unicast IPv6 address. + if !header.IsV6UnicastAddress(addr) { + return tcpip.ErrAddressFamilyNotSupported + } + + // Should not attempt to perform DAD on an address that is currently in + // the DAD process. + if _, ok := ndp.dad[addr]; ok { + // Should never happen because we should only ever call this + // function for newly created addresses. If we attemped to + // "add" an address that already existed, we would returned an + // error since we attempted to add a duplicate address, or its + // reference count would have been increased without doing the + // work that would have been done for an address that was brand + // new. See NIC.addPermanentAddressLocked. + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) + } + + remaining := ndp.configs.DupAddrDetectTransmits + + { + done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) + if err != nil { + return err + } + if done { + return nil + } + } + + remaining-- + + var done bool + var timer *time.Timer + timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { + var d bool + var err *tcpip.Error + + // doDadIteration does a single iteration of the DAD loop. + // + // Returns true if the integrator needs to be informed of DAD + // completing. + doDadIteration := func() bool { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if done { + // If we reach this point, it means that the DAD + // timer fired after another goroutine already + // obtained the NIC lock and stopped DAD before + // this function obtained the NIC lock. Simply + // return here and do nothing further. + return false + } + + ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] + if !ok { + // This should never happen. + // We should have an endpoint for addr since we + // are still performing DAD on it. If the + // endpoint does not exist, but we are doing DAD + // on it, then we started DAD at some point, but + // forgot to stop it when the endpoint was + // deleted. + panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) + } + + d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) + if err != nil || d { + delete(ndp.dad, addr) + + if err != nil { + log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) + } + + // Let the integrator know DAD has completed. + return true + } + + remaining-- + timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) + return false + } + + if doDadIteration() && ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + } + }) + + ndp.dad[addr] = dadState{ + timer: timer, + done: &done, + } + + return nil +} + +// doDuplicateAddressDetection is called on every iteration of the timer, and +// when DAD starts. +// +// It handles resolving the address (if there are no more NS to send), or +// sending the next NS if there are more NS to send. +// +// This function must only be called by IPv6 addresses that are currently +// tentative. +// +// The NIC that ndp belongs to (n) MUST be locked. +// +// Returns true if DAD has resolved; false if DAD is still ongoing. +func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { + if ref.getKind() != permanentTentative { + // The endpoint should still be marked as tentative + // since we are still performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) + } + + if remaining == 0 { + // DAD has resolved. + ref.setKind(permanent) + return true, nil + } + + // Send a new NS. + snmc := header.SolicitedNodeAddr(addr) + snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] + if !ok { + // This should never happen as if we have the + // address, we should have the solicited-node + // address. + panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) + } + + // Use the unspecified address as the source address when performing + // DAD. + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) + + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(addr) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + sent := r.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}); err != nil { + sent.Dropped.Increment() + return false, err + } + sent.NeighborSolicit.Increment() + + return false, nil +} + +// stopDuplicateAddressDetection ends a running Duplicate Address Detection +// process. Note, this may leave the DAD process for a tentative address in +// such a state forever, unless some other external event resolves the DAD +// process (receiving an NA from the true owner of addr, or an NS for addr +// (implying another node is attempting to use addr)). It is up to the caller +// of this function to handle such a scenario. Normally, addr will be removed +// from n right after this function returns or the address successfully +// resolved. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { + dad, ok := ndp.dad[addr] + if !ok { + // Not currently performing DAD on addr, just return. + return + } + + if dad.timer != nil { + dad.timer.Stop() + dad.timer = nil + + *dad.done = true + dad.done = nil + } + + delete(ndp.dad, addr) + + // Let the integrator know DAD did not resolve. + if ndp.nic.stack.ndpDisp != nil { + go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + } +} + +// handleRA handles a Router Advertisement message that arrived on the NIC +// this ndp is for. Does nothing if the NIC is configured to not handle RAs. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + // Is the NIC configured to handle RAs at all? + // + // Currently, the stack does not determine router interface status on a + // per-interface basis; it is a stack-wide configuration, so we check + // stack's forwarding flag to determine if the NIC is a routing + // interface. + if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + return + } + + // Is the NIC configured to discover default routers? + if ndp.configs.DiscoverDefaultRouters { + rtr, ok := ndp.defaultRouters[ip] + rl := ra.RouterLifetime() + switch { + case !ok && rl != 0: + // This is a new default router we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredDefaultRouters routers. + if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { + ndp.rememberDefaultRouter(ip, rl) + } + + case ok && rl != 0: + // This is an already discovered default router. Update + // the invalidation timer. + timer := rtr.invalidationTimer + + // We should ALWAYS have an invalidation timer for a + // discovered router. + if timer == nil { + panic("ndphandlera: RA invalidation timer should not be nil") + } + + if !timer.Stop() { + // If we reach this point, then we know the + // timer fired after we already took the NIC + // lock. Signal the timer so that once it + // obtains the lock, it doesn't actually + // invalidate the router as we just got a new + // RA that refreshes its lifetime to a non-zero + // value. See + // defaultRouterState.doNotInvalidateC for more + // details. + rtr.doNotInvalidateC <- struct{}{} + } + + timer.Reset(rl) + + case ok && rl == 0: + // We know about the router but it is no longer to be + // used as a default router so invalidate it. + ndp.invalidateDefaultRouter(ip) + } + } + + // TODO(b/140948104): Do Prefix Discovery. + // TODO(b/141556115): Do Parameter Discovery. +} + +// invalidateDefaultRouter invalidates a discovered default router. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { + rtr, ok := ndp.defaultRouters[ip] + + // Is the router still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + rtr.invalidationTimer.Stop() + rtr.invalidationTimer = nil + close(rtr.doNotInvalidateC) + rtr.doNotInvalidateC = nil + + delete(ndp.defaultRouters, ip) + + // Let the integrator know a discovered default router is invalidated. + if ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + } +} + +// rememberDefaultRouter remembers a newly discovered default router with IPv6 +// link-local address ip with lifetime rl. +// +// The router identified by ip MUST NOT already be known by the NIC. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { + if ndp.nic.stack.ndpDisp == nil { + return + } + + // Inform the integrator when we discovered a default router. + remember, routeTable := ndp.nic.stack.ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) + if !remember { + // Informed by the integrator to not remember the router, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the default router (R) in + // a race condition. See defaultRouterState.doNotInvalidateC for more + // details. + doNotInvalidateC := make(chan struct{}, 1) + + ndp.defaultRouters[ip] = defaultRouterState{ + invalidationTimer: time.AfterFunc(rl, func() { + ndp.nic.stack.mu.Lock() + defer ndp.nic.stack.mu.Unlock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + select { + case <-doNotInvalidateC: + return + default: + } + + ndp.invalidateDefaultRouter(ip) + }), + doNotInvalidateC: doNotInvalidateC, + } + + ndp.nic.stack.routeTable = routeTable +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go new file mode 100644 index 000000000..50ce1bbfa --- /dev/null +++ b/pkg/tcpip/stack/ndp_test.go @@ -0,0 +1,1013 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "encoding/binary" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + linkAddr1 = "\x02\x02\x03\x04\x05\x06" + linkAddr2 = "\x02\x02\x03\x04\x05\x07" + linkAddr3 = "\x02\x02\x03\x04\x05\x08" + defaultTimeout = 250 * time.Millisecond +) + +var ( + llAddr1 = header.LinkLocalAddr(linkAddr1) + llAddr2 = header.LinkLocalAddr(linkAddr2) + llAddr3 = header.LinkLocalAddr(linkAddr3) +) + +// TestDADDisabled tests that an address successfully resolves immediately +// when DAD is not enabled (the default for an empty stack.Options). +func TestDADDisabled(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + } + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + + // Should get the address immediately since we should not have performed + // DAD on it. + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) + } + + // We should not have sent any NDP NS messages. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + t.Fatalf("got NeighborSolicit = %d, want = 0", got) + } +} + +// ndpDADEvent is a set of parameters that was passed to +// ndpDispatcher.OnDuplicateAddressDetectionStatus. +type ndpDADEvent struct { + nicID tcpip.NICID + addr tcpip.Address + resolved bool + err *tcpip.Error +} + +type ndpRouterEvent struct { + nicID tcpip.NICID + addr tcpip.Address + // true if router was discovered, false if invalidated. + discovered bool +} + +var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) + +// ndpDispatcher implements NDPDispatcher so tests can know when various NDP +// related events happen for test purposes. +type ndpDispatcher struct { + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + routeTable []tcpip.Route +} + +// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { + if n.dadC != nil { + n.dadC <- ndpDADEvent{ + nicID, + addr, + resolved, + err, + } + } +} + +// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) { + if n.routerC != nil { + n.routerC <- ndpRouterEvent{ + nicID, + addr, + true, + } + } + + if !n.rememberRouter { + return false, nil + } + + rt := append([]tcpip.Route(nil), n.routeTable...) + rt = append(rt, tcpip.Route{ + Destination: header.IPv6EmptySubnet, + Gateway: addr, + NIC: nicID, + }) + n.routeTable = rt + return true, rt +} + +// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route { + if n.routerC != nil { + n.routerC <- ndpRouterEvent{ + nicID, + addr, + false, + } + } + + var rt []tcpip.Route + exclude := tcpip.Route{ + Destination: header.IPv6EmptySubnet, + Gateway: addr, + NIC: nicID, + } + + for _, r := range n.routeTable { + if r != exclude { + rt = append(rt, r) + } + } + n.routeTable = rt + return rt +} + +// TestDADResolve tests that an address successfully resolves after performing +// DAD for various values of DupAddrDetectTransmits and RetransmitTimer. +// Included in the subtests is a test to make sure that an invalid +// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. +func TestDADResolve(t *testing.T) { + tests := []struct { + name string + dupAddrDetectTransmits uint8 + retransTimer time.Duration + expectedRetransmitTimer time.Duration + }{ + {"1:1s:1s", 1, time.Second, time.Second}, + {"2:1s:1s", 2, time.Second, time.Second}, + {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second}, + // 0s is an invalid RetransmitTimer timer and will be fixed to + // the default RetransmitTimer value of 1s. + {"1:0s:1s", 1, 0, time.Second}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + } + opts.NDPConfigs.RetransmitTimer = test.retransTimer + opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + + stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit + + // Should have sent an NDP NS immediately. + if got := stat.Value(); got != 1 { + t.Fatalf("got NeighborSolicit = %d, want = 1", got) + + } + + // Address should not be considered bound to the NIC yet + // (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Wait for the remaining time - some delta (500ms), to + // make sure the address is still not resolved. + const delta = 500 * time.Millisecond + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(2 * delta): + // We should get a resolution event after 500ms + // (delta) since we wait for 500ms less than the + // expected resolution time above to make sure + // that the address did not yet resolve. Waiting + // for 1s (2x delta) without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) + } + + // Should not have sent any more NS messages. + if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) { + t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) + } + + // Validate the sent Neighbor Solicitation messages. + for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { + p := <-e.C + + // Make sure its an IPv6 packet. + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + + // Check NDP packet. + checker.IPv6(t, p.Header.ToVectorisedView().First(), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr1))) + } + }) + } + +} + +// TestDADFail tests to make sure that the DAD process fails if another node is +// detected to be performing DAD on the same address (receive an NS message from +// a node doing DAD for the same address), or if another node is detected to own +// the address already (receive an NA message for the tentative address). +func TestDADFail(t *testing.T) { + tests := []struct { + name string + makeBuf func(tgt tcpip.Address) buffer.Prependable + getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + "RxSolicit", + func(tgt tcpip.Address) buffer.Prependable { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(tgt) + snmc := header.SolicitedNodeAddr(tgt) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, + }) + + return hdr + + }, + func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return s.NeighborSolicit + }, + }, + { + "RxAdvert", + func(tgt tcpip.Address) buffer.Prependable { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(tgt) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: 255, + SrcAddr: tgt, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + return hdr + + }, + func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return s.NeighborAdvert + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + } + opts.NDPConfigs.RetransmitTimer = time.Second * 2 + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + + // Address should not be considered bound to the NIC yet + // (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Receive a packet to simulate multiple nodes owning or + // attempting to own the same address. + hdr := test.makeBuf(addr1) + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) + if got := stat.Value(); got != 1 { + t.Fatalf("got stat = %d, want = 1", got) + } + + // Wait for DAD to fail and make sure the address did + // not get resolved. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if e.resolved { + t.Fatal("got DAD event w/ resolved = true, want = false") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + }) + } +} + +// TestDADStop tests to make sure that the DAD process stops when an address is +// removed. +func TestDADStop(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + } + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Remove the address. This should stop DAD. + if err := s.RemoveAddress(1, addr1); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err) + } + + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if e.resolved { + t.Fatal("got DAD event w/ resolved = true, want = false") + } + + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Should not have sent more than 1 NS message. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + t.Fatalf("got NeighborSolicit = %d, want <= 1", got) + } +} + +// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NDP configurations using an invalid NICID. +func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + + // No NIC with ID 1 yet. + if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) + } +} + +// TestSetNDPConfigurations tests that we can update and use per-interface NDP +// configurations without affecting the default NDP configurations or other +// interfaces' configurations. +func TestSetNDPConfigurations(t *testing.T) { + tests := []struct { + name string + dupAddrDetectTransmits uint8 + retransmitTimer time.Duration + expectedRetransmitTimer time.Duration + }{ + { + "OK", + 1, + time.Second, + time.Second, + }, + { + "Invalid Retransmit Timer", + 1, + 0, + time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + }) + + // This NIC(1)'s NDP configurations will be updated to + // be different from the default. + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Created before updating NIC(1)'s NDP configurations + // but updating NIC(1)'s NDP configurations should not + // affect other existing NICs. + if err := s.CreateNIC(2, e); err != nil { + t.Fatalf("CreateNIC(2) = %s", err) + } + + // Update the NDP configurations on NIC(1) to use DAD. + configs := stack.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + RetransmitTimer: test.retransmitTimer, + } + if err := s.SetNDPConfigurations(1, configs); err != nil { + t.Fatalf("got SetNDPConfigurations(1, _) = %s", err) + } + + // Created after updating NIC(1)'s NDP configurations + // but the stack's default NDP configurations should not + // have been updated. + if err := s.CreateNIC(3, e); err != nil { + t.Fatalf("CreateNIC(3) = %s", err) + } + + // Add addresses for each NIC. + if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + } + if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err) + } + if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err) + } + + // Address should not be considered bound to NIC(1) yet + // (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Should get the address on NIC(2) and NIC(3) + // immediately since we should not have performed DAD on + // it as the stack was configured to not do DAD by + // default and we only updated the NDP configurations on + // NIC(1). + addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err) + } + if addr.Address != addr2 { + t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2) + } + addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err) + } + if addr.Address != addr3 { + t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3) + } + + // Sleep until right (500ms before) before resolution to + // make sure the address didn't resolve on NIC(1) yet. + const delta = 500 * time.Millisecond + time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(2 * delta): + // We should get a resolution event after 500ms + // (delta) since we wait for 500ms less than the + // expected resolution time above to make sure + // that the address did not yet resolve. Waiting + // for 1s (2x delta) without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != addr1 { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) + } + }) + } +} + +// raBuf returns a valid NDP Router Advertisement. +// +// Note, raBuf does not populate any of the RA fields other than the +// Router Lifetime. +func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(0) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + iph.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} +} + +// TestNoRouterDiscovery tests that router discovery will not be performed if +// configured not to. +func TestNoRouterDiscovery(t *testing.T) { + // Being configured to discover routers means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // router discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router when configured not to") + case <-time.After(defaultTimeout): + } + }) + } +} + +// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered router when the dispatcher asks it not to. +func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + routeTable := []tcpip.Route{ + { + header.IPv6EmptySubnet, + llAddr3, + 1, + }, + } + s.SetRouteTable(routeTable) + + // Rx an RA with short lifetime. + lifetime := time.Duration(1) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime))) + select { + case r := <-ndpDisp.routerC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != llAddr2 { + t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2) + } + if !r.discovered { + t.Fatal("got r.discovered = false, want = true") + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for router discovery event") + } + + // Original route table should not have been modified. + if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + } + + // Wait for the normal invalidation time plus an extra second to + // make sure we do not actually receive any invalidation events as + // we should not have remembered the router in the first place. + select { + case <-ndpDisp.routerC: + t.Fatal("should not have received any router events") + case <-time.After(lifetime*time.Second + defaultTimeout): + } + + // Original route table should not have been modified. + if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + } +} + +func TestRouterDiscovery(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 10), + rememberRouter: true, + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + waitForEvent := func(addr tcpip.Address, discovered bool, timeout time.Duration) { + t.Helper() + + select { + case r := <-ndpDisp.routerC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.discovered != discovered { + t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) + } + case <-time.After(timeout): + t.Fatal("timeout waiting for router discovery event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + case <-time.After(defaultTimeout): + } + + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + waitForEvent(llAddr2, true, defaultTimeout) + + // Should have a default route through the discovered router. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Rx an RA from another router (lladdr3) with non-zero lifetime. + l3Lifetime := time.Duration(6) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) + waitForEvent(llAddr3, true, defaultTimeout) + + // Should have default routes through the discovered routers. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Rx an RA from lladdr2 with lesser lifetime. + l2Lifetime := time.Duration(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + case <-time.After(defaultTimeout): + } + + // Should still have a default route through the discovered routers. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Wait for lladdr2's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + waitForEvent(llAddr2, false, l2Lifetime*time.Second+defaultTimeout) + + // Should no longer have the default route through lladdr2. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + waitForEvent(llAddr2, true, defaultTimeout) + + // Should have a default route through the discovered routers. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + waitForEvent(llAddr2, false, defaultTimeout) + + // Should have deleted the default route through the router that just + // got invalidated. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Wait for lladdr3's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + waitForEvent(llAddr3, false, l3Lifetime*time.Second+defaultTimeout) + + // Should not have any routes now that all discovered routers have been + // invalidated. + if got := len(s.GetRouteTable()); got != 0 { + t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) + } +} + +// TestRouterDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +func TestRouterDiscoveryMaxRouters(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 10), + rememberRouter: true, + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectedRt := [stack.MaxDiscoveredDefaultRouters]tcpip.Route{} + + // Receive an RA from 2 more than the max number of discovered routers. + for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + linkAddr := []byte{2, 2, 3, 4, 5, 0} + linkAddr[5] = byte(i) + llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) + + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + + if i <= stack.MaxDiscoveredDefaultRouters { + expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1} + select { + case r := <-ndpDisp.routerC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != llAddr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr) + } + if !r.discovered { + t.Fatal("got r.discovered = false, want = true") + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for router discovery event") + } + + } else { + select { + case <-ndpDisp.routerC: + t.Fatal("should not have discovered a new router after we already discovered the max number of routers") + case <-time.After(defaultTimeout): + } + } + } + + // Should only have default routes for the first + // stack.MaxDiscoveredDefaultRouters discovered routers. + if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0e8a23f00..28a28ae6e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -19,7 +19,6 @@ import ( "sync" "sync/atomic" - "gvisor.dev/gvisor/pkg/ilist" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -34,17 +33,24 @@ type NIC struct { linkEP LinkEndpoint loopback bool - demux *transportDemuxer - mu sync.RWMutex spoofing bool promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint endpoints map[NetworkEndpointID]*referencedNetworkEndpoint addressRanges []tcpip.Subnet mcastJoins map[NetworkEndpointID]int32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint stats NICStats + + // ndp is the NDP related state for NIC. + // + // Note, read and write operations on ndp require that the NIC is + // appropriately locked. + ndp ndpState } // NICStats includes transmitted and received stats. @@ -78,17 +84,26 @@ const ( NeverPrimaryEndpoint ) +// newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { - return &NIC{ + // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For + // example, make sure that the link address it provides is a valid + // unicast ethernet address. + + // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints + // observe an MTU of at least 1280 bytes. Ensure that this requirement + // of IPv6 is supported on this endpoint's LinkEndpoint. + + nic := &NIC{ stack: stack, id: id, name: name, linkEP: ep, loopback: loopback, - demux: newTransportDemuxer(stack), - primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), + packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -99,7 +114,23 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback Bytes: &tcpip.StatCounter{}, }, }, + ndp: ndpState{ + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + }, + } + nic.ndp.nic = nic + + // Register supported packet endpoint protocols. + for _, netProto := range header.Ethertypes { + nic.packetEPs[netProto] = []PacketEndpoint{} } + for _, netProto := range stack.networkProtocols { + nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + } + + return nic } // enable enables the NIC. enable will attach the link to its LinkEndpoint and @@ -107,6 +138,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback func (n *NIC) enable() *tcpip.Error { n.attachLinkEndpoint() + // Create an endpoint to receive broadcast packets on this interface. + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + if err := n.AddAddress(tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, + }, NeverPrimaryEndpoint); err != nil { + return err + } + } + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -114,11 +155,50 @@ func (n *NIC) enable() *tcpip.Error { // when we perform Duplicate Address Detection, or Router Advertisement // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 // section 4.2 for more information. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { - return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + // + // Also auto-generate an IPv6 link-local address based on the NIC's + // link address if it is configured to do so. Note, each interface is + // required to have IPv6 link-local unicast address, as per RFC 4291 + // section 2.1. + _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] + if !ok { + return nil } - return nil + n.mu.Lock() + defer n.mu.Unlock() + + if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + if !n.stack.autoGenIPv6LinkLocal { + return nil + } + + l2addr := n.linkEP.LinkAddress() + + // Only attempt to generate the link-local address if we have a + // valid MAC address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address + // (provided by LinkEndpoint.LinkAddress) before reaching this + // point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } + + addr := header.LinkLocalAddr(l2addr) + + _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + }, + }, CanBePrimaryEndpoint) + + return err } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it @@ -154,18 +234,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN n.mu.RLock() defer n.mu.RUnlock() - list := n.primary[protocol] - if list == nil { - return nil - } - - for e := list.Front(); e != nil; e = e.Next() { - r := e.(*referencedNetworkEndpoint) - // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set. - switch r.ep.ID().LocalAddress { - case header.IPv4Broadcast, header.IPv4Any: - continue - } + for _, r := range n.primary[protocol] { if r.isValidForOutgoing() && r.tryIncRef() { return r } @@ -275,13 +344,35 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { switch ref.getKind() { - case permanent: + case permanentTentative, permanent: // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: - // Promote the endpoint to become permanent. + // Promote the endpoint to become permanent and respect + // the new peb. if ref.tryIncRef() { ref.setKind(permanent) + + refs := n.primary[ref.protocol] + for i, r := range refs { + if r == ref { + switch peb { + case CanBePrimaryEndpoint: + return ref, nil + case FirstPrimaryEndpoint: + if i == 0 { + return ref, nil + } + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + case NeverPrimaryEndpoint: + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + return ref, nil + } + } + } + + n.insertPrimaryEndpointLocked(ref, peb) + return ref, nil } // tryIncRef failing means the endpoint is scheduled to be removed once @@ -291,6 +382,7 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p n.removeEndpointLocked(ref) } } + return n.addAddressLocked(protocolAddress, peb, permanent) } @@ -314,6 +406,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if err != nil { return nil, err } + + isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) + + // If the address is an IPv6 address and it is a permanent address, + // mark it as tentative so it goes through the DAD process. + if isIPv6Unicast && kind == permanent { + kind = permanentTentative + } + ref := &referencedNetworkEndpoint{ refs: 1, ep: ep, @@ -331,7 +432,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar // If we are adding an IPv6 unicast address, join the solicited-node // multicast address. - if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) { + if isIPv6Unicast { snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { return nil, err @@ -340,17 +441,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar n.endpoints[id] = ref - l, ok := n.primary[protocolAddress.Protocol] - if !ok { - l = &ilist.List{} - n.primary[protocolAddress.Protocol] = l - } + n.insertPrimaryEndpointLocked(ref, peb) - switch peb { - case CanBePrimaryEndpoint: - l.PushBack(ref) - case FirstPrimaryEndpoint: - l.PushFront(ref) + // If we are adding a tentative IPv6 address, start DAD. + if isIPv6Unicast && kind == permanentTentative { + if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { + return nil, err + } } return ref, nil @@ -375,10 +472,12 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { - // Don't include expired or tempory endpoints to avoid confusion and - // prevent the caller from using those. + // Don't include tentative, expired or temporary endpoints to + // avoid confusion and prevent the caller from using those. switch ref.getKind() { - case permanentExpired, temporary: + case permanentTentative, permanentExpired, temporary: + // TODO(b/140898488): Should tentative addresses be + // returned? continue } addrs = append(addrs, tcpip.ProtocolAddress{ @@ -399,12 +498,12 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for proto, list := range n.primary { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) - // Don't include expired or tempory endpoints to avoid confusion and - // prevent the caller from using those. + for _, ref := range list { + // Don't include tentative, expired or tempory endpoints + // to avoid confusion and prevent the caller from using + // those. switch ref.getKind() { - case permanentExpired, temporary: + case permanentTentative, permanentExpired, temporary: continue } @@ -464,6 +563,19 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { return append(sns, n.addressRanges...) } +// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required +// by peb. +// +// n MUST be locked. +func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { + switch peb { + case CanBePrimaryEndpoint: + n.primary[r.protocol] = append(n.primary[r.protocol], r) + case FirstPrimaryEndpoint: + n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...) + } +} + func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() @@ -481,9 +593,12 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { } delete(n.endpoints, id) - wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front() - if wasInList { - n.primary[r.protocol].Remove(r) + refs := n.primary[r.protocol] + for i, ref := range refs { + if ref == r { + n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + break + } } r.ep.Close() @@ -497,10 +612,22 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { r, ok := n.endpoints[NetworkEndpointID{addr}] - if !ok || r.getKind() != permanent { + if !ok { + return tcpip.ErrBadLocalAddress + } + + kind := r.getKind() + if kind != permanent && kind != permanentTentative { return tcpip.ErrBadLocalAddress } + isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) + + // If we are removing a tentative IPv6 unicast address, stop DAD. + if isIPv6Unicast && kind == permanentTentative { + n.ndp.stopDuplicateAddressDetection(addr) + } + r.setKind(permanentExpired) if !r.decRefLocked() { // The endpoint still has references to it. @@ -511,7 +638,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // If we are removing an IPv6 unicast address, leave the solicited-node // multicast address. - if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) { + if isIPv6Unicast { snmc := header.SolicitedNodeAddr(addr) if err := n.leaveGroupLocked(snmc); err != nil { return err @@ -541,6 +668,11 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address // exists yet. Otherwise it just increments its count. n MUST be locked before // joinGroupLocked is called. func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + // TODO(b/143102137): When implementing MLD, make sure MLD packets are + // not sent unless a valid link-local address is available for use on n + // as an MLD packet's source address must be a link-local address as + // outlined in RFC 3810 section 5. + id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -591,10 +723,10 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return nil } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } @@ -604,9 +736,9 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) netProto, ok := n.stack.networkProtocols[protocol] if !ok { @@ -614,36 +746,39 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr return } + // If no local link layer address is provided, assume it was sent + // directly to this NIC. + if local == "" { + local = n.linkEP.LinkAddress() + } + + // Are any packet sockets listening for this network protocol? + n.mu.RLock() + packetEPs := n.packetEPs[protocol] + // Check whether there are packet sockets listening for every protocol. + // If we received a packet with protocol EthernetProtocolAll, then the + // previous for loop will have handled it. + if protocol != header.EthernetProtocolAll { + packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) + } + n.mu.RUnlock() + for _, ep := range packetEPs { + ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + } + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } - if len(vv.First()) < netProto.MinimumPacketSize() { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(vv.First()) - - n.stack.AddLinkAddress(n.id, src, remote) - - // If the packet is destined to the IPv4 Broadcast address, then make a - // route to each IPv4 network endpoint and let each endpoint handle the - // packet. - if dst == header.IPv4Broadcast { - // n.endpoints is mutex protected so acquire lock. - n.mu.RLock() - for _, ref := range n.endpoints { - if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) - } - } - n.mu.RUnlock() - return - } + src, dst := netProto.ParseAddresses(pkt.Data.First()) if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return } @@ -671,31 +806,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(vv.First()) - vv.RemoveFirst() + hdr := buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Data.RemoveFirst() // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, pkt.Data, protocol); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + pkt.Data.Size())) } } return } - n.stack.stats.IP.InvalidAddressesReceived.Increment() + // If a packet socket handled the packet, don't treat it as invalid. + if len(packetEPs) == 0 { + n.stack.stats.IP.InvalidAddressesReceived.Increment() + } } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -707,46 +845,41 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) { - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) - } + n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(vv.First()) < transProto.MinimumPacketSize() { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { - return - } - if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, pkt, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, netHeader, vv) { + if state.defaultHandler(r, id, pkt) { return } } // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { n.stack.stats.MalformedRcvdPackets.Increment() } } // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -757,20 +890,17 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(vv.First()) < 8 { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { - return - } - if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { return } } @@ -785,14 +915,78 @@ func (n *NIC) Stack() *Stack { return n.stack } +// isAddrTentative returns true if addr is tentative on n. +// +// Note that if addr is not associated with n, then this function will return +// false. It will only return true if the address is associated with the NIC +// AND it is tentative. +func (n *NIC) isAddrTentative(addr tcpip.Address) bool { + ref, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok { + return false + } + + return ref.getKind() == permanentTentative +} + +// dupTentativeAddrDetected attempts to inform n that a tentative addr +// is a duplicate on a link. +// +// dupTentativeAddrDetected will delete the tentative address if it exists. +func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + ref, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok { + return tcpip.ErrBadAddress + } + + if ref.getKind() != permanentTentative { + return tcpip.ErrInvalidEndpointState + } + + return n.removePermanentAddressLocked(addr) +} + +// setNDPConfigs sets the NDP configurations for n. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNDPConfigs(c NDPConfigurations) { + c.validate() + + n.mu.Lock() + n.ndp.configs = c + n.mu.Unlock() +} + +// handleNDPRA handles an NDP Router Advertisement message that arrived on n. +func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.handleRA(ip, ra) +} + type networkEndpointKind int32 const ( + // A permanentTentative endpoint is a permanent address that is not yet + // considered to be fully bound to an interface in the traditional + // sense. That is, the address is associated with a NIC, but packets + // destined to the address MUST NOT be accepted and MUST be silently + // dropped, and the address MUST NOT be used as a source address for + // outgoing packets. For IPv6, addresses will be of this kind until + // NDP's Duplicate Address Detection has resolved, or be deleted if + // the process results in detecting a duplicate address. + permanentTentative networkEndpointKind = iota + // A permanent endpoint is created by adding a permanent address (vs. a // temporary one) to the NIC. Its reference count is biased by 1 to avoid // removal when no route holds a reference to it. It is removed by explicitly // removing the permanent address from the NIC. - permanent networkEndpointKind = iota + permanent // An expired permanent endoint is a permanent endoint that had its address // removed from the NIC, and it is waiting to be removed once no more routes @@ -810,8 +1004,37 @@ const ( temporary ) +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return tcpip.ErrNotSupported + } + n.packetEPs[netProto] = append(eps, ep) + + return nil +} + +func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + n.mu.Lock() + defer n.mu.Unlock() + + eps, ok := n.packetEPs[netProto] + if !ok { + return + } + + for i, epOther := range eps { + if epOther == ep { + n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) + return + } + } +} + type referencedNetworkEndpoint struct { - ilist.Entry ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 80101d4bb..c0026f5a3 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -60,24 +60,64 @@ const ( // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { + // UniqueID returns an unique ID for this transport endpoint. + UniqueID() uint64 + // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + // this transport endpoint. It sets pkt.TransportHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g., + // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + // HandleControlPacket takes ownership of pkt. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. This cleanup may happen asynchronously. Wait can + // be used to block on this asynchronous cleanup. + Close() + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // An endpoint can be requested to stop its worker goroutines by calling + // its Close method. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // RawTransportEndpoint is the interface that needs to be implemented by raw // transport protocol endpoints. RawTransportEndpoints receive the entire -// packet - including the link, network, and transport headers - as delivered -// to netstack. +// packet - including the network and transport headers - as delivered to +// netstack. type RawTransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. The packet contains all data from the link // layer up. - HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) +} + +// PacketEndpoint is the interface that needs to be implemented by packet +// transport protocol endpoints. These endpoints receive link layer headers in +// addition to whatever they contain (usually network and transport layer +// headers and a payload). +type PacketEndpoint interface { + // HandlePacket is called by the stack when new packets arrive that + // match the endpoint. + // + // Implementers should treat packet as immutable and should copy it + // before before modification. + // + // linkHeader may have a length of 0, in which case the PacketEndpoint + // should construct its own ethernet header for applications. + // + // HandlePacket takes ownership of pkt. + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -107,7 +147,9 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + // + // HandleUnknownDestinationPacket takes ownership of pkt. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -125,13 +167,21 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. It also returns the network layer - // header for the enpoint to inspect or pass up the stack. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) + // transport protocol endpoint. + // + // pkt.NetworkHeader must be set before calling DeliverTransportPacket. + // + // DeliverTransportPacket takes ownership of pkt. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) + // + // pkt.NetworkHeader must be set before calling + // DeliverTransportControlPacket. + // + // DeliverTransportControlPacket takes ownership of pkt. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -146,6 +196,19 @@ const ( PacketLoop ) +// NetworkHeaderParams are the header parameters given as input by the +// transport endpoint to the network. +type NetworkHeaderParams struct { + // Protocol refers to the transport protocol number. + Protocol tcpip.TransportProtocolNumber + + // TTL refers to Time To Live field of the IP-header. + TTL uint8 + + // TOS refers to TypeOfService or TrafficClass field of the IP-header. + TOS uint8 +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { @@ -170,7 +233,11 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error + WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error + + // WritePackets writes packets to the given destination address and + // protocol. + WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. @@ -186,8 +253,10 @@ type NetworkEndpoint interface { NICID() tcpip.NICID // HandlePacket is called by the link layer when new packets arrive to - // this network endpoint. - HandlePacket(r *Route, vv buffer.VectorisedView) + // this network endpoint. It sets pkt.NetworkHeader. + // + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -212,7 +281,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -229,9 +298,15 @@ type NetworkProtocol interface { // packets to the appropriate network endpoint after it has been handled by // the data link layer. type NetworkDispatcher interface { - // DeliverNetworkPacket finds the appropriate network protocol - // endpoint and hands the packet over for further processing. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + // DeliverNetworkPacket finds the appropriate network protocol endpoint + // and hands the packet over for further processing. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverNetworkPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverNetworkPacket takes ownership of pkt. + DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -253,12 +328,18 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityGSO + CapabilityHardwareGSO + + // CapabilitySoftwareGSO indicates the link endpoint supports of sending + // multiple packets using a single call (LinkEndpoint.WritePackets). + CapabilitySoftwareGSO ) // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. +// out through the implementer's data link endpoint. When a link header exists, +// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the +// stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a @@ -288,6 +369,18 @@ type LinkEndpoint interface { // r.LocalLinkAddress if it is provided. WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + // WritePackets writes packets with the given protocol through the + // given route. + // + // Right now, WritePackets is used only when the software segmentation + // offload is enabled. If it will be used for something else, it may + // require to change syscall filters. + WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + + // WriteRawPacket writes a packet directly to the link. The packet + // should already have an ethernet header. + WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error + // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. Attach(dispatcher NetworkDispatcher) @@ -311,13 +404,14 @@ type LinkEndpoint interface { type InjectableLinkEndpoint interface { LinkEndpoint - // Inject injects an inbound packet. - Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + // InjectInbound injects an inbound packet. + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) - // WriteRawPacket writes a fully formed outbound packet directly to the link. + // InjectOutbound writes a fully formed outbound packet directly to the + // link. // // dest is used by endpoints with multiple raw destinations. - WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error + InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error } // A LinkAddressResolver is an extension to a NetworkProtocol that @@ -346,10 +440,10 @@ type LinkAddressResolver interface { type LinkAddressCache interface { // CheckLocalAddress determines if the given local address exists, and if it // does not exist. - CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID + CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver @@ -360,17 +454,22 @@ type LinkAddressCache interface { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). - GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// UnassociatedEndpointFactory produces endpoints for writing packets not -// associated with a particular transport protocol. Such endpoints can be used -// to write arbitrary packets that include the IP header. -type UnassociatedEndpointFactory interface { - NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) +// RawFactory produces endpoints for writing various types of raw packets. +type RawFactory interface { + // NewUnassociatedEndpoint produces endpoints for writing packets not + // associated with a particular transport protocol. Such endpoints can + // be used to write arbitrary packets that include the network header. + NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + + // NewPacketEndpoint produces endpoints for reading and writing packets + // that include network and (when cooked is false) link layer headers. + NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } // GSOType is the type of GSO segments. @@ -381,8 +480,14 @@ type GSOType int // Types of gso segments. const ( GSONone GSOType = iota + + // Hardware GSO types: GSOTCPv4 GSOTCPv6 + + // GSOSW is used for software GSO segments which have to be sent by + // endpoint.WritePackets. + GSOSW ) // GSO contains generic segmentation offload properties. @@ -410,3 +515,7 @@ type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 } + +// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. +// This isn't a hard limit, because it is never set into packet headers. +const SoftwareGSOMaxSize = (1 << 16) diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 5c8b7977a..1a0a51b57 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -47,8 +47,8 @@ type Route struct { // starts. ref *referencedNetworkEndpoint - // loop controls where WritePacket should send packets. - loop PacketLooping + // Loop controls where WritePacket should send packets. + Loop PacketLooping } // makeRoute initializes a new route. It takes ownership of the provided @@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop = PacketLoop } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop } return Route{ @@ -67,7 +69,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip LocalLinkAddress: localLinkAddr, RemoteAddress: remoteAddr, ref: ref, - loop: loop, + Loop: loop, } } @@ -152,12 +154,12 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) + err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.Loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { @@ -167,6 +169,44 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec return err } +// PacketDescriptor is a packet descriptor which contains a packet header and +// offset and size of packet data in a payload view. +type PacketDescriptor struct { + Hdr buffer.Prependable + Off int + Size int +} + +// NewPacketDescriptors allocates a set of packet descriptors. +func NewPacketDescriptors(n int, hdrSize int) []PacketDescriptor { + buf := make([]byte, n*hdrSize) + hdrs := make([]PacketDescriptor, n) + for i := range hdrs { + hdrs[i].Hdr = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) + } + return hdrs +} + +// WritePackets writes the set of packets through the given route. +func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams) (int, *tcpip.Error) { + if !r.ref.isValidForOutgoing() { + return 0, tcpip.ErrInvalidEndpointState + } + + n, err := r.ref.ep.WritePackets(r, gso, hdrs, payload, params, r.Loop) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(hdrs) - n)) + } + r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) + payloadSize := 0 + for i := 0; i < n; i++ { + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdrs[i].Hdr.UsedLength())) + payloadSize += hdrs[i].Size + } + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { @@ -174,7 +214,7 @@ func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip. return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.Loop); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } @@ -208,10 +248,17 @@ func (r *Route) Clone() Route { return *r } -// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast. +// MakeLoopedRoute duplicates the given route with special handling for routes +// used for sending multicast or broadcast packets. In those cases the +// multicast/broadcast address is the remote address when sending out, but for +// incoming (looped) packets it becomes the local address. Similarly, the local +// interface address that was the local address going out becomes the remote +// address coming in. This is different to unicast routes where local and +// remote addresses remain the same as they identify location (local vs remote) +// not direction (source vs destination). func (r *Route) MakeLoopedRoute() Route { l := r.Clone() - if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { + if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress l.RemoteLinkAddress = l.LocalLinkAddress } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 18d1704a5..2f8d8e822 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,10 +20,13 @@ package stack import ( + "encoding/binary" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -41,11 +44,14 @@ const ( resolutionTimeout = 1 * time.Second // resolutionAttempts is set to the same ARP retries used in Linux. resolutionAttempts = 3 + + // DefaultTOS is the default type of service value for network endpoints. + DefaultTOS = 0 ) type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -339,6 +345,13 @@ type ResumableEndpoint interface { Resume(*Stack) } +// uniqueIDGenerator is a default unique ID generator. +type uniqueIDGenerator uint64 + +func (u *uniqueIDGenerator) UniqueID() uint64 { + return atomic.AddUint64((*uint64)(u), 1) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -346,10 +359,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver - // unassociatedFactory creates unassociated endpoints. If nil, raw - // endpoints are disabled. It is set during Stack creation and is - // immutable. - unassociatedFactory UnassociatedEndpointFactory + // rawFactory creates raw endpoints. If nil, raw endpoints are + // disabled. It is set during Stack creation and is immutable. + rawFactory RawFactory demux *transportDemuxer @@ -357,9 +369,10 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + forwarding bool + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -388,6 +401,32 @@ type Stack struct { // icmpRateLimiter is a global rate limiter for all ICMP messages generated // by the stack. icmpRateLimiter *ICMPRateLimiter + + // seed is a one-time random value initialized at stack startup + // and is used to seed the TCP port picking on active connections + // + // TODO(gvisor.dev/issue/940): S/R this field. + seed uint32 + + // ndpConfigs is the default NDP configurations used by interfaces. + ndpConfigs NDPConfigurations + + // autoGenIPv6LinkLocal determines whether or not the stack will attempt + // to auto-generate an IPv6 link-local address for newly enabled NICs. + // See the AutoGenIPv6LinkLocal field of Options for more details. + autoGenIPv6LinkLocal bool + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. + ndpDisp NDPDispatcher + + // uniqueIDGenerator is a generator of unique identifiers. + uniqueIDGenerator UniqueID +} + +// UniqueID is an abstract generator of unique identifiers. +type UniqueID interface { + UniqueID() uint64 } // Options contains optional Stack configuration. @@ -411,14 +450,71 @@ type Options struct { // stack (false). HandleLocal bool - // UnassociatedFactory produces unassociated endpoints raw endpoints. - // Raw endpoints are enabled only if this is non-nil. - UnassociatedFactory UnassociatedEndpointFactory + // UniqueID is an optional generator of unique identifiers. + UniqueID UniqueID + + // NDPConfigs is the default NDP configurations used by interfaces. + // + // By default, NDPConfigs will have a zero value for its + // DupAddrDetectTransmits field, implying that DAD will not be performed + // before assigning an address to a NIC. + NDPConfigs NDPConfigurations + + // AutoGenIPv6LinkLocal determins whether or not the stack will attempt + // to auto-generate an IPv6 link-local address for newly enabled NICs. + // Note, setting this to true does not mean that a link-local address + // will be assigned right away, or at all. If Duplicate Address + // Detection is enabled, an address will only be assigned if it + // successfully resolves. If it fails, no further attempt will be made + // to auto-generate an IPv6 link-local address. + // + // The generated link-local address will follow RFC 4291 Appendix A + // guidelines. + AutoGenIPv6LinkLocal bool + + // NDPDisp is the NDP event dispatcher that an integrator can provide to + // receive NDP related events. + NDPDisp NDPDispatcher + + // RawFactory produces raw endpoints. Raw endpoints are enabled only if + // this is non-nil. + RawFactory RawFactory +} + +// TransportEndpointInfo holds useful information about a transport endpoint +// which can be queried by monitoring tools. +// +// +stateify savable +type TransportEndpointInfo struct { + // The following fields are initialized at creation time and are + // immutable. + + NetProto tcpip.NetworkProtocolNumber + TransProto tcpip.TransportProtocolNumber + + // The following fields are protected by endpoint mu. + + ID TransportEndpointID + // BindNICID and bindAddr are set via calls to Bind(). They are used to + // reject attempts to send data or connect via a different NIC or + // address + BindNICID tcpip.NICID + BindAddr tcpip.Address + // RegisterNICID is the default NICID registered as a side-effect of + // connect or datagram write. + RegisterNICID tcpip.NICID } +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*TransportEndpointInfo) IsEndpointInfo() {} + // New allocates a new networking stack with only the requested networking and // transport protocols configured with default options. // +// Note, NDPConfigurations will be fixed before being used by the Stack. That +// is, if an invalid value was provided, it will be reset to the default value. +// // Protocol options can be changed by calling the // SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the // stack. Please refer to individual protocol implementations as to what options @@ -429,17 +525,30 @@ func New(opts Options) *Stack { clock = &tcpip.StdClock{} } + if opts.UniqueID == nil { + opts.UniqueID = new(uniqueIDGenerator) + } + + // Make sure opts.NDPConfigs contains valid values only. + opts.NDPConfigs.validate() + s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), - nics: make(map[tcpip.NICID]*NIC), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - icmpRateLimiter: NewICMPRateLimiter(), + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + ndpConfigs: opts.NDPConfigs, + autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + uniqueIDGenerator: opts.UniqueID, + ndpDisp: opts.NDPDisp, } // Add specified network protocols. @@ -457,8 +566,8 @@ func New(opts Options) *Stack { } } - // Add the factory for unassociated endpoints, if present. - s.unassociatedFactory = opts.UnassociatedFactory + // Add the factory for raw endpoints, if present. + s.rawFactory = opts.RawFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -466,6 +575,11 @@ func New(opts Options) *Stack { return s } +// UniqueID returns a unique identifier. +func (s *Stack) UniqueID() uint64 { + return s.uniqueIDGenerator.UniqueID() +} + // SetNetworkProtocolOption allows configuring individual protocol level // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value @@ -527,7 +641,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -593,12 +707,12 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if s.unassociatedFactory == nil { + if s.rawFactory == nil { return nil, tcpip.ErrNotPermitted } if !associated { - return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue) + return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue) } t, ok := s.transportProtocols[transport] @@ -609,6 +723,16 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return t.proto.NewRawEndpoint(s, network, waiterQueue) } +// NewPacketEndpoint creates a new packet endpoint listening for the given +// netProto. +func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if s.rawFactory == nil { + return nil, tcpip.ErrNotPermitted + } + + return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) +} + // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { @@ -893,7 +1017,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } else { for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) { + if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } if nic, ok := s.nics[route.NIC]; ok { @@ -931,13 +1055,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool // CheckLocalAddress determines if the given local address exists, and if it // does, returns the id of the NIC it's bound to. Returns 0 if the address // does not exist. -func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { +func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { s.mu.RLock() defer s.mu.RUnlock() // If a NIC is specified, we try to find the address there only. - if nicid != 0 { - nic := s.nics[nicid] + if nicID != 0 { + nic := s.nics[nicID] if nic == nil { return 0 } @@ -996,35 +1120,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { } // AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.add(fullAddr, linkAddr) // TODO: provide a way for a transport endpoint to receive a signal // that AddLinkAddress for a particular address has been called. } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() - nic := s.nics[nicid] + nic := s.nics[nicID] if nic == nil { s.mu.RUnlock() return "", nil, tcpip.ErrUnknownNICID } s.mu.RUnlock() - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } // RemoveWaker implements LinkAddressCache.RemoveWaker. -func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { +func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { s.mu.RLock() defer s.mu.RUnlock() - if nic := s.nics[nicid]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + if nic := s.nics[nicID]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.removeWaker(fullAddr, waker) } } @@ -1033,73 +1157,52 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { - if nicID == 0 { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { - if nicID == 0 { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep) - return - } +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} - s.mu.RLock() - defer s.mu.RUnlock() +// StartTransportEndpointCleanup removes the endpoint with the given id from +// the stack transport dispatcher. It also transitions it to the cleanup stage. +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.mu.Lock() + defer s.mu.Unlock() - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterEndpoint(netProtos, protocol, id, ep) - } + s.cleanupEndpoints[ep] = struct{}{} + + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} + +// CompleteTransportEndpointCleanup removes the endpoint from the cleanup +// stage. +func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { + s.mu.Lock() + delete(s.cleanupEndpoints, ep) + s.mu.Unlock() +} + +// FindTransportEndpoint finds an endpoint that most closely matches the provided +// id. If no endpoint is found it returns nil. +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, r) } // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { - if nicID == 0 { - return s.demux.registerRawEndpoint(netProto, transProto, ep) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerRawEndpoint(netProto, transProto, ep) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { - if nicID == 0 { - s.demux.unregisterRawEndpoint(netProto, transProto, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterRawEndpoint(netProto, transProto, ep) - } + s.demux.unregisterRawEndpoint(netProto, transProto, ep) } // RegisterRestoredEndpoint records e as an endpoint that has been restored on @@ -1110,6 +1213,69 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { s.mu.Unlock() } +// RegisteredEndpoints returns all endpoints which are currently registered. +func (s *Stack) RegisteredEndpoints() []TransportEndpoint { + s.mu.Lock() + defer s.mu.Unlock() + var es []TransportEndpoint + for _, e := range s.demux.protocol { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// CleanupEndpoints returns endpoints currently in the cleanup state. +func (s *Stack) CleanupEndpoints() []TransportEndpoint { + s.mu.Lock() + es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) + for e := range s.cleanupEndpoints { + es = append(es, e) + } + s.mu.Unlock() + return es +} + +// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful +// for restoring a stack after a save. +func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { + s.mu.Lock() + for _, e := range es { + s.cleanupEndpoints[e] = struct{}{} + } + s.mu.Unlock() +} + +// Close closes all currently registered transport endpoints. +// +// Endpoints created or modified during this call may not get closed. +func (s *Stack) Close() { + for _, e := range s.RegisteredEndpoints() { + e.Close() + } +} + +// Wait waits for all transport and link endpoints to halt their worker +// goroutines. +// +// Endpoints created or modified during this call may not get waited on. +// +// Note that link endpoints must be stopped via an implementation specific +// mechanism. +func (s *Stack) Wait() { + for _, e := range s.RegisteredEndpoints() { + e.Wait() + } + for _, e := range s.CleanupEndpoints() { + e.Wait() + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, n := range s.nics { + n.linkEP.Wait() + } +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { @@ -1124,6 +1290,109 @@ func (s *Stack) Resume() { } } +// RegisterPacketEndpoint registers ep with the stack, causing it to receive +// all traffic of the specified netProto on the given NIC. If nicID is 0, it +// receives traffic from every NIC. +func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + // If no NIC is specified, capture on all devices. + if nicID == 0 { + // Register with each NIC. + for _, nic := range s.nics { + if err := nic.registerPacketEndpoint(netProto, ep); err != nil { + s.unregisterPacketEndpointLocked(0, netProto, ep) + return err + } + } + return nil + } + + // Capture on a specific device. + nic, ok := s.nics[nicID] + if !ok { + return tcpip.ErrUnknownNICID + } + if err := nic.registerPacketEndpoint(netProto, ep); err != nil { + return err + } + + return nil +} + +// UnregisterPacketEndpoint unregisters ep for packets of the specified +// netProto from the specified NIC. If nicID is 0, ep is unregistered from all +// NICs. +func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + s.mu.Lock() + defer s.mu.Unlock() + s.unregisterPacketEndpointLocked(nicID, netProto, ep) +} + +func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { + // If no NIC is specified, unregister on all devices. + if nicID == 0 { + // Unregister with each NIC. + for _, nic := range s.nics { + nic.unregisterPacketEndpoint(netProto, ep) + } + return + } + + // Unregister in a single device. + nic, ok := s.nics[nicID] + if !ok { + return + } + nic.unregisterPacketEndpoint(netProto, ep) +} + +// WritePacket writes data directly to the specified NIC. It adds an ethernet +// header based on the arguments. +func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { + s.mu.Lock() + nic, ok := s.nics[nicID] + s.mu.Unlock() + if !ok { + return tcpip.ErrUnknownDevice + } + + // Add our own fake ethernet header. + ethFields := header.EthernetFields{ + SrcAddr: nic.linkEP.LinkAddress(), + DstAddr: dst, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + ethHeader := buffer.View(fakeHeader).ToVectorisedView() + ethHeader.Append(payload) + + if err := nic.linkEP.WriteRawPacket(ethHeader); err != nil { + return err + } + + return nil +} + +// WriteRawPacket writes data directly to the specified NIC without adding any +// headers. +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { + s.mu.Lock() + nic, ok := s.nics[nicID] + s.mu.Unlock() + if !ok { + return tcpip.ErrUnknownDevice + } + + if err := nic.linkEP.WriteRawPacket(payload); err != nil { + return err + } + + return nil +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1243,3 +1512,85 @@ func (s *Stack) SetICMPBurst(burst int) { func (s *Stack) AllowICMPMessage() bool { return s.icmpRateLimiter.Allow() } + +// IsAddrTentative returns true if addr is tentative on the NIC with ID id. +// +// Note that if addr is not associated with a NIC with id ID, then this +// function will return false. It will only return true if the address is +// associated with the NIC AND it is tentative. +func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] + if !ok { + return false, tcpip.ErrUnknownNICID + } + + return nic.isAddrTentative(addr), nil +} + +// DupTentativeAddrDetected attempts to inform the NIC with ID id that a +// tentative addr on it is a duplicate on a link. +func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.dupTentativeAddrDetected(addr) +} + +// SetNDPConfigurations sets the per-interface NDP configurations on the NIC +// with ID id to c. +// +// Note, if c contains invalid NDP configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.setNDPConfigs(c) + + return nil +} + +// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement +// message that it needs to handle. +func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.handleNDPRA(ip, ra) + + return nil +} + +// Seed returns a 32 bit value that can be used as a seed value for port +// picking, ISN generation etc. +// +// NOTE: The seed is generated once during stack initialization only. +func (s *Stack) Seed() uint32 { + return s.seed +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return binary.LittleEndian.Uint32(b) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d2dede8a9..bf1d6974c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -24,11 +24,14 @@ import ( "sort" "strings" "testing" + "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -55,7 +58,7 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int proto *fakeNetworkProtocol @@ -68,7 +71,7 @@ func (f *fakeNetworkEndpoint) MTU() uint32 { } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicid + return f.nicID } func (f *fakeNetworkEndpoint) PrefixLen() int { @@ -83,28 +86,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) + b := pkt.Data.First() + pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() + nb := pkt.Data.First() if len(nb) < fakeNetHeaderLen { return } - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + pkt.Data.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -119,7 +122,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -128,14 +131,16 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu b := hdr.Prepend(fakeNetHeaderLen) b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] - b[2] = byte(protocol) + b[2] = byte(params.Protocol) if loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) + f.HandlePacket(r, tcpip.PacketBuffer{ + Data: vv, + }) } if loop&stack.PacketOut == 0 { return nil @@ -144,6 +149,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + panic("not implemented") +} + func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { return tcpip.ErrNotSupported } @@ -189,9 +199,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, @@ -251,7 +261,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -261,7 +273,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -271,7 +285,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -280,7 +296,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -290,7 +308,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -310,7 +330,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) + return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -365,7 +385,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -660,11 +682,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } } -func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { +func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { t.Helper() - info, ok := s.NICInfo()[nicid] + info, ok := s.NICInfo()[nicID] if !ok { - t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + t.Fatalf("NICInfo() failed to find nicID=%d", nicID) } if len(addr) == 0 { // No address given, verify that there is no address assigned to the NIC. @@ -697,7 +719,7 @@ func TestEndpointExpiration(t *testing.T) { localAddrByte byte = 0x01 remoteAddr tcpip.Address = "\x03" noAddr tcpip.Address = "" - nicid tcpip.NICID = 1 + nicID tcpip.NICID = 1 ) localAddr := tcpip.Address([]byte{localAddrByte}) @@ -709,7 +731,7 @@ func TestEndpointExpiration(t *testing.T) { }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -726,13 +748,13 @@ func TestEndpointExpiration(t *testing.T) { buf[0] = localAddrByte if promiscuous { - if err := s.SetPromiscuousMode(nicid, true); err != nil { + if err := s.SetPromiscuousMode(nicID, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } } if spoofing { - if err := s.SetSpoofing(nicid, true); err != nil { + if err := s.SetSpoofing(nicID, true); err != nil { t.Fatal("SetSpoofing failed:", err) } } @@ -740,7 +762,7 @@ func TestEndpointExpiration(t *testing.T) { // 1. No Address yet, send should only work for spoofing, receive for // promiscuous mode. //----------------------- - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -755,20 +777,20 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -783,10 +805,10 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -804,10 +826,10 @@ func TestEndpointExpiration(t *testing.T) { // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -823,10 +845,10 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) testSend(t, r, ep, nil) @@ -834,17 +856,17 @@ func TestEndpointExpiration(t *testing.T) { // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -952,10 +974,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. testSendTo(t, s, dstAddr, ep, nil) @@ -967,10 +989,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. testSend(t, r, ep, nil) @@ -1016,17 +1038,33 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } -func TestBroadcastNeedsNoRoute(t *testing.T) { +func verifyRoute(gotRoute, wantRoute stack.Route) error { + if gotRoute.LocalAddress != wantRoute.LocalAddress { + return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) + } + if gotRoute.RemoteAddress != wantRoute.RemoteAddress { + return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) + } + if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { + return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + } + if gotRoute.NextHop != wantRoute.NextHop { + return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) + } + return nil +} + +func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) @@ -1039,28 +1077,99 @@ func TestBroadcastNeedsNoRoute(t *testing.T) { // If there is no endpoint, it won't work. if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } - if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err) + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + if err := s.AddProtocolAddress(1, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + } + + // If the NIC doesn't exist, it won't work. + if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { + t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + } +} + +func TestOutgoingBroadcastWithRouteTable(t *testing.T) { + defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. + nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. + nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") + + // Create a new stack with two NICs. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + if err := s.CreateNIC(2, ep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) } - if r.LocalAddress != header.IPv4Any { - t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any) + nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { + t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) } - if r.RemoteAddress != header.IPv4Broadcast { - t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast) + // Set the initial route table. + rt := []tcpip.Route{ + {Destination: nic1Addr.Subnet(), NIC: 1}, + {Destination: nic2Addr.Subnet(), NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2}, + {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1}, } + s.SetRouteTable(rt) - // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + // When an interface is given, the route for a broadcast goes through it. + r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + } + + // When an interface is not given, it consults the route table. + // 1. Case: Using the default route. + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) + } + + // 2. Case: Having an explicit route for broadcast will select that one. + rt = append( + []tcpip.Route{ + {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + }, + rt..., + ) + s.SetRouteTable(rt) + r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) + } + if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -1550,12 +1659,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } func TestAddAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1563,7 +1672,7 @@ func TestAddAddress(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) for _, addrLen := range []int{4, 16} { address := addrGen.next(addrLen) - if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { t.Fatalf("AddAddress(address=%s) failed: %s", address, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1572,17 +1681,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1599,24 +1708,24 @@ func TestAddProtocolAddress(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) } expectedAddresses = append(expectedAddresses, protocolAddress) } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1627,7 +1736,7 @@ func TestAddAddressWithOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1637,17 +1746,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1666,7 +1775,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) } expectedAddresses = append(expectedAddresses, protocolAddress) @@ -1674,7 +1783,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1700,7 +1809,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1760,7 +1871,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) select { case <-ep2.C: @@ -1777,3 +1890,297 @@ func TestNICForwarding(t *testing.T) { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } + +// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses +// (or lack there-of if disabled (default)). Note, DAD will be disabled in +// these tests. +func TestNICAutoGenAddr(t *testing.T) { + tests := []struct { + name string + autoGen bool + linkAddr tcpip.LinkAddress + shouldGen bool + }{ + { + "Disabled", + false, + linkAddr1, + false, + }, + { + "Enabled", + true, + linkAddr1, + true, + }, + { + "Nil MAC", + true, + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty MAC", + true, + tcpip.LinkAddress(""), + false, + }, + { + "Invalid MAC", + true, + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Multicast MAC", + true, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + false, + }, + { + "Unspecified MAC", + true, + tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + } + + if test.autoGen { + // Only set opts.AutoGenIPv6LinkLocal when + // test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by + // default. + opts.AutoGenIPv6LinkLocal = true + } + + e := channel.New(10, 1280, test.linkAddr) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + + if test.shouldGen { + // Should have auto-generated an address and + // resolved immediately (DAD is disabled). + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } + } else { + // Should not have auto-generated an address. + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + } + }) + } +} + +// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 +// link-local addresses will only be assigned after the DAD process resolves. +func TestNICAutoGenAddrDoesDAD(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + } + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Address should not be considered bound to the + // NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + linkLocalAddr := header.LinkLocalAddr(linkAddr1) + + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != linkLocalAddr { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } +} + +// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected +// when an address's kind gets "promoted" to permanent from permanentExpired. +func TestNewPEBOnPromotionToPermanent(t *testing.T) { + pebs := []stack.PrimaryEndpointBehavior{ + stack.NeverPrimaryEndpoint, + stack.CanBePrimaryEndpoint, + stack.FirstPrimaryEndpoint, + } + + for _, pi := range pebs { + for _, ps := range pebs { + t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + // Add a permanent address with initial + // PrimaryEndpointBehavior (peb), pi. If pi is + // NeverPrimaryEndpoint, the address should not + // be returned by a call to GetMainNICAddress; + // else, it should. + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + } + addr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("s.GetMainNICAddress failed:", err) + } + if pi == stack.NeverPrimaryEndpoint { + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + + } + } else if addr.Address != "\x01" { + t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatalf("NewSubnet failed:", err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // Take a route through the address so its ref + // count gets incremented and does not actually + // get deleted when RemoveAddress is called + // below. This is because we want to test that a + // new peb is respected when an address gets + // "promoted" to permanent from a + // permanentExpired kind. + r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + defer r.Release() + if err := s.RemoveAddress(1, "\x01"); err != nil { + t.Fatalf("RemoveAddress failed:", err) + } + + // + // At this point, the address should still be + // known by the NIC, but have its + // kind = permanentExpired. + // + + // Add some other address with peb set to + // FirstPrimaryEndpoint. + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + + } + + // Add back the address we removed earlier and + // make sure the new peb was respected. + // (The address should just be promoted now). + if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { + t.Fatal("AddAddressWithOptions failed:", err) + } + var primaryAddrs []tcpip.Address + for _, pa := range s.NICInfo()[1].ProtocolAddresses { + primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address) + } + var expectedList []tcpip.Address + switch ps { + case stack.FirstPrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x01", + "\x03", + } + case stack.CanBePrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x03", + "\x01", + } + case stack.NeverPrimaryEndpoint: + expectedList = []tcpip.Address{ + "\x03", + } + } + if !cmp.Equal(primaryAddrs, expectedList) { + t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList) + } + + // Once we remove the other address, if the new + // peb, ps, was NeverPrimaryEndpoint, no address + // should be returned by a call to + // GetMainNICAddress; else, our original address + // should be returned. + if err := s.RemoveAddress(1, "\x03"); err != nil { + t.Fatalf("RemoveAddress failed:", err) + } + addr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("s.GetMainNICAddress failed:", err) + } + if ps == stack.NeverPrimaryEndpoint { + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want) + + } + } else { + if addr.Address != "\x01" { + t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address) + } + } + }) + } + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index cf8a6d129..cb805522b 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -17,10 +17,10 @@ package stack import ( "fmt" "math/rand" + "sort" "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -35,7 +35,7 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]TransportEndpoint + endpoints map[TransportEndpointID]*endpointsByNic // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint @@ -43,19 +43,122 @@ type transportEndpoints struct { // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - e, ok := eps.endpoints[id] + epsByNic, ok := eps.endpoints[id] if !ok { return } - if multiPortEp, ok := e.(*multiPortEndpoint); ok { - if !multiPortEp.unregisterEndpoint(ep) { + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return + } + delete(eps.endpoints, id) +} + +func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { + eps.mu.RLock() + defer eps.mu.RUnlock() + es := make([]TransportEndpoint, 0, len(eps.endpoints)) + for _, e := range eps.endpoints { + es = append(es, e.transportEndpoints()...) + } + return es +} + +type endpointsByNic struct { + mu sync.RWMutex + endpoints map[tcpip.NICID]*multiPortEndpoint + // seed is a random secret for a jenkins hash. + seed uint32 +} + +func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + var eps []TransportEndpoint + for _, ep := range epsByNic.endpoints { + eps = append(eps, ep.transportEndpoints()...) + } + return eps +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { + epsByNic.mu.RLock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } } - delete(eps.endpoints, id) + + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints bound to the right device. + if isMulticastOrBroadcast(id.LocalAddress) { + mpep.handlePacketAll(r, id, pkt) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[n.ID()] + if !ok { + mpep, ok = epsByNic.endpoints[0] + } + if !ok { + return + } + + // TODO(eyalsoha): Why don't we look at id to see if this packet needs to + // broadcast like we are doing with handlePacket above? + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) +} + +// registerEndpoint returns true if it succeeds. It fails and returns +// false if ep already has an element with the same key. +func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + + if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { + // There was already a bind. + return multiPortEp.singleRegisterEndpoint(t, reusePort) + } + + // This is a new binding. + multiPortEp := &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.reuse = reusePort + epsByNic.endpoints[bindToDevice] = multiPortEp + return multiPortEp.singleRegisterEndpoint(t, reusePort) +} + +// unregisterEndpoint returns true if endpointsByNic has to be unregistered. +func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + return false + } + if multiPortEp.unregisterEndpoint(t) { + delete(epsByNic.endpoints, bindToDevice) + } + return len(epsByNic.endpoints) == 0 } // transportDemuxer demultiplexes packets targeted at a transport endpoint @@ -75,7 +178,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]TransportEndpoint), + endpoints: make(map[TransportEndpointID]*endpointsByNic), } } } @@ -85,10 +188,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) return err } } @@ -97,13 +200,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // multiPortEndpoint is a container for TransportEndpoints which are bound to -// the same pair of address and port. +// the same pair of address and port. endpointsArr always has at least one +// element. +// +// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save +// this to ensure that the underlying endpoints get saved/restored, but not not +// use the restored copy. +// +// +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int - // seed is a random secret for a jenkins hash. - seed uint32 + // reuse indicates if more than one endpoint is allowed. + reuse bool +} + +func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + return eps } // reciprocalScale scales a value into range [0, n). @@ -117,9 +234,10 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { - ep.mu.RLock() - defer ep.mu.RUnlock() +func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { + if len(mpep.endpointsArr) == 1 { + return mpep.endpointsArr[0] + } payload := []byte{ byte(id.LocalPort), @@ -128,51 +246,77 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd byte(id.RemotePort >> 8), } - h := jenkins.Sum32(ep.seed) + h := jenkins.Sum32(seed) h.Write(payload) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) - return ep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) + return mpep.endpointsArr[idx] } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast or multicast datagram, deliver the datagram to all - // endpoints managed by ep. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { - for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy. - if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) - break - } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { + ep.mu.RLock() + for i, endpoint := range ep.endpointsArr { + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. + if i == len(ep.endpointsArr)-1 { + endpoint.HandlePacket(r, id, pkt) + break } - } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, pkt.Clone()) } + ep.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { - ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) +// Close implements stack.TransportEndpoint.Close. +func (ep *multiPortEndpoint) Close() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Close() + } } -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { +// Wait implements stack.TransportEndpoint.Wait. +func (ep *multiPortEndpoint) Wait() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Wait() + } +} + +// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint +// list. The list might be empty already. +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - // A new endpoint is added into endpointsArr and its index there is - // saved in endpointsMap. This will allows to remove endpoint from - // the array fast. + if len(ep.endpointsArr) > 0 { + // If it was previously bound, we need to check if we can bind again. + if !ep.reuse || !reusePort { + return tcpip.ErrPortInUse + } + } + + // A new endpoint is added into endpointsArr and its index there is saved in + // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + + // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints + // can be restored in the same order. + sort.Slice(ep.endpointsArr, func(i, j int) bool { + return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() + }) + for i, e := range ep.endpointsArr { + ep.endpointsMap[e] = i + } + return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. @@ -197,53 +341,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { return true } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { + // TODO(eyalsoha): Why? reusePort = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return nil + return tcpip.ErrUnknownProtocol } eps.mu.Lock() defer eps.mu.Unlock() - var multiPortEp *multiPortEndpoint - if _, ok := eps.endpoints[id]; ok { - if !reusePort { - return tcpip.ErrPortInUse - } - multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) - if !ok { - return tcpip.ErrPortInUse - } + if epsByNic, ok := eps.endpoints[id]; ok { + // There was already a binding. + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } - if reusePort { - if multiPortEp == nil { - multiPortEp = &multiPortEndpoint{} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.seed = rand.Uint32() - eps.endpoints[id] = multiPortEp - } - - multiPortEp.singleRegisterEndpoint(ep) - - return nil + // This is a new binding. + epsByNic := &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), } - eps.endpoints[id] = ep + eps.endpoints[id] = epsByNic - return nil + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep) + eps.unregisterEndpoint(id, ep, bindToDevice) } } } @@ -259,30 +391,21 @@ var loopbackSubnet = func() tcpip.Subnet { // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if it // found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false } - // If a sender bound to the Loopback interface sends a broadcast, - // that broadcast must not be delivered to the sender. - if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort { - return false - } - - // If the packet is a broadcast, then find all matching transport endpoints. - // Otherwise, try to find a single matching transport endpoint. - destEps := make([]TransportEndpoint, 0, 1) eps.mu.RLock() - if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { - for epID, endpoint := range eps.endpoints { - if epID.LocalPort == id.LocalPort { - destEps = append(destEps, endpoint) - } - } - } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { + // Determine which transport endpoint or endpoints to deliver this packet to. + // If the packet is a broadcast or multicast, then find all matching + // transport endpoints. + var destEps []*endpointsByNic + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, id) + } else if ep := d.findEndpointLocked(eps, id); ep != nil { destEps = append(destEps, ep) } @@ -297,17 +420,19 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // Deliver the packet. - for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + // HandlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEps[:len(destEps)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEps[len(destEps)-1].handlePacket(r, id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -321,7 +446,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + rawEP.HandlePacket(r, pkt) foundRaw = true } eps.mu.RUnlock() @@ -331,7 +456,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -339,7 +464,7 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, vv, id) + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() // Fail if we didn't find one. @@ -348,15 +473,16 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, } // Deliver the packet. - ep.HandleControlPacket(id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, pkt) return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the local address. @@ -364,7 +490,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with the id minus the remote part. @@ -372,15 +498,54 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - return ep + matchedEPs = append(matchedEPs, ep) + } + return matchedEPs +} + +// findTransportEndpoint find a single endpoint that most closely matches the provided id. +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + return nil + } + // Try to find the endpoint. + eps.mu.RLock() + epsByNic := d.findEndpointLocked(eps, id) + // Fail if we didn't find one. + if epsByNic == nil { + eps.mu.RUnlock() + return nil + } + + epsByNic.mu.RLock() + eps.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return nil + } } + ep := selectEndpoint(id, mpep, epsByNic.seed) + epsByNic.mu.RUnlock() + return ep +} + +// findEndpointLocked returns the endpoint that most closely matches the given +// id. +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { + return matchedEPs[0] + } return nil } @@ -391,7 +556,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { - return nil + return tcpip.ErrNotSupported } eps.mu.Lock() @@ -418,3 +583,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } } } + +func isMulticastOrBroadcast(addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +} diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go new file mode 100644 index 000000000..3b28b06d0 --- /dev/null +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -0,0 +1,354 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testPort = 4096 +) + +type testContext struct { + t *testing.T + linkEPs map[string]*channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v6only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } +} + +// newDualTestContextMultiNic creates the testing context and also linkEpNames +// named NICs. +func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + linkEPs := make(map[string]*channel.Endpoint) + for i, linkEpName := range linkEpNames { + channelEP := channel.New(256, mtu, "") + nicID := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + linkEPs[linkEpName] = channelEP + + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress IPv4 failed: %v", err) + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress IPv6 failed: %v", err) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEPs: linkEPs, + } +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +func TestTransportDemuxerRegister(t *testing.T) { + for _, test := range []struct { + name string + proto tcpip.NetworkProtocolNumber + want *tcpip.Error + }{ + {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"success", ipv4.ProtocolNumber, nil}, + } { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + } + }) + } +} + +// TestReuseBindToDevice injects varied packets on input devices and checks that +// the distribution of packets received matches expectations. +func TestDistribution(t *testing.T) { + type endpointSockopts struct { + reuse int + bindToDevice string + } + for _, test := range []struct { + name string + // endpoints will received the inject packets. + endpoints []endpointSockopts + // wantedDistribution is the wanted ratio of packets received on each + // endpoint for each NIC on which packets are injected. + wantedDistributions map[string][]float64 + }{ + { + "BindPortReuse", + // 5 endpoints that all have reuse set. + []endpointSockopts{ + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed evenly. + "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + }, + }, + { + "BindToDevice", + // 3 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{0, "dev0"}, + endpointSockopts{0, "dev1"}, + endpointSockopts{0, "dev2"}, + }, + map[string][]float64{ + // Injected packets on dev0 go only to the endpoint bound to dev0. + "dev0": []float64{1, 0, 0}, + // Injected packets on dev1 go only to the endpoint bound to dev1. + "dev1": []float64{0, 1, 0}, + // Injected packets on dev2 go only to the endpoint bound to dev2. + "dev2": []float64{0, 0, 1}, + }, + }, + { + "ReuseAndBindToDevice", + // 6 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed among endpoints bound to + // dev0. + "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + // Injected packets on dev1 get distributed among endpoints bound to + // dev1 or unbound. + "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + // Injected packets on dev999 go only to the unbound. + "dev999": []float64{0, 0, 0, 0, 0, 1}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + for device, wantedDistribution := range test.wantedDistributions { + t.Run(device, func(t *testing.T) { + var devices []string + for d := range test.wantedDistributions { + devices = append(devices, d) + } + c := newDualTestContextMultiNic(t, defaultMTU, devices) + defer c.cleanup() + + c.createV6Endpoint(false) + + eps := make(map[tcpip.Endpoint]int) + + pollChannel := make(chan tcpip.Endpoint) + for i, endpoint := range test.endpoints { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + eps[ep] = i + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(ep) + + defer ep.Close() + reusePortOption := tcpip.ReusePortOption(endpoint.reuse) + if err := ep.SetSockOpt(reusePortOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + } + bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) + if err := ep.SetSockOpt(bindToDeviceOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + } + if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + } + } + + npackets := 100000 + nports := 10000 + if got, want := len(test.endpoints), len(wantedDistribution); got != want { + t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) + } + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, + &headers{ + srcPort: testPort + port, + dstPort: stackPort}, + device) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled by the same + // socket. + if want, got := ports[port], ep; want != got { + t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) + } + } + } + + // Check that a packet distribution is as expected. + for ep, i := range eps { + wantedRatio := wantedDistribution[i] + wantedRecv := wantedRatio * float64(npackets) + actualRecv := stats[ep] + actualRatio := float64(stats[ep]) / float64(npackets) + // The deviation is less than 10%. + if math.Abs(actualRatio-wantedRatio) > 0.05 { + t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 56e8a5d9b..2cacea99a 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -38,19 +38,27 @@ const ( // Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't // use it. type fakeTransportEndpoint struct { - id stack.TransportEndpointID + stack.TransportEndpointInfo stack *stack.Stack - netProto tcpip.NetworkProtocolNumber proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint } -func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto} +func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { + return &f.TransportEndpointInfo +} + +func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { + return nil +} + +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Close() { @@ -75,7 +83,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { + if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil { return 0, nil, err } @@ -126,8 +134,8 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() // Try to register so that we can start receiving packets. - f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) + f.ID.RemoteAddress = addr.Addr + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) if err != nil { return err } @@ -137,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -168,7 +180,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, + false, /* reuse */ + 0, /* bindtoDevice */ ); err != nil { return err } @@ -184,14 +197,16 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - id: id, - stack: f.stack, - netProto: f.netProto, + stack: f.stack, + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, proto: f.proto, peerAddr: r.RemoteAddress, route: r.Clone(), @@ -199,7 +214,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -208,15 +223,15 @@ func (f *fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { -} +func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool @@ -241,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -256,7 +271,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } @@ -327,7 +342,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -336,7 +353,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -345,7 +364,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -398,7 +419,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -407,7 +430,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -416,7 +441,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -569,7 +596,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: req.ToVectorisedView(), + }) aep, _, err := ep.Accept() if err != nil || aep == nil { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c021c67ac..bd5eb89ca 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -57,6 +57,9 @@ type Error struct { // String implements fmt.Stringer.String. func (e *Error) String() string { + if e == nil { + return "<nil>" + } return e.msg } @@ -228,6 +231,13 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// Equal returns true if s equals o. +// +// Needed to use cmp.Equal on Subnet as its fields are unexported. +func (s Subnet) Equal(o Subnet) bool { + return s == o +} + // NICID is a number that uniquely identifies a NIC. type NICID int32 @@ -252,7 +262,7 @@ type FullAddress struct { // This may not be used by all endpoint types. NIC NICID - // Addr is the network address. + // Addr is the network or link layer address. Addr Address // Port is the transport port. @@ -426,6 +436,26 @@ type Endpoint interface { // IPTables returns the iptables for this endpoint's stack. IPTables() (iptables.IPTables, error) + + // Info returns a copy to the transport endpoint info. + Info() EndpointInfo + + // Stats returns a reference to the endpoint stats. + Stats() EndpointStats +} + +// EndpointInfo is the interface implemented by each endpoint info struct. +type EndpointInfo interface { + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo + // marker interface. + IsEndpointInfo() +} + +// EndpointStats is the interface implemented by each endpoint stats struct. +type EndpointStats interface { + // IsEndpointStats is an empty method to implement the tcpip.EndpointStats + // marker interface. + IsEndpointStats() } // WriteOptions contains options for Endpoint.Write. @@ -466,6 +496,11 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + // TODO(b/137664753): convert all int socket options to be handled via // GetSockOptInt. ) @@ -478,11 +513,6 @@ type ErrorOption struct{} // socket is to be restricted to sending and receiving IPv6 packets only. type V6OnlyOption int -// DelayOption is used by SetSockOpt/GetSockOpt to specify if data should be -// sent out immediately by the transport protocol. For TCP, it determines if the -// Nagle algorithm is on or off. -type DelayOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int @@ -495,6 +525,10 @@ type ReuseAddressOption int // to be bound to an identical socket address. type ReusePortOption int +// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets +// should bind only on a specific NIC. +type BindToDeviceOption string + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int @@ -546,6 +580,22 @@ type ModerateReceiveBufferOption bool // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. type MaxSegOption int +// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop +// limit value for unicast messages. The default is protocol specific. +// +// A zero value indicates the default. +type TTLOption uint8 + +// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state +// before being marked closed. +type TCPLingerTimeoutOption time.Duration + +// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TIME_WAIT state +// before being marked closed. +type TCPTimeWaitTimeoutOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -587,6 +637,18 @@ type OutOfBandInlineOption int // datagram sockets are allowed to send packets to a broadcast address. type BroadcastOption int +// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify +// a default TTL. +type DefaultTTLOption uint8 + +// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS +// for all subsequent outgoing IPv4 packets from the endpoint. +type IPv4TOSOption uint8 + +// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS +// for all subsequent outgoing IPv6 packets from the endpoint. +type IPv6TrafficClassOption uint8 + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. @@ -628,6 +690,11 @@ func (s *StatCounter) Increment() { s.IncrementBy(1) } +// Decrement minuses one to the counter. +func (s *StatCounter) Decrement() { + s.IncrementBy(^uint64(0)) +} + // Value returns the current value of the counter. func (s *StatCounter) Value() uint64 { return atomic.LoadUint64(&s.count) @@ -816,6 +883,14 @@ type IPStats struct { // OutgoingPacketErrors is the total number of IP packets which failed // to write to a link-layer endpoint. OutgoingPacketErrors *StatCounter + + // MalformedPacketsReceived is the total number of IP Packets that were + // dropped due to the IP packet header failing validation checks. + MalformedPacketsReceived *StatCounter + + // MalformedFragmentsReceived is the total number of IP Fragments that were + // dropped due to the fragment failing validation checks. + MalformedFragmentsReceived *StatCounter } // TCPStats collects TCP-specific stats. @@ -828,6 +903,15 @@ type TCPStats struct { // successfully via Listen. PassiveConnectionOpenings *StatCounter + // CurrentEstablished is the number of TCP connections for which the + // current state is either ESTABLISHED or CLOSE-WAIT. + CurrentEstablished *StatCounter + + // EstablishedResets is the number of times TCP connections have made + // a direct transition to the CLOSED state from either the + // ESTABLISHED state or the CLOSE-WAIT state. + EstablishedResets *StatCounter + // ListenOverflowSynDrop is the number of times the listen queue overflowed // and a SYN was dropped. ListenOverflowSynDrop *StatCounter @@ -862,6 +946,9 @@ type TCPStats struct { // SegmentsSent is the number of TCP segments sent. SegmentsSent *StatCounter + // SegmentSendErrors is the number of TCP segments failed to be sent. + SegmentSendErrors *StatCounter + // ResetsSent is the number of TCP resets sent. ResetsSent *StatCounter @@ -914,6 +1001,9 @@ type UDPStats struct { // PacketsSent is the number of UDP datagrams sent via sendUDP. PacketsSent *StatCounter + + // PacketSendErrors is the number of datagrams failed to be sent. + PacketSendErrors *StatCounter } // Stats holds statistics about the networking stack. @@ -924,7 +1014,7 @@ type Stats struct { // stack that were for an unknown or unsupported protocol. UnknownProtocolRcvdPackets *StatCounter - // MalformedRcvPackets is the number of packets received by the stack + // MalformedRcvdPackets is the number of packets received by the stack // that were deemed malformed. MalformedRcvdPackets *StatCounter @@ -944,18 +1034,95 @@ type Stats struct { UDP UDPStats } +// ReceiveErrors collects packet receive errors within transport endpoint. +type ReceiveErrors struct { + // ReceiveBufferOverflow is the number of received packets dropped + // due to the receive buffer being full. + ReceiveBufferOverflow StatCounter + + // MalformedPacketsReceived is the number of incoming packets + // dropped due to the packet header being in a malformed state. + MalformedPacketsReceived StatCounter + + // ClosedReceiver is the number of received packets dropped because + // of receiving endpoint state being closed. + ClosedReceiver StatCounter +} + +// SendErrors collects packet send errors within the transport layer for +// an endpoint. +type SendErrors struct { + // SendToNetworkFailed is the number of packets failed to be written to + // the network endpoint. + SendToNetworkFailed StatCounter + + // NoRoute is the number of times we failed to resolve IP route. + NoRoute StatCounter + + // NoLinkAddr is the number of times we failed to resolve ARP. + NoLinkAddr StatCounter +} + +// ReadErrors collects segment read errors from an endpoint read call. +type ReadErrors struct { + // ReadClosed is the number of received packet drops because the endpoint + // was shutdown for read. + ReadClosed StatCounter + + // InvalidEndpointState is the number of times we found the endpoint state + // to be unexpected. + InvalidEndpointState StatCounter +} + +// WriteErrors collects packet write errors from an endpoint write call. +type WriteErrors struct { + // WriteClosed is the number of packet drops because the endpoint + // was shutdown for write. + WriteClosed StatCounter + + // InvalidEndpointState is the number of times we found the endpoint state + // to be unexpected. + InvalidEndpointState StatCounter + + // InvalidArgs is the number of times invalid input arguments were + // provided for endpoint Write call. + InvalidArgs StatCounter +} + +// TransportEndpointStats collects statistics about the endpoint. +type TransportEndpointStats struct { + // PacketsReceived is the number of successful packet receives. + PacketsReceived StatCounter + + // PacketsSent is the number of successful packet sends. + PacketsSent StatCounter + + // ReceiveErrors collects packet receive errors within transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects packet read errors from an endpoint read call. + ReadErrors ReadErrors + + // SendErrors collects packet send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects packet write errors from an endpoint write call. + WriteErrors WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*TransportEndpointStats) IsEndpointStats() {} + func fillIn(v reflect.Value) { for i := 0; i < v.NumField(); i++ { v := v.Field(i) - switch v.Kind() { - case reflect.Ptr: - if s := v.Addr().Interface().(**StatCounter); *s == nil { - *s = &StatCounter{} + if s, ok := v.Addr().Interface().(**StatCounter); ok { + if *s == nil { + *s = new(StatCounter) } - case reflect.Struct: + } else { fillIn(v) - default: - panic(fmt.Sprintf("unexpected type %s", v.Type())) } } } @@ -966,6 +1133,26 @@ func (s Stats) FillIn() Stats { return s } +// Clone returns a copy of the TransportEndpointStats by atomically reading +// each field. +func (src *TransportEndpointStats) Clone() TransportEndpointStats { + var dst TransportEndpointStats + clone(reflect.ValueOf(&dst).Elem(), reflect.ValueOf(src).Elem()) + return dst +} + +func clone(dst reflect.Value, src reflect.Value) { + for i := 0; i < dst.NumField(); i++ { + d := dst.Field(i) + s := src.Field(i) + if c, ok := s.Addr().Interface().(*StatCounter); ok { + d.Addr().Interface().(*StatCounter).IncrementBy(c.Value()) + } else { + clone(d, s) + } + } +} + // String implements the fmt.Stringer interface. func (a Address) String() string { switch len(a) { @@ -1091,6 +1278,47 @@ func (a AddressWithPrefix) String() string { return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen) } +// Subnet converts the address and prefix into a Subnet value and returns it. +func (a AddressWithPrefix) Subnet() Subnet { + addrLen := len(a.Address) + if a.PrefixLen <= 0 { + return Subnet{ + address: Address(strings.Repeat("\x00", addrLen)), + mask: AddressMask(strings.Repeat("\x00", addrLen)), + } + } + if a.PrefixLen >= addrLen*8 { + return Subnet{ + address: a.Address, + mask: AddressMask(strings.Repeat("\xff", addrLen)), + } + } + + sa := make([]byte, addrLen) + sm := make([]byte, addrLen) + n := uint(a.PrefixLen) + for i := 0; i < addrLen; i++ { + if n >= 8 { + sa[i] = a.Address[i] + sm[i] = 0xff + n -= 8 + continue + } + sm[i] = ^byte(0xff >> n) + sa[i] = a.Address[i] & sm[i] + n = 0 + } + + // For extra caution, call NewSubnet rather than directly creating the Subnet + // value. If that fails it indicates a serious bug in this code, so panic is + // in order. + s, err := NewSubnet(Address(sa), AddressMask(sm)) + if err != nil { + panic("invalid subnet: " + err.Error()) + } + return s +} + // ProtocolAddress is an address and the network protocol it is associated // with. type ProtocolAddress struct { @@ -1111,8 +1339,8 @@ var ( // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { - es := make([]Endpoint, 0, len(danglingEndpoints)) danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) for e := range danglingEndpoints { es = append(es, e) } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index fb3a0a5ee..8c0aacffa 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -195,3 +195,34 @@ func TestStatsString(t *testing.T) { t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) } } + +func TestAddressWithPrefixSubnet(t *testing.T) { + tests := []struct { + addr Address + prefixLen int + subnetAddr Address + subnetMask AddressMask + }{ + {"\xaa\x55\x33\x42", -1, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 0, "\x00\x00\x00\x00", "\x00\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 1, "\x80\x00\x00\x00", "\x80\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 7, "\xaa\x00\x00\x00", "\xfe\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 8, "\xaa\x00\x00\x00", "\xff\x00\x00\x00"}, + {"\xaa\x55\x33\x42", 24, "\xaa\x55\x33\x00", "\xff\xff\xff\x00"}, + {"\xaa\x55\x33\x42", 31, "\xaa\x55\x33\x42", "\xff\xff\xff\xfe"}, + {"\xaa\x55\x33\x42", 32, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + {"\xaa\x55\x33\x42", 33, "\xaa\x55\x33\x42", "\xff\xff\xff\xff"}, + } + for _, tt := range tests { + ap := AddressWithPrefix{Address: tt.addr, PrefixLen: tt.prefixLen} + gotSubnet := ap.Subnet() + wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) + if err != nil { + t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + continue + } + if gotSubnet != wantSubnet { + t.Errorf("got subnet = %q, want = %q", gotSubnet, wantSubnet) + } + } +} diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index a52262e87..48764b978 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.9 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index d78a162b8..9254c3dea 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "icmp_packet_list", out = "icmp_packet_list.go", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a111fdb2a..70e008d36 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -31,9 +31,6 @@ type icmpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } type endpointState int @@ -52,12 +49,13 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber - transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -73,30 +71,32 @@ type endpoint struct { sndBufSize int // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags - id stack.TransportEndpointID state endpointState - // bindNICID and bindAddr are set via calls to Bind(). They are used to - // reject attempts to send data or connect via a different NIC or - // address - bindNICID tcpip.NICID - bindAddr tcpip.Address - // regNICID is the default NIC to be used when callers don't specify a - // NIC. - regNICID tcpip.NICID - route stack.Route `state:"manual"` -} - -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + route stack.Route `state:"manual"` + ttl uint8 + stats tcpip.TransportEndpointStats `state:"nosave"` +} + +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return &endpoint{ - stack: stack, - netProto: netProto, - transProto: transProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: stateInitial, + uniqueID: s.UniqueID(), }, nil } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -104,7 +104,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -143,6 +143,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -205,6 +206,29 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -254,13 +278,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + nicID := to.NIC + if e.BindNICID != 0 { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicID = e.BindNICID } toCopy := *to @@ -271,7 +295,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -294,12 +318,12 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - switch e.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - err = send4(route, e.id.LocalPort, v) + err = send4(route, e.ID.LocalPort, v, e.ttl) case header.IPv6ProtocolNumber: - err = send6(route, e.id.LocalPort, v) + err = send6(route, e.ID.LocalPort, v, e.ttl) } if err != nil { @@ -314,8 +338,15 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(o) + e.mu.Unlock() + } + return nil } @@ -362,12 +393,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = 0 return nil + case *tcpip.TTLOption: + e.rcvMu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.rcvMu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } -func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } @@ -389,10 +426,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) } -func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return tcpip.ErrInvalidEndpointState } @@ -409,21 +449,24 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - icmpv6.SetChecksum(0) - icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + dataVV := data.ToVectorisedView() + icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, hdr, dataVV, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { return 0, tcpip.ErrNoRoute } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -440,20 +483,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC localPort := uint16(0) switch e.state { case stateBound, stateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -464,7 +507,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -481,14 +524,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } - e.id = id + e.ID = id e.route = r.Clone() - e.regNICID = nicid + e.RegisterNICID = nicID e.state = stateConnected @@ -539,18 +582,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil @@ -597,8 +640,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return err } - e.id = id - e.regNICID = addr.NIC + e.ID = id + e.RegisterNICID = addr.NIC // Mark endpoint as bound. e.state = stateBound @@ -621,8 +664,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { return err } - e.bindNICID = addr.NIC - e.bindAddr = addr.Addr + e.BindNICID = addr.NIC + e.BindAddr = addr.Addr return nil } @@ -633,9 +676,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, nil } @@ -649,9 +692,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -675,19 +718,21 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Only accept echo replies. - switch e.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(vv.First()) + h := header.ICMPv4(pkt.Data.First()) if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(vv.First()) + h := header.ICMPv6(pkt.Data.First()) if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } } @@ -695,31 +740,39 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv e.rcvMu.Lock() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &icmpPacket{ + packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, }, } - pkt.data = vv.Clone(pkt.views[:]) + packet.data = pkt.Data - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() - + e.stats.PacketsReceived.Increment() // Notify any waiters that there's data to be read now. if wasEmpty { e.waiterQueue.Notify(waiter.EventIn) @@ -727,7 +780,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't @@ -735,3 +788,20 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C func (e *endpoint) State() uint32 { return 0 } + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index c587b96b6..9d263c0ec 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -76,19 +76,19 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) if err != nil { panic(err) } - e.id.LocalAddress = e.route.LocalAddress - } else if len(e.id.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + e.ID.LocalAddress = e.route.LocalAddress + } else if len(e.ID.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } - e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index bfb16f7c3..9ce500e80 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD new file mode 100644 index 000000000..8ea2e6ee5 --- /dev/null +++ b/pkg/tcpip/transport/packet/BUILD @@ -0,0 +1,46 @@ +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_template_instance( + name = "packet_list", + out = "packet_list.go", + package = "packet", + prefix = "packet", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*packet", + "Linker": "*packet", + }, +) + +go_library( + name = "packet", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "packet_list.go", + ], + importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", + imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sleep", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/iptables", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) + +filegroup( + name = "autogen", + srcs = [ + "packet_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go new file mode 100644 index 000000000..0010b5e5f --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -0,0 +1,362 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packet provides the implementation of packet sockets (see +// packet(7)). Packet sockets allow applications to: +// +// * manually write and inspect link, network, and transport headers +// * receive all traffic of a given network protocol, or all protocols +// +// Packet sockets are similar to raw sockets, but provide even more power to +// users, letting them effectively talk directly to the network device. +// +// Packet sockets skip the input and output iptables chains. +package packet + +import ( + "sync" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// +stateify savable +type packet struct { + packetEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal +// to have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + cooked bool + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` +} + +// NewEndpoint returns a new packet endpoint. +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + }, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } + + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { + return nil, err + } + return ep, nil +} + +// Close implements tcpip.Endpoint.Close. +func (ep *endpoint) Close() { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return + } + + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + + // Clear the receive list. + ep.rcvClosed = true + ep.rcvBufSize = 0 + for !ep.rcvList.Empty() { + ep.rcvList.Remove(ep.rcvList.Front()) + } + + ep.closed = true + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (ep *endpoint) ModerateRecvBuf(copied int) {} + +// IPTables implements tcpip.Endpoint.IPTables. +func (ep *endpoint) IPTables() (iptables.IPTables, error) { + return ep.stack.IPTables(), nil +} + +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + ep.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. + if ep.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if ep.rcvClosed { + ep.stats.ReadErrors.ReadClosed.Increment() + err = tcpip.ErrClosedForReceive + } + ep.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + packet := ep.rcvList.Front() + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + + ep.rcvMu.Unlock() + + if addr != nil { + *addr = packet.senderAddr + } + + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil +} + +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // TODO(b/129292371): Implement. + return 0, nil, tcpip.ErrInvalidOptionValue +} + +// Peek implements tcpip.Endpoint.Peek. +func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be +// disconnected, and this function always returns tpcip.ErrNotSupported. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be +// connected, and this function always returnes tcpip.ErrNotSupported. +func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used +// with Shutdown, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with +// Listen, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with +// Accept, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + // TODO(gvisor.dev/issue/173): Add Bind support. + + // "By default, all packets of the specified protocol type are passed + // to a packet socket. To get packets only from a specific interface + // use bind(2) specifying an address in a struct sockaddr_ll to bind + // the packet socket to an interface. Fields used for binding are + // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." + // - packet(7). + + return tcpip.ErrNotSupported +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + // Even a connected socket doesn't return a remote address. + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. + if (mask & waiter.EventIn) != 0 { + ep.rcvMu.Lock() + if !ep.rcvList.Empty() || ep.rcvClosed { + result |= waiter.EventIn + } + ep.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be +// used with SetSockOpt, and this function always returns +// tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// HandlePacket implements stack.PacketEndpoint.HandlePacket. +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + ep.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if ep.rcvClosed { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if ep.rcvBufSize >= ep.rcvBufSizeMax { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return + } + + wasEmpty := ep.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + var packet packet + // TODO(b/129292371): Return network protocol. + if len(pkt.LinkHeader) > 0 { + // Get info directly from the ethernet header. + hdr := header.Ethernet(pkt.LinkHeader) + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(hdr.SourceAddress()), + } + } else { + // Guess the would-be ethernet header. + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(localAddr), + } + } + + if ep.cooked { + // Cooked packets can simply be queued. + packet.data = pkt.Data + } else { + // Raw packets need their ethernet headers prepended before + // queueing. + var linkHeader buffer.View + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + } + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } + packet.timestampNS = ep.stack.NowNanoseconds() + + ep.rcvList.PushBack(&packet) + ep.rcvBufSize += packet.data.Size() + + ep.rcvMu.Unlock() + ep.stats.PacketsReceived.Increment() + // Notify waiters that there's data to be read. + if wasEmpty { + ep.waiterQueue.Notify(waiter.EventIn) + } +} + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} + +// Info returns a copy of the endpoint info. +func (ep *endpoint) Info() tcpip.EndpointInfo { + ep.mu.RLock() + // Make a copy of the endpoint info. + ret := ep.TransportEndpointInfo + ep.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (ep *endpoint) Stats() tcpip.EndpointStats { + return &ep.stats +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go new file mode 100644 index 000000000..9b88f17e4 --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -0,0 +1,72 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// saveData saves packet.data field. +func (p *packet) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads packet.data field. +func (p *packet) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (ep *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after saveRcvBufSizeMax(), which would have + // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + ep.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) saveRcvBufSizeMax() int { + max := ep.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + ep.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + ep.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) loadRcvBufSizeMax(max int) { + ep.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (ep *endpoint) afterLoad() { + // StackFromEnv is a stack used specifically for save/restore. + ep.stack = stack.StackFromEnv + + // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. + if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 7241f6c19..4af49218c 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -1,17 +1,17 @@ -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( - name = "packet_list", - out = "packet_list.go", + name = "raw_packet_list", + out = "raw_packet_list.go", package = "raw", - prefix = "packet", + prefix = "rawPacket", template = "//pkg/ilist:generic_list", types = { - "Element": "*packet", - "Linker": "*packet", + "Element": "*rawPacket", + "Linker": "*rawPacket", }, ) @@ -20,8 +20,8 @@ go_library( srcs = [ "endpoint.go", "endpoint_state.go", - "packet_list.go", "protocol.go", + "raw_packet_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], @@ -34,6 +34,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/packet", "//pkg/waiter", ], ) @@ -41,7 +42,7 @@ go_library( filegroup( name = "autogen", srcs = [ - "packet_list.go", + "raw_packet_list.go", ], visibility = ["//:sandbox"], ) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index a02731a5d..230a1537a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -17,8 +17,7 @@ // // * manually write and inspect transport layer headers and payloads // * receive all traffic of a given transport protocol (e.g. ICMP or UDP) -// * optionally write and inspect network layer and link layer headers for -// packets +// * optionally write and inspect network layer headers of packets // // Raw sockets don't have any notion of ports, and incoming packets are // demultiplexed solely by protocol number. Thus, a raw UDP endpoint will @@ -38,15 +37,11 @@ import ( ) // +stateify savable -type packet struct { - packetEntry +type rawPacket struct { + rawPacketEntry // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -62,18 +57,17 @@ type packet struct { // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber - transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue associated bool // The following fields are used to manage the receive queue and are // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` - rcvList packetList + rcvList rawPacketList rcvBufSizeMax int `state:".(int)"` rcvBufSize int rcvClosed bool @@ -84,35 +78,28 @@ type endpoint struct { closed bool connected bool bound bool - // registeredNIC is the NIC to which th endpoint is explicitly - // registered. Is set when Connect or Bind are used to specify a NIC. - registeredNIC tcpip.NICID - // boundNIC and boundAddr are set on calls to Bind(). When callers - // attempt actions that would invalidate the binding data (e.g. sending - // data via a NIC other than boundNIC), the endpoint will return an - // error. - boundNIC tcpip.NICID - boundAddr tcpip.Address // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. - route stack.Route `state:"manual"` + route stack.Route `state:"manual"` + stats tcpip.TransportEndpointStats `state:"nosave"` } // NewEndpoint returns a raw endpoint for the given protocols. -// TODO(b/129292371): IP_HDRINCL and AF_PACKET. func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if netProto != header.IPv4ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } - ep := &endpoint{ - stack: stack, - netProto: netProto, - transProto: transProto, + e := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, @@ -123,114 +110,139 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - ep.rcvBufSizeMax = 0 - ep.waiterQueue = nil - return ep, nil + e.rcvBufSizeMax = 0 + e.waiterQueue = nil + return e, nil } - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { return nil, err } - return ep, nil + return e, nil } // Close implements tcpip.Endpoint.Close. -func (ep *endpoint) Close() { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() - if ep.closed || !ep.associated { + if e.closed || !e.associated { return } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) - ep.rcvMu.Lock() - defer ep.rcvMu.Unlock() + e.rcvMu.Lock() + defer e.rcvMu.Unlock() // Clear the receive list. - ep.rcvClosed = true - ep.rcvBufSize = 0 - for !ep.rcvList.Empty() { - ep.rcvList.Remove(ep.rcvList.Front()) + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + e.rcvList.Remove(e.rcvList.Front()) } - if ep.connected { - ep.route.Release() - ep.connected = false + if e.connected { + e.route.Release() + e.connected = false } - ep.closed = true + e.closed = true - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (iptables.IPTables, error) { - return ep.stack.IPTables(), nil +func (e *endpoint) IPTables() (iptables.IPTables, error) { + return e.stack.IPTables(), nil } // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !ep.associated { +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if !e.associated { return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue } - ep.rcvMu.Lock() + e.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. - if ep.rcvList.Empty() { + if e.rcvList.Empty() { err := tcpip.ErrWouldBlock - if ep.rcvClosed { + if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return buffer.View{}, tcpip.ControlMessages{}, err } - packet := ep.rcvList.Front() - ep.rcvList.Remove(packet) - ep.rcvBufSize -= packet.data.Size() + pkt := e.rcvList.Front() + e.rcvList.Remove(pkt) + e.rcvBufSize -= pkt.data.Size() - ep.rcvMu.Unlock() + e.rcvMu.Unlock() if addr != nil { - *addr = packet.senderAddr + *addr = pkt.senderAddr } - return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil + return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil } // Write implements tcpip.Endpoint.Write. -func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - ep.mu.RLock() + e.mu.RLock() - if ep.closed { - ep.mu.RUnlock() + if e.closed { + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { + e.mu.RUnlock() return 0, nil, err } // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !ep.associated { + if !e.associated { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -251,66 +263,66 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <- if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. - if !ep.connected { - ep.mu.RUnlock() + if !e.connected { + e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if ep.route.IsResolutionRequired() { - savedRoute := &ep.route + if e.route.IsResolutionRequired() { + savedRoute := &e.route // Promote lock to exclusive if using a shared route, // given that it may need to change in finishWrite. - ep.mu.RUnlock() - ep.mu.Lock() + e.mu.RUnlock() + e.mu.Lock() // Make sure that the route didn't change during the // time we didn't hold the lock. - if !ep.connected || savedRoute != &ep.route { - ep.mu.Unlock() + if !e.connected || savedRoute != &e.route { + e.mu.Unlock() return 0, nil, tcpip.ErrInvalidEndpointState } - n, ch, err := ep.finishWrite(payloadBytes, savedRoute) - ep.mu.Unlock() + n, ch, err := e.finishWrite(payloadBytes, savedRoute) + e.mu.Unlock() return n, ch, err } - n, ch, err := ep.finishWrite(payloadBytes, &ep.route) - ep.mu.RUnlock() + n, ch, err := e.finishWrite(payloadBytes, &e.route) + e.mu.RUnlock() return n, ch, err } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC - if ep.bound && nic != 0 && nic != ep.boundNIC { - ep.mu.RUnlock() + if e.bound && nic != 0 && nic != e.BindNICID { + e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } // We don't support IPv6 yet, so this has to be an IPv4 address. if len(opts.To.Addr) != header.IPv4AddressSize { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } - // Find the route to the destination. If boundAddress is 0, + // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. - route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false) + route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - ep.mu.RUnlock() + e.mu.RUnlock() return 0, nil, err } - n, ch, err := ep.finishWrite(payloadBytes, &route) + n, ch, err := e.finishWrite(payloadBytes, &route) route.Release() - ep.mu.RUnlock() + e.mu.RUnlock() return n, ch, err } // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -323,16 +335,16 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch ep.netProto { + switch e.NetProto { case header.IPv4ProtocolNumber: - if !ep.associated { + if !e.associated { if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { return 0, nil, err } break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil { + if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { return 0, nil, err } @@ -344,7 +356,7 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -354,11 +366,11 @@ func (*endpoint) Disconnect() *tcpip.Error { } // Connect implements tcpip.Endpoint.Connect. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() - if ep.closed { + if e.closed { return tcpip.ErrInvalidEndpointState } @@ -368,15 +380,15 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } nic := addr.NIC - if ep.bound { - if ep.boundNIC == 0 { + if e.bound { + if e.BindNICID == 0 { // If we're bound, but not to a specific NIC, the NIC // in addr will be used. Nothing to do here. } else if addr.NIC == 0 { // If we're bound to a specific NIC, but addr doesn't // specify a NIC, use the bound NIC. - nic = ep.boundNIC - } else if addr.NIC != ep.boundNIC { + nic = e.BindNICID + } else if addr.NIC != e.BindNICID { // We're bound and addr specifies a NIC. They must be // the same. return tcpip.ErrInvalidEndpointState @@ -384,53 +396,53 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the destination. - route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false) + route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) if err != nil { return err } defer route.Release() - if ep.associated { + if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { return err } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = nic + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = nic } // Save the route we've connected via. - ep.route = route.Clone() - ep.connected = true + e.route = route.Clone() + e.connected = true return nil } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() - if !ep.connected { + if !e.connected { return tcpip.ErrNotConnected } return nil } // Listen implements tcpip.Endpoint.Listen. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } // Bind implements tcpip.Endpoint.Bind. -func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { - ep.mu.Lock() - defer ep.mu.Unlock() +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() // Callers must provide an IPv4 address or no network address (for // binding to a NIC, but not an address). @@ -439,56 +451,56 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // If a local address was specified, verify that it's valid. - if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 { + if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } - if ep.associated { + if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { return err } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = addr.NIC - ep.boundNIC = addr.NIC + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = addr.NIC + e.BindNICID = addr.NIC } - ep.boundAddr = addr.Addr - ep.bound = true + e.BindAddr = addr.Addr + e.bound = true return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } // Readiness implements tcpip.Endpoint.Readiness. -func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // The endpoint is always writable. result := waiter.EventOut & mask // Determine whether the endpoint is readable. if (mask & waiter.EventIn) != 0 { - ep.rcvMu.Lock() - if !ep.rcvList.Empty() || ep.rcvClosed { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { result |= waiter.EventIn } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() } return result } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -498,28 +510,28 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 - ep.rcvMu.Lock() - if !ep.rcvList.Empty() { - p := ep.rcvList.Front() + e.rcvMu.Lock() + if !e.rcvList.Empty() { + p := e.rcvList.Front() v = p.data.Size() } - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return v, nil case tcpip.SendBufferSizeOption: - ep.mu.Lock() - v := ep.sndBufSize - ep.mu.Unlock() + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() return v, nil case tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - v := ep.rcvBufSizeMax - ep.rcvMu.Unlock() + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() return v, nil } @@ -528,7 +540,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: return nil @@ -543,63 +555,89 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { - ep.rcvMu.Lock() +func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. - if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax { - ep.stack.Stats().DroppedPackets.Increment() - ep.rcvMu.Unlock() + if e.rcvClosed { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() return } - if ep.bound { + if e.rcvBufSize >= e.rcvBufSizeMax { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return + } + + if e.bound { // If bound to a NIC, only accept data for that NIC. - if ep.boundNIC != 0 && ep.boundNIC != route.NICID() { - ep.rcvMu.Unlock() + if e.BindNICID != 0 && e.BindNICID != route.NICID() { + e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress { - ep.rcvMu.Unlock() + if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + e.rcvMu.Unlock() return } } // If connected, only accept packets from the remote address we // connected to. - if ep.connected && ep.route.RemoteAddress != route.RemoteAddress { - ep.rcvMu.Unlock() + if e.connected && e.route.RemoteAddress != route.RemoteAddress { + e.rcvMu.Unlock() return } - wasEmpty := ep.rcvBufSize == 0 + wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - packet := &packet{ + packet := &rawPacket{ senderAddr: tcpip.FullAddress{ NIC: route.NICID(), Addr: route.RemoteAddress, }, } - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - packet.data = combinedVV.Clone(packet.views[:]) - packet.timestampNS = ep.stack.NowNanoseconds() - - ep.rcvList.PushBack(packet) - ep.rcvBufSize += packet.data.Size() + networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) + combinedVV := networkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + packet.timestampNS = e.stack.NowNanoseconds() - ep.rcvMu.Unlock() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() + e.rcvMu.Unlock() + e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. if wasEmpty { - ep.waiterQueue.Notify(waiter.EventIn) + e.waiterQueue.Notify(waiter.EventIn) } } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (e *endpoint) State() uint32 { return 0 } + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 168953dec..33bfb56cd 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -20,15 +20,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// saveData saves packet.data field. -func (p *packet) saveData() buffer.VectorisedView { +// saveData saves rawPacket.data field. +func (p *rawPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, // which is not allowed by state framework (in-struct pointer). return p.data.Clone(nil) } -// loadData loads packet.data field. -func (p *packet) loadData(data buffer.VectorisedView) { +// loadData loads rawPacket.data field. +func (p *rawPacket) loadData(data buffer.VectorisedView) { // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization // here because data.views is not guaranteed to be loaded by now. Plus, // data.views will be allocated anyway so there really is little point @@ -73,7 +73,7 @@ func (ep *endpoint) Resume(s *stack.Stack) { // If the endpoint is connected, re-connect. if ep.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) if err != nil { panic(err) } @@ -81,12 +81,14 @@ func (ep *endpoint) Resume(s *stack.Stack) { // If the endpoint is bound, re-bind. if ep.bound { - if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { - panic(err) + if ep.associated { + if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + panic(err) + } } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index a2512d666..f30aa2a4a 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -17,13 +17,19 @@ package raw import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/packet" "gvisor.dev/gvisor/pkg/waiter" ) -// EndpointFactory implements stack.UnassociatedEndpointFactory. +// EndpointFactory implements stack.RawFactory. type EndpointFactory struct{} -// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. -func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } + +// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. +func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 39a839ab7..3f47b328d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "tcp_segment_list", out = "tcp_segment_list.go", @@ -45,10 +44,12 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/seqnum", @@ -70,7 +71,7 @@ filegroup( go_test( name = "tcp_test", - size = "small", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 0802e984e..f24b51b91 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -229,7 +230,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n := newEndpoint(l.stack, netProto, nil) n.v6only = l.v6only - n.id = s.id + n.ID = s.id n.boundNICID = s.route.NICID() n.route = s.route.Clone() n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} @@ -242,7 +243,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil { n.Close() return nil, err } @@ -268,8 +269,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + isn := generateSecureISN(s.id, l.stack.Seed()) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts) if err != nil { return nil, err } @@ -288,9 +289,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(isn, irs, opts) if err := h.execute(); err != nil { - ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.Close() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -298,7 +298,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() + ep.stack.Stats().TCP.CurrentEstablished.Increment() ep.state = StateEstablished + ep.isConnectNotified = true ep.mu.Unlock() // Update the receive window scaling. We can't do it before the @@ -311,14 +313,14 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.id] = n + l.pendingEndpoints[n.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.id) + delete(l.pendingEndpoints, n.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -360,12 +362,19 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer decSynRcvdCount() defer e.decSynRcvdCount() defer s.decRef() + n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } @@ -399,6 +408,17 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + if s.flagsAreSet(header.TCPFlagSyn | header.TCPFlagAck) { + // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST + // must be sent in response to a SYN-ACK while in the listen + // state to prevent completing a handshake from an old SYN. + e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil) + return + } + + // TODO(b/143300739): Use the userMSS of the listening socket + // for accepted sockets. + switch s.flags { case header.TCPFlagSyn: opts := parseSynSegmentOptions(s) @@ -414,6 +434,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { } decSynRcvdCount() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } else { @@ -421,6 +442,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // is full then drop the syn. if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } @@ -431,15 +453,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // // Enable Timestamp option if the original syn did have // the timestamp option specified. - mss := mssForRoute(&s.route) synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, TSVal: tcpTimeStamp(timeStampOffset()), TSEcr: opts.TSVal, - MSS: uint16(mss), + MSS: mssForRoute(&s.route), } - sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } @@ -451,6 +472,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // complete the connection at the time of retransmit if // the backlog has space. e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return } @@ -505,6 +527,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() return } @@ -515,7 +538,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished + n.isConnectNotified = true // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -526,6 +551,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } @@ -536,7 +566,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Lock() v6only := e.v6only e.mu.Unlock() - ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto) defer func() { // Mark endpoint as closed. This will prevent goroutines running diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 21038a65a..a114c06c1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "sync" "time" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -78,9 +80,6 @@ type handshake struct { // mss is the maximum segment size received from the peer. mss uint16 - // amss is the maximum segment size advertised by us to the peer. - amss uint16 - // sndWndScale is the send window scale, as defined in RFC 1323. A // negative value means no scaling is supported by the peer. sndWndScale int @@ -142,7 +141,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what linux uses + // generate a sequence number that overlaps less than one + // time per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000) times in a second. + // To wrap the whole 32 bit space would require + // 2^32/1562500 ~ 274 seconds. + // + // Which sort of guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -238,6 +262,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() h.ep.state = StateSynRecv + ttl := h.ep.ttl h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), @@ -251,8 +276,10 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { SACKPermitted: rcvSynOpts.SACKPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) - + if ttl == 0 { + ttl = s.route.DefaultTTL() + } + h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -296,7 +323,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -383,6 +410,11 @@ func (h *handshake) resolveRoute() *tcpip.Error { switch index { case wakerForResolution: if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + if err == tcpip.ErrNoLinkAddress { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + } else if err != nil { + h.ep.stats.SendErrors.NoRoute.Increment() + } // Either success (err == nil) or failure. return err } @@ -437,7 +469,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. - h.ep.amss = mssForRoute(&h.ep.route) + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -460,7 +492,8 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + for h.state != handshakeCompleted { switch index, _ := s.Fetch(true); index { case wakerForResend: @@ -469,7 +502,7 @@ func (h *handshake) execute() *tcpip.Error { return tcpip.ErrTimeout } rt.Reset(timeOut) - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) case wakerForNotification: n := h.ep.fetchNotifications() @@ -579,24 +612,30 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { options := makeSynOptions(opts) - err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil) + // We ignore SYN send errors and let the callers re-attempt send. + if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil { + e.stats.SendErrors.SynSendToNetworkFailed.Increment() + } putOptions(options) - return err + return nil } -// sendTCP sends a TCP segment with the provided options via the provided -// network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - // Allocate a buffer for the TCP header. - hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) - - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil { + e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() + return err } + e.stats.SegmentsSent.Increment() + return nil +} +func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDescriptor, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { + optLen := len(opts) + hdr := &d.Hdr + packetSize := d.Size + off := d.Off // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) tcp.Encode(&header.TCPFields{ @@ -610,7 +649,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise }) copy(tcp[header.TCPMinimumSize:], opts) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(hdr.UsedLength() + packetSize) xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { @@ -620,16 +659,79 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVV(data, xsum) + xsum = header.ChecksumVVWithOffset(data, xsum, off, packetSize) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } +} + +func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + mss := int(gso.MSS) + n := (data.Size() + mss - 1) / mss + + hdrs := stack.NewPacketDescriptors(n, header.TCPMinimumSize+int(r.MaxHeaderLength())+optLen) + + size := data.Size() + off := 0 + for i := 0; i < n; i++ { + packetSize := mss + if packetSize > size { + packetSize = size + } + size -= packetSize + hdrs[i].Off = off + hdrs[i].Size = packetSize + buildTCPHdr(r, id, &hdrs[i], data, flags, seq, ack, rcvWnd, opts, gso) + off += packetSize + seq = seq.Add(seqnum.Size(packetSize)) + } + if ttl == 0 { + ttl = r.DefaultTTL() + } + sent, err := r.WritePackets(gso, hdrs, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + if err != nil { + r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) + } + r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent)) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) + } + + d := &stack.PacketDescriptor{ + Hdr: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + Off: 0, + Size: data.Size(), + } + buildTCPHdr(r, id, d, data, flags, seq, ack, rcvWnd, opts, gso) + + if ttl == 0 { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + r.Stats().TCP.SegmentSendErrors.Increment() + return err + } r.Stats().TCP.SegmentsSent.Increment() if (flags & header.TCPFlagRst) != 0 { r.Stats().TCP.ResetsSent.Increment() } - - return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl) + return nil } // makeOptions makes an options slice. @@ -678,7 +780,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso) putOptions(options) return err } @@ -727,10 +829,26 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. + if e.state == StateEstablished || e.state == StateCloseWait { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } e.state = StateError - e.hardError = err + e.HardError = err if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { + resetSeqNum = e.snd.sndNxt + } + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) } } @@ -744,6 +862,51 @@ func (e *endpoint) completeWorkerLocked() { } } +func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + s.decRef() + e.mu.Lock() + switch e.state { + // In case of a RST in CLOSE-WAIT linux moves + // the socket to closed state with an error set + // to indicate EPIPE. + // + // Technically this seems to be at odds w/ RFC. + // As per https://tools.ietf.org/html/rfc793#section-2.7 + // page 69 the behavior for a segment arriving + // w/ RST bit set in CLOSE-WAIT is inlined below. + // + // ESTABLISHED + // FIN-WAIT-1 + // FIN-WAIT-2 + // CLOSE-WAIT + + // If the RST bit is set then, any outstanding RECEIVEs and + // SEND should receive "reset" responses. All segment queues + // should be flushed. Users should also receive an unsolicited + // general "connection reset" signal. Enter the CLOSED state, + // delete the TCB, and return. + case StateCloseWait: + e.state = StateClose + e.HardError = tcpip.ErrAborted + // We need to set this explicitly here because otherwise + // the port registrations will not be released till the + // endpoint is actively closed by the application. + e.workerCleanup = true + e.mu.Unlock() + return false, nil + default: + e.mu.Unlock() + return false, tcpip.ErrConnectionReset + } + } + return true, nil +} + // handleSegments pulls segments from the queue and processes them. It returns // no error if the protocol loop should continue, an error otherwise. func (e *endpoint) handleSegments() *tcpip.Error { @@ -761,14 +924,34 @@ func (e *endpoint) handleSegments() *tcpip.Error { } if s.flagIsSet(header.TCPFlagRst) { - if e.rcv.acceptable(s.sequenceNumber, 0) { - // RFC 793, page 37 states that "in all states - // except SYN-SENT, all reset (RST) segments are - // validated by checking their SEQ-fields." So - // we only process it if it's acceptable. - s.decRef() - return tcpip.ErrConnectionReset + if ok, err := e.handleReset(s); !ok { + return err } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK> + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -777,7 +960,15 @@ func (e *endpoint) handleSegments() *tcpip.Error { // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment // information." - e.rcv.handleRcvdSegment(s) + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + s.decRef() + return err + } + if drop { + s.decRef() + continue + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -876,7 +1067,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() - // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -897,8 +1087,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateError - e.hardError = err + e.HardError = err // Lock released below. epilogue() @@ -920,6 +1112,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // RTT itself. e.rcvAutoParams.prevCopied = initialRcvWnd e.rcvListMu.Unlock() + e.stack.Stats().TCP.CurrentEstablished.Increment() + e.mu.Lock() + e.state = StateEstablished + e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -927,7 +1123,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -958,7 +1153,13 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed due + // to the TCP_FIN_WAIT2 timeout was hit. Just + // mark the socket as closed. + e.mu.Lock() + e.state = StateClose + e.mu.Unlock() + return nil }, }, { @@ -1001,17 +1202,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() } + if n¬ifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + e.mu.Lock() + if e.state == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { + closeWaker.Assert() + }) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1033,6 +1235,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } } + if n¬ifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1059,15 +1267,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() + e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + + for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() @@ -1083,15 +1292,156 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil } + e.mu.Lock() + } + + state := e.state + e.mu.Unlock() + var reuseTW func() + if state == StateTimeWait { + // Disable close timer as we now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() if e.state != StateError { + e.stack.Stats().TCP.EstablishedResets.Increment() + e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateClose } + // Lock released below. epilogue() + // A new SYN was received during TIME_WAIT and we need to abort + // the timewait and redirect the segment to the listener queue + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + tcpEP.enqueueSegment(s) + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint. + return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than one see on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + <-e.undrain + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index c54610a87..dfaa4a559 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -42,7 +42,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { } } -func testV4Connect(t *testing.T, c *context.Context) { +func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { // Start connection attempt. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventOut) @@ -55,12 +55,11 @@ func testV4Connect(t *testing.T, c *context.Context) { // Receive SYN packet. b := c.GetPacket() - checker.IPv4(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv4(t, b, synCheckers...) tcp := header.TCP(header.IPv4(b).Payload()) c.IRS = seqnum.Value(tcp.SequenceNumber()) @@ -76,14 +75,13 @@ func testV4Connect(t *testing.T, c *context.Context) { }) // Receive ACK packet. - checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv4(t, c.GetPacket(), ackCheckers...) // Wait for connection to be established. select { @@ -152,7 +150,7 @@ func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { testV4Connect(t, c) } -func testV6Connect(t *testing.T, c *context.Context) { +func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { // Start connection attempt to IPv6 address. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventOut) @@ -165,12 +163,11 @@ func testV6Connect(t *testing.T, c *context.Context) { // Receive SYN packet. b := c.GetV6Packet() - checker.IPv6(t, b, - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagSyn), - ), - ) + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv6(t, b, synCheckers...) tcp := header.TCP(header.IPv6(b).Payload()) c.IRS = seqnum.Value(tcp.SequenceNumber()) @@ -186,14 +183,13 @@ func testV6Connect(t *testing.T, c *context.Context) { }) // Receive ACK packet. - checker.IPv6(t, c.GetV6Packet(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPFlags(header.TCPFlagAck), - checker.SeqNum(uint32(c.IRS)+1), - checker.AckNum(uint32(iss)+1), - ), - ) + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv6(t, c.GetV6Packet(), ackCheckers...) // Wait for connection to be established. select { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 35b489c68..04c92c04c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "fmt" "math" "strings" @@ -26,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -119,6 +121,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -170,6 +177,101 @@ type rcvBufAutoTuneParams struct { disabled bool } +// ReceiveErrors collect segment receive errors within transport layer. +type ReceiveErrors struct { + tcpip.ReceiveErrors + + // SegmentQueueDropped is the number of segments dropped due to + // a full segment queue. + SegmentQueueDropped tcpip.StatCounter + + // ChecksumErrors is the number of segments dropped due to bad checksums. + ChecksumErrors tcpip.StatCounter + + // ListenOverflowSynDrop is the number of times the listen queue overflowed + // and a SYN was dropped. + ListenOverflowSynDrop tcpip.StatCounter + + // ListenOverflowAckDrop is the number of times the final ACK + // in the handshake was dropped due to overflow. + ListenOverflowAckDrop tcpip.StatCounter + + // ZeroRcvWindowState is the number of times we advertised + // a zero receive window when rcvList is full. + ZeroRcvWindowState tcpip.StatCounter +} + +// SendErrors collect segment send errors within the transport layer. +type SendErrors struct { + tcpip.SendErrors + + // SegmentSendToNetworkFailed is the number of TCP segments failed to be sent + // to the network endpoint. + SegmentSendToNetworkFailed tcpip.StatCounter + + // SynSendToNetworkFailed is the number of TCP SYNs failed to be sent + // to the network endpoint. + SynSendToNetworkFailed tcpip.StatCounter + + // Retransmits is the number of TCP segments retransmitted. + Retransmits tcpip.StatCounter + + // FastRetransmit is the number of segments retransmitted in fast + // recovery. + FastRetransmit tcpip.StatCounter + + // Timeouts is the number of times the RTO expired. + Timeouts tcpip.StatCounter +} + +// Stats holds statistics about the endpoint. +type Stats struct { + // SegmentsReceived is the number of TCP segments received that + // the transport layer successfully parsed. + SegmentsReceived tcpip.StatCounter + + // SegmentsSent is the number of TCP segments sent. + SegmentsSent tcpip.StatCounter + + // FailedConnectionAttempts is the number of times we saw Connect and + // Accept errors. + FailedConnectionAttempts tcpip.StatCounter + + // ReceiveErrors collects segment receive errors within the + // transport layer. + ReceiveErrors ReceiveErrors + + // ReadErrors collects segment read errors from an endpoint read call. + ReadErrors tcpip.ReadErrors + + // SendErrors collects segment send errors within the transport layer. + SendErrors SendErrors + + // WriteErrors collects segment write errors from an endpoint write call. + WriteErrors tcpip.WriteErrors +} + +// IsEndpointStats is an empty method to implement the tcpip.EndpointStats +// marker interface. +func (*Stats) IsEndpointStats() {} + +// EndpointInfo holds useful information about a transport endpoint which +// can be queried by monitoring tools. +// +// +stateify savable +type EndpointInfo struct { + stack.TransportEndpointInfo + + // HardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. HardError is protected by endpoint mu. + HardError *tcpip.Error `state:".(string)"` +} + +// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo +// marker interface. +func (*EndpointInfo) IsEndpointInfo() {} + // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -178,6 +280,8 @@ type rcvBufAutoTuneParams struct { // // +stateify savable type endpoint struct { + EndpointInfo + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -186,9 +290,9 @@ type endpoint struct { // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. - stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber + stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -218,14 +322,19 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` + ttl uint8 v6only bool isConnectNotified bool // TCP should never broadcast but Linux nevertheless supports enabling/ @@ -240,11 +349,6 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` - // hardError is meaningful only when state is stateError, it stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. hardError is protected by mu. - hardError *tcpip.Error `state:".(string)"` - // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -280,6 +384,9 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. + bindToDevice tcpip.NICID + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -315,7 +422,7 @@ type endpoint struct { // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. - userMSS int + userMSS uint16 // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the @@ -393,13 +500,49 @@ type endpoint struct { probe stack.TCPProbeFunc `state:"nosave"` // The following are only used to assist the restore run to re-connect. - bindAddress tcpip.Address connectingAddress tcpip.Address // amss is the advertised MSS to the peer by this endpoint. amss uint16 + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + gso *stack.GSO + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// calculateAdvertisedMSS calculates the MSS to advertise. +// +// If userMSS is non-zero and is not greater than the maximum possible MSS for +// r, it will be used; otherwise, the maximum possible MSS will be used. +func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { + // The maximum possible MSS is dependent on the route. + maxMSS := mssForRoute(&r) + + if userMSS != 0 && userMSS < maxMSS { + return userMSS + } + + return maxMSS } // StopWork halts packet processing. Only to be used in tests. @@ -427,10 +570,15 @@ type keepalive struct { waker sleep.Waker `state:"nosave"` } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + EndpointInfo: EndpointInfo{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + }, waiterQueue: waiterQueue, state: StateInitial, rcvBufSize: DefaultReceiveBufferSize, @@ -443,29 +591,40 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } var rs ReceiveBufferSizeOption - if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } var cs tcpip.CongestionControlOption - if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } var mrb tcpip.ModerateReceiveBufferOption - if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { + if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - if p := stack.GetTCPProbe(); p != nil { + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptInt(tcpip.DelayOption, 1) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -552,6 +711,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -564,14 +730,16 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) e.isPortReserved = false } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -597,9 +765,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -625,16 +791,17 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) e.isPortReserved = false } e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -645,7 +812,9 @@ func (e *endpoint) initialReceiveWindow() int { if rcvWnd > math.MaxUint16 { rcvWnd = math.MaxUint16 } - routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2 + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 if rcvWnd > routeWnd { rcvWnd = routeWnd } @@ -731,11 +900,12 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.hardError + he := e.HardError e.mu.RUnlock() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } + e.stats.ReadErrors.InvalidEndpointState.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -744,6 +914,9 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RUnlock() + if err == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } return v, tcpip.ControlMessages{}, err } @@ -787,7 +960,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { if !e.state.connected() { switch e.state { case StateError: - return 0, e.hardError + return 0, e.HardError default: return 0, tcpip.ErrClosedForSend } @@ -818,6 +991,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if err != nil { e.sndBufMu.Unlock() e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -852,6 +1026,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if err != nil { e.sndBufMu.Unlock() e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() return 0, nil, err } @@ -863,7 +1038,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Add data to the send queue. - s := newSegmentFromView(&e.route, e.id, v) + s := newSegmentFromView(&e.route, e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) @@ -895,8 +1070,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.state; !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.hardError + return 0, tcpip.ControlMessages{}, e.HardError } + e.stats.ReadErrors.InvalidEndpointState.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState } @@ -905,6 +1081,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.state.connected() { + e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -1018,14 +1195,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { e.sndBufMu.Unlock() return nil - default: - return nil - } -} - -// SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { case tcpip.DelayOption: if v == 0 { atomic.StoreUint32(&e.delay, 0) @@ -1037,6 +1206,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + default: + return nil + } +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch v := opt.(type) { case tcpip.CorkOption: if v == 0 { atomic.StoreUint32(&e.cork, 0) @@ -1060,6 +1239,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicID, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicID + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.QuickAckOption: if v == 0 { atomic.StoreUint32(&e.slowAck, 1) @@ -1074,14 +1268,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidOptionValue } e.mu.Lock() - e.userMSS = int(userMSS) + e.userMSS = uint16(userMSS) e.mu.Unlock() e.notifyProtocolGoroutine(notifyMSSChanged) return nil case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -1096,6 +1290,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.v6only = v != 0 return nil + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + return nil + case tcpip.KeepaliveEnabledOption: e.keepalive.Lock() e.keepalive.enabled = v != 0 @@ -1164,6 +1364,45 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. return tcpip.ErrNoSuchFile + + case tcpip.IPv4TOSOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1190,6 +1429,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() + case tcpip.SendBufferSizeOption: e.sndBufMu.Lock() v := e.sndBufSize @@ -1202,8 +1442,16 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil + case tcpip.DelayOption: + var o int + if v := atomic.LoadUint32(&e.delay); v != 0 { + o = 1 + } + return o, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -1224,13 +1472,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = header.TCPDefaultMSS return nil - case *tcpip.DelayOption: - *o = 0 - if v := atomic.LoadUint32(&e.delay); v != 0 { - *o = 1 - } - return nil - case *tcpip.CorkOption: *o = 0 if v := atomic.LoadUint32(&e.cork); v != 0 { @@ -1260,6 +1501,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = "" + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1269,7 +1520,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -1283,6 +1534,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.mu.RLock() @@ -1347,13 +1604,31 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if header.IsV4MappedAddress(addr.Addr) { // Fail if using a v4 mapped address on a v6only endpoint. if e.v6only { @@ -1369,7 +1644,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -1383,7 +1658,12 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return e.connect(addr, true, true) + err := e.connect(addr, true, true) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err } // connect connects the endpoint to its peer. In the normal non-S/R case, the @@ -1392,14 +1672,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() connectingAddr := addr.Addr @@ -1419,7 +1694,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return tcpip.ErrAlreadyConnected } - nicid := addr.NIC + nicID := addr.NIC switch e.state { case StateBound: // If we're already bound to a NIC but the caller is requesting @@ -1428,11 +1703,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1444,29 +1719,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return tcpip.ErrAlreadyConnecting case StateError: - return e.hardError + return e.HardError default: return tcpip.ErrInvalidEndpointState } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() - origID := e.id + origID := e.ID netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.id.LocalAddress = r.LocalAddress - e.id.RemoteAddress = r.RemoteAddress - e.id.RemotePort = addr.Port + e.ID.LocalAddress = r.LocalAddress + e.ID.RemoteAddress = r.RemoteAddress + e.ID.RemotePort = addr.Port - if e.id.LocalPort != 0 { + if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1475,20 +1750,35 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.id.LocalAddress == e.id.RemoteAddress - if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - if sameAddr && p == e.id.RemotePort { + sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + + // Calculate a port offset based on the destination IP/port and + // src IP to ensure that for a given tuple (srcIP, destIP, + // destPort) the offset used as a starting point is the same to + // ensure that we can cycle through the port space effectively. + h := jenkins.Sum32(e.stack.Seed()) + h.Write([]byte(e.ID.LocalAddress)) + h.Write([]byte(e.ID.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) + h.Write(portBuf) + portOffset := h.Sum32() + + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if sameAddr && p == e.ID.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { + // reusePort is false below because connect cannot reuse a port even if + // reusePort was set. + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) { return false, nil } - id := e.id + id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: - e.id = id + e.ID = id return true, nil case tcpip.ErrPortInUse: return false, nil @@ -1504,14 +1794,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) e.isPortReserved = false } e.isRegistered = true e.state = StateConnecting e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -1523,7 +1813,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.segmentQueue.mu.Lock() for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.id + s.id = e.ID s.route = r.Clone() e.sndWaker.Assert() } @@ -1531,6 +1821,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) e.state = StateEstablished + e.stack.Stats().TCP.CurrentEstablished.Increment() } if run { @@ -1551,9 +1842,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1568,6 +1858,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1583,17 +1874,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.id, nil) + s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1601,24 +1889,37 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { +func (e *endpoint) Listen(backlog int) *tcpip.Error { + err := e.listen(backlog) + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } + return err +} + +func (e *endpoint) listen(backlog int) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - defer func() { - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -1644,11 +1945,12 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // Endpoint must be bound before it can transition to listen mode. if e.state != StateBound { + e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil { return err } @@ -1692,12 +1994,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -1712,7 +2009,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { return tcpip.ErrAlreadyBound } - e.bindAddress = addr.Addr + e.BindAddr = addr.Addr netProto, err := e.checkV4Mapped(&addr) if err != nil { return err @@ -1729,26 +2026,26 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) if err != nil { return err } e.isPortReserved = true e.effectiveNetProtos = netProtos - e.id.LocalPort = port + e.ID.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func() { + defer func(bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil - e.id.LocalPort = 0 - e.id.LocalAddress = "" + e.ID.LocalPort = 0 + e.ID.LocalAddress = "" e.boundNICID = 0 } - }() + }(e.bindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. @@ -1759,7 +2056,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } e.boundNICID = nic - e.id.LocalAddress = addr.Addr + e.ID.LocalAddress = addr.Addr } // Mark endpoint as bound. @@ -1774,8 +2071,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -1790,19 +2087,20 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, NIC: e.boundNICID, }, nil } // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + s := newSegment(r, id, pkt) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() return } @@ -1810,27 +2108,34 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv if !s.csumValid { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() return } e.stack.Stats().TCP.ValidSegmentsReceived.Increment() + e.stats.SegmentsReceived.Increment() if (s.flags & header.TCPFlagRst) != 0 { e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() } else { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.SegmentQueueDropped.Increment() s.decRef() } } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -1874,6 +2179,7 @@ func (e *endpoint) readyToRead(s *segment) { // that a subsequent read of the segment will correctly trigger // a non-zero notification. if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() e.zeroWindow = true } e.rcvList.PushBack(s) @@ -2026,7 +2332,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Copy EndpointID. e.mu.Lock() - s.ID = stack.TCPEndpointID(e.id) + s.ID = stack.TCPEndpointID(e.ID) e.mu.Unlock() // Copy endpoint rcv state. @@ -2119,11 +2425,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { return s } -func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityGSO == 0 { - return - } - +func (e *endpoint) initHardwareGSO() { gso := &stack.GSO{} switch e.route.NetProto { case header.IPv4ProtocolNumber: @@ -2133,7 +2435,7 @@ func (e *endpoint) initGSO() { gso.Type = stack.GSOTCPv6 gso.L3HdrLen = header.IPv6MinimumSize default: - panic(fmt.Sprintf("Unknown netProto: %v", e.netProto)) + panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) } gso.NeedsCsum = true gso.CsumOffset = header.TCPChecksumOffset @@ -2141,6 +2443,18 @@ func (e *endpoint) initGSO() { e.gso = gso } +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + e.initHardwareGSO() + } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + e.gso = &stack.GSO{ + MaxSize: e.route.GSOMaxSize(), + Type: stack.GSOSW, + NeedsCsum: false, + } + } +} + // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { @@ -2149,6 +2463,37 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.EndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 831389ec7..7aa4c3f0e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -55,7 +55,7 @@ func (e *endpoint) beforeSave() { case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { - panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() @@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() { } fallthrough case StateError, StateClose: - for e.state == StateError && e.workerRunning { + for (e.state == StateError || e.state == StateClose) && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + // Freeze segment queue before registering to prevent any segments + // from being delivered while it is being restored. + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. + e.state = StateInitial stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() + state := e.origEndpointState - state := e.state switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -189,12 +195,13 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - e.state = StateInitial - if len(e.bindAddress) == 0 { - e.bindAddress = e.id.LocalAddress + if len(e.BindAddr) == 0 { + e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + addr := e.BindAddr + port := e.ID.LocalPort + if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { + panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) } } @@ -202,21 +209,31 @@ func (e *endpoint) Resume(s *stack.Stack) { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.id.RemoteAddress + e.connectingAddress = e.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if state == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() case StateListen: tcpip.AsyncLoading.Add(1) @@ -236,7 +253,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -263,8 +280,12 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() } - fallthrough + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) case StateError: + e.state = StateError + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } @@ -288,21 +309,21 @@ func (e *endpoint) loadLastError(s string) { } // saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { +func (e *EndpointInfo) saveHardError() string { + if e.HardError == nil { return "" } - return e.hardError.String() + return e.HardError.String() } // loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { +func (e *EndpointInfo) loadHardError(s string) { if s == "" { return } - e.hardError = loadError(s) + e.HardError = loadError(s) } var messageToError map[string]*tcpip.Error diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 63666f0b3..4983bca81 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index d5d8ab96a..89b965c23 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -23,6 +23,7 @@ package tcp import ( "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,12 +55,23 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second ) // SACKEnabled option can be used to enable SACK support in the TCP // protocol. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool +// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol. +type DelayEnabled bool + // SendBufferSizeOption allows the default, min and max send buffer sizes for // TCP endpoints to be queried or configured. type SendBufferSizeOption struct { @@ -84,11 +96,14 @@ const ( type protocol struct { mu sync.Mutex sackEnabled bool + delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + tcpLingerTimeout time.Duration + tcpTimeWaitTimeout time.Duration } // Number returns the tcp protocol number. @@ -97,7 +112,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return newEndpoint(stack, netProto, waiterQueue), nil } @@ -126,8 +141,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { @@ -153,7 +168,7 @@ func replyWithReset(s *segment) { ack := s.sequenceNumber.Add(s.logicalLen()) - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. @@ -165,6 +180,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case DelayEnabled: + p.mu.Lock() + p.delayEnabled = bool(v) + p.mu.Unlock() + return nil + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue @@ -202,6 +223,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -216,6 +255,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *DelayEnabled: + p.mu.Lock() + *v = DelayEnabled(p.delayEnabled) + p.mu.Unlock() + return nil + case *SendBufferSizeOption: p.mu.Lock() *v = p.sendBufferSize @@ -246,6 +291,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *tcpip.TCPLingerTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.Unlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -258,5 +315,7 @@ func NewProtocol() stack.TransportProtocol { recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..068b90fb6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -209,6 +210,11 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum switch r.ep.state { case StateFinWait1: r.ep.state = StateFinWait2 + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: r.ep.state = StateTimeWait case StateLastAck: @@ -253,23 +259,105 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait, StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + s.decRef() + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed(not owned by application on our end + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs. + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + } + // We don't care about receive processing anymore if the receive side // is closed. - if r.closed { - return + // + // NOTE: We still want to permit a FIN as it's possible only our + // end has closed and the peer is yet to send a FIN. Hence we + // compare only the payload. + segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + return true, nil + } + return false, nil +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { + r.ep.mu.RLock() + state := r.ep.state + closed := r.ep.closed + r.ep.mu.RUnlock() + + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } } segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber // If the sequence number range is outside the acceptable range, just - // send an ACK. This is according to RFC 793, page 37. + // send an ACK and stop further processing of the segment. + // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { r.ep.snd.sendAck() - return + return true, nil } // Defer segment processing if it can't be consumed now. @@ -288,7 +376,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { // have to retransmit. r.ep.snd.sendAck() } - return + return false, nil } // Since we consumed a segment update the receiver's RTT estimate @@ -315,4 +403,67 @@ func (r *receiver) handleRcvdSegment(s *segment) { r.pendingBufUsed -= s.logicalLen() s.decRef() } + return false, nil +} + +// handleTimeWaitSegment handles inbound segments received when the endpoint +// has entered the TIME_WAIT state. +func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) { + segSeq := s.sequenceNumber + segLen := seqnum.Size(s.data.Size()) + + // Just silently drop any RST packets in TIME_WAIT. We do not support + // TIME_WAIT assasination as a result we confirm w/ fix 1 as described + // in https://tools.ietf.org/html/rfc1337#section-3. + if s.flagIsSet(header.TCPFlagRst) { + return false, false + } + + // If it's a SYN and the sequence number is higher than any seen before + // for this connection then try and redirect it to a listening endpoint + // if available. + // + // RFC 1122: + // "When a connection is [...] on TIME-WAIT state [...] + // [a TCP] MAY accept a new SYN from the remote TCP to + // reopen the connection directly, if it: + + // (1) assigns its initial sequence number for the new + // connection to be larger than the largest sequence + // number it used on the previous connection incarnation, + // and + + // (2) returns to TIME-WAIT state if the SYN turns out + // to be an old duplicate". + if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + + return false, true + } + + // Drop the segment if it does not contain an ACK. + if !s.flagIsSet(header.TCPFlagAck) { + return false, false + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index ea725d513..1c10da5ca 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -60,13 +61,13 @@ type segment struct { xmitTime time.Time `state:".(unixTime)"` } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } - s.data = vv.Clone(s.views[:]) + s.data = pkt.Data.Clone(s.views[:]) s.rcvdTime = time.Now() return s } @@ -99,8 +100,14 @@ func (s *segment) clone() *segment { return t } -func (s *segment) flagIsSet(flag uint8) bool { - return (s.flags & flag) != 0 +// flagIsSet checks if at least one flag in flags is set in s.flags. +func (s *segment) flagIsSet(flags uint8) bool { + return s.flags&flags != 0 +} + +// flagsAreSet checks if all flags in flags are set in s.flags. +func (s *segment) flagsAreSet(flags uint8) bool { + return s.flags&flags == flags } func (s *segment) decRef() { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 735edfe55..d3f7c9125 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -417,6 +417,7 @@ func (s *sender) resendSegment() { s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() + s.ep.stats.SendErrors.FastRetransmit.Increment() // Run SetPipe() as per RFC 6675 section 5 Step 4.4 s.SetPipe() @@ -435,6 +436,7 @@ func (s *sender) retransmitTimerExpired() bool { } s.ep.stack.Stats().TCP.Timeouts.Increment() + s.ep.stats.SendErrors.Timeouts.Increment() // Give up if we've waited more than a minute since the last resend. if s.rto >= 60*time.Second { @@ -672,6 +674,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se default: s.ep.state = StateFinWait1 } + s.ep.stack.Stats().TCP.CurrentEstablished.Decrement() s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. @@ -1188,6 +1191,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { func (s *sender) sendSegment(seg *segment) *tcpip.Error { if !seg.xmitTime.IsZero() { s.ep.stack.Stats().TCP.Retransmits.Increment() + s.ep.stats.SendErrors.Retransmits.Increment() if s.sndCwnd < s.sndSsthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index 9fa97528b..782d7b42c 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -500,6 +500,14 @@ func TestRetransmit(t *testing.T) { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 4e7f1a740..afea124ec 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -520,10 +520,18 @@ func TestSACKRecovery(t *testing.T) { t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want { + t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want) + } + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want { t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want) } + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want { + t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want) + } + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond) // Acknowledge all pending data to recover point. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2be094876..84579ce52 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -99,7 +99,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { - t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want) } } @@ -122,6 +125,9 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want { + t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want) + } } func TestTCPSegmentsSentIncrement(t *testing.T) { @@ -136,6 +142,9 @@ func TestTCPSegmentsSentIncrement(t *testing.T) { if got := stats.TCP.SegmentsSent.Value(); got != want { t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want { + t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want) + } } func TestTCPResetsSentIncrement(t *testing.T) { @@ -197,17 +206,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -247,7 +257,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -255,6 +265,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. + tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + } + c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), @@ -276,6 +293,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Get the ACK to the FIN we just sent. c.GetPacket() + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. + time.Sleep(20 * time.Millisecond) + // Now resend the same ACK, this ACK should generate a RST as there // should be no endpoint in SYN-RCVD state and we are not using // syn-cookies yet. The reason we send the same ACK is we need a valid @@ -367,6 +389,13 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -387,12 +416,24 @@ func TestConnectResetAfterClose(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) + + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), RcvWnd: 30000, }) - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) for { b := c.GetPacket() tcpHdr := header.TCP(header.IPv4(b).Payload()) @@ -404,7 +445,7 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), ), @@ -465,6 +506,347 @@ func TestSimpleReceive(t *testing.T) { ) } +// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when +// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV4(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.Create(-1) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv4 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when +// creating a new active IPv6 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV6(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Set the MSS socket option. + opt := tcpip.MaxSegOption(test.setMSS) + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv6 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +func TestSendRstOnListenerRxSynAckV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestSendRstOnListenerRxSynAckV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: 100, + AckNum: 200, + }) + + checker.IPv6(t, c.GetV6Packet(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + checker.SeqNum(200))) +} + +func TestTOSV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + c.EP = ep + + const tos = 0xC0 + if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { + t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + } + + var v tcpip.IPv4TOSOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv4TOSOption(tos); v != want { + t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testV4Connect(t, c, checker.TOS(tos, 0)) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestTrafficClassV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + const tos = 0xC0 + if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { + t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + } + + var v tcpip.IPv6TrafficClassOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv6TrafficClassOption(tos); v != want { + t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + // Test the connection request. + testV6Connect(t, c, checker.TOS(tos, 0)) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetV6Packet() + checker.IPv6(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestConnectBindToDevice(t *testing.T) { + for _, test := range []struct { + name string + device string + want tcp.EndpointState + }{ + {"RightDevice", "nic1", tcp.StateEstablished}, + {"WrongDevice", "nic2", tcp.StateSynSent}, + {"AnyDevice", "", tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + bindToDevice := tcpip.BindToDeviceOption(test.device) + c.EP.SetSockOpt(bindToDevice) + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + c.GetPacket() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + }) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -760,8 +1142,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -797,6 +1178,10 @@ func TestShutdownRead(t *testing.T) { if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { + t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want) + } } func TestFullWindowReceive(t *testing.T) { @@ -853,6 +1238,11 @@ func TestFullWindowReceive(t *testing.T) { t.Fatalf("got data = %v, want = %v", v, data) } + var want uint64 = 1 + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want { + t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want) + } + // Check that we get an ACK for the newly non-zero window. checker.IPv4(t, c.GetPacket(), checker.TCP( @@ -1444,7 +1834,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -1492,7 +1882,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOpt(tcpip.DelayOption(1)) + c.EP.SetSockOptInt(tcpip.DelayOption, 1) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -1525,7 +1915,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOpt(tcpip.DelayOption(0)) + c.EP.SetSockOptInt(tcpip.DelayOption, 0) // Check that data is received. second := c.GetPacket() @@ -1562,7 +1952,7 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.DelayOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, } @@ -1692,6 +2082,34 @@ func TestSendGreaterThanMTU(t *testing.T) { testBrokenUpWrite(t, c, maxPayload) } +func TestSetTTL(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := context.New(t, 65535) + defer c.Cleanup() + + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + t.Fatalf("SetSockOpt failed: %v", err) + } + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + + checker.IPv4(t, b, checker.TTL(wantTTL)) + }) + } +} + func TestActiveSendMSSLessThanMTU(t *testing.T) { const maxPayload = 100 c := context.New(t, 65535) @@ -1997,6 +2415,13 @@ loop: t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) } } + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { + t.Fatalf("got EP state is not StateError") + } } func TestSendOnResetConnection(t *testing.T) { @@ -2568,6 +2993,17 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { + t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want) + } + // Ensure there were no errors during handshake. If these stats have + // incremented, then the connection should not have been established. + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0) + } + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0) + } } func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { @@ -2592,6 +3028,9 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } } func TestReceivedIncorrectChecksumIncrement(t *testing.T) { @@ -2618,6 +3057,9 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { if got := stats.TCP.ChecksumErrors.Value(); got != want { t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want) + } } func TestReceivedSegmentQueuing(t *testing.T) { @@ -2674,6 +3116,13 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) @@ -2681,12 +3130,12 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -2724,10 +3173,9 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -2744,7 +3192,7 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Peek failed: %v", err) + t.Fatalf("Peek failed: %s", err) } peekBuf = peekBuf[:n] @@ -2755,7 +3203,7 @@ func TestReadAfterClosedState(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -2765,11 +3213,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -2970,6 +3418,62 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) } +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer ep.Close() + + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } + + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) + } + + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s + } + + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) + } +} + func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ @@ -3880,7 +4384,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -4014,6 +4519,9 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want) } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want) + } we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -4052,6 +4560,14 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + // Expect InvalidEndpointState errors on a read at this point. + if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + } + if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { + t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + } + if err := ep.Listen(10); err != nil { t.Fatalf("Listen failed: %v", err) } @@ -4083,6 +4599,9 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } + if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { + t.Errorf("Unexpected error attempting to call connect on an established endpoint, got: %v, want: %v", err, tcpip.ErrAlreadyConnected) + } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -4370,3 +4889,590 @@ func TestReceiveBufferAutoTuning(t *testing.T) { payloadSize *= 2 } } + +func TestDelayEnabled(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + checkDelayOption(t, c, false, 0) // Delay is disabled by default. + + for _, v := range []struct { + delayEnabled tcp.DelayEnabled + wantDelayOption int + }{ + {delayEnabled: false, wantDelayOption: 0}, + {delayEnabled: true, wantDelayOption: 1}, + } { + c := context.New(t, defaultMTU) + defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %v", v.delayEnabled, err) + } + checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + } +} + +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { + t.Helper() + + var gotDelayEnabled tcp.DelayEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { + t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %v", err) + } + if gotDelayEnabled != wantDelayEnabled { + t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) + } + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) + if err != nil { + t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) + } + gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + if err != nil { + t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + } + if gotDelayOption != wantDelayOption { + t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + } +} + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + time.Sleep(2 * time.Second) + + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(uint32(ackHeaders.SeqNum)), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index d3f1d2cdf..4854e719d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -158,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { + if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + if testing.Verbose() { + wep2 = sniffer.New(channel.New(1000, mtu, "")) + } + if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -295,7 +302,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -343,13 +352,17 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.Inject(ipv4.ProtocolNumber, s) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: s, + }) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegment(payload, h), + }) } // SendAck sends an ACK packet. @@ -511,7 +524,9 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // CreateConnected creates a connected TCP endpoint. @@ -588,12 +603,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.Port = tcpHdr.SourcePort() } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { +// Create creates a TCP endpoint. +func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -606,6 +617,15 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("SetSockOpt failed failed: %v", err) } } +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { + c.Create(epRcvBuf) c.Connect(iss, rcvWnd, options) } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c1ca22b35..c9460aa0d 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "udp_packet_list", out = "udp_packet_list.go", @@ -52,6 +51,7 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0bec7e62d..5270f24df 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -31,9 +31,6 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } // EndpointState represents the state of a UDP endpoint. @@ -49,6 +46,22 @@ const ( StateClosed ) +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnected: + return "CONNECTING" + case StateClosed: + return "CLOSED" + default: + return "UNKNOWN" + } +} + // endpoint represents a UDP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -58,11 +71,13 @@ const ( // // +stateify savable type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -76,20 +91,23 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` sndBufSize int - id stack.TransportEndpointID state EndpointState - bindNICID tcpip.NICID - regNICID tcpip.NICID route stack.Route `state:"manual"` dstPort uint16 v6only bool + ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID multicastLoop bool reusePort bool + bindToDevice tcpip.NICID broadcast bool + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -104,6 +122,9 @@ type endpoint struct { // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). effectiveNetProtos []tcpip.NetworkProtocolNumber + + // TODO(b/142022063): Add ability to save and restore per endpoint stats. + stats tcpip.TransportEndpointStats `state:"nosave"` } // +stateify savable @@ -112,10 +133,13 @@ type multicastMembership struct { multicastAddr tcpip.Address } -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { return &endpoint{ - stack: stack, - netProto: netProto, + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.UDPProtocolNumber, + }, waiterQueue: waiterQueue, // RFC 1075 section 5.4 recommends a TTL of 1 for membership // requests. @@ -133,9 +157,16 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + state: StateInitial, + uniqueID: s.UniqueID(), } } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -144,12 +175,12 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) } for _, mem := range e.multicastMemberships { - e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } e.multicastMemberships = nil @@ -189,6 +220,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess if e.rcvList.Empty() { err := tcpip.ErrWouldBlock if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() @@ -250,33 +282,56 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { - localAddr := e.id.LocalAddress +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { + localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" } if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicid == 0 { - nicid = e.multicastNICID + if nicID == 0 { + nicID = e.multicastNICID } - if localAddr == "" && nicid == 0 { + if localAddr == "" && nicID == 0 { localAddr = e.multicastAddr } } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { return stack.Route{}, 0, err } - return r, nicid, nil + return r, nicID, nil } // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -327,13 +382,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC - if e.bindNICID != 0 { - if nicid != 0 && nicid != e.bindNICID { + nicID := to.NIC + if e.BindNICID != 0 { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.bindNICID + nicID = e.BindNICID } if to.Addr == header.IPv4Broadcast && !e.broadcast { @@ -345,7 +400,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - r, _, err := e.connectRoute(nicid, *to, netProto) + r, _, err := e.connectRoute(nicID, *to, netProto) if err != nil { return 0, nil, err } @@ -373,12 +428,16 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrMessageTooLong } - ttl := route.DefaultTTL() + ttl := e.ttl + useDefaultTTL := ttl == 0 + if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { ttl = e.multicastTTL + // Multicast allows a 0 TTL. + useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -399,7 +458,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrInvalidEndpointState } @@ -413,6 +472,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.v6only = v != 0 + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -447,7 +511,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if e.bindNICID != 0 && e.bindNICID != nic { + if e.BindNICID != 0 && e.BindNICID != nic { return tcpip.ErrInvalidEndpointState } @@ -474,7 +538,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -491,7 +555,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } - if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -512,7 +576,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } } } else { - nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { return tcpip.ErrUnknownDevice @@ -534,7 +598,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrBadLocalAddress } - if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } @@ -551,12 +615,39 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.reusePort = v != 0 e.mu.Unlock() + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicID, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicID + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 e.mu.Unlock() return nil + + case tcpip.IPv4TOSOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + return nil } return nil } @@ -598,7 +689,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. - if e.netProto != header.IPv6ProtocolNumber { + if e.NetProto != header.IPv6ProtocolNumber { return tcpip.ErrUnknownProtocolOption } @@ -612,6 +703,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.TTLOption: + e.mu.Lock() + *o = tcpip.TTLOption(e.ttl) + e.mu.Unlock() + return nil + case *tcpip.MulticastTTLOption: e.mu.Lock() *o = tcpip.MulticastTTLOption(e.multicastTTL) @@ -646,6 +743,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = tcpip.BindToDeviceOption("") + return nil + case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -661,6 +768,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -668,7 +787,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -691,14 +810,21 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u udp.SetChecksum(^udp.CalculateChecksum(xsum)) } + if useDefaultTTL { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + r.Stats().UDP.PacketSendErrors.Increment() + return err + } + // Track count of packets sent. r.Stats().UDP.PacketsSent.Increment() - - return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl) + return nil } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto + netProto := e.NetProto if len(addr.Addr) == 0 { return netProto, nil } @@ -715,14 +841,14 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t } // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.id.LocalAddress) == 16 { + if !allowMismatch && len(e.ID.LocalAddress) == 16 { return 0, tcpip.ErrNetworkUnreachable } } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { + if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -739,27 +865,27 @@ func (e *endpoint) Disconnect() *tcpip.Error { } id := stack.TransportEndpointID{} // Exclude ephemerally bound endpoints. - if e.bindNICID != 0 || e.id.LocalAddress == "" { + if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error id = stack.TransportEndpointID{ - LocalPort: e.id.LocalPort, - LocalAddress: e.id.LocalAddress, + LocalPort: e.ID.LocalPort, + LocalAddress: e.ID.LocalAddress, } - id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { return err } e.state = StateBound } else { - if e.id.LocalPort != 0 { + if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice) } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.id = id + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.ID = id e.route.Release() e.route = stack.Route{} e.dstPort = 0 @@ -781,33 +907,33 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC var localPort uint16 switch e.state { case StateInitial: case StateBound, StateConnected: - localPort = e.id.LocalPort - if e.bindNICID == 0 { + localPort = e.ID.LocalPort + if e.BindNICID == 0 { break } - if nicid != 0 && nicid != e.bindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.bindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } - r, nicid, err := e.connectRoute(nicid, addr, netProto) + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err } defer r.Release() id := stack.TransportEndpointID{ - LocalAddress: e.id.LocalAddress, + LocalAddress: e.ID.LocalAddress, LocalPort: localPort, RemotePort: addr.Port, RemoteAddress: r.RemoteAddress, @@ -828,20 +954,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } // Remove the old registration. - if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + if e.ID.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) } - e.id = id + e.ID = id e.route = r.Clone() e.dstPort = addr.Port - e.regNICID = nicid + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos e.state = StateConnected @@ -896,18 +1022,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if e.id.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if e.ID.LocalPort == 0 { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { return id, err } id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) } return id, err } @@ -935,11 +1061,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } } - nicid := addr.NIC + nicID := addr.NIC if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { // A local unicast address was specified, verify that it's valid. - nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicid == 0 { + nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nicID == 0 { return tcpip.ErrBadLocalAddress } } @@ -948,13 +1074,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } - e.id = id - e.regNICID = nicid + e.ID = id + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos // Mark endpoint as bound. @@ -979,7 +1105,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // Save the effective NICID generated by bindLocked. - e.bindNICID = e.regNICID + e.BindNICID = e.RegisterNICID return nil } @@ -990,9 +1116,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, }, nil } @@ -1006,9 +1132,9 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } return tcpip.FullAddress{ - NIC: e.regNICID, - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, }, nil } @@ -1032,42 +1158,52 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - vv.TrimFront(header.UDPMinimumSize) + pkt.Data.TrimFront(header.UDPMinimumSize) e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() + e.stats.PacketsReceived.Increment() // Drop the packet if our buffer is currently full. - if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return } wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &udpPacket{ + packet := &udpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, Port: hdr.SourcePort(), }, } - pkt.data = vv.Clone(pkt.views[:]) - e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + packet.data = pkt.Data + e.rcvList.PushBack(packet) + e.rcvBufSize += pkt.Data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() @@ -1078,7 +1214,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. @@ -1088,6 +1224,23 @@ func (e *endpoint) State() uint32 { return uint32(e.state) } +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index be46e6d4e..b227e353b 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -72,7 +72,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s for _, m := range e.multicastMemberships { - if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } } @@ -93,13 +93,13 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) if err != nil { panic(err) } - } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. - if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } } @@ -107,9 +107,9 @@ func (e *endpoint) Resume(s *stack.Stack) { // Our saved state had a port, but we don't actually have a // reservation. We need to remove the port from our state, but still // pass it to the reservation machinery. - id := e.id - e.id.LocalPort = 0 - e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + id := e.ID + e.ID.LocalPort = 0 + e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a9edc2c8d..fc706ede2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -16,7 +16,6 @@ package udp import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, id: id, - vv: vv, + pkt: pkt, }) return true @@ -62,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - vv buffer.VectorisedView + pkt tcpip.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that @@ -74,15 +73,15 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { ep.Close() return nil, err } - ep.id = r.id + ep.ID = r.id ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort - ep.regNICID = r.route.NICID() + ep.RegisterNICID = r.route.NICID() ep.state = StateConnected @@ -90,7 +89,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.vv) + ep.HandlePacket(r.route, r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index f5cc932dd..43f11b700 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,10 +66,10 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -116,13 +116,18 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + // The buffers used by pkt may be used elsewhere in the system. + // For example, a raw or packet socket may use what UDP + // considers an unreachable destination. Thus we deep copy pkt + // to prevent multiple ownership and SR errors. + newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) + payload := newNetHeader.ToVectorisedView() + payload.Append(pkt.Data.ToView().ToVectorisedView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -130,7 +135,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -151,12 +156,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload.Append(pkt.Data) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -164,7 +169,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL()) + r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 2ec27be4d..30ee9801b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -17,7 +17,6 @@ package udp_test import ( "bytes" "fmt" - "math" "math/rand" "testing" "time" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -384,18 +384,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h) + c.injectV4Packet(payload, &h, true /* valid */) } else { - c.injectV6Packet(payload, &h) + c.injectV6Packet(payload, &h, true /* valid */) } } // injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv6(buf) @@ -409,10 +412,16 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) + l := uint16(header.UDPMinimumSize + len(payload)) + if !valid { + // Change the UDP payload length to corrupt the header + // as requested by the caller. + l++ + } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: uint16(header.UDPMinimumSize + len(payload)), + Length: l, }) // Calculate the UDP pseudo-header checksum. @@ -423,15 +432,22 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) { u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } -// injectV6Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { +// injectV4Packet creates a V4 test packet with the given payload and header +// values, and injects it into the link endpoint. valid indicates if the +// caller intends to inject a packet with a valid or an invalid UDP header. +// We can invalidate the header by corrupting the UDP payload length. +func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv4(buf) @@ -461,7 +477,12 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) { u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } func newPayload() []byte { @@ -476,94 +497,67 @@ func newMinPayload(minSize int) []byte { return b } -func TestBindPortReuse(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - var eps [5]tcpip.Endpoint - reusePortOpt := tcpip.ReusePortOption(1) - - pollChannel := make(chan tcpip.Endpoint) - for i := 0; i < len(eps); i++ { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(eps[i]) +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - defer eps[i].Close() - if err := eps[i].SetSockOpt(reusePortOpt); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) } + defer ep.Close() - npackets := 100000 - nports := 10000 - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.injectV6Packet(payload, &header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - }) + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled - // by the same socket. - if ports[port] != ep { - t.Fatalf("Port mismatch") - } - } + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) } - if len(stats) != len(eps) { - t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s } - // Check that a packet distribution is fair between sockets. - for _, c := range stats { - n := float64(npackets) / float64(len(eps)) - // The deviation is less than 10%. - if math.Abs(float64(c)-n) > n/10 { - t.Fatal(c, n) - } + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) } } -// testRead sends a packet of the given test flow into the stack by injecting it -// into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// testReadInternal sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then attempts to read it from the +// UDP endpoint and depending on if this was expected to succeed verifies its +// correctness. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { c.t.Helper() payload := newPayload() @@ -574,6 +568,9 @@ func testRead(c *testContext, flow testFlow) { c.wq.EventRegister(&we, waiter.EventIn) defer c.wq.EventUnregister(&we) + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + var addr tcpip.FullAddress v, _, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { @@ -581,25 +578,55 @@ func testRead(c *testContext, flow testFlow) { select { case <-ch: v, _, err = c.ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - case <-time.After(1 * time.Second): - c.t.Fatalf("Timed out waiting for data") + case <-time.After(300 * time.Millisecond): + if packetShouldBeDropped { + return // expected to time out + } + c.t.Fatal("timed out waiting for data") } } + if expectReadError && err != nil { + c.checkEndpointReadStats(1, epstats, err) + return + } + + if err != nil { + c.t.Fatal("Read failed:", err) + } + + if packetShouldBeDropped { + c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) + } + // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) } // Check the payload. if !bytes.Equal(payload, v) { - c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + c.checkEndpointReadStats(1, epstats, err) +} + +// testRead sends a packet of the given test flow into the stack by injecting it +// into the link endpoint. It then reads it from the UDP endpoint and verifies +// its correctness. +func testRead(c *testContext, flow testFlow) { + c.t.Helper() + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) +} + +// testFailingRead sends a packet of the given test flow into the stack by +// injecting it into the link endpoint. It then tries to read it from the UDP +// endpoint and expects this to fail. +func testFailingRead(c *testContext, flow testFlow, expectReadError bool) { + c.t.Helper() + testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError) } func TestBindEphemeralPort(t *testing.T) { @@ -771,13 +798,17 @@ func TestReadOnBoundToMulticast(t *testing.T) { c.t.Fatal("SetSockOpt failed:", err) } + // Check that we receive multicast packets but not unicast or broadcast + // ones. testRead(c, flow) + testFailingRead(c, broadcast, false /* expectReadError */) + testFailingRead(c, unicastV4, false /* expectReadError */) }) } } // TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast -// address and receive broadcast data on it. +// address and can receive only broadcast data. func TestV4ReadOnBoundToBroadcast(t *testing.T) { for _, flow := range []testFlow{broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -792,8 +823,31 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - // Test acceptance. + // Check that we receive broadcast packets but not unicast ones. + testRead(c, flow) + testFailingRead(c, unicastV4, false /* expectReadError */) + }) + } +} + +// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY +// and receive broadcast and unicast data. +func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { + for _, flow := range []testFlow{broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s (", err) + } + + // Check that we receive both broadcast and unicast packets. testRead(c, flow) + testRead(c, unicastV4) }) } } @@ -802,7 +856,8 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { // and verifies it fails with the provided error code. func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { c.t.Helper() - + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) @@ -810,6 +865,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) + c.checkEndpointWriteStats(1, epstats, gotErr) if gotErr != wantErr { c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr) } @@ -835,6 +891,8 @@ func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...chec func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() + // Take a snapshot of the stats to validate them at the end of the test. + epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() writeOpts := tcpip.WriteOptions{} if setDest { @@ -852,7 +910,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - + c.checkEndpointWriteStats(1, epstats, err) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -921,6 +979,10 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Write to V4 mapped address. testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + const want = 1 + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { + c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) + } } func TestDualWriteConnectedToV4Mapped(t *testing.T) { @@ -1183,6 +1245,109 @@ func TestTTL(t *testing.T) { } } +func TestSetTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() + } + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + ep.Close() + + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } + }) + } +} + +func TestTOSV4(t *testing.T) { + for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + const tos = 0xC0 + var v tcpip.IPv4TOSOption + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + // Test for expected default value. + if v != 0 { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + } + + if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { + c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + } + + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv4TOSOption(tos); v != want { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testWrite(c, flow, checker.TOS(tos, 0)) + }) + } +} + +func TestTOSV6(t *testing.T) { + for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + const tos = 0xC0 + var v tcpip.IPv6TrafficClassOption + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + // Test for expected default value. + if v != 0 { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + } + + if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { + c.t.Errorf("SetSockOpt failed: %s", err) + } + + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.IPv6TrafficClassOption(tos); v != want { + c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + } + + testWrite(c, flow, checker.TOS(tos, 0)) + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1396,3 +1561,117 @@ func TestV6UnknownDestination(t *testing.T) { }) } } + +// TestIncrementMalformedPacketsReceived verifies if the malformed received +// global and endpoint stats get incremented. +func TestIncrementMalformedPacketsReceived(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + payload := newPayload() + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + c.injectV6Packet(payload, &h, false /* !valid */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + } +} + +// TestShutdownRead verifies endpoint read shutdown and error +// stats increment on packet receive. +func TestShutdownRead(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingRead(c, unicastV6, true /* expectReadError */) + + var want uint64 = 1 + if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want { + t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want) + } +} + +// TestShutdownWrite verifies endpoint write shutdown and error +// stats increment on packet write. +func TestShutdownWrite(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) +} + +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil: + want.PacketsSent.IncrementBy(incr) + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + want.WriteErrors.InvalidArgs.IncrementBy(incr) + case tcpip.ErrClosedForSend: + want.WriteErrors.WriteClosed.IncrementBy(incr) + case tcpip.ErrInvalidEndpointState: + want.WriteErrors.InvalidEndpointState.IncrementBy(incr) + case tcpip.ErrNoLinkAddress: + want.SendErrors.NoLinkAddr.IncrementBy(incr) + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + want.SendErrors.NoRoute.IncrementBy(incr) + default: + want.SendErrors.SendToNetworkFailed.IncrementBy(incr) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} + +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { + got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() + switch err { + case nil, tcpip.ErrWouldBlock: + case tcpip.ErrClosedForReceive: + want.ReadErrors.ReadClosed.IncrementBy(incr) + default: + c.t.Errorf("Endpoint error missing stats update err %v", err) + } + if got != want { + c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) + } +} diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 8dc88becb..1f7efb064 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,10 +1,9 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") - -package(licenses = ["notice"]) - load("//tools/go_generics:defs.bzl", "go_template_instance") load("//tools/go_stateify:defs.bzl", "go_library") +package(licenses = ["notice"]) + go_template_instance( name = "waiter_list", out = "waiter_list.go", diff --git a/runsc/BUILD b/runsc/BUILD index 5e7dacb87..e5587421d 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -1,7 +1,7 @@ package(licenses = ["notice"]) # Apache 2.0 load("@io_bazel_rules_go//go:def.bzl", "go_binary") -load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_deb", "pkg_tar") +load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar") go_binary( name = "runsc", @@ -76,26 +76,29 @@ pkg_tar( genrule( name = "deb-version", + # Note that runsc must appear in the srcs parameter and not the tools + # parameter, otherwise it will not be stamped. This is reasonable, as tools + # may be encoded differently in the build graph (cached more aggressively + # because they are assumes to be hermetic). + srcs = [":runsc"], outs = ["version.txt"], cmd = "$(location :runsc) -version | grep 'runsc version' | sed 's/^[^0-9]*//' > $@", stamp = 1, - tools = [":runsc"], ) pkg_deb( name = "runsc-debian", architecture = "amd64", data = ":debian-data", + # Note that the description_file will be flatten (all newlines removed), + # and therefore it is kept to a simple one-line description. The expected + # format for debian packages is "short summary\nLonger explanation of + # tool." and this is impossible with the flattening. description_file = "debian/description", homepage = "https://gvisor.dev/", maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>", package = "runsc", postinst = "debian/postinst.sh", - tags = [ - # TODO(b/135475885): pkg_deb requires python2: - # https://github.com/bazelbuild/bazel/issues/8443 - "manual", - ], version_file = ":version.txt", visibility = [ "//visibility:public", diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index d90381c0f..58e86ae7f 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -57,10 +57,11 @@ go_library( "//pkg/sentry/pgalloc", "//pkg/sentry/platform", "//pkg/sentry/sighandling", - "//pkg/sentry/socket/epsocket", "//pkg/sentry/socket/hostinet", "//pkg/sentry/socket/netlink", "//pkg/sentry/socket/netlink/route", + "//pkg/sentry/socket/netlink/uevent", + "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix", "//pkg/sentry/state", "//pkg/sentry/strace", diff --git a/runsc/boot/config.go b/runsc/boot/config.go index 38278d0a2..72a33534f 100644 --- a/runsc/boot/config.go +++ b/runsc/boot/config.go @@ -178,8 +178,11 @@ type Config struct { // capabilities. EnableRaw bool - // GSO indicates that generic segmentation offload is enabled. - GSO bool + // HardwareGSO indicates that hardware segmentation offload is enabled. + HardwareGSO bool + + // SoftwareGSO indicates that software segmentation offload is enabled. + SoftwareGSO bool // LogPackets indicates that all network packets should be logged. LogPackets bool @@ -231,6 +234,10 @@ type Config struct { // ReferenceLeakMode sets reference leak check mode ReferenceLeakMode refs.LeakMode + // OverlayfsStaleRead causes cached FDs to reopen after a file is opened for + // write to workaround overlayfs limitation on kernels before 4.19. + OverlayfsStaleRead bool + // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in // tests. It allows runsc to start the sandbox process as the current // user, and without chrooting the sandbox process. This can be @@ -271,6 +278,9 @@ func (c *Config) ToFlags() []string { "--rootless=" + strconv.FormatBool(c.Rootless), "--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr), "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode), + "--gso=" + strconv.FormatBool(c.HardwareGSO), + "--software-gso=" + strconv.FormatBool(c.SoftwareGSO), + "--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead), } // Only include these if set since it is never to be used by users. if c.TestOnlyAllowRunAsCurrentUserWithoutChroot { diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 72cbabd16..f62be4c59 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -18,7 +18,6 @@ import ( "errors" "fmt" "os" - "path" "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -27,12 +26,13 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/state" "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" + "gvisor.dev/gvisor/runsc/specutils" ) const ( @@ -51,7 +51,7 @@ const ( ContainerEvent = "containerManager.Event" // ContainerExecuteAsync is the URPC endpoint for executing a command in a - // container.. + // container. ContainerExecuteAsync = "containerManager.ExecuteAsync" // ContainerPause pauses the container. @@ -142,7 +142,7 @@ func newController(fd int, l *Loader) (*controller, error) { } srv.Register(manager) - if eps, ok := l.k.NetworkStack().(*epsocket.Stack); ok { + if eps, ok := l.k.NetworkStack().(*netstack.Stack); ok { net := &Network{ Stack: eps.Stack, } @@ -161,7 +161,7 @@ func newController(fd int, l *Loader) (*controller, error) { }, nil } -// containerManager manages sandboes containers. +// containerManager manages sandbox containers. type containerManager struct { // startChan is used to signal when the root container process should // be started. @@ -234,17 +234,13 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error { if args.CID == "" { return errors.New("start argument missing container ID") } - // Prevent CIDs containing ".." from confusing the sentry when creating - // /containers/<cid> directory. - // TODO(b/129293409): Once we have multiple independent roots, this - // check won't be necessary. - if path.Clean(args.CID) != args.CID { - return fmt.Errorf("container ID shouldn't contain directory traversals such as \"..\": %q", args.CID) - } if len(args.FilePayload.Files) < 4 { return fmt.Errorf("start arguments must contain stdin, stderr, and stdout followed by at least one file for the container root gofer") } + // All validation passed, logs the spec for debugging. + specutils.LogSpec(args.Spec) + err := cm.l.startContainer(args.Spec, args.Conf, args.CID, args.FilePayload.Files) if err != nil { log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err) @@ -355,7 +351,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { fs.SetRestoreEnvironment(*renv) // Prepare to load from the state file. - if eps, ok := networkStack.(*epsocket.Stack); ok { + if eps, ok := networkStack.(*netstack.Stack); ok { stack.StackFromEnv = eps.Stack // FIXME(b/36201077) } info, err := specFile.Stat() @@ -384,7 +380,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { } // Since we have a new kernel we also must make a new watchdog. - dog := watchdog.New(k, watchdog.DefaultTimeout, cm.l.conf.WatchdogAction) + dogOpts := watchdog.DefaultOpts + dogOpts.TaskTimeoutAction = cm.l.conf.WatchdogAction + dog := watchdog.New(k, dogOpts) // Change the loader fields to reflect the changes made when restoring. cm.l.k = k diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index a2ecc6bcb..5ad108261 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -44,6 +44,7 @@ var allowedSyscalls = seccomp.SyscallRules{ }, syscall.SYS_CLOSE: {}, syscall.SYS_DUP: {}, + syscall.SYS_DUP2: {}, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{ @@ -242,6 +243,15 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.AllowValue(0), }, }, + unix.SYS_SENDMMSG: []seccomp.Rule{ + { + seccomp.AllowAny{}, + seccomp.AllowAny{}, + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.MSG_DONTWAIT), + seccomp.AllowValue(0), + }, + }, syscall.SYS_RESTART_SYSCALL: {}, syscall.SYS_RT_SIGACTION: {}, syscall.SYS_RT_SIGPROCMASK: {}, diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 34c674840..76036c147 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -64,6 +64,9 @@ const ( nonefs = "none" ) +// tmpfs has some extra supported options that we must pass through. +var tmpfsAllowedOptions = []string{"mode", "uid", "gid"} + func addOverlay(ctx context.Context, conf *Config, lower *fs.Inode, name string, lowerFlags fs.MountSourceFlags) (*fs.Inode, error) { // Upper layer uses the same flags as lower, but it must be read-write. upperFlags := lowerFlags @@ -172,27 +175,25 @@ func p9MountOptions(fd int, fa FileAccessType) []string { func parseAndFilterOptions(opts []string, allowedKeys ...string) ([]string, error) { var out []string for _, o := range opts { - kv := strings.Split(o, "=") - switch len(kv) { - case 1: - if specutils.ContainsStr(allowedKeys, o) { - out = append(out, o) - continue - } - log.Warningf("ignoring unsupported key %q", kv) - case 2: - if specutils.ContainsStr(allowedKeys, kv[0]) { - out = append(out, o) - continue - } - log.Warningf("ignoring unsupported key %q", kv[0]) - default: - return nil, fmt.Errorf("invalid option %q", o) + ok, err := parseMountOption(o, allowedKeys...) + if err != nil { + return nil, err + } + if ok { + out = append(out, o) } } return out, nil } +func parseMountOption(opt string, allowedKeys ...string) (bool, error) { + kv := strings.SplitN(opt, "=", 3) + if len(kv) > 2 { + return false, fmt.Errorf("invalid option %q", opt) + } + return specutils.ContainsStr(allowedKeys, kv[0]), nil +} + // mountDevice returns a device string based on the fs type and target // of the mount. func mountDevice(m specs.Mount) string { @@ -207,6 +208,8 @@ func mountDevice(m specs.Mount) string { func mountFlags(opts []string) fs.MountSourceFlags { mf := fs.MountSourceFlags{} + // Note: changes to supported options must be reflected in + // isSupportedMountFlag() as well. for _, o := range opts { switch o { case "rw": @@ -224,6 +227,18 @@ func mountFlags(opts []string) fs.MountSourceFlags { return mf } +func isSupportedMountFlag(fstype, opt string) bool { + switch opt { + case "rw", "ro", "noatime", "noexec": + return true + } + if fstype == tmpfs { + ok, err := parseMountOption(opt, tmpfsAllowedOptions...) + return ok && err == nil + } + return false +} + func mustFindFilesystem(name string) fs.Filesystem { fs, ok := fs.FindFilesystem(name) if !ok { @@ -427,6 +442,39 @@ func (m *mountHint) isSupported() bool { return m.mount.Type == tmpfs && m.share == pod } +// checkCompatible verifies that shared mount is compatible with master. +// For now enforce that all options are the same. Once bind mount is properly +// supported, then we should ensure the master is less restrictive than the +// container, e.g. master can be 'rw' while container mounts as 'ro'. +func (m *mountHint) checkCompatible(mount specs.Mount) error { + // Remove options that don't affect to mount's behavior. + masterOpts := filterUnsupportedOptions(m.mount) + slaveOpts := filterUnsupportedOptions(mount) + + if len(masterOpts) != len(slaveOpts) { + return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts) + } + + sort.Strings(masterOpts) + sort.Strings(slaveOpts) + for i, opt := range masterOpts { + if opt != slaveOpts[i] { + return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts) + } + } + return nil +} + +func filterUnsupportedOptions(mount specs.Mount) []string { + rv := make([]string, 0, len(mount.Options)) + for _, o := range mount.Options { + if isSupportedMountFlag(mount.Type, o) { + rv = append(rv, o) + } + } + return rv +} + // podMountHints contains a collection of mountHints for the pod. type podMountHints struct { mounts map[string]*mountHint @@ -655,6 +703,14 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (* log.Infof("Mounting root over 9P, ioFD: %d", fd) p9FS := mustFindFilesystem("9p") opts := p9MountOptions(fd, conf.FileAccess) + + if conf.OverlayfsStaleRead { + // We can't check for overlayfs here because sandbox is chroot'ed and gofer + // can only send mount options for specs.Mounts (specs.Root is missing + // Options field). So assume root is always on top of overlayfs. + opts = append(opts, "overlayfs_stale_read") + } + rootInode, err := p9FS.Mount(ctx, rootDevice, mf, strings.Join(opts, ","), nil) if err != nil { return nil, fmt.Errorf("creating root mount point: %v", err) @@ -689,7 +745,6 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( fsName string opts []string useOverlay bool - err error ) switch m.Type { @@ -700,8 +755,11 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( case tmpfs: fsName = m.Type - // tmpfs has some extra supported options that we must pass through. - opts, err = parseAndFilterOptions(m.Options, "mode", "uid", "gid") + var err error + opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedOptions...) + if err != nil { + return "", nil, false, err + } case bind: fd := c.fds.remove() @@ -717,7 +775,7 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( // for now. log.Warningf("ignoring unknown filesystem type %q", m.Type) } - return fsName, opts, useOverlay, err + return fsName, opts, useOverlay, nil } // mountSubmount mounts volumes inside the container's root. Because mounts may @@ -786,17 +844,8 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns // mountSharedSubmount binds mount to a previously mounted volume that is shared // among containers in the same pod. func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.MountNamespace, root *fs.Dirent, mount specs.Mount, source *mountHint) error { - // For now enforce that all options are the same. Once bind mount is properly - // supported, then we should ensure the master is less restrictive than the - // container, e.g. master can be 'rw' while container mounts as 'ro'. - if len(mount.Options) != len(source.mount.Options) { - return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", source.mount.Options, mount.Options) - } - sort.Strings(mount.Options) - for i, opt := range mount.Options { - if opt != source.mount.Options[i] { - return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", source.mount.Options, mount.Options) - } + if err := source.checkCompatible(mount); err != nil { + return err } maxTraversals := uint(0) diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index adf345490..f05d5973f 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -62,10 +62,11 @@ import ( "gvisor.dev/gvisor/runsc/specutils" // Include supported socket providers. - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" "gvisor.dev/gvisor/pkg/sentry/socket/hostinet" _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink" _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route" + _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" _ "gvisor.dev/gvisor/pkg/sentry/socket/unix" ) @@ -232,7 +233,7 @@ func New(args Args) (*Loader, error) { // this point. Netns is configured before Run() is called. Netstack is // configured using a control uRPC message. Host network is configured inside // Run(). - networkStack, err := newEmptyNetworkStack(args.Conf, k) + networkStack, err := newEmptyNetworkStack(args.Conf, k, k) if err != nil { return nil, fmt.Errorf("creating network: %v", err) } @@ -300,7 +301,9 @@ func New(args Args) (*Loader, error) { } // Create a watchdog. - dog := watchdog.New(k, watchdog.DefaultTimeout, args.Conf.WatchdogAction) + dogOpts := watchdog.DefaultOpts + dogOpts.TaskTimeoutAction = args.Conf.WatchdogAction + dog := watchdog.New(k, dogOpts) procArgs, err := newProcess(args.ID, args.Spec, creds, k, k.RootPIDNamespace()) if err != nil { @@ -905,7 +908,7 @@ func (l *Loader) WaitExit() kernel.ExitStatus { return l.k.GlobalInit().ExitStatus() } -func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { +func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { switch conf.Network { case NetworkHost: return hostinet.NewStack(), nil @@ -914,15 +917,16 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { // NetworkNone sets up loopback using netstack. netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} - s := epsocket.Stack{stack.New(stack.Options{ + s := netstack.Stack{stack.New(stack.Options{ NetworkProtocols: netProtos, TransportProtocols: transProtos, Clock: clock, - Stats: epsocket.Metrics, + Stats: netstack.Metrics, HandleLocal: true, // Enable raw sockets for users with sufficient // privileges. - UnassociatedFactory: raw.EndpointFactory{}, + RawFactory: raw.EndpointFactory{}, + UniqueID: uniqueID, })} // Enable SACK Recovery. @@ -930,6 +934,10 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { return nil, fmt.Errorf("failed to enable SACK: %v", err) } + // Set default TTLs as required by socket/netstack. + s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + // Enable Receive Buffer Auto-Tuning. if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 32cba5ac1..f98c5fd36 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -50,12 +50,13 @@ type DefaultRoute struct { // FDBasedLink configures an fd-based link. type FDBasedLink struct { - Name string - MTU int - Addresses []net.IP - Routes []Route - GSOMaxSize uint32 - LinkAddress net.HardwareAddr + Name string + MTU int + Addresses []net.IP + Routes []Route + GSOMaxSize uint32 + SoftwareGSOEnabled bool + LinkAddress net.HardwareAddr // NumChannels controls how many underlying FD's are to be used to // create this endpoint. @@ -163,6 +164,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct Address: mac, PacketDispatchMode: fdbased.RecvMMsg, GSOMaxSize: link.GSOMaxSize, + SoftwareGSOEnabled: link.SoftwareGSOEnabled, RXChecksumOffload: true, }) if err != nil { diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index bf1225e1c..d1e99243b 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -105,11 +105,11 @@ func (ex *Exec) SetFlags(f *flag.FlagSet) { // Execute implements subcommands.Command.Execute. It starts a process in an // already created container. func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - e, id, err := ex.parseArgs(f) + conf := args[0].(*boot.Config) + e, id, err := ex.parseArgs(f, conf.EnableRaw) if err != nil { Fatalf("parsing process spec: %v", err) } - conf := args[0].(*boot.Config) waitStatus := args[1].(*syscall.WaitStatus) c, err := container.Load(conf.RootDir, id) @@ -117,6 +117,9 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) Fatalf("loading sandbox: %v", err) } + log.Debugf("Exec arguments: %+v", e) + log.Debugf("Exec capablities: %+v", e.Capabilities) + // Replace empty settings with defaults from container. if e.WorkingDirectory == "" { e.WorkingDirectory = c.Spec.Process.Cwd @@ -129,14 +132,11 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } if e.Capabilities == nil { - // enableRaw is set to true to prevent the filtering out of - // CAP_NET_RAW. This is the opposite of Create() because exec - // requires the capability to be set explicitly, while 'docker - // run' sets it by default. - e.Capabilities, err = specutils.Capabilities(true /* enableRaw */, c.Spec.Process.Capabilities) + e.Capabilities, err = specutils.Capabilities(conf.EnableRaw, c.Spec.Process.Capabilities) if err != nil { Fatalf("creating capabilities: %v", err) } + log.Infof("Using exec capabilities from container: %+v", e.Capabilities) } // containerd expects an actual process to represent the container being @@ -283,14 +283,14 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi // parseArgs parses exec information from the command line or a JSON file // depending on whether the --process flag was used. Returns an ExecArgs and // the ID of the container to be used. -func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) { +func (ex *Exec) parseArgs(f *flag.FlagSet, enableRaw bool) (*control.ExecArgs, string, error) { if ex.processPath == "" { // Requires at least a container ID and command. if f.NArg() < 2 { f.Usage() return nil, "", fmt.Errorf("both a container-id and command are required") } - e, err := ex.argsFromCLI(f.Args()[1:]) + e, err := ex.argsFromCLI(f.Args()[1:], enableRaw) return e, f.Arg(0), err } // Requires only the container ID. @@ -298,11 +298,11 @@ func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) { f.Usage() return nil, "", fmt.Errorf("a container-id is required") } - e, err := ex.argsFromProcessFile() + e, err := ex.argsFromProcessFile(enableRaw) return e, f.Arg(0), err } -func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) { +func (ex *Exec) argsFromCLI(argv []string, enableRaw bool) (*control.ExecArgs, error) { extraKGIDs := make([]auth.KGID, 0, len(ex.extraKGIDs)) for _, s := range ex.extraKGIDs { kgid, err := strconv.Atoi(s) @@ -315,7 +315,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) { var caps *auth.TaskCapabilities if len(ex.caps) > 0 { var err error - caps, err = capabilities(ex.caps) + caps, err = capabilities(ex.caps, enableRaw) if err != nil { return nil, fmt.Errorf("capabilities error: %v", err) } @@ -333,7 +333,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) { }, nil } -func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) { +func (ex *Exec) argsFromProcessFile(enableRaw bool) (*control.ExecArgs, error) { f, err := os.Open(ex.processPath) if err != nil { return nil, fmt.Errorf("error opening process file: %s, %v", ex.processPath, err) @@ -343,21 +343,21 @@ func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) { if err := json.NewDecoder(f).Decode(&p); err != nil { return nil, fmt.Errorf("error parsing process file: %s, %v", ex.processPath, err) } - return argsFromProcess(&p) + return argsFromProcess(&p, enableRaw) } // argsFromProcess performs all the non-IO conversion from the Process struct // to ExecArgs. -func argsFromProcess(p *specs.Process) (*control.ExecArgs, error) { +func argsFromProcess(p *specs.Process, enableRaw bool) (*control.ExecArgs, error) { // Create capabilities. var caps *auth.TaskCapabilities if p.Capabilities != nil { var err error - // enableRaw is set to true to prevent the filtering out of - // CAP_NET_RAW. This is the opposite of Create() because exec - // requires the capability to be set explicitly, while 'docker - // run' sets it by default. - caps, err = specutils.Capabilities(true /* enableRaw */, p.Capabilities) + // Starting from Docker 19, capabilities are explicitly set for exec (instead + // of nil like before). So we can't distinguish 'exec' from + // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter + // CAP_NET_RAW in the same way as container start. + caps, err = specutils.Capabilities(enableRaw, p.Capabilities) if err != nil { return nil, fmt.Errorf("error creating capabilities: %v", err) } @@ -410,7 +410,7 @@ func resolveEnvs(envs ...[]string) ([]string, error) { // capabilities takes a list of capabilities as strings and returns an // auth.TaskCapabilities struct with those capabilities in every capability set. // This mimics runc's behavior. -func capabilities(cs []string) (*auth.TaskCapabilities, error) { +func capabilities(cs []string, enableRaw bool) (*auth.TaskCapabilities, error) { var specCaps specs.LinuxCapabilities for _, cap := range cs { specCaps.Ambient = append(specCaps.Ambient, cap) @@ -419,11 +419,11 @@ func capabilities(cs []string) (*auth.TaskCapabilities, error) { specCaps.Inheritable = append(specCaps.Inheritable, cap) specCaps.Permitted = append(specCaps.Permitted, cap) } - // enableRaw is set to true to prevent the filtering out of - // CAP_NET_RAW. This is the opposite of Create() because exec requires - // the capability to be set explicitly, while 'docker run' sets it by - // default. - return specutils.Capabilities(true /* enableRaw */, &specCaps) + // Starting from Docker 19, capabilities are explicitly set for exec (instead + // of nil like before). So we can't distinguish 'exec' from + // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter + // CAP_NET_RAW in the same way as container start. + return specutils.Capabilities(enableRaw, &specCaps) } // stringSlice allows a flag to be used multiple times, where each occurrence diff --git a/runsc/cmd/exec_test.go b/runsc/cmd/exec_test.go index eb38a431f..a1e980d08 100644 --- a/runsc/cmd/exec_test.go +++ b/runsc/cmd/exec_test.go @@ -91,7 +91,7 @@ func TestCLIArgs(t *testing.T) { } for _, tc := range testCases { - e, err := tc.ex.argsFromCLI(tc.argv) + e, err := tc.ex.argsFromCLI(tc.argv, true) if err != nil { t.Errorf("argsFromCLI(%+v): got error: %+v", tc.ex, err) } else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) { @@ -144,7 +144,7 @@ func TestJSONArgs(t *testing.T) { } for _, tc := range testCases { - e, err := argsFromProcess(&tc.p) + e, err := argsFromProcess(&tc.p, true) if err != nil { t.Errorf("argsFromProcess(%+v): got error: %+v", tc.p, err) } else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) { diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index fbd579fb8..4831210c0 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -27,6 +27,7 @@ import ( "flag" "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/unet" @@ -135,7 +136,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // // Note that all mount points have been mounted in the proper location in // setupRootFS(). - cleanMounts, err := resolveMounts(spec.Mounts, root) + cleanMounts, err := resolveMounts(conf, spec.Mounts, root) if err != nil { Fatalf("Failure to resolve mounts: %v", err) } @@ -380,7 +381,7 @@ func setupMounts(mounts []specs.Mount, root string) error { // Otherwise, it may follow symlinks to locations that would be overwritten // with another mount point and return the wrong location. In short, make sure // setupMounts() has been called before. -func resolveMounts(mounts []specs.Mount, root string) ([]specs.Mount, error) { +func resolveMounts(conf *boot.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) { cleanMounts := make([]specs.Mount, 0, len(mounts)) for _, m := range mounts { if m.Type != "bind" || !specutils.IsSupportedDevMount(m) { @@ -395,8 +396,15 @@ func resolveMounts(mounts []specs.Mount, root string) ([]specs.Mount, error) { if err != nil { panic(fmt.Sprintf("%q could not be made relative to %q: %v", dst, root, err)) } + + opts, err := adjustMountOptions(conf, filepath.Join(root, relDst), m.Options) + if err != nil { + return nil, err + } + cpy := m cpy.Destination = filepath.Join("/", relDst) + cpy.Options = opts cleanMounts = append(cleanMounts, cpy) } return cleanMounts, nil @@ -423,7 +431,7 @@ func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, erro path := filepath.Join(base, name) if !strings.HasPrefix(path, root) { // One cannot '..' their way out of root. - path = root + base = root continue } fi, err := os.Lstat(path) @@ -453,3 +461,20 @@ func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, erro } return base, nil } + +// adjustMountOptions adds 'overlayfs_stale_read' if mounting over overlayfs. +func adjustMountOptions(conf *boot.Config, path string, opts []string) ([]string, error) { + rv := make([]string, len(opts)) + copy(rv, opts) + + if conf.OverlayfsStaleRead { + statfs := syscall.Statfs_t{} + if err := syscall.Statfs(path, &statfs); err != nil { + return nil, err + } + if statfs.Type == unix.OVERLAYFS_SUPER_MAGIC { + rv = append(rv, "overlayfs_stale_read") + } + } + return rv, nil +} diff --git a/runsc/container/BUILD b/runsc/container/BUILD index bc1fa25e3..2bd12120d 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "container.go", "hook.go", + "state_file.go", "status.go", ], importpath = "gvisor.dev/gvisor/runsc/container", @@ -47,6 +48,7 @@ go_test( ], deps = [ "//pkg/abi/linux", + "//pkg/bits", "//pkg/log", "//pkg/sentry/control", "//pkg/sentry/kernel", diff --git a/runsc/container/container.go b/runsc/container/container.go index a721c1c31..68782c4be 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -17,13 +17,11 @@ package container import ( "context" - "encoding/json" "fmt" "io/ioutil" "os" "os/exec" "os/signal" - "path/filepath" "regexp" "strconv" "strings" @@ -31,7 +29,6 @@ import ( "time" "github.com/cenkalti/backoff" - "github.com/gofrs/flock" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" @@ -41,17 +38,6 @@ import ( "gvisor.dev/gvisor/runsc/specutils" ) -const ( - // metadataFilename is the name of the metadata file relative to the - // container root directory that holds sandbox metadata. - metadataFilename = "meta.json" - - // metadataLockFilename is the name of a lock file in the container - // root directory that is used to prevent concurrent modifications to - // the container state and metadata. - metadataLockFilename = "meta.lock" -) - // validateID validates the container id. func validateID(id string) error { // See libcontainer/factory_linux.go. @@ -99,11 +85,6 @@ type Container struct { // BundleDir is the directory containing the container bundle. BundleDir string `json:"bundleDir"` - // Root is the directory containing the container metadata file. If this - // container is the root container, Root and RootContainerDir will be the - // same. - Root string `json:"root"` - // CreatedAt is the time the container was created. CreatedAt time.Time `json:"createdAt"` @@ -121,21 +102,24 @@ type Container struct { // be 0 if the gofer has been killed. GoferPid int `json:"goferPid"` + // Sandbox is the sandbox this container is running in. It's set when the + // container is created and reset when the sandbox is destroyed. + Sandbox *sandbox.Sandbox `json:"sandbox"` + + // Saver handles load from/save to the state file safely from multiple + // processes. + Saver StateFile `json:"saver"` + + // + // Fields below this line are not saved in the state file and will not + // be preserved across commands. + // + // goferIsChild is set if a gofer process is a child of the current process. // // This field isn't saved to json, because only a creator of a gofer // process will have it as a child process. goferIsChild bool - - // Sandbox is the sandbox this container is running in. It's set when the - // container is created and reset when the sandbox is destroyed. - Sandbox *sandbox.Sandbox `json:"sandbox"` - - // RootContainerDir is the root directory containing the metadata file of the - // sandbox root container. It's used to lock in order to serialize creating - // and deleting this Container's metadata directory. If this container is the - // root container, this is the same as Root. - RootContainerDir string } // loadSandbox loads all containers that belong to the sandbox with the given @@ -166,43 +150,35 @@ func loadSandbox(rootDir, id string) ([]*Container, error) { return containers, nil } -// Load loads a container with the given id from a metadata file. id may be an -// abbreviation of the full container id, in which case Load loads the -// container to which id unambiguously refers to. -// Returns ErrNotExist if container doesn't exist. -func Load(rootDir, id string) (*Container, error) { - log.Debugf("Load container %q %q", rootDir, id) - if err := validateID(id); err != nil { +// Load loads a container with the given id from a metadata file. partialID may +// be an abbreviation of the full container id, in which case Load loads the +// container to which id unambiguously refers to. Returns ErrNotExist if +// container doesn't exist. +func Load(rootDir, partialID string) (*Container, error) { + log.Debugf("Load container %q %q", rootDir, partialID) + if err := validateID(partialID); err != nil { return nil, fmt.Errorf("validating id: %v", err) } - cRoot, err := findContainerRoot(rootDir, id) + id, err := findContainerID(rootDir, partialID) if err != nil { // Preserve error so that callers can distinguish 'not found' errors. return nil, err } - // Lock the container metadata to prevent other runsc instances from - // writing to it while we are reading it. - unlock, err := lockContainerMetadata(cRoot) - if err != nil { - return nil, err + state := StateFile{ + RootDir: rootDir, + ID: id, } - defer unlock() + defer state.close() - // Read the container metadata file and create a new Container from it. - metaFile := filepath.Join(cRoot, metadataFilename) - metaBytes, err := ioutil.ReadFile(metaFile) - if err != nil { + c := &Container{} + if err := state.load(c); err != nil { if os.IsNotExist(err) { // Preserve error so that callers can distinguish 'not found' errors. return nil, err } - return nil, fmt.Errorf("reading container metadata file %q: %v", metaFile, err) - } - var c Container - if err := json.Unmarshal(metaBytes, &c); err != nil { - return nil, fmt.Errorf("unmarshaling container metadata from %q: %v", metaFile, err) + return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err) } // If the status is "Running" or "Created", check that the sandbox @@ -223,57 +199,37 @@ func Load(rootDir, id string) (*Container, error) { } } - return &c, nil + return c, nil } -func findContainerRoot(rootDir, partialID string) (string, error) { +func findContainerID(rootDir, partialID string) (string, error) { // Check whether the id fully specifies an existing container. - cRoot := filepath.Join(rootDir, partialID) - if _, err := os.Stat(cRoot); err == nil { - return cRoot, nil + stateFile := buildStatePath(rootDir, partialID) + if _, err := os.Stat(stateFile); err == nil { + return partialID, nil } // Now see whether id could be an abbreviation of exactly 1 of the // container ids. If id is ambiguous (it could match more than 1 // container), it is an error. - cRoot = "" ids, err := List(rootDir) if err != nil { return "", err } + rv := "" for _, id := range ids { if strings.HasPrefix(id, partialID) { - if cRoot != "" { - return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, cRoot, id) + if rv != "" { + return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id) } - cRoot = id + rv = id } } - if cRoot == "" { + if rv == "" { return "", os.ErrNotExist } - log.Debugf("abbreviated id %q resolves to full id %q", partialID, cRoot) - return filepath.Join(rootDir, cRoot), nil -} - -// List returns all container ids in the given root directory. -func List(rootDir string) ([]string, error) { - log.Debugf("List containers %q", rootDir) - fs, err := ioutil.ReadDir(rootDir) - if err != nil { - return nil, fmt.Errorf("reading dir %q: %v", rootDir, err) - } - var out []string - for _, f := range fs { - // Filter out directories that do no belong to a container. - cid := f.Name() - if validateID(cid) == nil { - if _, err := os.Stat(filepath.Join(rootDir, cid, metadataFilename)); err == nil { - out = append(out, f.Name()) - } - } - } - return out, nil + log.Debugf("abbreviated id %q resolves to full id %q", partialID, rv) + return rv, nil } // Args is used to configure a new container. @@ -316,44 +272,34 @@ func New(conf *boot.Config, args Args) (*Container, error) { return nil, err } - unlockRoot, err := maybeLockRootContainer(args.Spec, conf.RootDir) - if err != nil { - return nil, err + if err := os.MkdirAll(conf.RootDir, 0711); err != nil { + return nil, fmt.Errorf("creating container root directory: %v", err) } - defer unlockRoot() + + c := &Container{ + ID: args.ID, + Spec: args.Spec, + ConsoleSocket: args.ConsoleSocket, + BundleDir: args.BundleDir, + Status: Creating, + CreatedAt: time.Now(), + Owner: os.Getenv("USER"), + Saver: StateFile{ + RootDir: conf.RootDir, + ID: args.ID, + }, + } + // The Cleanup object cleans up partially created containers when an error + // occurs. Any errors occurring during cleanup itself are ignored. + cu := specutils.MakeCleanup(func() { _ = c.Destroy() }) + defer cu.Clean() // Lock the container metadata file to prevent concurrent creations of // containers with the same id. - containerRoot := filepath.Join(conf.RootDir, args.ID) - unlock, err := lockContainerMetadata(containerRoot) - if err != nil { + if err := c.Saver.lockForNew(); err != nil { return nil, err } - defer unlock() - - // Check if the container already exists by looking for the metadata - // file. - if _, err := os.Stat(filepath.Join(containerRoot, metadataFilename)); err == nil { - return nil, fmt.Errorf("container with id %q already exists", args.ID) - } else if !os.IsNotExist(err) { - return nil, fmt.Errorf("looking for existing container in %q: %v", containerRoot, err) - } - - c := &Container{ - ID: args.ID, - Spec: args.Spec, - ConsoleSocket: args.ConsoleSocket, - BundleDir: args.BundleDir, - Root: containerRoot, - Status: Creating, - CreatedAt: time.Now(), - Owner: os.Getenv("USER"), - RootContainerDir: conf.RootDir, - } - // The Cleanup object cleans up partially created containers when an error occurs. - // Any errors occuring during cleanup itself are ignored. - cu := specutils.MakeCleanup(func() { _ = c.Destroy() }) - defer cu.Clean() + defer c.Saver.unlock() // If the metadata annotations indicate that this container should be // started in an existing sandbox, we must do so. The metadata will @@ -431,7 +377,7 @@ func New(conf *boot.Config, args Args) (*Container, error) { c.changeStatus(Created) // Save the metadata file. - if err := c.save(); err != nil { + if err := c.saveLocked(); err != nil { return nil, err } @@ -451,17 +397,12 @@ func New(conf *boot.Config, args Args) (*Container, error) { func (c *Container) Start(conf *boot.Config) error { log.Debugf("Start container %q", c.ID) - unlockRoot, err := maybeLockRootContainer(c.Spec, c.RootContainerDir) - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlockRoot() + unlock := specutils.MakeCleanup(func() { c.Saver.unlock() }) + defer unlock.Clean() - unlock, err := c.lock() - if err != nil { - return err - } - defer unlock() if err := c.requireStatus("start", Created); err != nil { return err } @@ -509,14 +450,15 @@ func (c *Container) Start(conf *boot.Config) error { } c.changeStatus(Running) - if err := c.save(); err != nil { + if err := c.saveLocked(); err != nil { return err } - // Adjust the oom_score_adj for sandbox. This must be done after - // save(). - err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false) - if err != nil { + // Release lock before adjusting OOM score because the lock is acquired there. + unlock.Clean() + + // Adjust the oom_score_adj for sandbox. This must be done after saveLocked(). + if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil { return err } @@ -529,11 +471,10 @@ func (c *Container) Start(conf *boot.Config) error { // to restore a container from its state file. func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile string) error { log.Debugf("Restore container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if err := c.requireStatus("restore", Created); err != nil { return err @@ -551,7 +492,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile str return err } c.changeStatus(Running) - return c.save() + return c.saveLocked() } // Run is a helper that calls Create + Start + Wait. @@ -711,11 +652,10 @@ func (c *Container) Checkpoint(f *os.File) error { // The call only succeeds if the container's status is created or running. func (c *Container) Pause() error { log.Debugf("Pausing container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if c.Status != Created && c.Status != Running { return fmt.Errorf("cannot pause container %q in state %v", c.ID, c.Status) @@ -725,18 +665,17 @@ func (c *Container) Pause() error { return fmt.Errorf("pausing container: %v", err) } c.changeStatus(Paused) - return c.save() + return c.saveLocked() } // Resume unpauses the container and its kernel. // The call only succeeds if the container's status is paused. func (c *Container) Resume() error { log.Debugf("Resuming container %q", c.ID) - unlock, err := c.lock() - if err != nil { + if err := c.Saver.lock(); err != nil { return err } - defer unlock() + defer c.Saver.unlock() if c.Status != Paused { return fmt.Errorf("cannot resume container %q in state %v", c.ID, c.Status) @@ -745,7 +684,7 @@ func (c *Container) Resume() error { return fmt.Errorf("resuming container: %v", err) } c.changeStatus(Running) - return c.save() + return c.saveLocked() } // State returns the metadata of the container. @@ -773,6 +712,17 @@ func (c *Container) Processes() ([]*control.Process, error) { func (c *Container) Destroy() error { log.Debugf("Destroy container %q", c.ID) + if err := c.Saver.lock(); err != nil { + return err + } + defer func() { + c.Saver.unlock() + c.Saver.close() + }() + + // Stored for later use as stop() sets c.Sandbox to nil. + sb := c.Sandbox + // We must perform the following cleanup steps: // * stop the container and gofer processes, // * remove the container filesystem on the host, and @@ -782,48 +732,43 @@ func (c *Container) Destroy() error { // do our best to perform all of the cleanups. Hence, we keep a slice // of errors return their concatenation. var errs []string - - unlock, err := maybeLockRootContainer(c.Spec, c.RootContainerDir) - if err != nil { - return err - } - defer unlock() - - // Stored for later use as stop() sets c.Sandbox to nil. - sb := c.Sandbox - if err := c.stop(); err != nil { err = fmt.Errorf("stopping container: %v", err) log.Warningf("%v", err) errs = append(errs, err.Error()) } - if err := os.RemoveAll(c.Root); err != nil && !os.IsNotExist(err) { - err = fmt.Errorf("deleting container root directory %q: %v", c.Root, err) + if err := c.Saver.destroy(); err != nil { + err = fmt.Errorf("deleting container state files: %v", err) log.Warningf("%v", err) errs = append(errs, err.Error()) } c.changeStatus(Stopped) - // Adjust oom_score_adj for the sandbox. This must be done after the - // container is stopped and the directory at c.Root is removed. - // We must test if the sandbox is nil because Destroy should be - // idempotent. - if sb != nil { - if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil { + // Adjust oom_score_adj for the sandbox. This must be done after the container + // is stopped and the directory at c.Root is removed. Adjustment can be + // skipped if the root container is exiting, because it brings down the entire + // sandbox. + // + // Use 'sb' to tell whether it has been executed before because Destroy must + // be idempotent. + if sb != nil && !isRoot(c.Spec) { + if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil { errs = append(errs, err.Error()) } } // "If any poststop hook fails, the runtime MUST log a warning, but the - // remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec. - // Based on the OCI, "The post-stop hooks MUST be called after the container is - // deleted but before the delete operation returns" + // remaining hooks and lifecycle continue as if the hook had + // succeeded" - OCI spec. + // + // Based on the OCI, "The post-stop hooks MUST be called after the container + // is deleted but before the delete operation returns" // Run it here to: // 1) Conform to the OCI. - // 2) Make sure it only runs once, because the root has been deleted, the container - // can't be loaded again. + // 2) Make sure it only runs once, because the root has been deleted, the + // container can't be loaded again. if c.Spec.Hooks != nil { executeHooksBestEffort(c.Spec.Hooks.Poststop, c.State()) } @@ -834,18 +779,13 @@ func (c *Container) Destroy() error { return fmt.Errorf(strings.Join(errs, "\n")) } -// save saves the container metadata to a file. +// saveLocked saves the container metadata to a file. // // Precondition: container must be locked with container.lock(). -func (c *Container) save() error { +func (c *Container) saveLocked() error { log.Debugf("Save container %q", c.ID) - metaFile := filepath.Join(c.Root, metadataFilename) - meta, err := json.Marshal(c) - if err != nil { - return fmt.Errorf("invalid container metadata: %v", err) - } - if err := ioutil.WriteFile(metaFile, meta, 0640); err != nil { - return fmt.Errorf("writing container metadata: %v", err) + if err := c.Saver.saveLocked(c); err != nil { + return fmt.Errorf("saving container metadata: %v", err) } return nil } @@ -1106,50 +1046,8 @@ func (c *Container) requireStatus(action string, statuses ...Status) error { return fmt.Errorf("cannot %s container %q in state %s", action, c.ID, c.Status) } -// lock takes a file lock on the container metadata lock file. -func (c *Container) lock() (func() error, error) { - return lockContainerMetadata(filepath.Join(c.Root, c.ID)) -} - -// lockContainerMetadata takes a file lock on the metadata lock file in the -// given container root directory. -func lockContainerMetadata(containerRootDir string) (func() error, error) { - if err := os.MkdirAll(containerRootDir, 0711); err != nil { - return nil, fmt.Errorf("creating container root directory %q: %v", containerRootDir, err) - } - f := filepath.Join(containerRootDir, metadataLockFilename) - l := flock.NewFlock(f) - if err := l.Lock(); err != nil { - return nil, fmt.Errorf("acquiring lock on container lock file %q: %v", f, err) - } - return l.Unlock, nil -} - -// maybeLockRootContainer locks the sandbox root container. It is used to -// prevent races to create and delete child container sandboxes. -func maybeLockRootContainer(spec *specs.Spec, rootDir string) (func() error, error) { - if isRoot(spec) { - return func() error { return nil }, nil - } - - sbid, ok := specutils.SandboxID(spec) - if !ok { - return nil, fmt.Errorf("no sandbox ID found when locking root container") - } - sb, err := Load(rootDir, sbid) - if err != nil { - return nil, err - } - - unlock, err := sb.lock() - if err != nil { - return nil, err - } - return unlock, nil -} - func isRoot(spec *specs.Spec) bool { - return specutils.ShouldCreateSandbox(spec) + return specutils.SpecContainerType(spec) != specutils.ContainerTypeContainer } // runInCgroup executes fn inside the specified cgroup. If cg is nil, execute @@ -1170,7 +1068,12 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error { func (c *Container) adjustGoferOOMScoreAdj() error { if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil { if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil { - return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) + // Ignore NotExist error because it can be returned when the sandbox + // exited while OOM score was being adjusted. + if !os.IsNotExist(err) { + return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) + } + log.Warningf("Gofer process (%d) not found setting oom_score_adj", c.GoferPid) } } @@ -1198,7 +1101,7 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // Get the lowest score for all containers. var lowScore int scoreFound := false - if len(containers) == 1 && len(containers[0].Spec.Annotations[specutils.ContainerdContainerTypeAnnotation]) == 0 { + if len(containers) == 1 && specutils.SpecContainerType(containers[0].Spec) == specutils.ContainerTypeUnspecified { // This is a single-container sandbox. Set the oom_score_adj to // the value specified in the OCI bundle. if containers[0].Spec.Process.OOMScoreAdj != nil { @@ -1214,7 +1117,7 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // // We will use OOMScoreAdj in the single-container case where the // containerd container-type annotation is not present. - if container.Spec.Annotations[specutils.ContainerdContainerTypeAnnotation] == specutils.ContainerdContainerTypeSandbox { + if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox { continue } @@ -1252,7 +1155,12 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // Set the lowest of all containers oom_score_adj to the sandbox. if err := setOOMScoreAdj(s.Pid, lowScore); err != nil { - return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err) + // Ignore NotExist error because it can be returned when the sandbox + // exited while OOM score was being adjusted. + if !os.IsNotExist(err) { + return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err) + } + log.Warningf("Sandbox process (%d) not found setting oom_score_adj", s.Pid) } return nil diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 2ac12e5b6..07eacaac0 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -34,6 +34,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -1547,7 +1548,8 @@ func TestAbbreviatedIDs(t *testing.T) { } defer os.RemoveAll(rootDir) - conf := testutil.TestConfigWithRoot(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir cids := []string{ "foo-" + testutil.UniqueContainerID(), @@ -2049,6 +2051,67 @@ func TestMountSymlink(t *testing.T) { } } +// Check that --net-raw disables the CAP_NET_RAW capability. +func TestNetRaw(t *testing.T) { + capNetRaw := strconv.FormatUint(bits.MaskOf64(int(linux.CAP_NET_RAW)), 10) + app, err := testutil.FindFile("runsc/container/test_app/test_app") + if err != nil { + t.Fatal("error finding test_app:", err) + } + + for _, enableRaw := range []bool{true, false} { + conf := testutil.TestConfig() + conf.EnableRaw = enableRaw + + test := "--enabled" + if !enableRaw { + test = "--disabled" + } + + spec := testutil.NewSpecWithArgs(app, "capability", test, capNetRaw) + if err := run(spec, conf); err != nil { + t.Fatalf("Error running container: %v", err) + } + } +} + +// TestOverlayfsStaleRead most basic test that '--overlayfs-stale-read' works. +func TestOverlayfsStaleRead(t *testing.T) { + conf := testutil.TestConfig() + conf.OverlayfsStaleRead = true + + in, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.in") + if err != nil { + t.Fatalf("ioutil.TempFile() failed: %v", err) + } + defer in.Close() + if _, err := in.WriteString("stale data"); err != nil { + t.Fatalf("in.Write() failed: %v", err) + } + + out, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.out") + if err != nil { + t.Fatalf("ioutil.TempFile() failed: %v", err) + } + defer out.Close() + + const want = "foobar" + cmd := fmt.Sprintf("cat %q && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name()) + spec := testutil.NewSpecWithArgs("/bin/bash", "-c", cmd) + if err := run(spec, conf); err != nil { + t.Fatalf("Error running container: %v", err) + } + + gotBytes, err := ioutil.ReadAll(out) + if err != nil { + t.Fatalf("out.Read() failed: %v", err) + } + got := strings.TrimSpace(string(gotBytes)) + if want != got { + t.Errorf("Wrong content in out file, got: %q. want: %q", got, want) + } +} + // executeSync synchronously executes a new process. func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { pid, err := cont.Execute(args) diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index bd45a5118..a5a62378c 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -60,13 +60,8 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { } func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) { - // Setup root dir if one hasn't been provided. if len(conf.RootDir) == 0 { - rootDir, err := testutil.SetupRootDir() - if err != nil { - return nil, nil, fmt.Errorf("error creating root dir: %v", err) - } - conf.RootDir = rootDir + panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") } var containers []*Container @@ -78,7 +73,6 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*C for _, b := range bundles { os.RemoveAll(b) } - os.RemoveAll(conf.RootDir) } for i, spec := range specs { bundleDir, err := testutil.SetupBundleDir(spec) @@ -144,6 +138,13 @@ func TestMultiContainerSanity(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep) @@ -175,6 +176,13 @@ func TestMultiPIDNS(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} testSpecs, ids := createSpecs(sleep, sleep) @@ -213,6 +221,13 @@ func TestMultiPIDNSPath(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} testSpecs, ids := createSpecs(sleep, sleep, sleep) @@ -268,13 +283,21 @@ func TestMultiPIDNSPath(t *testing.T) { } func TestMultiContainerWait(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // The first container should run the entire duration of the test. cmd1 := []string{"sleep", "100"} // We'll wait on the second container, which is much shorter lived. cmd2 := []string{"sleep", "1"} specs, ids := createSpecs(cmd1, cmd2) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -344,12 +367,14 @@ func TestExecWait(t *testing.T) { } defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + // The first container should run the entire duration of the test. cmd1 := []string{"sleep", "100"} // We'll wait on the second container, which is much shorter lived. cmd2 := []string{"sleep", "1"} specs, ids := createSpecs(cmd1, cmd2) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -432,7 +457,15 @@ func TestMultiContainerMount(t *testing.T) { }) // Setup the containers. + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + containers, cleanup, err := startContainers(conf, sps, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -454,6 +487,13 @@ func TestMultiContainerSignal(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep) @@ -548,6 +588,13 @@ func TestMultiContainerDestroy(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // First container will remain intact while the second container is killed. podSpecs, ids := createSpecs( []string{"sleep", "100"}, @@ -599,13 +646,21 @@ func TestMultiContainerDestroy(t *testing.T) { } func TestMultiContainerProcesses(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Note: use curly braces to keep 'sh' process around. Otherwise, shell // will just execve into 'sleep' and both containers will look the // same. specs, ids := createSpecs( []string{"sleep", "100"}, []string{"sh", "-c", "{ sleep 100; }"}) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -650,6 +705,15 @@ func TestMultiContainerProcesses(t *testing.T) { // TestMultiContainerKillAll checks that all process that belong to a container // are killed when SIGKILL is sent to *all* processes in that container. func TestMultiContainerKillAll(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + for _, tc := range []struct { killContainer bool }{ @@ -665,7 +729,6 @@ func TestMultiContainerKillAll(t *testing.T) { specs, ids := createSpecs( []string{app, "task-tree", "--depth=2", "--width=2"}, []string{app, "task-tree", "--depth=4", "--width=2"}) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -739,19 +802,13 @@ func TestMultiContainerDestroyNotStarted(t *testing.T) { specs, ids := createSpecs( []string{"/bin/sleep", "100"}, []string{"/bin/sleep", "100"}) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - conf := testutil.TestConfigWithRoot(rootDir) - - // Create and start root container. - rootBundleDir, err := testutil.SetupBundleDir(specs[0]) + conf := testutil.TestConfig() + rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) if err != nil { t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(rootBundleDir) rootArgs := Args{ @@ -800,19 +857,12 @@ func TestMultiContainerDestroyStarting(t *testing.T) { } specs, ids := createSpecs(cmds...) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - - conf := testutil.TestConfigWithRoot(rootDir) - - // Create and start root container. - rootBundleDir, err := testutil.SetupBundleDir(specs[0]) + conf := testutil.TestConfig() + rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf) if err != nil { t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(rootBundleDir) rootArgs := Args{ @@ -886,9 +936,17 @@ func TestMultiContainerDifferentFilesystems(t *testing.T) { script := fmt.Sprintf("if [ -f %q ]; then exit 1; else touch %q; fi", filename, filename) cmd := []string{"sh", "-c", script} + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Make sure overlay is enabled, and none of the root filesystems are // read-only, otherwise we won't be able to create the file. - conf := testutil.TestConfig() conf.Overlay = true specs, ids := createSpecs(cmdRoot, cmd, cmd) for _, s := range specs { @@ -941,26 +999,21 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) { } allSpecs, allIDs := createSpecs(cmds...) - rootDir, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer os.RemoveAll(rootDir) - // Split up the specs and IDs. rootSpec := allSpecs[0] rootID := allIDs[0] childrenSpecs := allSpecs[1:] childrenIDs := allIDs[1:] - bundleDir, err := testutil.SetupBundleDir(rootSpec) + conf := testutil.TestConfig() + rootDir, bundleDir, err := testutil.SetupContainer(rootSpec, conf) if err != nil { - t.Fatalf("error setting up bundle dir: %v", err) + t.Fatalf("error setting up container: %v", err) } + defer os.RemoveAll(rootDir) defer os.RemoveAll(bundleDir) // Start root container. - conf := testutil.TestConfigWithRoot(rootDir) rootArgs := Args{ ID: rootID, Spec: rootSpec, @@ -1029,6 +1082,13 @@ func TestMultiContainerSharedMount(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1137,6 +1197,13 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1197,6 +1264,13 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { for _, conf := range configs(all...) { t.Logf("Running test with conf: %+v", conf) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf.RootDir = rootDir + // Setup the containers. sleep := []string{"sleep", "100"} podSpec, ids := createSpecs(sleep, sleep) @@ -1297,6 +1371,59 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { } } +// Test that unsupported pod mounts options are ignored when matching master and +// slave mounts. +func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + + // Setup the containers. + sleep := []string{"/bin/sleep", "100"} + podSpec, ids := createSpecs(sleep, sleep) + mnt0 := specs.Mount{ + Destination: "/mydir/test", + Source: "/some/dir", + Type: "tmpfs", + Options: []string{"rw", "rbind", "relatime"}, + } + podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) + + mnt1 := mnt0 + mnt1.Destination = "/mydir2/test2" + mnt1.Options = []string{"rw", "nosuid"} + podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) + + createSharedMount(mnt0, "test-mount", podSpec...) + + containers, cleanup, err := startContainers(conf, podSpec, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() + + execs := []execDesc{ + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, + desc: "directory is mounted in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, + desc: "directory is mounted in container1", + }, + } + if err := execMany(execs); err != nil { + t.Fatal(err.Error()) + } +} + // Test that one container can send an FD to another container, even though // they have distinct MountNamespaces. func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { @@ -1329,6 +1456,15 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { Type: "tmpfs", } + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + // Create the specs. specs, ids := createSpecs( []string{"sleep", "1000"}, @@ -1339,7 +1475,6 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { specs[1].Mounts = append(specs[2].Mounts, sharedMnt, writeableMnt) specs[2].Mounts = append(specs[1].Mounts, sharedMnt) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -1358,9 +1493,17 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) { // Test that container is destroyed when Gofer is killed. func TestMultiContainerGoferKilled(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep, sleep) - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -1436,7 +1579,15 @@ func TestMultiContainerGoferKilled(t *testing.T) { func TestMultiContainerLoadSandbox(t *testing.T) { sleep := []string{"sleep", "100"} specs, ids := createSpecs(sleep, sleep, sleep) + + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir // Create containers for the sandbox. wants, cleanup, err := startContainers(conf, specs, ids) @@ -1529,7 +1680,15 @@ func TestMultiContainerRunNonRoot(t *testing.T) { Type: "bind", }) + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + conf := testutil.TestConfig() + conf.RootDir = rootDir + pod, cleanup, err := startContainers(conf, podSpecs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go new file mode 100644 index 000000000..d95151ea5 --- /dev/null +++ b/runsc/container/state_file.go @@ -0,0 +1,185 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package container + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "sync" + + "github.com/gofrs/flock" + "gvisor.dev/gvisor/pkg/log" +) + +const stateFileExtension = ".state" + +// StateFile handles load from/save to container state safely from multiple +// processes. It uses a lock file to provide synchronization between operations. +// +// The lock file is located at: "${s.RootDir}/${s.ID}.lock". +// The state file is located at: "${s.RootDir}/${s.ID}.state". +type StateFile struct { + // RootDir is the directory containing the container metadata file. + RootDir string `json:"rootDir"` + + // ID is the container ID. + ID string `json:"id"` + + // + // Fields below this line are not saved in the state file and will not + // be preserved across commands. + // + + once sync.Once + flock *flock.Flock +} + +// List returns all container ids in the given root directory. +func List(rootDir string) ([]string, error) { + log.Debugf("List containers %q", rootDir) + list, err := filepath.Glob(filepath.Join(rootDir, "*"+stateFileExtension)) + if err != nil { + return nil, err + } + var out []string + for _, path := range list { + // Filter out files that do no belong to a container. + fileName := filepath.Base(path) + if len(fileName) < len(stateFileExtension) { + panic(fmt.Sprintf("invalid file match %q", path)) + } + // Remove the extension. + cid := fileName[:len(fileName)-len(stateFileExtension)] + if validateID(cid) == nil { + out = append(out, cid) + } + } + return out, nil +} + +// lock globally locks all locking operations for the container. +func (s *StateFile) lock() error { + s.once.Do(func() { + s.flock = flock.NewFlock(s.lockPath()) + }) + + if err := s.flock.Lock(); err != nil { + return fmt.Errorf("acquiring lock on %q: %v", s.flock, err) + } + return nil +} + +// lockForNew acquires the lock and checks if the state file doesn't exist. This +// is done to ensure that more than one creation didn't race to create +// containers with the same ID. +func (s *StateFile) lockForNew() error { + if err := s.lock(); err != nil { + return err + } + + // Checks if the container already exists by looking for the metadata file. + if _, err := os.Stat(s.statePath()); err == nil { + s.unlock() + return fmt.Errorf("container already exists") + } else if !os.IsNotExist(err) { + s.unlock() + return fmt.Errorf("looking for existing container: %v", err) + } + return nil +} + +// unlock globally unlocks all locking operations for the container. +func (s *StateFile) unlock() error { + if !s.flock.Locked() { + panic("unlock called without lock held") + } + + if err := s.flock.Unlock(); err != nil { + log.Warningf("Error to release lock on %q: %v", s.flock, err) + return fmt.Errorf("releasing lock on %q: %v", s.flock, err) + } + return nil +} + +// saveLocked saves 'v' to the state file. +// +// Preconditions: lock() must been called before. +func (s *StateFile) saveLocked(v interface{}) error { + if !s.flock.Locked() { + panic("saveLocked called without lock held") + } + + meta, err := json.Marshal(v) + if err != nil { + return err + } + if err := ioutil.WriteFile(s.statePath(), meta, 0640); err != nil { + return fmt.Errorf("writing json file: %v", err) + } + return nil +} + +func (s *StateFile) load(v interface{}) error { + if err := s.lock(); err != nil { + return err + } + defer s.unlock() + + metaBytes, err := ioutil.ReadFile(s.statePath()) + if err != nil { + return err + } + return json.Unmarshal(metaBytes, &v) +} + +func (s *StateFile) close() error { + if s.flock == nil { + return nil + } + if s.flock.Locked() { + panic("Closing locked file") + } + return s.flock.Close() +} + +func buildStatePath(rootDir, id string) string { + return filepath.Join(rootDir, id+stateFileExtension) +} + +// statePath is the full path to the state file. +func (s *StateFile) statePath() string { + return buildStatePath(s.RootDir, s.ID) +} + +// lockPath is the full path to the lock file. +func (s *StateFile) lockPath() string { + return filepath.Join(s.RootDir, s.ID+".lock") +} + +// destroy deletes all state created by the stateFile. It may be called with the +// lock file held. In that case, the lock file must still be unlocked and +// properly closed after destroy returns. +func (s *StateFile) destroy() error { + if err := os.Remove(s.statePath()); err != nil && !os.IsNotExist(err) { + return err + } + if err := os.Remove(s.lockPath()); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go index 7f735c254..913d781c6 100644 --- a/runsc/container/test_app/test_app.go +++ b/runsc/container/test_app/test_app.go @@ -19,10 +19,12 @@ package main import ( "context" "fmt" + "io/ioutil" "log" "net" "os" "os/exec" + "regexp" "strconv" sys "syscall" "time" @@ -35,6 +37,7 @@ import ( func main() { subcommands.Register(subcommands.HelpCommand(), "") subcommands.Register(subcommands.FlagsCommand(), "") + subcommands.Register(new(capability), "") subcommands.Register(new(fdReceiver), "") subcommands.Register(new(fdSender), "") subcommands.Register(new(forkBomb), "") @@ -287,3 +290,65 @@ func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interfac } return subcommands.ExitSuccess } + +type capability struct { + enabled uint64 + disabled uint64 +} + +// Name implements subcommands.Command. +func (*capability) Name() string { + return "capability" +} + +// Synopsis implements subcommands.Command. +func (*capability) Synopsis() string { + return "checks if effective capabilities are set/unset" +} + +// Usage implements subcommands.Command. +func (*capability) Usage() string { + return "capability [--enabled=number] [--disabled=number]" +} + +// SetFlags implements subcommands.Command. +func (c *capability) SetFlags(f *flag.FlagSet) { + f.Uint64Var(&c.enabled, "enabled", 0, "") + f.Uint64Var(&c.disabled, "disabled", 0, "") +} + +// Execute implements subcommands.Command. +func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if c.enabled == 0 && c.disabled == 0 { + fmt.Println("One of the flags must be set") + return subcommands.ExitUsageError + } + + status, err := ioutil.ReadFile("/proc/self/status") + if err != nil { + fmt.Printf("Error reading %q: %v\n", "proc/self/status", err) + return subcommands.ExitFailure + } + re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n") + matches := re.FindStringSubmatch(string(status)) + if matches == nil || len(matches) != 2 { + fmt.Printf("Effective capabilities not found in\n%s\n", status) + return subcommands.ExitFailure + } + caps, err := strconv.ParseUint(matches[1], 16, 64) + if err != nil { + fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err) + return subcommands.ExitFailure + } + + if c.enabled != 0 && (caps&c.enabled) != c.enabled { + fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps) + return subcommands.ExitFailure + } + if c.disabled != 0 && (caps&c.disabled) != 0 { + fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps) + return subcommands.ExitFailure + } + + return subcommands.ExitSuccess +} diff --git a/runsc/criutil/criutil.go b/runsc/criutil/criutil.go index c8ddf5a9a..773f5a1c4 100644 --- a/runsc/criutil/criutil.go +++ b/runsc/criutil/criutil.go @@ -157,13 +157,55 @@ func (cc *Crictl) RmPod(podID string) error { return err } -// StartPodAndContainer pulls an image, then starts a sandbox and container in -// that sandbox. It returns the pod ID and container ID. -func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) { +// StartContainer pulls the given image ands starts the container in the +// sandbox with the given podID. +func (cc *Crictl) StartContainer(podID, image, sbSpec, contSpec string) (string, error) { + // Write the specs to files that can be read by crictl. + sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec) + if err != nil { + return "", fmt.Errorf("failed to write sandbox spec: %v", err) + } + contSpecFile, err := testutil.WriteTmpFile("contSpec", contSpec) + if err != nil { + return "", fmt.Errorf("failed to write container spec: %v", err) + } + + return cc.startContainer(podID, image, sbSpecFile, contSpecFile) +} + +func (cc *Crictl) startContainer(podID, image, sbSpecFile, contSpecFile string) (string, error) { if err := cc.Pull(image); err != nil { - return "", "", fmt.Errorf("failed to pull %s: %v", image, err) + return "", fmt.Errorf("failed to pull %s: %v", image, err) + } + + contID, err := cc.Create(podID, contSpecFile, sbSpecFile) + if err != nil { + return "", fmt.Errorf("failed to create container in pod %q: %v", podID, err) + } + + if _, err := cc.Start(contID); err != nil { + return "", fmt.Errorf("failed to start container %q in pod %q: %v", contID, podID, err) + } + + return contID, nil +} + +// StopContainer stops and deletes the container with the given container ID. +func (cc *Crictl) StopContainer(contID string) error { + if err := cc.Stop(contID); err != nil { + return fmt.Errorf("failed to stop container %q: %v", contID, err) + } + + if err := cc.Rm(contID); err != nil { + return fmt.Errorf("failed to remove container %q: %v", contID, err) } + return nil +} + +// StartPodAndContainer pulls an image, then starts a sandbox and container in +// that sandbox. It returns the pod ID and container ID. +func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) { // Write the specs to files that can be read by crictl. sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec) if err != nil { @@ -179,28 +221,17 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, return "", "", err } - contID, err := cc.Create(podID, contSpecFile, sbSpecFile) - if err != nil { - return "", "", fmt.Errorf("failed to create container in pod %q: %v", podID, err) - } + contID, err := cc.startContainer(podID, image, sbSpecFile, contSpecFile) - if _, err := cc.Start(contID); err != nil { - return "", "", fmt.Errorf("failed to start container %q in pod %q: %v", contID, podID, err) - } - - return podID, contID, nil + return podID, contID, err } // StopPodAndContainer stops a container and pod. func (cc *Crictl) StopPodAndContainer(podID, contID string) error { - if err := cc.Stop(contID); err != nil { + if err := cc.StopContainer(contID); err != nil { return fmt.Errorf("failed to stop container %q in pod %q: %v", contID, podID, err) } - if err := cc.Rm(contID); err != nil { - return fmt.Errorf("failed to remove container %q in pod %q: %v", contID, podID, err) - } - if err := cc.StopPod(podID); err != nil { return fmt.Errorf("failed to stop pod %q: %v", podID, err) } diff --git a/runsc/debian/description b/runsc/debian/description index 6e3b1b2c0..9e8e08805 100644 --- a/runsc/debian/description +++ b/runsc/debian/description @@ -1,5 +1 @@ -gVisor is a user-space kernel, written in Go, that implements a substantial -portion of the Linux system surface. It includes an Open Container Initiative -(OCI) runtime called runsc that provides an isolation boundary between the -application and the host kernel. The runsc runtime integrates with Docker and -Kubernetes, making it simple to run sandboxed containers. +gVisor container sandbox runtime diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go index e37ec0ffd..57f6ae8de 100644 --- a/runsc/dockerutil/dockerutil.go +++ b/runsc/dockerutil/dockerutil.go @@ -282,7 +282,14 @@ func (d *Docker) Logs() (string, error) { // Exec calls 'docker exec' with the arguments provided. func (d *Docker) Exec(args ...string) (string, error) { - a := []string{"exec", d.Name} + return d.ExecWithFlags(nil, args...) +} + +// ExecWithFlags calls 'docker exec <flags> name <args>'. +func (d *Docker) ExecWithFlags(flags []string, args ...string) (string, error) { + a := []string{"exec"} + a = append(a, flags...) + a = append(a, d.Name) a = append(a, args...) return do(a...) } diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index 80a4aa2fe..afcb41801 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -6,6 +6,8 @@ go_library( name = "fsgofer", srcs = [ "fsgofer.go", + "fsgofer_amd64_unsafe.go", + "fsgofer_arm64_unsafe.go", "fsgofer_unsafe.go", ], importpath = "gvisor.dev/gvisor/runsc/fsgofer", diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD index 02168ad1b..bac73f89d 100644 --- a/runsc/fsgofer/filter/BUILD +++ b/runsc/fsgofer/filter/BUILD @@ -6,6 +6,8 @@ go_library( name = "filter", srcs = [ "config.go", + "config_amd64.go", + "config_arm64.go", "extra_filters.go", "extra_filters_msan.go", "extra_filters_race.go", diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index c7922b54f..a1792330f 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -25,11 +25,7 @@ import ( // allowedSyscalls is the set of syscalls executed by the gofer. var allowedSyscalls = seccomp.SyscallRules{ - syscall.SYS_ACCEPT: {}, - syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_GET_FS)}, - {seccomp.AllowValue(linux.ARCH_SET_FS)}, - }, + syscall.SYS_ACCEPT: {}, syscall.SYS_CLOCK_GETTIME: {}, syscall.SYS_CLONE: []seccomp.Rule{ { @@ -155,7 +151,6 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_MPROTECT: {}, syscall.SYS_MUNMAP: {}, syscall.SYS_NANOSLEEP: {}, - syscall.SYS_NEWFSTATAT: {}, syscall.SYS_OPENAT: {}, syscall.SYS_PPOLL: {}, syscall.SYS_PREAD64: {}, @@ -177,6 +172,7 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_RENAMEAT: {}, syscall.SYS_RESTART_SYSCALL: {}, syscall.SYS_RT_SIGPROCMASK: {}, + syscall.SYS_RT_SIGRETURN: {}, syscall.SYS_SCHED_YIELD: {}, syscall.SYS_SENDMSG: []seccomp.Rule{ // Used by fdchannel.Endpoint.SendFD(). @@ -219,6 +215,18 @@ var udsSyscalls = seccomp.SyscallRules{ syscall.SYS_SOCKET: []seccomp.Rule{ { seccomp.AllowValue(syscall.AF_UNIX), + seccomp.AllowValue(syscall.SOCK_STREAM), + seccomp.AllowValue(0), + }, + { + seccomp.AllowValue(syscall.AF_UNIX), + seccomp.AllowValue(syscall.SOCK_DGRAM), + seccomp.AllowValue(0), + }, + { + seccomp.AllowValue(syscall.AF_UNIX), + seccomp.AllowValue(syscall.SOCK_SEQPACKET), + seccomp.AllowValue(0), }, }, syscall.SYS_CONNECT: []seccomp.Rule{ diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go new file mode 100644 index 000000000..a4b28cb8b --- /dev/null +++ b/runsc/fsgofer/filter/config_amd64.go @@ -0,0 +1,33 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package filter + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/seccomp" +) + +func init() { + allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{ + {seccomp.AllowValue(linux.ARCH_GET_FS)}, + {seccomp.AllowValue(linux.ARCH_SET_FS)}, + } + + allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{} +} diff --git a/runsc/fsgofer/filter/config_arm64.go b/runsc/fsgofer/filter/config_arm64.go new file mode 100644 index 000000000..d2697deb7 --- /dev/null +++ b/runsc/fsgofer/filter/config_arm64.go @@ -0,0 +1,27 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package filter + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/seccomp" +) + +func init() { + allowedSyscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{} +} diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index a570f1a41..9117d9616 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -21,7 +21,6 @@ package fsgofer import ( - "errors" "fmt" "io" "math" @@ -126,63 +125,31 @@ func NewAttachPoint(prefix string, c Config) (p9.Attacher, error) { // Attach implements p9.Attacher. func (a *attachPoint) Attach() (p9.File, error) { - // dirFD (1st argument) is ignored because 'prefix' is always absolute. - stat, err := statAt(-1, a.prefix) - if err != nil { - return nil, fmt.Errorf("stat file %q, err: %v", a.prefix, err) - } - - // Acquire the attach point lock. a.attachedMu.Lock() defer a.attachedMu.Unlock() - // Hold the file descriptor we are converting into a p9.File. - var f *fd.FD - - // Apply the S_IFMT bitmask so we can detect file type appropriately. - switch fmtStat := stat.Mode & syscall.S_IFMT; fmtStat { - case syscall.S_IFSOCK: - // Check to see if the CLI option has been set to allow the UDS mount. - if !a.conf.HostUDS { - return nil, errors.New("host UDS support is disabled") - } - - // Attempt to open a connection. Bubble up the failures. - f, err = fd.DialUnix(a.prefix) - if err != nil { - return nil, err - } - - default: - // Default to Read/Write permissions. - mode := syscall.O_RDWR - - // If the configuration is Read Only or the mount point is a directory, - // set the mode to Read Only. - if a.conf.ROMount || fmtStat == syscall.S_IFDIR { - mode = syscall.O_RDONLY - } + if a.attached { + return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix) + } - // Open the mount point & capture the FD. - f, err = fd.Open(a.prefix, openFlags|mode, 0) - if err != nil { - return nil, fmt.Errorf("unable to open file %q, err: %v", a.prefix, err) - } + f, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) { + return fd.Open(a.prefix, openFlags|mode, 0) + }) + if err != nil { + return nil, fmt.Errorf("unable to open %q: %v", a.prefix, err) } - // Close the connection if already attached. - if a.attached { - f.Close() - return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix) + stat, err := stat(f.FD()) + if err != nil { + return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err) } - // Return a localFile object to the caller with the UDS FD included. - rv, err := newLocalFile(a, f, a.prefix, stat) + lf, err := newLocalFile(a, f, a.prefix, stat) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err) } a.attached = true - return rv, nil + return lf, nil } // makeQID returns a unique QID for the given stat buffer. @@ -298,10 +265,10 @@ func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, erro // actual file open and is customizable by the caller. func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) { // Attempt to open file in the following mode in order: - // 1. RDONLY | NONBLOCK: for all files, works for directories and ro mounts too. - // Use non-blocking to prevent getting stuck inside open(2) for FIFOs. This option - // has no effect on regular files. - // 2. PATH: for symlinks + // 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs. + // Use non-blocking to prevent getting stuck inside open(2) for + // FIFOs. This option has no effect on regular files. + // 2. PATH: for symlinks, sockets. modes := []int{syscall.O_RDONLY | syscall.O_NONBLOCK, unix.O_PATH} var err error @@ -330,7 +297,7 @@ func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) return file, nil } -func getSupportedFileType(stat syscall.Stat_t) (fileType, error) { +func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) { var ft fileType switch stat.Mode & syscall.S_IFMT { case syscall.S_IFREG: @@ -340,6 +307,9 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) { case syscall.S_IFLNK: ft = symlink case syscall.S_IFSOCK: + if !permitSocket { + return unknown, syscall.EPERM + } ft = socket default: return unknown, syscall.EPERM @@ -348,7 +318,7 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) { } func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) { - ft, err := getSupportedFileType(stat) + ft, err := getSupportedFileType(stat, a.conf.HostUDS) if err != nil { return nil, err } @@ -396,23 +366,24 @@ func fchown(fd int, uid p9.UID, gid p9.GID) error { } // Open implements p9.File. -func (l *localFile) Open(mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { +func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { if l.isOpen() { panic(fmt.Sprintf("attempting to open already opened file: %q", l.hostPath)) } // Check if control file can be used or if a new open must be created. var newFile *fd.FD - if mode == p9.ReadOnly { - log.Debugf("Open reusing control file, mode: %v, %q", mode, l.hostPath) + if flags == p9.ReadOnly { + log.Debugf("Open reusing control file, flags: %v, %q", flags, l.hostPath) newFile = l.file } else { // Ideally reopen would call name_to_handle_at (with empty name) and // open_by_handle_at to reopen the file without using 'hostPath'. However, // name_to_handle_at and open_by_handle_at aren't supported by overlay2. - log.Debugf("Open reopening file, mode: %v, %q", mode, l.hostPath) + log.Debugf("Open reopening file, flags: %v, %q", flags, l.hostPath) var err error - newFile, err = reopenProcFd(l.file, openFlags|mode.OSFlags()) + // Constrain open flags to the open mode and O_TRUNC. + newFile, err = reopenProcFd(l.file, openFlags|(flags.OSFlags()&(syscall.O_ACCMODE|syscall.O_TRUNC))) if err != nil { return nil, p9.QID{}, 0, extractErrno(err) } @@ -439,7 +410,7 @@ func (l *localFile) Open(mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { } l.file = newFile } - l.mode = mode + l.mode = flags & p9.OpenFlagsModeMask return fd, l.attachPoint.makeQID(stat), 0, nil } @@ -631,7 +602,7 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) Mode: p9.FileMode(stat.Mode), UID: p9.UID(stat.Uid), GID: p9.GID(stat.Gid), - NLink: stat.Nlink, + NLink: uint64(stat.Nlink), RDev: stat.Rdev, Size: uint64(stat.Size), BlockSize: uint64(stat.Blksize), @@ -1062,12 +1033,48 @@ func (l *localFile) Flush() error { } // Connect implements p9.File. -func (l *localFile) Connect(p9.ConnectFlags) (*fd.FD, error) { - // Check to see if the CLI option has been set to allow the UDS mount. +func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { if !l.attachPoint.conf.HostUDS { - return nil, errors.New("host UDS support is disabled") + return nil, syscall.ECONNREFUSED } - return fd.DialUnix(l.hostPath) + + // TODO(gvisor.dev/issue/1003): Due to different app vs replacement + // mappings, the app path may have fit in the sockaddr, but we can't + // fit f.path in our sockaddr. We'd need to redirect through a shorter + // path in order to actually connect to this socket. + if len(l.hostPath) > linux.UnixPathMax { + return nil, syscall.ECONNREFUSED + } + + var stype int + switch flags { + case p9.StreamSocket: + stype = syscall.SOCK_STREAM + case p9.DgramSocket: + stype = syscall.SOCK_DGRAM + case p9.SeqpacketSocket: + stype = syscall.SOCK_SEQPACKET + default: + return nil, syscall.ENXIO + } + + f, err := syscall.Socket(syscall.AF_UNIX, stype, 0) + if err != nil { + return nil, err + } + + if err := syscall.SetNonblock(f, true); err != nil { + syscall.Close(f) + return nil, err + } + + sa := syscall.SockaddrUnix{Name: l.hostPath} + if err := syscall.Connect(f, &sa); err != nil { + syscall.Close(f) + return nil, err + } + + return fd.New(f), nil } // Close implements p9.File. diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go new file mode 100644 index 000000000..5d4aab597 --- /dev/null +++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go @@ -0,0 +1,49 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package fsgofer + +import ( + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/syserr" +) + +func statAt(dirFd int, name string) (syscall.Stat_t, error) { + nameBytes, err := syscall.BytePtrFromString(name) + if err != nil { + return syscall.Stat_t{}, err + } + namePtr := unsafe.Pointer(nameBytes) + + var stat syscall.Stat_t + statPtr := unsafe.Pointer(&stat) + + if _, _, errno := syscall.Syscall6( + syscall.SYS_NEWFSTATAT, + uintptr(dirFd), + uintptr(namePtr), + uintptr(statPtr), + linux.AT_SYMLINK_NOFOLLOW, + 0, + 0); errno != 0 { + + return syscall.Stat_t{}, syserr.FromHost(errno).ToError() + } + return stat, nil +} diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go new file mode 100644 index 000000000..8041fd352 --- /dev/null +++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go @@ -0,0 +1,49 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package fsgofer + +import ( + "syscall" + "unsafe" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/syserr" +) + +func statAt(dirFd int, name string) (syscall.Stat_t, error) { + nameBytes, err := syscall.BytePtrFromString(name) + if err != nil { + return syscall.Stat_t{}, err + } + namePtr := unsafe.Pointer(nameBytes) + + var stat syscall.Stat_t + statPtr := unsafe.Pointer(&stat) + + if _, _, errno := syscall.Syscall6( + syscall.SYS_FSTATAT, + uintptr(dirFd), + uintptr(namePtr), + uintptr(statPtr), + linux.AT_SYMLINK_NOFOLLOW, + 0, + 0); errno != 0 { + + return syscall.Stat_t{}, syserr.FromHost(errno).ToError() + } + return stat, nil +} diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index cbbe71019..05af7e397 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -665,7 +665,7 @@ func TestAttachInvalidType(t *testing.T) { } f, err := a.Attach() if f != nil || err == nil { - t.Fatalf("Attach should have failed, got (%v, nil)", f) + t.Fatalf("Attach should have failed, got (%v, %v)", f, err) } }) } diff --git a/runsc/fsgofer/fsgofer_unsafe.go b/runsc/fsgofer/fsgofer_unsafe.go index ff2556aee..542b54365 100644 --- a/runsc/fsgofer/fsgofer_unsafe.go +++ b/runsc/fsgofer/fsgofer_unsafe.go @@ -18,34 +18,9 @@ import ( "syscall" "unsafe" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/syserr" ) -func statAt(dirFd int, name string) (syscall.Stat_t, error) { - nameBytes, err := syscall.BytePtrFromString(name) - if err != nil { - return syscall.Stat_t{}, err - } - namePtr := unsafe.Pointer(nameBytes) - - var stat syscall.Stat_t - statPtr := unsafe.Pointer(&stat) - - if _, _, errno := syscall.Syscall6( - syscall.SYS_NEWFSTATAT, - uintptr(dirFd), - uintptr(namePtr), - uintptr(statPtr), - linux.AT_SYMLINK_NOFOLLOW, - 0, - 0); errno != 0 { - - return syscall.Stat_t{}, syserr.FromHost(errno).ToError() - } - return stat, nil -} - func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) error { // utimensat(2) doesn't accept empty name, instead name must be nil to make it // operate directly on 'dirFd' unlike other *at syscalls. diff --git a/runsc/main.go b/runsc/main.go index 7dce9dc00..711f60d4f 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -41,35 +41,39 @@ import ( var ( // Although these flags are not part of the OCI spec, they are used by // Docker, and thus should not be changed. - rootDir = flag.String("root", "", "root directory for storage of container state") - logFilename = flag.String("log", "", "file path where internal debug information is written, default is stdout") - logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s") - debug = flag.Bool("debug", false, "enable debug logging") - showVersion = flag.Bool("version", false, "show version and exit") + rootDir = flag.String("root", "", "root directory for storage of container state.") + logFilename = flag.String("log", "", "file path where internal debug information is written, default is stdout.") + logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.") + debug = flag.Bool("debug", false, "enable debug logging.") + showVersion = flag.Bool("version", false, "show version and exit.") + // TODO(gvisor.dev/issue/193): support systemd cgroups + systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.") // These flags are unique to runsc, and are used to configure parts of the // system that are not covered by the runtime spec. // Debugging flags. debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") - logPackets = flag.Bool("log-packets", false, "enable network packet logging") + logPackets = flag.Bool("log-packets", false, "enable network packet logging.") logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") - debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s") - alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr") + debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") + alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.") // Debugging flags: strace related - strace = flag.Bool("strace", false, "enable strace") + strace = flag.Bool("strace", false, "enable strace.") straceSyscalls = flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.") - straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs") + straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.") // Flags that control sandbox runtime behavior. - platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm") + platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.") network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") - gso = flag.Bool("gso", true, "enable generic segmenation offload") + hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") + softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware ofload can't be enabled.") fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") - fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "Allow the gofer to mount Unix Domain Sockets.") + fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") + overlayfsStaleRead = flag.Bool("overlayfs-stale-read", false, "reopen cached FDs after a file is opened for write to workaround overlayfs limitation on kernels before 4.19.") watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.") panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") @@ -134,6 +138,12 @@ func main() { os.Exit(0) } + // TODO(gvisor.dev/issue/193): support systemd cgroups + if *systemdCgroup { + fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193") + os.Exit(1) + } + var errorLogger io.Writer if *logFD > -1 { errorLogger = os.NewFile(uintptr(*logFD), "error log file") @@ -199,7 +209,8 @@ func main() { FSGoferHostUDS: *fsGoferHostUDS, Overlay: *overlay, Network: netType, - GSO: *gso, + HardwareGSO: *hardwareGSO, + SoftwareGSO: *softwareGSO, LogPackets: *logPackets, Platform: platformType, Strace: *strace, @@ -212,6 +223,7 @@ func main() { Rootless: *rootless, AlsoLogToStderr: *alsoLogToStderr, ReferenceLeakMode: refsLeakMode, + OverlayfsStaleRead: *overlayfsStaleRead, TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, TestOnlyTestNameEnv: *testOnlyTestNameEnv, diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 7fdceaab6..27459e6d1 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/log", "//pkg/sentry/control", "//pkg/sentry/platform", + "//pkg/tcpip/stack", "//pkg/urpc", "//runsc/boot", "//runsc/boot/platforms", diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 5634f0707..d42de0176 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -28,6 +28,7 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" @@ -61,7 +62,7 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi // Build the path to the net namespace of the sandbox process. // This is what we will copy. nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net") - if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.GSO, conf.NumNetworkChannels); err != nil { + if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.NumNetworkChannels); err != nil { return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err) } case boot.NetworkHost: @@ -136,7 +137,7 @@ func isRootNS() (bool, error) { // createInterfacesAndRoutesFromNS scrapes the interface and routes from the // net namespace with the given path, creates them in the sandbox, and removes // them from the host. -func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO bool, numNetworkChannels int) error { +func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, numNetworkChannels int) error { // Join the network namespace that we will be copying. restore, err := joinNetNS(nsPath) if err != nil { @@ -232,7 +233,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO // Create the socket for the device. for i := 0; i < link.NumChannels; i++ { log.Debugf("Creating Channel %d", i) - socketEntry, err := createSocket(iface, ifaceLink, enableGSO) + socketEntry, err := createSocket(iface, ifaceLink, hardwareGSO) if err != nil { return fmt.Errorf("failed to createSocket for %s : %v", iface.Name, err) } @@ -246,6 +247,11 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO } args.FilePayload.Files = append(args.FilePayload.Files, socketEntry.deviceFile) } + if link.GSOMaxSize == 0 && softwareGSO { + // Hardware GSO is disabled. Let's enable software GSO. + link.GSOMaxSize = stack.SoftwareGSOMaxSize + link.SoftwareGSOEnabled = true + } // Collect the addresses for the interface, enable forwarding, // and remove them from the host. diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD index fbfb8e2f8..205638803 100644 --- a/runsc/specutils/BUILD +++ b/runsc/specutils/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_library( name = "specutils", srcs = [ + "cri.go", "fs.go", "namespace.go", "specutils.go", @@ -13,6 +14,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/bits", "//pkg/log", "//pkg/sentry/kernel/auth", "@com_github_cenkalti_backoff//:go_default_library", diff --git a/runsc/specutils/cri.go b/runsc/specutils/cri.go new file mode 100644 index 000000000..9c5877cd5 --- /dev/null +++ b/runsc/specutils/cri.go @@ -0,0 +1,110 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package specutils + +import ( + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +const ( + // ContainerdContainerTypeAnnotation is the OCI annotation set by + // containerd to indicate whether the container to create should have + // its own sandbox or a container within an existing sandbox. + ContainerdContainerTypeAnnotation = "io.kubernetes.cri.container-type" + // ContainerdContainerTypeContainer is the container type value + // indicating the container should be created in an existing sandbox. + ContainerdContainerTypeContainer = "container" + // ContainerdContainerTypeSandbox is the container type value + // indicating the container should be created in a new sandbox. + ContainerdContainerTypeSandbox = "sandbox" + + // ContainerdSandboxIDAnnotation is the OCI annotation set to indicate + // which sandbox the container should be created in when the container + // is not the first container in the sandbox. + ContainerdSandboxIDAnnotation = "io.kubernetes.cri.sandbox-id" + + // CRIOContainerTypeAnnotation is the OCI annotation set by + // CRI-O to indicate whether the container to create should have + // its own sandbox or a container within an existing sandbox. + CRIOContainerTypeAnnotation = "io.kubernetes.cri-o.ContainerType" + + // CRIOContainerTypeContainer is the container type value + // indicating the container should be created in an existing sandbox. + CRIOContainerTypeContainer = "container" + // CRIOContainerTypeSandbox is the container type value + // indicating the container should be created in a new sandbox. + CRIOContainerTypeSandbox = "sandbox" + + // CRIOSandboxIDAnnotation is the OCI annotation set to indicate + // which sandbox the container should be created in when the container + // is not the first container in the sandbox. + CRIOSandboxIDAnnotation = "io.kubernetes.cri-o.SandboxID" +) + +// ContainerType represents the type of container requested by the calling container manager. +type ContainerType int + +const ( + // ContainerTypeUnspecified indicates that no known container type + // annotation was found in the spec. + ContainerTypeUnspecified ContainerType = iota + // ContainerTypeUnknown indicates that a container type was specified + // but is unknown to us. + ContainerTypeUnknown + // ContainerTypeSandbox indicates that the container should be run in a + // new sandbox. + ContainerTypeSandbox + // ContainerTypeContainer indicates that the container should be run in + // an existing sandbox. + ContainerTypeContainer +) + +// SpecContainerType tries to determine the type of container specified by the +// container manager using well-known container annotations. +func SpecContainerType(spec *specs.Spec) ContainerType { + if t, ok := spec.Annotations[ContainerdContainerTypeAnnotation]; ok { + switch t { + case ContainerdContainerTypeSandbox: + return ContainerTypeSandbox + case ContainerdContainerTypeContainer: + return ContainerTypeContainer + default: + return ContainerTypeUnknown + } + } + if t, ok := spec.Annotations[CRIOContainerTypeAnnotation]; ok { + switch t { + case CRIOContainerTypeSandbox: + return ContainerTypeSandbox + case CRIOContainerTypeContainer: + return ContainerTypeContainer + default: + return ContainerTypeUnknown + } + } + return ContainerTypeUnspecified +} + +// SandboxID returns the ID of the sandbox to join and whether an ID was found +// in the spec. +func SandboxID(spec *specs.Spec) (string, bool) { + if id, ok := spec.Annotations[ContainerdSandboxIDAnnotation]; ok { + return id, true + } + if id, ok := spec.Annotations[CRIOSandboxIDAnnotation]; ok { + return id, true + } + return "", false +} diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go index d441419cb..c7dd3051c 100644 --- a/runsc/specutils/namespace.go +++ b/runsc/specutils/namespace.go @@ -33,19 +33,19 @@ import ( func nsCloneFlag(nst specs.LinuxNamespaceType) uintptr { switch nst { case specs.IPCNamespace: - return syscall.CLONE_NEWIPC + return unix.CLONE_NEWIPC case specs.MountNamespace: - return syscall.CLONE_NEWNS + return unix.CLONE_NEWNS case specs.NetworkNamespace: - return syscall.CLONE_NEWNET + return unix.CLONE_NEWNET case specs.PIDNamespace: - return syscall.CLONE_NEWPID + return unix.CLONE_NEWPID case specs.UTSNamespace: - return syscall.CLONE_NEWUTS + return unix.CLONE_NEWUTS case specs.UserNamespace: - return syscall.CLONE_NEWUSER + return unix.CLONE_NEWUSER case specs.CgroupNamespace: - panic("cgroup namespace has no associated clone flag") + return unix.CLONE_NEWCGROUP default: panic(fmt.Sprintf("unknown namespace %v", nst)) } diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index cb9e58dfb..d3c2e4e78 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -31,6 +31,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) @@ -91,7 +92,7 @@ func ValidateSpec(spec *specs.Spec) error { log.Warningf("AppArmor profile %q is being ignored", spec.Process.ApparmorProfile) } - // TODO(b/72226747): Apply seccomp to application inside sandbox. + // TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox. if spec.Linux != nil && spec.Linux.Seccomp != nil { log.Warningf("Seccomp spec is being ignored") } @@ -107,23 +108,18 @@ func ValidateSpec(spec *specs.Spec) error { } } - // Two annotations are use by containerd to support multi-container pods. - // "io.kubernetes.cri.container-type" - // "io.kubernetes.cri.sandbox-id" - containerType, hasContainerType := spec.Annotations[ContainerdContainerTypeAnnotation] - _, hasSandboxID := spec.Annotations[ContainerdSandboxIDAnnotation] - switch { - // Non-containerd use won't set a container type. - case !hasContainerType: - case containerType == ContainerdContainerTypeSandbox: - // When starting a container in an existing sandbox, the sandbox ID - // must be set. - case containerType == ContainerdContainerTypeContainer: - if !hasSandboxID { - return fmt.Errorf("spec has container-type of %s, but no sandbox ID set", containerType) + // CRI specifies whether a container should start a new sandbox, or run + // another container in an existing sandbox. + switch SpecContainerType(spec) { + case ContainerTypeContainer: + // When starting a container in an existing sandbox, the + // sandbox ID must be set. + if _, ok := SandboxID(spec); !ok { + return fmt.Errorf("spec has container-type of container, but no sandbox ID set") } + case ContainerTypeUnknown: + return fmt.Errorf("unknown container-type") default: - return fmt.Errorf("unknown container-type: %s", containerType) } return nil @@ -241,6 +237,15 @@ func AllCapabilities() *specs.LinuxCapabilities { } } +// AllCapabilitiesUint64 returns a bitmask containing all capabilities set. +func AllCapabilitiesUint64() uint64 { + var rv uint64 + for _, cap := range capFromName { + rv |= bits.MaskOf64(int(cap)) + } + return rv +} + var capFromName = map[string]linux.Capability{ "CAP_CHOWN": linux.CAP_CHOWN, "CAP_DAC_OVERRIDE": linux.CAP_DAC_OVERRIDE, @@ -328,39 +333,6 @@ func IsSupportedDevMount(m specs.Mount) bool { return true } -const ( - // ContainerdContainerTypeAnnotation is the OCI annotation set by - // containerd to indicate whether the container to create should have - // its own sandbox or a container within an existing sandbox. - ContainerdContainerTypeAnnotation = "io.kubernetes.cri.container-type" - // ContainerdContainerTypeContainer is the container type value - // indicating the container should be created in an existing sandbox. - ContainerdContainerTypeContainer = "container" - // ContainerdContainerTypeSandbox is the container type value - // indicating the container should be created in a new sandbox. - ContainerdContainerTypeSandbox = "sandbox" - - // ContainerdSandboxIDAnnotation is the OCI annotation set to indicate - // which sandbox the container should be created in when the container - // is not the first container in the sandbox. - ContainerdSandboxIDAnnotation = "io.kubernetes.cri.sandbox-id" -) - -// ShouldCreateSandbox returns true if the spec indicates that a new sandbox -// should be created for the container. If false, the container should be -// started in an existing sandbox. -func ShouldCreateSandbox(spec *specs.Spec) bool { - t, ok := spec.Annotations[ContainerdContainerTypeAnnotation] - return !ok || t == ContainerdContainerTypeSandbox -} - -// SandboxID returns the ID of the sandbox to join and whether an ID was found -// in the spec. -func SandboxID(spec *specs.Spec) (string, bool) { - id, ok := spec.Annotations[ContainerdSandboxIDAnnotation] - return id, ok -} - // WaitForReady waits for a process to become ready. The process is ready when // the 'ready' function returns true. It continues to wait if 'ready' returns // false. It returns error on timeout, if the process stops or if 'ready' fails. diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index d44ebc906..c96ca2eb6 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/runsc/testutil", visibility = ["//:sandbox"], deps = [ + "//pkg/log", "//runsc/boot", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index edf8b126c..9632776d2 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -25,7 +25,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "math" "math/rand" "net/http" @@ -42,6 +41,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" ) @@ -151,13 +151,6 @@ func TestConfig() *boot.Config { } } -// TestConfigWithRoot returns the default configuration to use in tests. -func TestConfigWithRoot(rootDir string) *boot.Config { - conf := TestConfig() - conf.RootDir = rootDir - return conf -} - // NewSpecWithArgs creates a simple spec with the given args suitable for use // in tests. func NewSpecWithArgs(args ...string) *specs.Spec { @@ -286,7 +279,7 @@ func WaitForHTTP(port int, timeout time.Duration) error { url := fmt.Sprintf("http://localhost:%d/", port) resp, err := c.Get(url) if err != nil { - log.Printf("Waiting %s: %v", url, err) + log.Infof("Waiting %s: %v", url, err) return err } resp.Body.Close() diff --git a/scripts/build.sh b/scripts/build.sh index b3a6e4e7a..8b2094cb0 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -17,63 +17,71 @@ source $(dirname $0)/common.sh # Install required packages for make_repository.sh et al. -sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils +sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils xz-utils # Build runsc. runsc=$(build -c opt //runsc) # Build packages. -pkg=$(build -c opt --host_force_python=py2 //runsc:runsc-debian) +pkgs=$(build -c opt //runsc:runsc-debian) + +# Stop here if we have no artifacts directory. +[[ -v KOKORO_ARTIFACTS_DIR ]] || exit 0 + +# install_raw installs raw artifacts. +install_raw() { + mkdir -p "$1" + cp -f "${runsc}" "$1"/runsc + sha512sum "$1"/runsc | awk '{print $1 " runsc"}' > "$1"/runsc.sha512 +} # Build a repository, if the key is available. +# +# Note that make_repository.sh script will install packages into the provided +# root, but will output to stdout a directory that can be copied arbitrarily +# into "${KOKORO_ARTIFACTS_DIR}"/dists/XXX. We do things this way because we +# will copy the same repository structure into multiple locations, below. if [[ -v KOKORO_REPO_KEY ]]; then - repo=$(tools/make_repository.sh "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" gvisor-bot@google.com main ${pkg}) + repo=$(tools/make_repository.sh \ + "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" \ + gvisor-bot@google.com \ + main \ + "${KOKORO_ARTIFACTS_DIR}" \ + ${pkgs}) fi -# Install installs artifacts. -install() { - local -r binaries_dir="$1" - local -r repo_dir="$2" - mkdir -p "${binaries_dir}" - cp -f "${runsc}" "${binaries_dir}"/runsc - sha512sum "${binaries_dir}"/runsc | awk '{print $1 " runsc"}' > "${binaries_dir}"/runsc.sha512 +# install_repo installs a repository. +# +# Note that packages are already installed, as noted above. +install_repo() { if [[ -v repo ]]; then - rm -rf "${repo_dir}" && mkdir -p "$(dirname "${repo_dir}")" - cp -a "${repo}" "${repo_dir}" + rm -rf "$1" && mkdir -p "$(dirname "$1")" && cp -a "${repo}" "$1" fi } -# Move the runsc binary into "latest" directory, and also a directory with the -# current date. If the current commit happens to correpond to a tag, then we -# will also move everything into a directory named after the given tag. -if [[ -v KOKORO_ARTIFACTS_DIR ]]; then - if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then - # The "latest" directory and current date. - stamp="$(date -Idate)" - install "${KOKORO_ARTIFACTS_DIR}/nightly/latest" \ - "${KOKORO_ARTIFACTS_DIR}/dists/nightly/latest" - install "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" \ - "${KOKORO_ARTIFACTS_DIR}/dists/nightly/${stamp}" - else - # Is it a tagged release? Build that instead. In that case, we also try to - # update the base release directory, in case this is an update. Finally, we - # update the "release" directory, which has the last released version. - tags="$(git tag --points-at HEAD)" - if ! [[ -z "${tags}" ]]; then - # Note that a given commit can match any number of tags. We have to - # iterate through all possible tags and produce associated artifacts. - for tag in ${tags}; do - name=$(echo "${tag}" | cut -d'-' -f2) - base=$(echo "${name}" | cut -d'.' -f1) - install "${KOKORO_ARTIFACTS_DIR}/release/${name}" \ - "${KOKORO_ARTIFACTS_DIR}/dists/${name}" - if [[ "${base}" != "${tag}" ]]; then - install "${KOKORO_ARTIFACTS_DIR}/release/${base}" \ - "${KOKORO_ARTIFACTS_DIR}/dists/${base}" - fi - install "${KOKORO_ARTIFACTS_DIR}/release/latest" \ - "${KOKORO_ARTIFACTS_DIR}/dists/latest" - done - fi +# If nightly, install only nightly artifacts. +if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then + # The "latest" directory and current date. + stamp="$(date -Idate)" + install_raw "${KOKORO_ARTIFACTS_DIR}/nightly/latest" + install_raw "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" + install_repo "${KOKORO_ARTIFACTS_DIR}/dists/nightly" +else + # We keep only the latest master raw release. + install_raw "${KOKORO_ARTIFACTS_DIR}/master/latest" + install_repo "${KOKORO_ARTIFACTS_DIR}/dists/master" + + # Is it a tagged release? Build that too. + tags="$(git tag --points-at HEAD)" + if ! [[ -z "${tags}" ]]; then + # Note that a given commit can match any number of tags. We have to iterate + # through all possible tags and produce associated artifacts. + for tag in ${tags}; do + name=$(echo "${tag}" | cut -d'-' -f2) + base=$(echo "${name}" | cut -d'.' -f1) + install_raw "${KOKORO_ARTIFACTS_DIR}/release/${name}" + install_repo "${KOKORO_ARTIFACTS_DIR}/dists/release" + install_repo "${KOKORO_ARTIFACTS_DIR}/dists/${base}" + done fi fi diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh index ea2291a4d..a82163297 100755 --- a/scripts/common_bazel.sh +++ b/scripts/common_bazel.sh @@ -71,6 +71,13 @@ function run_as_root() { function collect_logs() { # Zip out everything into a convenient form. if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then + # Merge results files of all shards for each test suite. + for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do + junitparser merge `find $d -name test.xml` $d/test.xml + cat $d/shard_*_of_*/test.log > $d/test.log + ls -l $d/shard_*_of_*/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/outputs.zip + done + find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf # Move test logs to Kokoro directory. tar is used to conveniently perform # renames while moving files. find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | @@ -79,9 +86,16 @@ function collect_logs() { # Collect sentry logs, if any. if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then - local -r logs=$(ls "${RUNSC_LOGS_DIR}") + # Check if the directory is empty or not (only the first line it needed). + local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) if [[ "${logs}" ]]; then - tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${RUNTIME}.tar.gz" -C "${RUNSC_LOGS_DIR}" . + local -r archive=runsc_logs_"${RUNTIME}".tar.gz + if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then + echo "runsc logs will be uploaded to:" + echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" + echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" + fi + tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" . fi fi fi diff --git a/scripts/dev.sh b/scripts/dev.sh index ee74dcb72..c67003018 100755 --- a/scripts/dev.sh +++ b/scripts/dev.sh @@ -58,7 +58,7 @@ if [[ ${REFRESH} -eq 0 ]]; then echo echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup." echo "Use --runtime="${RUNTIME}" with your Docker command." - echo " docker run --rm --runtime="${RUNTIME}" --rm hello-world" + echo " docker run --rm --runtime="${RUNTIME}" hello-world" echo echo "If you rebuild, use $0 --refresh." diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh new file mode 100755 index 000000000..9ee991e42 --- /dev/null +++ b/scripts/runtime_tests.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +if [ ! -v RUNTIME_TEST_NAME ]; then + echo 'Must set $RUNTIME_TEST_NAME' >&2 + exit 1 +fi + +install_runsc_for_test runtimes +test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}_test" diff --git a/scripts/swgso_tests.sh b/scripts/swgso_tests.sh new file mode 100755 index 000000000..0de2df1d2 --- /dev/null +++ b/scripts/swgso_tests.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +# Install the runtime and perform basic tests. +install_runsc_for_test swgso --software-gso=true --gso=false +test_runsc //test/image:image_test //test/e2e:integration_test diff --git a/scripts/syscall_kvm_tests.sh b/scripts/syscall_kvm_tests.sh new file mode 100755 index 000000000..de85daa5a --- /dev/null +++ b/scripts/syscall_kvm_tests.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +source $(dirname $0)/common.sh + +# TODO(b/112165693): "test --test_tag_filters=runsc_kvm" can be used +# when the "manual" tag will be removed for kvm tests. +test `bazel query "attr(tags, runsc_kvm, tests(//test/syscalls/...))"` diff --git a/test/README.md b/test/README.md index 09c36b461..97fe7ea04 100644 --- a/test/README.md +++ b/test/README.md @@ -10,9 +10,31 @@ they may need extra setup in the test machine and extra configuration to run. functionality. - **image:** basic end to end test for popular images. These require the same setup as integration tests. -- **root:** tests that require to be run as root. +- **root:** tests that require to be run as root. These require the same setup + as integration tests. - **util:** utilities library to support the tests. For the above noted cases, the relevant runtime must be installed via `runsc -install` before running. This is handled automatically by the test scripts in -the `kokoro` directory. +install` before running. Just note that they require specific configuration to +work. This is handled automatically by the test scripts in the `scripts` +directory and they can be used to run tests locally on your machine. They are +also used to run these tests in `kokoro`. + +**Example:** + +To run image and integration tests, run: + +`./scripts/docker_test.sh` + +To run root tests, run: + +`./scripts/root_test.sh` + +There are a few other interesting variations for image and integration tests: + +* overlay: sets writable overlay inside the sentry +* hostnet: configures host network pass-thru, instead of netstack +* kvm: runsc the test using the KVM platform, instead of ptrace + +The test will build runsc, configure it with your local docker, restart +`dockerd`, and run tests. The location for runsc logs is printed to the output. diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 99442cffb..4fe03a220 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -19,7 +19,9 @@ go_test( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/bits", "//runsc/dockerutil", + "//runsc/specutils", "//runsc/testutil", ], ) diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go index 7238c2afe..4074d2285 100644 --- a/test/e2e/exec_test.go +++ b/test/e2e/exec_test.go @@ -30,14 +30,17 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/runsc/dockerutil" + "gvisor.dev/gvisor/runsc/specutils" ) +// Test that exec uses the exact same capability set as the container. func TestExecCapabilities(t *testing.T) { if err := dockerutil.Pull("alpine"); err != nil { t.Fatalf("docker pull failed: %v", err) } - d := dockerutil.MakeDocker("exec-test") + d := dockerutil.MakeDocker("exec-capabilities-test") // Start the container. if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { @@ -52,29 +55,61 @@ func TestExecCapabilities(t *testing.T) { if len(matches) != 2 { t.Fatalf("There should be a match for the whole line and the capability bitmask") } - capString := matches[1] - t.Log("Root capabilities:", capString) + want := fmt.Sprintf("CapEff:\t%s\n", matches[1]) + t.Log("Root capabilities:", want) - // CAP_NET_RAW was in the capability set for the container, but was - // removed. However, `exec` does not remove it. Verify that it's not - // set in the container, then re-add it for comparison. - caps, err := strconv.ParseUint(capString, 16, 64) + // Now check that exec'd process capabilities match the root. + got, err := d.Exec("grep", "CapEff:", "/proc/self/status") if err != nil { - t.Fatalf("failed to convert capabilities %q: %v", capString, err) + t.Fatalf("docker exec failed: %v", err) } - if caps&(1<<uint64(linux.CAP_NET_RAW)) != 0 { - t.Fatalf("CAP_NET_RAW should be filtered, but is set in the container: %x", caps) + t.Logf("CapEff: %v", got) + if got != want { + t.Errorf("wrong capabilities, got: %q, want: %q", got, want) } - caps |= 1 << uint64(linux.CAP_NET_RAW) - want := fmt.Sprintf("CapEff:\t%016x\n", caps) +} - // Now check that exec'd process capabilities match the root. - got, err := d.Exec("grep", "CapEff:", "/proc/self/status") +// Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW +// which is removed from the container when --net-raw=false. +func TestExecPrivileged(t *testing.T) { + if err := dockerutil.Pull("alpine"); err != nil { + t.Fatalf("docker pull failed: %v", err) + } + d := dockerutil.MakeDocker("exec-privileged-test") + + // Start the container with all capabilities dropped. + if err := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { + t.Fatalf("docker run failed: %v", err) + } + defer d.CleanUp() + + // Check that all capabilities where dropped from container. + matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) + if err != nil { + t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) + } + if len(matches) != 2 { + t.Fatalf("There should be a match for the whole line and the capability bitmask") + } + containerCaps, err := strconv.ParseUint(matches[1], 16, 64) + if err != nil { + t.Fatalf("failed to convert capabilities %q: %v", matches[1], err) + } + t.Logf("Container capabilities: %#x", containerCaps) + if containerCaps != 0 { + t.Fatalf("Container should have no capabilities: %x", containerCaps) + } + + // Check that 'exec --privileged' adds all capabilities, except + // for CAP_NET_RAW. + got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status") if err != nil { t.Fatalf("docker exec failed: %v", err) } + t.Logf("Exec CapEff: %v", got) + want := fmt.Sprintf("CapEff:\t%016x\n", specutils.AllCapabilitiesUint64()&^bits.MaskOf64(int(linux.CAP_NET_RAW))) if got != want { - t.Errorf("wrong capabilities, got: %q, want: %q", got, want) + t.Errorf("Wrong capabilities, got: %q, want: %q. Make sure runsc is not using '--net-raw'", got, want) } } @@ -173,8 +208,27 @@ func TestExecEnv(t *testing.T) { if err != nil { t.Fatalf("docker exec failed: %v", err) } - if want := "BAR"; !strings.Contains(got, want) { - t.Errorf("wanted exec output to contain %q, got %q", want, got) + if got, want := strings.TrimSpace(got), "BAR"; got != want { + t.Errorf("bad output from 'docker exec'. Got %q; Want %q.", got, want) + } +} + +// TestRunEnvHasHome tests that run always has HOME environment set. +func TestRunEnvHasHome(t *testing.T) { + // Base alpine image does not have any environment variables set. + if err := dockerutil.Pull("alpine"); err != nil { + t.Fatalf("docker pull failed: %v", err) + } + d := dockerutil.MakeDocker("run-env-test") + + // Exec "echo $HOME". The 'bin' user's home dir is '/bin'. + got, err := d.RunFg("--user", "bin", "alpine", "/bin/sh", "-c", "echo $HOME") + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + defer d.CleanUp() + if got, want := strings.TrimSpace(got), "/bin"; got != want { + t.Errorf("bad output from 'docker run'. Got %q; Want %q.", got, want) } } @@ -184,7 +238,7 @@ func TestExecEnvHasHome(t *testing.T) { if err := dockerutil.Pull("alpine"); err != nil { t.Fatalf("docker pull failed: %v", err) } - d := dockerutil.MakeDocker("exec-env-test") + d := dockerutil.MakeDocker("exec-env-home-test") // We will check that HOME is set for root user, and also for a new // non-root user we will create. diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 7cc0de129..28064e557 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -175,6 +175,9 @@ func TestCheckpointRestore(t *testing.T) { t.Fatal(err) } + // TODO(b/143498576): Remove after github.com/moby/moby/issues/38963 is fixed. + time.Sleep(1 * time.Second) + if err := d.Restore("test"); err != nil { t.Fatal("docker restore failed:", err) } diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index 76f1e4f2a..4038661cb 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -24,6 +24,7 @@ import ( "strconv" "strings" "testing" + "time" "gvisor.dev/gvisor/runsc/cgroup" "gvisor.dev/gvisor/runsc/dockerutil" @@ -56,6 +57,59 @@ func verifyPid(pid int, path string) error { } // TestCgroup sets cgroup options and checks that cgroup was properly configured. +func TestMemCGroup(t *testing.T) { + allocMemSize := 128 << 20 + if err := dockerutil.Pull("python"); err != nil { + t.Fatal("docker pull failed:", err) + } + d := dockerutil.MakeDocker("memusage-test") + + // Start a new container and allocate the specified about of memory. + args := []string{ + "--memory=256MB", + "python", + "python", + "-c", + fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize), + } + if err := d.Run(args...); err != nil { + t.Fatal("docker create failed:", err) + } + defer d.CleanUp() + + gid, err := d.ID() + if err != nil { + t.Fatalf("Docker.ID() failed: %v", err) + } + t.Logf("cgroup ID: %s", gid) + + path := filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.usage_in_bytes") + memUsage := 0 + + // Wait when the container will allocate memory. + start := time.Now() + for time.Now().Sub(start) < 30*time.Second { + outRaw, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("failed to read %q: %v", path, err) + } + out := strings.TrimSpace(string(outRaw)) + memUsage, err = strconv.Atoi(out) + if err != nil { + t.Fatalf("Atoi(%v): %v", out, err) + } + + if memUsage > allocMemSize { + return + } + + time.Sleep(100 * time.Millisecond) + } + + t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20) +} + +// TestCgroup sets cgroup options and checks that cgroup was properly configured. func TestCgroup(t *testing.T) { if err := dockerutil.Pull("alpine"); err != nil { t.Fatal("docker pull failed:", err) diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go index d597664f5..3f90c4c6a 100644 --- a/test/root/crictl_test.go +++ b/test/root/crictl_test.go @@ -126,6 +126,59 @@ func TestMountOverSymlinks(t *testing.T) { } } +// TestHomeDir tests that the HOME environment variable is set for +// multi-containers. +func TestHomeDir(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + contSpec := testdata.SimpleSpec("root", "k8s.gcr.io/busybox", []string{"sleep", "1000"}) + podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, contSpec) + if err != nil { + t.Fatal(err) + } + + t.Run("root container", func(t *testing.T) { + out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") + if err != nil { + t.Fatal(err) + } + if got, want := strings.TrimSpace(string(out)), "/root"; got != want { + t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) + } + }) + + t.Run("sub-container", func(t *testing.T) { + // Create a sub container in the same pod. + subContSpec := testdata.SimpleSpec("subcontainer", "k8s.gcr.io/busybox", []string{"sleep", "1000"}) + subContID, err := crictl.StartContainer(podID, "k8s.gcr.io/busybox", testdata.Sandbox, subContSpec) + if err != nil { + t.Fatal(err) + } + + out, err := crictl.Exec(subContID, "sh", "-c", "echo $HOME") + if err != nil { + t.Fatal(err) + } + if got, want := strings.TrimSpace(string(out)), "/root"; got != want { + t.Fatalf("Home directory invalid. Got %q, Want: %q", got, want) + } + + if err := crictl.StopContainer(subContID); err != nil { + t.Fatal(err) + } + }) + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatal(err) + } + +} + // setup sets up before a test. Specifically it: // * Creates directories and a socket for containerd to utilize. // * Runs containerd and waits for it to reach a "ready" state for testing. diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 6cd378a1b..126f0975a 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -40,6 +40,15 @@ var ( // TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a // single container sandbox. func TestOOMScoreAdjSingle(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -84,7 +93,6 @@ func TestOOMScoreAdjSingle(t *testing.T) { s := testutil.NewSpecWithArgs("sleep", "1000") s.Process.OOMScoreAdj = testCase.OOMScoreAdj - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id}) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -123,6 +131,15 @@ func TestOOMScoreAdjSingle(t *testing.T) { // TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a // multi-container sandbox. func TestOOMScoreAdjMulti(t *testing.T) { + rootDir, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer os.RemoveAll(rootDir) + + conf := testutil.TestConfig() + conf.RootDir = rootDir + ppid, err := specutils.GetParentPid(os.Getpid()) if err != nil { t.Fatalf("getting parent pid: %v", err) @@ -240,7 +257,6 @@ func TestOOMScoreAdjMulti(t *testing.T) { } } - conf := testutil.TestConfig() containers, cleanup, err := startContainers(conf, specs, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -327,13 +343,8 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { } func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) { - // Setup root dir if one hasn't been provided. if len(conf.RootDir) == 0 { - rootDir, err := testutil.SetupRootDir() - if err != nil { - return nil, nil, fmt.Errorf("error creating root dir: %v", err) - } - conf.RootDir = rootDir + panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") } var containers []*container.Container @@ -345,7 +356,6 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*c for _, b := range bundles { os.RemoveAll(b) } - os.RemoveAll(conf.RootDir) } for i, spec := range specs { bundleDir, err := testutil.SetupBundleDir(spec) diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD index 14c19ef1e..125633680 100644 --- a/test/root/testdata/BUILD +++ b/test/root/testdata/BUILD @@ -10,6 +10,7 @@ go_library( "httpd.go", "httpd_mount_paths.go", "sandbox.go", + "simple.go", ], importpath = "gvisor.dev/gvisor/test/root/testdata", visibility = [ diff --git a/test/root/testdata/simple.go b/test/root/testdata/simple.go new file mode 100644 index 000000000..1cca53f0c --- /dev/null +++ b/test/root/testdata/simple.go @@ -0,0 +1,41 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testdata + +import ( + "encoding/json" + "fmt" +) + +// SimpleSpec returns a JSON config for a simple container that runs the +// specified command in the specified image. +func SimpleSpec(name, image string, cmd []string) string { + cmds, err := json.Marshal(cmd) + if err != nil { + // This shouldn't happen. + panic(err) + } + return fmt.Sprintf(` +{ + "metadata": { + "name": %q + }, + "image": { + "image": %q + }, + "command": %s + } +`, name, image, cmds) +} diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index dfb4e2a97..367295206 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,6 +1,6 @@ # These packages are used to run language runtime tests inside gVisor sandboxes. -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") load("//test/runtimes:build_defs.bzl", "runtime_test") package(licenses = ["notice"]) @@ -16,26 +16,38 @@ go_binary( ) runtime_test( - image = "gcr.io/gvisor-presubmit/go1.12", + name = "go1.12", + blacklist_file = "blacklist_go1.12.csv", lang = "go", ) runtime_test( - image = "gcr.io/gvisor-presubmit/java11", + name = "java11", + blacklist_file = "blacklist_java11.csv", lang = "java", ) runtime_test( - image = "gcr.io/gvisor-presubmit/nodejs12.4.0", + name = "nodejs12.4.0", + blacklist_file = "blacklist_nodejs12.4.0.csv", lang = "nodejs", ) runtime_test( - image = "gcr.io/gvisor-presubmit/php7.3.6", + name = "php7.3.6", + blacklist_file = "blacklist_php7.3.6.csv", lang = "php", ) runtime_test( - image = "gcr.io/gvisor-presubmit/python3.7.3", + name = "python3.7.3", + blacklist_file = "blacklist_python3.7.3.csv", lang = "python", ) + +go_test( + name = "blacklist_test", + size = "small", + srcs = ["blacklist_test.go"], + embed = [":runner"], +) diff --git a/test/runtimes/blacklist_go1.12.csv b/test/runtimes/blacklist_go1.12.csv new file mode 100644 index 000000000..8c8ae0c5d --- /dev/null +++ b/test/runtimes/blacklist_go1.12.csv @@ -0,0 +1,16 @@ +test name,bug id,comment +cgo_errors,,FLAKY +cgo_test,,FLAKY +go_test:cmd/go,,FLAKY +go_test:cmd/vendor/golang.org/x/sys/unix,b/118783622,/dev devices missing +go_test:net,b/118784196,socket: invalid argument. Works as intended: see bug. +go_test:os,b/118780122,we have a pollable filesystem but that's a surprise +go_test:os/signal,b/118780860,/dev/pts not properly supported +go_test:runtime,b/118782341,sigtrap not reported or caught or something +go_test:syscall,b/118781998,bad bytes -- bad mem addr +race,b/118782931,thread sanitizer. Works as intended: b/62219744. +runtime:cpu124,b/118778254,segmentation fault +test:0_1,,FLAKY +testasan,, +testcarchive,b/118782924,no sigpipe +testshared,,FLAKY diff --git a/test/runtimes/blacklist_java11.csv b/test/runtimes/blacklist_java11.csv new file mode 100644 index 000000000..c012e5a56 --- /dev/null +++ b/test/runtimes/blacklist_java11.csv @@ -0,0 +1,126 @@ +test name,bug id,comment +com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker +com/sun/jdi/NashornPopFrameTest.java,, +com/sun/jdi/ProcessAttachTest.java,, +com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker +com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,, +com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,, +com/sun/tools/attach/AttachSelf.java,, +com/sun/tools/attach/BasicTests.java,, +com/sun/tools/attach/PermissionTest.java,, +com/sun/tools/attach/StartManagementAgent.java,, +com/sun/tools/attach/TempDirTest.java,, +com/sun/tools/attach/modules/Driver.java,, +java/lang/Character/CheckScript.java,,Fails in Docker +java/lang/Character/CheckUnicode.java,,Fails in Docker +java/lang/Class/GetPackageBootLoaderChildLayer.java,, +java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker +java/lang/String/nativeEncoding/StringPlatformChars.java,, +java/net/DatagramSocket/ReuseAddressTest.java,, +java/net/DatagramSocket/SendDatagramToBadAddress.java,b/78473345, +java/net/Inet4Address/PingThis.java,, +java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103, +java/net/MulticastSocket/MulticastTTL.java,, +java/net/MulticastSocket/Promiscuous.java,, +java/net/MulticastSocket/SetLoopbackMode.java,, +java/net/MulticastSocket/SetTTLAndGetTTL.java,, +java/net/MulticastSocket/Test.java,, +java/net/MulticastSocket/TestDefaults.java,, +java/net/MulticastSocket/TimeToLive.java,, +java/net/NetworkInterface/NetworkInterfaceStreamTest.java,, +java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported +java/net/Socket/TrafficClass.java,b/78527818,Not supported on gVisor +java/net/Socket/UrgentDataTest.java,b/111515323, +java/net/Socket/setReuseAddress/Basic.java,b/78519214,SO_REUSEADDR enabled by default +java/net/SocketOption/OptionsTest.java,,Fails in Docker +java/net/SocketOption/TcpKeepAliveTest.java,, +java/net/SocketPermission/SocketPermissionTest.java,, +java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker +java/net/httpclient/RequestBuilderTest.java,,Fails in Docker +java/net/httpclient/ShortResponseBody.java,, +java/net/httpclient/ShortResponseBodyWithRetry.java,, +java/nio/channels/AsyncCloseAndInterrupt.java,, +java/nio/channels/AsynchronousServerSocketChannel/Basic.java,, +java/nio/channels/AsynchronousSocketChannel/Basic.java,b/77921528,SO_KEEPALIVE is not settable +java/nio/channels/DatagramChannel/BasicMulticastTests.java,, +java/nio/channels/DatagramChannel/SocketOptionTests.java,,Fails in Docker +java/nio/channels/DatagramChannel/UseDGWithIPv6.java,, +java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker +java/nio/channels/Selector/OutOfBand.java,, +java/nio/channels/Selector/SelectWithConsumer.java,,Flaky +java/nio/channels/ServerSocketChannel/SocketOptionTests.java,, +java/nio/channels/SocketChannel/LingerOnClose.java,, +java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901, +java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker +java/rmi/activation/Activatable/extLoadedImpl/ext.sh,, +java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,, +java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker +java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker +java/util/Calendar/JapaneseEraNameTest.java,, +java/util/Currency/CurrencyTest.java,,Fails in Docker +java/util/Currency/ValidateISO4217.java,,Fails in Docker +java/util/Locale/LSRDataTest.java,, +java/util/concurrent/locks/Lock/TimedAcquireLeak.java,, +java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker +java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,, +java/util/logging/TestLoggerWeakRefLeak.java,, +javax/imageio/AppletResourceTest.java,, +javax/management/security/HashedPasswordFileTest.java,, +javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker +javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,, +jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,, +jdk/jfr/event/runtime/TestThreadParkEvent.java,, +jdk/jfr/event/sampling/TestNative.java,, +jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,, +jdk/jfr/jcmd/TestJcmdConfigure.java,, +jdk/jfr/jcmd/TestJcmdDump.java,, +jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,, +jdk/jfr/jcmd/TestJcmdDumpLimited.java,, +jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,, +jdk/jfr/jcmd/TestJcmdLegacy.java,, +jdk/jfr/jcmd/TestJcmdSaveToFile.java,, +jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,, +jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,, +jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,, +jdk/jfr/jcmd/TestJcmdStartStopDefault.java,, +jdk/jfr/jcmd/TestJcmdStartWithOptions.java,, +jdk/jfr/jcmd/TestJcmdStartWithSettings.java,, +jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,, +jdk/jfr/jvm/TestJfrJavaBase.java,, +jdk/jfr/startupargs/TestStartRecording.java,, +jdk/modules/incubator/ImageModules.java,, +jdk/net/Sockets/ExtOptionTest.java,, +jdk/net/Sockets/QuickAckTest.java,, +lib/security/cacerts/VerifyCACerts.java,, +sun/management/jmxremote/bootstrap/CustomLauncherTest.java,, +sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,, +sun/management/jmxremote/bootstrap/LocalManagementTest.java,, +sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,, +sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,, +sun/management/jmxremote/startstop/JMXStartStopTest.java,, +sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,, +sun/management/jmxremote/startstop/JMXStatusTest.java,, +sun/text/resources/LocaleDataTest.java,, +sun/tools/jcmd/TestJcmdSanity.java,, +sun/tools/jhsdb/AlternateHashingTest.java,, +sun/tools/jhsdb/BasicLauncherTest.java,, +sun/tools/jhsdb/HeapDumpTest.java,, +sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,, +sun/tools/jinfo/BasicJInfoTest.java,, +sun/tools/jinfo/JInfoTest.java,, +sun/tools/jmap/BasicJMapTest.java,, +sun/tools/jstack/BasicJStackTest.java,, +sun/tools/jstack/DeadlockDetectionTest.java,, +sun/tools/jstatd/TestJstatdExternalRegistry.java,, +sun/tools/jstatd/TestJstatdPort.java,,Flaky +sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky +sun/util/calendar/zi/TestZoneInfo310.java,, +tools/jar/modularJar/Basic.java,, +tools/jar/multiRelease/Basic.java,, +tools/jimage/JImageExtractTest.java,, +tools/jimage/JImageTest.java,, +tools/jlink/JLinkTest.java,, +tools/jlink/plugins/IncludeLocalesPluginTest.java,, +tools/jmod/hashes/HashesTest.java,, +tools/launcher/BigJar.java,b/111611473, +tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,, diff --git a/test/runtimes/blacklist_nodejs12.4.0.csv b/test/runtimes/blacklist_nodejs12.4.0.csv new file mode 100644 index 000000000..4ab4e2927 --- /dev/null +++ b/test/runtimes/blacklist_nodejs12.4.0.csv @@ -0,0 +1,47 @@ +test name,bug id,comment +benchmark/test-benchmark-fs.js,, +benchmark/test-benchmark-module.js,, +benchmark/test-benchmark-napi.js,, +doctool/test-make-doc.js,b/68848110,Expected to fail. +fixtures/test-error-first-line-offset.js,, +fixtures/test-fs-readfile-error.js,, +fixtures/test-fs-stat-sync-overflow.js,, +internet/test-dgram-broadcast-multi-process.js,, +internet/test-dgram-multicast-multi-process.js,, +internet/test-dgram-multicast-set-interface-lo.js,, +parallel/test-cluster-dgram-reuse.js,b/64024294, +parallel/test-dgram-bind-fd.js,b/132447356, +parallel/test-dgram-create-socket-handle-fd.js,b/132447238, +parallel/test-dgram-createSocket-type.js,b/68847739, +parallel/test-dgram-socket-buffer-size.js,b/68847921, +parallel/test-fs-access.js,, +parallel/test-fs-write-stream-double-close.js,, +parallel/test-fs-write-stream-throw-type-error.js,b/110226209, +parallel/test-fs-write-stream.js,, +parallel/test-http2-respond-file-error-pipe-offset.js,, +parallel/test-os.js,, +parallel/test-process-uid-gid.js,, +pseudo-tty/test-assert-colors.js,, +pseudo-tty/test-assert-no-color.js,, +pseudo-tty/test-assert-position-indicator.js,, +pseudo-tty/test-async-wrap-getasyncid-tty.js,, +pseudo-tty/test-fatal-error.js,, +pseudo-tty/test-handle-wrap-isrefed-tty.js,, +pseudo-tty/test-readable-tty-keepalive.js,, +pseudo-tty/test-set-raw-mode-reset-process-exit.js,, +pseudo-tty/test-set-raw-mode-reset-signal.js,, +pseudo-tty/test-set-raw-mode-reset.js,, +pseudo-tty/test-stderr-stdout-handle-sigwinch.js,, +pseudo-tty/test-stdout-read.js,, +pseudo-tty/test-tty-color-support.js,, +pseudo-tty/test-tty-isatty.js,, +pseudo-tty/test-tty-stdin-call-end.js,, +pseudo-tty/test-tty-stdin-end.js,, +pseudo-tty/test-stdin-write.js,, +pseudo-tty/test-tty-stdout-end.js,, +pseudo-tty/test-tty-stdout-resize.js,, +pseudo-tty/test-tty-stream-constructors.js,, +pseudo-tty/test-tty-window-size.js,, +pseudo-tty/test-tty-wrap.js,, +pummel/test-net-pingpong.js,, +pummel/test-vm-memleak.js,, diff --git a/test/runtimes/blacklist_php7.3.6.csv b/test/runtimes/blacklist_php7.3.6.csv new file mode 100644 index 000000000..456bf7487 --- /dev/null +++ b/test/runtimes/blacklist_php7.3.6.csv @@ -0,0 +1,29 @@ +test name,bug id,comment +ext/intl/tests/bug77895.phpt,, +ext/intl/tests/dateformat_bug65683_2.phpt,, +ext/mbstring/tests/bug76319.phpt,, +ext/mbstring/tests/bug76958.phpt,, +ext/mbstring/tests/bug77025.phpt,, +ext/mbstring/tests/bug77165.phpt,, +ext/mbstring/tests/bug77454.phpt,, +ext/mbstring/tests/mb_convert_encoding_leak.phpt,, +ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, +ext/standard/tests/file/filetype_variation.phpt,, +ext/standard/tests/file/fopen_variation19.phpt,, +ext/standard/tests/file/php_fd_wrapper_01.phpt,, +ext/standard/tests/file/php_fd_wrapper_02.phpt,, +ext/standard/tests/file/php_fd_wrapper_03.phpt,, +ext/standard/tests/file/php_fd_wrapper_04.phpt,, +ext/standard/tests/file/realpath_bug77484.phpt,, +ext/standard/tests/file/rename_variation.phpt,b/68717309, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,, +ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, +ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, +ext/standard/tests/network/bug20134.phpt,, +tests/output/stream_isatty_err.phpt,b/68720279, +tests/output/stream_isatty_in-err.phpt,b/68720282, +tests/output/stream_isatty_in-out-err.phpt,, +tests/output/stream_isatty_in-out.phpt,b/68720299, +tests/output/stream_isatty_out-err.phpt,b/68720311, +tests/output/stream_isatty_out.phpt,b/68720325, diff --git a/test/runtimes/blacklist_python3.7.3.csv b/test/runtimes/blacklist_python3.7.3.csv new file mode 100644 index 000000000..2b9947212 --- /dev/null +++ b/test/runtimes/blacklist_python3.7.3.csv @@ -0,0 +1,27 @@ +test name,bug id,comment +test_asynchat,b/76031995,SO_REUSEADDR +test_asyncio,,Fails on Docker. +test_asyncore,b/76031995,SO_REUSEADDR +test_epoll,, +test_fcntl,,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. +test_ftplib,,Fails in Docker +test_httplib,b/76031995,SO_REUSEADDR +test_imaplib,, +test_logging,, +test_multiprocessing_fork,,Flaky. Sometimes times out. +test_multiprocessing_forkserver,,Flaky. Sometimes times out. +test_multiprocessing_main_handling,,Flaky. Sometimes times out. +test_multiprocessing_spawn,,Flaky. Sometimes times out. +test_nntplib,b/76031995,tests should not set SO_REUSEADDR +test_poplib,,Fails on Docker +test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted +test_pty,b/76157709,out of pty devices +test_readline,b/76157709,out of pty devices +test_resource,b/76174079, +test_selectors,b/76116849,OSError not raised with epoll +test_smtplib,b/76031995,SO_REUSEADDR and unclosed sockets +test_socket,b/75983380, +test_ssl,b/76031995,SO_REUSEADDR +test_subprocess,, +test_support,b/76031995,SO_REUSEADDR +test_telnetlib,b/76031995,SO_REUSEADDR diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/blacklist_test.go new file mode 100644 index 000000000..52f49b984 --- /dev/null +++ b/test/runtimes/blacklist_test.go @@ -0,0 +1,37 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "flag" + "os" + "testing" +) + +func TestMain(m *testing.M) { + flag.Parse() + os.Exit(m.Run()) +} + +// Test that the blacklist parses without error. +func TestBlacklists(t *testing.T) { + bl, err := getBlacklist() + if err != nil { + t.Fatalf("error parsing blacklist: %v", err) + } + if *blacklistFile != "" && len(bl) == 0 { + t.Errorf("got empty blacklist for file %q", blacklistFile) + } +} diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl index 5e3065342..6f84ca852 100644 --- a/test/runtimes/build_defs.bzl +++ b/test/runtimes/build_defs.bzl @@ -1,33 +1,72 @@ -"""Defines a rule for runsc test targets.""" +"""Defines a rule for runtime test targets.""" + +load("@io_bazel_rules_go//go:def.bzl", "go_test") -# runtime_test is a macro that will create targets to run the given test target -# with different runtime options. def runtime_test( + name, lang, - image, + image_repo = "gcr.io/gvisor-presubmit", + image_name = None, + blacklist_file = None, shard_count = 50, size = "enormous"): + """Generates sh_test and blacklist test targets for a given runtime. + + Args: + name: The name of the runtime being tested. Typically, the lang + version. + This is used in the names of the generated test targets. + lang: The language being tested. + image_repo: The docker repository containing the proctor image to run. + i.e., the prefix to the fully qualified docker image id. + image_name: The name of the image in the image_repo. + Defaults to the test name. + blacklist_file: A test blacklist to pass to the runtime test's runner. + shard_count: See Bazel common test attributes. + size: See Bazel common test attributes. + """ + if image_name == None: + image_name = name + args = [ + "--lang", + lang, + "--image", + "/".join([image_repo, image_name]), + ] + data = [ + ":runner", + ] + if blacklist_file: + args += ["--blacklist_file", "test/runtimes/" + blacklist_file] + data += [blacklist_file] + + # Add a test that the blacklist parses correctly. + blacklist_test(name, blacklist_file) + sh_test( - name = lang + "_test", + name = name + "_test", srcs = ["runner.sh"], - args = [ - "--lang", - lang, - "--image", - image, - ], - data = [ - ":runner", - ], + args = args, + data = data, size = size, shard_count = shard_count, tags = [ # Requires docker and runsc to be configured before the test runs. - "manual", "local", + # Don't include test target in wildcard target patterns. + "manual", ], ) +def blacklist_test(name, blacklist_file): + """Test that a blacklist parses correctly.""" + go_test( + name = name + "_blacklist_test", + embed = [":runner"], + srcs = ["blacklist_test.go"], + args = ["--blacklist_file", "test/runtimes/" + blacklist_file], + data = [blacklist_file], + ) + def sh_test(**kwargs): """Wraps the standard sh_test.""" native.sh_test( diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go index 3a15f59a7..bec37c69d 100644 --- a/test/runtimes/runner.go +++ b/test/runtimes/runner.go @@ -16,6 +16,7 @@ package main import ( + "encoding/csv" "flag" "fmt" "io" @@ -31,8 +32,9 @@ import ( ) var ( - lang = flag.String("lang", "", "language runtime to test") - image = flag.String("image", "", "docker image with runtime tests") + lang = flag.String("lang", "", "language runtime to test") + image = flag.String("image", "", "docker image with runtime tests") + blacklistFile = flag.String("blacklist_file", "", "file containing blacklist of tests to exclude, in CSV format with fields: test name, bug id, comment") ) // Wait time for each test to run. @@ -52,6 +54,13 @@ func main() { // defered functions before exiting. It returns an exit code that should be // passed to os.Exit. func runTests() int { + // Get tests to blacklist. + blacklist, err := getBlacklist() + if err != nil { + fmt.Fprintf(os.Stderr, "Error getting blacklist: %s\n", err.Error()) + return 1 + } + // Create a single docker container that will be used for all tests. d := dockerutil.MakeDocker("gvisor-" + *lang) defer d.CleanUp() @@ -59,7 +68,7 @@ func runTests() int { // Get a slice of tests to run. This will also start a single Docker // container that will be used to run each test. The final test will // stop the Docker container. - tests, err := getTests(d) + tests, err := getTests(d, blacklist) if err != nil { fmt.Fprintf(os.Stderr, "%s\n", err.Error()) return 1 @@ -71,7 +80,7 @@ func runTests() int { // getTests returns a slice of tests to run, subject to the shard size and // index. -func getTests(d dockerutil.Docker) ([]testing.InternalTest, error) { +func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) { // Pull the image. if err := dockerutil.Pull(*image); err != nil { return nil, fmt.Errorf("docker pull %q failed: %v", *image, err) @@ -106,12 +115,18 @@ func getTests(d dockerutil.Docker) ([]testing.InternalTest, error) { itests = append(itests, testing.InternalTest{ Name: tc, F: func(t *testing.T) { + // Is the test blacklisted? + if _, ok := blacklist[tc]; ok { + t.Skip("SKIP: blacklisted test %q", tc) + } + var ( now = time.Now() done = make(chan struct{}) output string err error ) + go func() { fmt.Printf("RUNNING %s...\n", tc) output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc) @@ -134,6 +149,43 @@ func getTests(d dockerutil.Docker) ([]testing.InternalTest, error) { return itests, nil } +// getBlacklist reads the blacklist file and returns a set of test names to +// exclude. +func getBlacklist() (map[string]struct{}, error) { + blacklist := make(map[string]struct{}) + if *blacklistFile == "" { + return blacklist, nil + } + file, err := testutil.FindFile(*blacklistFile) + if err != nil { + return nil, err + } + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + r := csv.NewReader(f) + + // First line is header. Skip it. + if _, err := r.Read(); err != nil { + return nil, err + } + + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + blacklist[record[0]] = struct{}{} + } + return blacklist, nil +} + // testDeps implements testing.testDeps (an unexported interface), and is // required to use testing.MainStart. type testDeps struct{} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 341e6b252..722d14b53 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -9,7 +9,7 @@ syscall_test(test = "//test/syscalls/linux:accept_bind_stream_test") syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", ) @@ -79,6 +79,12 @@ syscall_test(test = "//test/syscalls/linux:clock_nanosleep_test") syscall_test(test = "//test/syscalls/linux:concurrency_test") syscall_test( + add_uds_tree = True, + test = "//test/syscalls/linux:connect_external_test", + use_tmpfs = True, +) + +syscall_test( add_overlay = True, test = "//test/syscalls/linux:creat_test", ) @@ -185,7 +191,7 @@ syscall_test( ) syscall_test( - size = "medium", + size = "large", shard_count = 5, test = "//test/syscalls/linux:itimer_test", ) @@ -428,7 +434,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", ) @@ -439,7 +445,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", ) @@ -452,19 +458,19 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", ) @@ -475,13 +481,13 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", ) syscall_test( size = "medium", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", ) @@ -492,7 +498,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", ) @@ -501,10 +507,16 @@ syscall_test( test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test", ) +syscall_test(test = "//test/syscalls/linux:socket_ip_unbound_test") + syscall_test(test = "//test/syscalls/linux:socket_netdevice_test") +syscall_test(test = "//test/syscalls/linux:socket_netlink_test") + syscall_test(test = "//test/syscalls/linux:socket_netlink_route_test") +syscall_test(test = "//test/syscalls/linux:socket_netlink_uevent_test") + syscall_test(test = "//test/syscalls/linux:socket_blocking_local_test") syscall_test(test = "//test/syscalls/linux:socket_blocking_ip_test") @@ -548,7 +560,7 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", ) @@ -587,7 +599,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", ) @@ -714,6 +726,7 @@ go_binary( "//runsc/specutils", "//runsc/testutil", "//test/syscalls/gtest", + "//test/uds", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index e94ef5602..dcf5b73ed 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -8,6 +8,7 @@ def syscall_test( size = "small", use_tmpfs = False, add_overlay = False, + add_uds_tree = False, tags = None): _syscall_test( test = test, @@ -15,6 +16,7 @@ def syscall_test( size = size, platform = "native", use_tmpfs = False, + add_uds_tree = add_uds_tree, tags = tags, ) @@ -24,6 +26,7 @@ def syscall_test( size = size, platform = "kvm", use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, tags = tags, ) @@ -33,6 +36,7 @@ def syscall_test( size = size, platform = "ptrace", use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, tags = tags, ) @@ -43,6 +47,7 @@ def syscall_test( size = size, platform = "ptrace", use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. + add_uds_tree = add_uds_tree, tags = tags, overlay = True, ) @@ -55,6 +60,7 @@ def syscall_test( size = size, platform = "ptrace", use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, tags = tags, file_access = "shared", ) @@ -67,7 +73,8 @@ def _syscall_test( use_tmpfs, tags, file_access = "exclusive", - overlay = False): + overlay = False, + add_uds_tree = False): test_name = test.split(":")[1] # Prepend "runsc" to non-native platform names. @@ -103,6 +110,7 @@ def _syscall_test( "--use-tmpfs=" + str(use_tmpfs), "--file-access=" + file_access, "--overlay=" + str(overlay), + "--add-uds-tree=" + str(add_uds_tree), ] sh_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 28b23ce58..f8b8cb724 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -333,6 +333,7 @@ cc_binary( linkstatic = 1, deps = [ ":socket_test_util", + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", "@com_google_googletest//:gtest", @@ -479,6 +480,21 @@ cc_binary( ) cc_binary( + name = "connect_external_test", + testonly = 1, + srcs = ["connect_external.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "creat_test", testonly = 1, srcs = ["creat.cc"], @@ -654,6 +670,7 @@ cc_binary( "//test/util:thread_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest", ], ) @@ -1551,11 +1568,14 @@ cc_binary( srcs = ["proc_net.cc"], linkstatic = 1, deps = [ + ":socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], ) @@ -1904,6 +1924,7 @@ cc_binary( srcs = ["sendfile.cc"], linkstatic = 1, deps = [ + "//test/util:eventfd_util", "//test/util:file_descriptor", "//test/util:temp_path", "//test/util:test_main", @@ -2120,6 +2141,7 @@ cc_library( deps = [ ":socket_test_util", "//test/util:test_util", + "//test/util:thread_util", "@com_google_googletest//:gtest", ], alwayslink = 1, @@ -2464,6 +2486,63 @@ cc_binary( ) cc_binary( + name = "socket_bind_to_device_test", + testonly = 1, + srcs = [ + "socket_bind_to_device.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "socket_bind_to_device_sequence_test", + testonly = 1, + srcs = [ + "socket_bind_to_device_sequence.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "socket_bind_to_device_distribution_test", + testonly = 1, + srcs = [ + "socket_bind_to_device_distribution.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "socket_ip_udp_loopback_non_blocking_test", testonly = 1, srcs = [ @@ -2496,6 +2575,22 @@ cc_binary( ) cc_binary( + name = "socket_ip_unbound_test", + testonly = 1, + srcs = [ + "socket_ip_unbound.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "socket_domain_test", testonly = 1, srcs = [ @@ -2582,6 +2677,20 @@ cc_binary( ) cc_binary( + name = "socket_netlink_test", + testonly = 1, + srcs = ["socket_netlink.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "socket_netlink_route_test", testonly = 1, srcs = ["socket_netlink_route.cc"], @@ -2598,6 +2707,21 @@ cc_binary( ], ) +cc_binary( + name = "socket_netlink_uevent_test", + testonly = 1, + srcs = ["socket_netlink_uevent.cc"], + linkstatic = 1, + deps = [ + ":socket_netlink_util", + ":socket_test_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_googletest//:gtest", + ], +) + # These socket tests are in a library because the test cases are shared # across several test build targets. cc_library( @@ -2740,6 +2864,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "socket_bind_to_device_util", + testonly = 1, + srcs = [ + "socket_bind_to_device_util.cc", + ], + hdrs = [ + "socket_bind_to_device_util.h", + ], + deps = [ + "//test/util:test_util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + cc_binary( name = "socket_stream_local_test", testonly = 1, @@ -3163,8 +3304,6 @@ cc_binary( testonly = 1, srcs = ["timers.cc"], linkstatic = 1, - # FIXME(b/136599201) - tags = ["flaky"], deps = [ "//test/util:cleanup", "//test/util:logging", @@ -3253,6 +3392,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:uid_util", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc index 1122ea240..427c42ede 100644 --- a/test/syscalls/linux/accept_bind.cc +++ b/test/syscalls/linux/accept_bind.cc @@ -17,7 +17,6 @@ #include <algorithm> #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/file_descriptor.h" @@ -140,6 +139,18 @@ TEST_P(AllSocketPairTest, Connect) { SyscallSucceeds()); } +TEST_P(AllSocketPairTest, ConnectNonListening) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallFailsWithErrno(ECONNREFUSED)); +} + TEST_P(AllSocketPairTest, ConnectToFilePath) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/accept_bind_stream.cc b/test/syscalls/linux/accept_bind_stream.cc index b6cdb3f4f..7bcd91e9e 100644 --- a/test/syscalls/linux/accept_bind_stream.cc +++ b/test/syscalls/linux/accept_bind_stream.cc @@ -17,7 +17,6 @@ #include <algorithm> #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/bind.cc b/test/syscalls/linux/bind.cc index de8cca53b..9547c4ab2 100644 --- a/test/syscalls/linux/bind.cc +++ b/test/syscalls/linux/bind.cc @@ -17,7 +17,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc index 498c45f16..de1611c21 100644 --- a/test/syscalls/linux/chroot.cc +++ b/test/syscalls/linux/chroot.cc @@ -24,7 +24,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" diff --git a/test/syscalls/linux/clock_nanosleep.cc b/test/syscalls/linux/clock_nanosleep.cc index 52a69d230..b55cddc52 100644 --- a/test/syscalls/linux/clock_nanosleep.cc +++ b/test/syscalls/linux/clock_nanosleep.cc @@ -43,7 +43,7 @@ int sys_clock_nanosleep(clockid_t clkid, int flags, PosixErrorOr<absl::Time> GetTime(clockid_t clk) { struct timespec ts = {}; - int rc = clock_gettime(clk, &ts); + const int rc = clock_gettime(clk, &ts); MaybeSave(); if (rc < 0) { return PosixError(errno, "clock_gettime"); @@ -67,31 +67,32 @@ TEST_P(WallClockNanosleepTest, InvalidValues) { } TEST_P(WallClockNanosleepTest, SleepOneSecond) { - absl::Duration const duration = absl::Seconds(1); - struct timespec dur = absl::ToTimespec(duration); + constexpr absl::Duration kSleepDuration = absl::Seconds(1); + struct timespec duration = absl::ToTimespec(kSleepDuration); - absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &dur, &dur), - SyscallSucceeds()); - absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); + const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); + EXPECT_THAT( + RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &duration, &duration), + SyscallSucceeds()); + const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - EXPECT_GE(after - before, duration); + EXPECT_GE(after - before, kSleepDuration); } TEST_P(WallClockNanosleepTest, InterruptedNanosleep) { - absl::Duration const duration = absl::Seconds(60); - struct timespec dur = absl::ToTimespec(duration); + constexpr absl::Duration kSleepDuration = absl::Seconds(60); + struct timespec duration = absl::ToTimespec(kSleepDuration); // Install no-op signal handler for SIGALRM. struct sigaction sa = {}; sigfillset(&sa.sa_mask); sa.sa_handler = +[](int signo) {}; - auto const cleanup_sa = + const auto cleanup_sa = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa)); // Measure time since setting the alarm, since the alarm will interrupt the // sleep and hence determine how long we sleep. - absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); + const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); // Set an alarm to go off while sleeping. struct itimerval timer = {}; @@ -99,26 +100,51 @@ TEST_P(WallClockNanosleepTest, InterruptedNanosleep) { timer.it_value.tv_usec = 0; timer.it_interval.tv_sec = 1; timer.it_interval.tv_usec = 0; - auto const cleanup = + const auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, timer)); - EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &dur, &dur), + EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &duration, &duration), SyscallFailsWithErrno(EINTR)); - absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); + const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - absl::Duration const remaining = absl::DurationFromTimespec(dur); - EXPECT_GE(after - before + remaining, duration); + // Remaining time updated. + const absl::Duration remaining = absl::DurationFromTimespec(duration); + EXPECT_GE(after - before + remaining, kSleepDuration); +} + +// Remaining time is *not* updated if nanosleep completes uninterrupted. +TEST_P(WallClockNanosleepTest, UninterruptedNanosleep) { + constexpr absl::Duration kSleepDuration = absl::Milliseconds(10); + const struct timespec duration = absl::ToTimespec(kSleepDuration); + + while (true) { + constexpr int kRemainingMagic = 42; + struct timespec remaining; + remaining.tv_sec = kRemainingMagic; + remaining.tv_nsec = kRemainingMagic; + + int ret = sys_clock_nanosleep(GetParam(), 0, &duration, &remaining); + if (ret == EINTR) { + // Retry from beginning. We want a single uninterrupted call. + continue; + } + + EXPECT_THAT(ret, SyscallSucceeds()); + EXPECT_EQ(remaining.tv_sec, kRemainingMagic); + EXPECT_EQ(remaining.tv_nsec, kRemainingMagic); + break; + } } TEST_P(WallClockNanosleepTest, SleepUntil) { - absl::Time const now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); - absl::Time const until = now + absl::Seconds(2); - struct timespec ts = absl::ToTimespec(until); + const absl::Time now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); + const absl::Time until = now + absl::Seconds(2); + const struct timespec ts = absl::ToTimespec(until); EXPECT_THAT( RetryEINTR(sys_clock_nanosleep)(GetParam(), TIMER_ABSTIME, &ts, nullptr), SyscallSucceeds()); - absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); + const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam())); EXPECT_GE(after, until); } @@ -127,8 +153,8 @@ INSTANTIATE_TEST_SUITE_P(Sleepers, WallClockNanosleepTest, ::testing::Values(CLOCK_REALTIME, CLOCK_MONOTONIC)); TEST(ClockNanosleepProcessTest, SleepFiveSeconds) { - absl::Duration const kDuration = absl::Seconds(5); - struct timespec dur = absl::ToTimespec(kDuration); + const absl::Duration kSleepDuration = absl::Seconds(5); + struct timespec duration = absl::ToTimespec(kSleepDuration); // Ensure that CLOCK_PROCESS_CPUTIME_ID advances. std::atomic<bool> done(false); @@ -136,16 +162,16 @@ TEST(ClockNanosleepProcessTest, SleepFiveSeconds) { while (!done.load()) { } }); - auto const cleanup_done = Cleanup([&] { done.store(true); }); + const auto cleanup_done = Cleanup([&] { done.store(true); }); - absl::Time const before = + const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID)); - EXPECT_THAT( - RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0, &dur, &dur), - SyscallSucceeds()); - absl::Time const after = + EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0, + &duration, &duration), + SyscallSucceeds()); + const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID)); - EXPECT_GE(after - before, kDuration); + EXPECT_GE(after - before, kSleepDuration); } } // namespace diff --git a/test/syscalls/linux/connect_external.cc b/test/syscalls/linux/connect_external.cc new file mode 100644 index 000000000..bfe1da82e --- /dev/null +++ b/test/syscalls/linux/connect_external.cc @@ -0,0 +1,163 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <stdlib.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <string> +#include <tuple> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/test_util.h" + +// This file contains tests specific to connecting to host UDS managed outside +// the sandbox / test. +// +// A set of ultity sockets will be created externally in $TEST_UDS_TREE and +// $TEST_UDS_ATTACH_TREE for these tests to interact with. + +namespace gvisor { +namespace testing { + +namespace { + +struct ProtocolSocket { + int protocol; + std::string name; +}; + +// Parameter is (socket root dir, ProtocolSocket). +using GoferStreamSeqpacketTest = + ::testing::TestWithParam<std::tuple<std::string, ProtocolSocket>>; + +// Connect to a socket and verify that write/read work. +// +// An "echo" socket doesn't work for dgram sockets because our socket is +// unnamed. The server thus has no way to reply to us. +TEST_P(GoferStreamSeqpacketTest, Echo) { + std::string env; + ProtocolSocket proto; + std::tie(env, proto) = GetParam(); + + char *val = getenv(env.c_str()); + ASSERT_NE(val, nullptr); + std::string root(val); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0)); + + std::string socket_path = JoinPath(root, proto.name, "echo"); + + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); + + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + constexpr int kBufferSize = 64; + char send_buffer[kBufferSize]; + memset(send_buffer, 'a', sizeof(send_buffer)); + + ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)), + SyscallSucceedsWithValue(sizeof(send_buffer))); + + char recv_buffer[kBufferSize]; + ASSERT_THAT(ReadFd(sock.get(), recv_buffer, sizeof(recv_buffer)), + SyscallSucceedsWithValue(sizeof(recv_buffer))); + ASSERT_EQ(0, memcmp(send_buffer, recv_buffer, sizeof(send_buffer))); +} + +// It is not possible to connect to a bound but non-listening socket. +TEST_P(GoferStreamSeqpacketTest, NonListening) { + std::string env; + ProtocolSocket proto; + std::tie(env, proto) = GetParam(); + + char *val = getenv(env.c_str()); + ASSERT_NE(val, nullptr); + std::string root(val); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0)); + + std::string socket_path = JoinPath(root, proto.name, "nonlistening"); + + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); + + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + sizeof(addr)), + SyscallFailsWithErrno(ECONNREFUSED)); +} + +INSTANTIATE_TEST_SUITE_P( + StreamSeqpacket, GoferStreamSeqpacketTest, + ::testing::Combine( + // Test access via standard path and attach point. + ::testing::Values("TEST_UDS_TREE", "TEST_UDS_ATTACH_TREE"), + ::testing::Values(ProtocolSocket{SOCK_STREAM, "stream"}, + ProtocolSocket{SOCK_SEQPACKET, "seqpacket"}))); + +// Parameter is socket root dir. +using GoferDgramTest = ::testing::TestWithParam<std::string>; + +// Connect to a socket and verify that write works. +// +// An "echo" socket doesn't work for dgram sockets because our socket is +// unnamed. The server thus has no way to reply to us. +TEST_P(GoferDgramTest, Null) { + std::string env = GetParam(); + char *val = getenv(env.c_str()); + ASSERT_NE(val, nullptr); + std::string root(val); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_DGRAM, 0)); + + std::string socket_path = JoinPath(root, "dgram/null"); + + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + memcpy(addr.sun_path, socket_path.c_str(), socket_path.length()); + + ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + constexpr int kBufferSize = 64; + char send_buffer[kBufferSize]; + memset(send_buffer, 'a', sizeof(send_buffer)); + + ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)), + SyscallSucceedsWithValue(sizeof(send_buffer))); +} + +INSTANTIATE_TEST_SUITE_P(Dgram, GoferDgramTest, + // Test access via standard path and attach point. + ::testing::Values("TEST_UDS_TREE", + "TEST_UDS_ATTACH_TREE")); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index 4c7c95321..581f03533 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -33,6 +33,7 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/multiprocess_util.h" @@ -68,11 +69,12 @@ constexpr char kExit42[] = "--exec_exit_42"; constexpr char kExecWithThread[] = "--exec_exec_with_thread"; constexpr char kExecFromThread[] = "--exec_exec_from_thread"; -// Runs filename with argv and checks that the exit status is expect_status and -// that stderr contains expect_stderr. -void CheckOutput(const std::string& filename, const ExecveArray& argv, - const ExecveArray& envv, int expect_status, - const std::string& expect_stderr) { +// Runs file specified by dirfd and pathname with argv and checks that the exit +// status is expect_status and that stderr contains expect_stderr. +void CheckExecHelper(const absl::optional<int32_t> dirfd, + const std::string& pathname, const ExecveArray& argv, + const ExecveArray& envv, const int flags, + int expect_status, const std::string& expect_stderr) { int pipe_fds[2]; ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds()); @@ -110,8 +112,15 @@ void CheckOutput(const std::string& filename, const ExecveArray& argv, // CloexecEventfd depend on that not happening. }; - auto kill = ASSERT_NO_ERRNO_AND_VALUE( - ForkAndExec(filename, argv, envv, remap_stderr, &child, &execve_errno)); + Cleanup kill; + if (dirfd.has_value()) { + kill = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(*dirfd, pathname, argv, + envv, flags, remap_stderr, + &child, &execve_errno)); + } else { + kill = ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(pathname, argv, envv, remap_stderr, &child, &execve_errno)); + } ASSERT_EQ(0, execve_errno); @@ -140,57 +149,71 @@ void CheckOutput(const std::string& filename, const ExecveArray& argv, EXPECT_TRUE(absl::StrContains(output, expect_stderr)) << output; } -TEST(ExecDeathTest, EmptyPath) { +void CheckExec(const std::string& filename, const ExecveArray& argv, + const ExecveArray& envv, int expect_status, + const std::string& expect_stderr) { + CheckExecHelper(/*dirfd=*/absl::optional<int32_t>(), filename, argv, envv, + /*flags=*/0, expect_status, expect_stderr); +} + +void CheckExecveat(const int32_t dirfd, const std::string& pathname, + const ExecveArray& argv, const ExecveArray& envv, + const int flags, int expect_status, + const std::string& expect_stderr) { + CheckExecHelper(absl::optional<int32_t>(dirfd), pathname, argv, envv, flags, + expect_status, expect_stderr); +} + +TEST(ExecTest, EmptyPath) { int execve_errno; ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec("", {}, {}, nullptr, &execve_errno)); EXPECT_EQ(execve_errno, ENOENT); } -TEST(ExecDeathTest, Basic) { - CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {}, - ArgEnvExitStatus(0, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n")); +TEST(ExecTest, Basic) { + CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {}, + ArgEnvExitStatus(0, 0), + absl::StrCat(WorkloadPath(kBasicWorkload), "\n")); } -TEST(ExecDeathTest, OneArg) { - CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"}, - {}, ArgEnvExitStatus(1, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); +TEST(ExecTest, OneArg) { + CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"}, + {}, ArgEnvExitStatus(1, 0), + absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); } -TEST(ExecDeathTest, FiveArg) { - CheckOutput(WorkloadPath(kBasicWorkload), - {WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, - ArgEnvExitStatus(5, 0), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); +TEST(ExecTest, FiveArg) { + CheckExec(WorkloadPath(kBasicWorkload), + {WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {}, + ArgEnvExitStatus(5, 0), + absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); } -TEST(ExecDeathTest, OneEnv) { - CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, - {"1"}, ArgEnvExitStatus(0, 1), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); +TEST(ExecTest, OneEnv) { + CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {"1"}, + ArgEnvExitStatus(0, 1), + absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n")); } -TEST(ExecDeathTest, FiveEnv) { - CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, - {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5), - absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); +TEST(ExecTest, FiveEnv) { + CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, + {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5), + absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n")); } -TEST(ExecDeathTest, OneArgOneEnv) { - CheckOutput(WorkloadPath(kBasicWorkload), - {WorkloadPath(kBasicWorkload), "arg"}, {"env"}, - ArgEnvExitStatus(1, 1), - absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n")); +TEST(ExecTest, OneArgOneEnv) { + CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "arg"}, + {"env"}, ArgEnvExitStatus(1, 1), + absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n")); } -TEST(ExecDeathTest, InterpreterScript) { - CheckOutput(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {}, - ArgEnvExitStatus(25, 0), ""); +TEST(ExecTest, InterpreterScript) { + CheckExec(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {}, + ArgEnvExitStatus(25, 0), ""); } // Everything after the path in the interpreter script is a single argument. -TEST(ExecDeathTest, InterpreterScriptArgSplit) { +TEST(ExecTest, InterpreterScriptArgSplit) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); @@ -199,12 +222,12 @@ TEST(ExecDeathTest, InterpreterScriptArgSplit) { GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo bar"), 0755)); - CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo bar\n", script.path(), "\n")); + CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), + absl::StrCat(link.path(), "\nfoo bar\n", script.path(), "\n")); } // Original argv[0] is replaced with the script path. -TEST(ExecDeathTest, InterpreterScriptArgvZero) { +TEST(ExecTest, InterpreterScriptArgvZero) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); @@ -212,13 +235,13 @@ TEST(ExecDeathTest, InterpreterScriptArgvZero) { TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); - CheckOutput(script.path(), {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); + CheckExec(script.path(), {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), + absl::StrCat(link.path(), "\n", script.path(), "\n")); } // Original argv[0] is replaced with the script path, exactly as passed to // execve. -TEST(ExecDeathTest, InterpreterScriptArgvZeroRelative) { +TEST(ExecTest, InterpreterScriptArgvZeroRelative) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); @@ -230,12 +253,12 @@ TEST(ExecDeathTest, InterpreterScriptArgvZeroRelative) { auto script_relative = ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path())); - CheckOutput(script_relative, {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script_relative, "\n")); + CheckExec(script_relative, {"REPLACED"}, {}, ArgEnvExitStatus(1, 0), + absl::StrCat(link.path(), "\n", script_relative, "\n")); } // argv[0] is added as the script path, even if there was none. -TEST(ExecDeathTest, InterpreterScriptArgvZeroAdded) { +TEST(ExecTest, InterpreterScriptArgvZeroAdded) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); @@ -243,12 +266,12 @@ TEST(ExecDeathTest, InterpreterScriptArgvZeroAdded) { TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755)); - CheckOutput(script.path(), {}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); + CheckExec(script.path(), {}, {}, ArgEnvExitStatus(1, 0), + absl::StrCat(link.path(), "\n", script.path(), "\n")); } // A NUL byte in the script line ends parsing. -TEST(ExecDeathTest, InterpreterScriptArgNUL) { +TEST(ExecTest, InterpreterScriptArgNUL) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); @@ -258,12 +281,12 @@ TEST(ExecDeathTest, InterpreterScriptArgNUL) { absl::StrCat("#!", link.path(), " foo", std::string(1, '\0'), "bar"), 0755)); - CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); + CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), + absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); } // Trailing whitespace following interpreter path is ignored. -TEST(ExecDeathTest, InterpreterScriptTrailingWhitespace) { +TEST(ExecTest, InterpreterScriptTrailingWhitespace) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); @@ -271,12 +294,12 @@ TEST(ExecDeathTest, InterpreterScriptTrailingWhitespace) { TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " "), 0755)); - CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(1, 0), - absl::StrCat(link.path(), "\n", script.path(), "\n")); + CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(1, 0), + absl::StrCat(link.path(), "\n", script.path(), "\n")); } // Multiple whitespace characters between interpreter and arg allowed. -TEST(ExecDeathTest, InterpreterScriptArgWhitespace) { +TEST(ExecTest, InterpreterScriptArgWhitespace) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); @@ -284,11 +307,11 @@ TEST(ExecDeathTest, InterpreterScriptArgWhitespace) { TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo"), 0755)); - CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), - absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); + CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0), + absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n")); } -TEST(ExecDeathTest, InterpreterScriptNoPath) { +TEST(ExecTest, InterpreterScriptNoPath) { TempPath script = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "#!", 0755)); @@ -299,7 +322,7 @@ TEST(ExecDeathTest, InterpreterScriptNoPath) { } // AT_EXECFN is the path passed to execve. -TEST(ExecDeathTest, ExecFn) { +TEST(ExecTest, ExecFn) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload))); @@ -314,18 +337,18 @@ TEST(ExecDeathTest, ExecFn) { auto script_relative = ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path())); - CheckOutput(script_relative, {script_relative}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(script_relative, "\n")); + CheckExec(script_relative, {script_relative}, {}, ArgEnvExitStatus(0, 0), + absl::StrCat(script_relative, "\n")); } -TEST(ExecDeathTest, ExecName) { +TEST(ExecTest, ExecName) { std::string path = WorkloadPath(kStateWorkload); - CheckOutput(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(Basename(path).substr(0, 15), "\n")); + CheckExec(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0), + absl::StrCat(Basename(path).substr(0, 15), "\n")); } -TEST(ExecDeathTest, ExecNameScript) { +TEST(ExecTest, ExecNameScript) { // Symlink through /tmp to ensure the path is short enough. TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload))); @@ -336,21 +359,21 @@ TEST(ExecDeathTest, ExecNameScript) { std::string script_path = script.path(); - CheckOutput(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), - absl::StrCat(Basename(script_path).substr(0, 15), "\n")); + CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), + absl::StrCat(Basename(script_path).substr(0, 15), "\n")); } // execve may be called by a multithreaded process. -TEST(ExecDeathTest, WithSiblingThread) { - CheckOutput("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {}, - W_EXITCODE(42, 0), ""); +TEST(ExecTest, WithSiblingThread) { + CheckExec("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {}, + W_EXITCODE(42, 0), ""); } // execve may be called from a thread other than the leader of a multithreaded // process. -TEST(ExecDeathTest, FromSiblingThread) { - CheckOutput("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {}, - W_EXITCODE(42, 0), ""); +TEST(ExecTest, FromSiblingThread) { + CheckExec("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {}, + W_EXITCODE(42, 0), ""); } TEST(ExecTest, NotFound) { @@ -376,7 +399,7 @@ void SignalHandler(int signo) { // Signal handlers are reset on execve(2), unless they have default or ignored // disposition. -TEST(ExecStateDeathTest, HandlerReset) { +TEST(ExecStateTest, HandlerReset) { struct sigaction sa; sa.sa_handler = SignalHandler; ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); @@ -388,11 +411,11 @@ TEST(ExecStateDeathTest, HandlerReset) { absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_DFL))), }; - CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // Ignored signal dispositions are not reset. -TEST(ExecStateDeathTest, IgnorePreserved) { +TEST(ExecStateTest, IgnorePreserved) { struct sigaction sa; sa.sa_handler = SIG_IGN; ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); @@ -404,11 +427,11 @@ TEST(ExecStateDeathTest, IgnorePreserved) { absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_IGN))), }; - CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // Signal masks are not reset on exec -TEST(ExecStateDeathTest, SignalMask) { +TEST(ExecStateTest, SignalMask) { sigset_t s; sigemptyset(&s); sigaddset(&s, SIGUSR1); @@ -420,12 +443,12 @@ TEST(ExecStateDeathTest, SignalMask) { absl::StrCat(SIGUSR1), }; - CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); + CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), ""); } // itimers persist across execve. // N.B. Timers created with timer_create(2) should not be preserved! -TEST(ExecStateDeathTest, ItimerPreserved) { +TEST(ExecStateTest, ItimerPreserved) { // The fork in ForkAndExec clears itimers, so only set them up after fork. auto setup_itimer = [] { // Ignore SIGALRM, as we don't actually care about timer @@ -472,10 +495,10 @@ TEST(ExecStateDeathTest, ItimerPreserved) { TEST(ProcSelfExe, ChangesAcrossExecve) { // See exec_proc_exe_workload for more details. We simply // assert that the /proc/self/exe link changes across execve. - CheckOutput(WorkloadPath(kProcExeWorkload), - {WorkloadPath(kProcExeWorkload), - ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))}, - {}, W_EXITCODE(0, 0), ""); + CheckExec(WorkloadPath(kProcExeWorkload), + {WorkloadPath(kProcExeWorkload), + ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))}, + {}, W_EXITCODE(0, 0), ""); } TEST(ExecTest, CloexecNormalFile) { @@ -484,20 +507,20 @@ TEST(ExecTest, CloexecNormalFile) { const FileDescriptor fd_closed_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC)); - CheckOutput(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), - absl::StrCat(fd_closed_on_exec.get())}, - {}, W_EXITCODE(0, 0), ""); + CheckExec(WorkloadPath(kAssertClosedWorkload), + {WorkloadPath(kAssertClosedWorkload), + absl::StrCat(fd_closed_on_exec.get())}, + {}, W_EXITCODE(0, 0), ""); // The assert closed workload exits with code 2 if the file still exists. We // can use this to do a negative test. const FileDescriptor fd_open_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY)); - CheckOutput(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), - absl::StrCat(fd_open_on_exec.get())}, - {}, W_EXITCODE(2, 0), ""); + CheckExec(WorkloadPath(kAssertClosedWorkload), + {WorkloadPath(kAssertClosedWorkload), + absl::StrCat(fd_open_on_exec.get())}, + {}, W_EXITCODE(2, 0), ""); } TEST(ExecTest, CloexecEventfd) { @@ -505,9 +528,239 @@ TEST(ExecTest, CloexecEventfd) { ASSERT_THAT(efd = eventfd(0, EFD_CLOEXEC), SyscallSucceeds()); FileDescriptor fd(efd); - CheckOutput(WorkloadPath(kAssertClosedWorkload), - {WorkloadPath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, - W_EXITCODE(0, 0), ""); + CheckExec(WorkloadPath(kAssertClosedWorkload), + {WorkloadPath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {}, + W_EXITCODE(0, 0), ""); +} + +constexpr int kLinuxMaxSymlinks = 40; + +TEST(ExecTest, SymlinkLimitExceeded) { + std::string path = WorkloadPath(kBasicWorkload); + + // Hold onto TempPath objects so they are not destructed prematurely. + std::vector<TempPath> symlinks; + for (int i = 0; i < kLinuxMaxSymlinks + 1; i++) { + symlinks.push_back( + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", path))); + path = symlinks[i].path(); + } + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE( + ForkAndExec(path, {path}, {}, /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) { + std::string tmp_dir = "/tmp"; + std::string interpreter_path = "/bin/echo"; + TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + tmp_dir, absl::StrCat("#!", interpreter_path), 0755)); + std::string script_path = script.path(); + + // Hold onto TempPath objects so they are not destructed prematurely. + std::vector<TempPath> interpreter_symlinks; + std::vector<TempPath> script_symlinks; + for (int i = 0; i < kLinuxMaxSymlinks; i++) { + interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(tmp_dir, interpreter_path))); + interpreter_path = interpreter_symlinks[i].path(); + script_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(tmp_dir, script_path))); + script_path = script_symlinks[i].path(); + } + + CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), ""); +} + +TEST(ExecveatTest, BasicWithFDCWD) { + std::string path = WorkloadPath(kBasicWorkload); + CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0), + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, Basic) { + std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + + CheckExecveat(dirfd.get(), base, {absolute_path}, {}, /*flags=*/0, + ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n")); +} + +TEST(ExecveatTest, FDNotADirectory) { + std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(absolute_path, 0)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), base, {absolute_path}, {}, + /*flags=*/0, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOTDIR); +} + +TEST(ExecveatTest, AbsolutePathWithFDCWD) { + std::string path = WorkloadPath(kBasicWorkload); + CheckExecveat(AT_FDCWD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, AbsolutePath) { + std::string path = WorkloadPath(kBasicWorkload); + // File descriptor should be ignored when an absolute path is given. + const int32_t badFD = -1; + CheckExecveat(badFD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0, + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, EmptyPathBasic) { + std::string path = WorkloadPath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH, ArgEnvExitStatus(0, 0), + absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, EmptyPathWithDirFD) { + std::string path = WorkloadPath(kBasicWorkload); + std::string parent_dir = std::string(Dirname(path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), "", {path}, {}, + AT_EMPTY_PATH, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, EACCES); +} + +TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) { + std::string path = WorkloadPath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( + fd.get(), "", {path}, {}, /*flags=*/0, /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) { + std::string path = WorkloadPath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH)); + + CheckExecveat(fd.get(), path, {path}, {}, AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, RelativePathWithEmptyPathFlag) { + std::string absolute_path = WorkloadPath(kBasicWorkload); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + + CheckExecveat(dirfd.get(), base, {absolute_path}, {}, AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n")); +} + +TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) { + std::string parent_dir = "/tmp"; + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload))); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY)); + std::string base = std::string(Basename(link.path())); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, + AT_SYMLINK_NOFOLLOW, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) { + std::string parent_dir = "/tmp"; + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload))); + std::string path = link.path(); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(AT_FDCWD, path, {path}, {}, + AT_SYMLINK_NOFOLLOW, + /*child=*/nullptr, &execve_errno)); + EXPECT_EQ(execve_errno, ELOOP); +} + +TEST(ExecveatTest, SymlinkNoFollowAndEmptyPath) { + TempPath link = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload))); + std::string path = link.path(); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, 0)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH | AT_SYMLINK_NOFOLLOW, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, SymlinkNoFollowIgnoreSymlinkAncestor) { + TempPath parent_link = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", "/bin")); + std::string path_with_symlink = JoinPath(parent_link.path(), "echo"); + + CheckExecveat(AT_FDCWD, path_with_symlink, {path_with_symlink}, {}, + AT_SYMLINK_NOFOLLOW, ArgEnvExitStatus(0, 0), ""); +} + +TEST(ExecveatTest, SymlinkNoFollowWithNormalFile) { + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/bin", O_DIRECTORY)); + + CheckExecveat(dirfd.get(), "echo", {"echo"}, {}, AT_SYMLINK_NOFOLLOW, + ArgEnvExitStatus(0, 0), ""); +} + +TEST(ExecveatTest, BasicWithCloexecFD) { + std::string path = WorkloadPath(kBasicWorkload); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); + + CheckExecveat(fd.get(), "", {path}, {}, AT_SYMLINK_NOFOLLOW | AT_EMPTY_PATH, + ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n")); +} + +TEST(ExecveatTest, InterpreterScriptWithCloexecFD) { + std::string path = WorkloadPath(kExitScript); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), "", {path}, {}, + AT_EMPTY_PATH, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, InterpreterScriptWithCloexecDirFD) { + std::string absolute_path = WorkloadPath(kExitScript); + std::string parent_dir = std::string(Dirname(absolute_path)); + std::string base = std::string(Basename(absolute_path)); + const FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_CLOEXEC | O_DIRECTORY)); + + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {}, + /*flags=*/0, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, ENOENT); +} + +TEST(ExecveatTest, InvalidFlags) { + int execve_errno; + ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat( + /*dirfd=*/-1, "", {}, {}, /*flags=*/0xFFFF, /*child=*/nullptr, + &execve_errno)); + EXPECT_EQ(execve_errno, EINVAL); } // Priority consistent across calls to execve() @@ -522,9 +775,8 @@ TEST(GetpriorityTest, ExecveMaintainsPriority) { // Program run (priority_execve) will exit(X) where // X=getpriority(PRIO_PROCESS,0). Check that this exit value is prio. - CheckOutput(WorkloadPath(kPriorityWorkload), - {WorkloadPath(kPriorityWorkload)}, {}, - W_EXITCODE(expected_exit_code, 0), ""); + CheckExec(WorkloadPath(kPriorityWorkload), {WorkloadPath(kPriorityWorkload)}, + {}, W_EXITCODE(expected_exit_code, 0), ""); } void ExecWithThread() { diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 91b55015c..0a3931e5a 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -401,12 +401,17 @@ TEST(ElfTest, DataSegment) { }))); } -// Additonal pages beyond filesz are always RW. +// Additonal pages beyond filesz honor (only) execute protections. // -// N.B. Linux uses set_brk -> vm_brk to additional pages beyond filesz (even -// though start_brk itself will always be beyond memsz). As a result, the -// segment permissions don't apply; the mapping is always RW. +// N.B. Linux changed this in 4.11 (16e72e9b30986 "powerpc: do not make the +// entire heap executable"). Previously, extra pages were always RW. TEST(ElfTest, ExtraMemPages) { + // gVisor has the newer behavior. + if (!IsRunningOnGvisor()) { + auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); + SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 11)); + } + ElfBinary<64> elf = StandardElf(); // Create a standard ELF, but extend to 1.5 pages. The second page will be the @@ -415,7 +420,7 @@ TEST(ElfTest, ExtraMemPages) { decltype(elf)::ElfPhdr phdr = {}; phdr.p_type = PT_LOAD; - // RWX segment. The extra anon page will be RW anyways. + // RWX segment. The extra anon page will also be RWX. // // N.B. Linux uses clear_user to clear the end of the file-mapped page, which // respects the mapping protections. Thus if we map this RO with memsz > @@ -454,7 +459,7 @@ TEST(ElfTest, ExtraMemPages) { {0x41000, 0x42000, true, true, true, true, kPageSize, 0, 0, 0, file.path().c_str()}, // extra page from anon. - {0x42000, 0x43000, true, true, false, true, 0, 0, 0, 0, ""}, + {0x42000, 0x43000, true, true, true, true, 0, 0, 0, 0, ""}, }))); } @@ -469,7 +474,7 @@ TEST(ElfTest, AnonOnlySegment) { phdr.p_offset = 0; phdr.p_vaddr = 0x41000; phdr.p_filesz = 0; - phdr.p_memsz = kPageSize - 0xe8; + phdr.p_memsz = kPageSize; elf.phdrs.push_back(phdr); elf.UpdateOffsets(); @@ -854,6 +859,11 @@ TEST(ElfTest, ELFInterpreter) { // The first segment really needs to start at 0 for a normal PIE binary, and // thus includes the headers. uint64_t const offset = interpreter.phdrs[1].p_offset; + // N.B. Since Linux 4.10 (0036d1f7eb95b "binfmt_elf: fix calculations for bss + // padding"), Linux unconditionally zeroes the remainder of the highest mapped + // page in an interpreter, failing if the protections don't allow write. Thus + // we must mark this writeable. + interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; interpreter.phdrs[1].p_offset = 0x0; interpreter.phdrs[1].p_vaddr = 0x0; interpreter.phdrs[1].p_filesz += offset; @@ -903,15 +913,15 @@ TEST(ElfTest, ELFInterpreter) { const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); - EXPECT_THAT(child, - ContainsMappings(std::vector<ProcMapsEntry>({ - // Main binary - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - binary_file.path().c_str()}, - // Interpreter - {interp_load_addr, interp_load_addr + 0x1000, true, false, - true, true, 0, 0, 0, 0, interpreter_file.path().c_str()}, - }))); + EXPECT_THAT( + child, ContainsMappings(std::vector<ProcMapsEntry>({ + // Main binary + {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, + binary_file.path().c_str()}, + // Interpreter + {interp_load_addr, interp_load_addr + 0x1000, true, true, true, + true, 0, 0, 0, 0, interpreter_file.path().c_str()}, + }))); } // Test parameter to ElfInterpterStaticTest cases. The first item is a suffix to @@ -928,6 +938,8 @@ TEST_P(ElfInterpreterStaticTest, Test) { const int expected_errno = std::get<1>(GetParam()); ElfBinary<64> interpreter = StandardElf(); + // See comment in ElfTest.ELFInterpreter. + interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; interpreter.UpdateOffsets(); TempPath interpreter_file = ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter)); @@ -957,7 +969,7 @@ TEST_P(ElfInterpreterStaticTest, Test) { EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({ // Interpreter. - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, + {0x40000, 0x41000, true, true, true, true, 0, 0, 0, 0, interpreter_file.path().c_str()}, }))); } @@ -1035,6 +1047,8 @@ TEST(ElfTest, ELFInterpreterRelative) { // The first segment really needs to start at 0 for a normal PIE binary, and // thus includes the headers. uint64_t const offset = interpreter.phdrs[1].p_offset; + // See comment in ElfTest.ELFInterpreter. + interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; interpreter.phdrs[1].p_offset = 0x0; interpreter.phdrs[1].p_vaddr = 0x0; interpreter.phdrs[1].p_filesz += offset; @@ -1073,15 +1087,15 @@ TEST(ElfTest, ELFInterpreterRelative) { const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1); - EXPECT_THAT(child, - ContainsMappings(std::vector<ProcMapsEntry>({ - // Main binary - {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, - binary_file.path().c_str()}, - // Interpreter - {interp_load_addr, interp_load_addr + 0x1000, true, false, - true, true, 0, 0, 0, 0, interpreter_file.path().c_str()}, - }))); + EXPECT_THAT( + child, ContainsMappings(std::vector<ProcMapsEntry>({ + // Main binary + {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0, + binary_file.path().c_str()}, + // Interpreter + {interp_load_addr, interp_load_addr + 0x1000, true, true, true, + true, 0, 0, 0, 0, interpreter_file.path().c_str()}, + }))); } // ELF interpreter architecture doesn't match the binary. @@ -1095,6 +1109,8 @@ TEST(ElfTest, ELFInterpreterWrongArch) { // The first segment really needs to start at 0 for a normal PIE binary, and // thus includes the headers. uint64_t const offset = interpreter.phdrs[1].p_offset; + // See comment in ElfTest.ELFInterpreter. + interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; interpreter.phdrs[1].p_offset = 0x0; interpreter.phdrs[1].p_vaddr = 0x0; interpreter.phdrs[1].p_filesz += offset; @@ -1174,6 +1190,8 @@ TEST(ElfTest, ElfInterpreterNoExecute) { // The first segment really needs to start at 0 for a normal PIE binary, and // thus includes the headers. uint64_t const offset = interpreter.phdrs[1].p_offset; + // See comment in ElfTest.ELFInterpreter. + interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X; interpreter.phdrs[1].p_offset = 0x0; interpreter.phdrs[1].p_vaddr = 0x0; interpreter.phdrs[1].p_filesz += offset; diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h index 36efabcae..4d155b618 100644 --- a/test/syscalls/linux/file_base.h +++ b/test/syscalls/linux/file_base.h @@ -32,7 +32,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc index 4948a76f0..c4f8bff08 100644 --- a/test/syscalls/linux/ioctl.cc +++ b/test/syscalls/linux/ioctl.cc @@ -25,7 +25,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index 410b42a47..57e99596f 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -122,6 +122,14 @@ SocketKind IPv4UDPUnboundSocket(int type) { UnboundSocketCreator(AF_INET, type | SOCK_DGRAM, IPPROTO_UDP)}; } +SocketKind IPv6UDPUnboundSocket(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "IPv6 UDP socket"); + return SocketKind{ + description, AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP, + UnboundSocketCreator(AF_INET6, type | SOCK_DGRAM, IPPROTO_UDP)}; +} + SocketKind IPv4TCPUnboundSocket(int type) { std::string description = absl::StrCat(DescribeSocketType(type), "IPv4 TCP socket"); @@ -130,6 +138,14 @@ SocketKind IPv4TCPUnboundSocket(int type) { UnboundSocketCreator(AF_INET, type | SOCK_STREAM, IPPROTO_TCP)}; } +SocketKind IPv6TCPUnboundSocket(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "IPv6 TCP socket"); + return SocketKind{ + description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, + UnboundSocketCreator(AF_INET6, type | SOCK_STREAM, IPPROTO_TCP)}; +} + PosixError IfAddrHelper::Load() { Release(); RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_)); diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 3d36b9620..072230d85 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -92,10 +92,18 @@ SocketPairKind IPv4UDPUnboundSocketPair(int type); // a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type. SocketKind IPv4UDPUnboundSocket(int type); +// IPv6UDPUnboundSocketPair returns a SocketKind that represents +// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type. +SocketKind IPv6UDPUnboundSocket(int type); + // IPv4TCPUnboundSocketPair returns a SocketKind that represents // a SimpleSocket created with AF_INET, SOCK_STREAM and the given type. SocketKind IPv4TCPUnboundSocket(int type); +// IPv6TCPUnboundSocketPair returns a SocketKind that represents +// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type. +SocketKind IPv6TCPUnboundSocket(int type); + // IfAddrHelper is a helper class that determines the local interfaces present // and provides functions to obtain their names, index numbers, and IP address. class IfAddrHelper { diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index 51ce323b9..b77e4cbd1 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -267,6 +267,9 @@ int TestSIGPROFFairness(absl::Duration sleep) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { + // TODO(b/143247272): CPU time accounting is inaccurate for the KVM platform. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + pid_t child; int execve_errno; auto kill = ASSERT_NO_ERRNO_AND_VALUE( @@ -288,6 +291,9 @@ TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) { + // TODO(b/143247272): CPU time accounting is inaccurate for the KVM platform. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + pid_t child; int execve_errno; auto kill = ASSERT_NO_ERRNO_AND_VALUE( @@ -336,7 +342,9 @@ int main(int argc, char** argv) { } if (arg == gvisor::testing::kSIGPROFFairnessIdle) { MaskSIGPIPE(); - return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(10)); + // Sleep time > ClockTick (10ms) exercises sleeping gVisor's + // kernel.cpuClockTicker. + return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(25)); } } diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc index 08ff4052c..7fd0ea20c 100644 --- a/test/syscalls/linux/madvise.cc +++ b/test/syscalls/linux/madvise.cc @@ -25,7 +25,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/file_descriptor.h" #include "test/util/logging.h" #include "test/util/memory_util.h" diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc index a6e20f9c3..ff2f49863 100644 --- a/test/syscalls/linux/memory_accounting.cc +++ b/test/syscalls/linux/memory_accounting.cc @@ -16,7 +16,6 @@ #include <map> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 7a3379b9e..92ae55eec 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -61,6 +61,9 @@ namespace testing { namespace { +using ::testing::AnyOf; +using ::testing::Eq; + constexpr char kMessage[] = "soweoneul malhaebwa"; constexpr in_port_t kPort = 0x409c; // htons(40000) @@ -83,11 +86,14 @@ void SendUDPMessage(int sock) { // Send an IP packet and make sure ETH_P_<something else> doesn't pick it up. TEST(BasicCookedPacketTest, WrongType) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP)); + FileDescriptor sock = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_PUP))); // Let's use a simple IP payload: a UDP datagram. FileDescriptor udp_sock = @@ -118,18 +124,35 @@ class CookedPacketTest : public ::testing::TestWithParam<int> { }; void CookedPacketTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } + + if (!IsRunningOnGvisor()) { + FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); + FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); + char enabled; + ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())), SyscallSucceeds()); } void CookedPacketTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); - - EXPECT_THAT(close(socket_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(socket_), SyscallSucceeds()); + } } int CookedPacketTest::GetLoopbackIndex() { @@ -142,9 +165,6 @@ int CookedPacketTest::GetLoopbackIndex() { // Receive via a packet socket. TEST_P(CookedPacketTest, Receive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); - // Let's use a simple IP payload: a UDP datagram. FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); @@ -165,13 +185,16 @@ TEST_P(CookedPacketTest, Receive) { ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); - ASSERT_EQ(src_len, sizeof(src)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + // TODO(b/129292371): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); - EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); EXPECT_EQ(src.sll_halen, ETH_ALEN); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { @@ -201,7 +224,7 @@ TEST_P(CookedPacketTest, Receive) { // Send via a packet socket. TEST_P(CookedPacketTest, Send) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // TODO(b/129292371): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 9e96460ee..d258d353c 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -26,6 +26,7 @@ #include <sys/types.h> #include <unistd.h> +#include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/base/internal/endian.h" #include "test/syscalls/linux/socket_test_util.h" @@ -61,6 +62,9 @@ namespace testing { namespace { +using ::testing::AnyOf; +using ::testing::Eq; + constexpr char kMessage[] = "soweoneul malhaebwa"; constexpr in_port_t kPort = 0x409c; // htons(40000) @@ -97,8 +101,11 @@ class RawPacketTest : public ::testing::TestWithParam<int> { }; void RawPacketTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } if (!IsRunningOnGvisor()) { FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( @@ -119,10 +126,10 @@ void RawPacketTest::SetUp() { } void RawPacketTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); - - EXPECT_THAT(close(socket_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(socket_), SyscallSucceeds()); + } } int RawPacketTest::GetLoopbackIndex() { @@ -135,9 +142,6 @@ int RawPacketTest::GetLoopbackIndex() { // Receive via a packet socket. TEST_P(RawPacketTest, Receive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - SKIP_IF(IsRunningOnGvisor()); - // Let's use a simple IP payload: a UDP datagram. FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); @@ -158,16 +162,16 @@ TEST_P(RawPacketTest, Receive) { ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); - // sizeof(src) is the size of a struct sockaddr_ll. sockaddr_ll ends with an 8 - // byte physical address field, but ethernet (MAC) addresses only use 6 bytes. - // Thus src_len should get modified to be 2 less than the size of sockaddr_ll. - ASSERT_EQ(src_len, sizeof(src) - 2); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + // TODO(b/129292371): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP)); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); - EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK); EXPECT_EQ(src.sll_halen, ETH_ALEN); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { @@ -208,7 +212,7 @@ TEST_P(RawPacketTest, Receive) { // Send via a packet socket. TEST_P(RawPacketTest, Send) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + // TODO(b/129292371): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. @@ -306,7 +310,7 @@ TEST_P(RawPacketTest, Send) { } INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, - ::testing::Values(ETH_P_IP /*, ETH_P_ALL*/)); + ::testing::Values(ETH_P_IP, ETH_P_ALL)); } // namespace diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 10e2a6dfc..c0b354e65 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -20,7 +20,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/notification.h" #include "absl/time/clock.h" diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index 5e3eb1735..2cecf2e5f 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -20,7 +20,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc index eebd129f2..f7ea44054 100644 --- a/test/syscalls/linux/preadv.cc +++ b/test/syscalls/linux/preadv.cc @@ -22,7 +22,6 @@ #include <string> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc index aac960130..c9246367d 100644 --- a/test/syscalls/linux/preadv2.cc +++ b/test/syscalls/linux/preadv2.cc @@ -21,7 +21,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/memory/memory.h" #include "test/syscalls/linux/file_base.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 6f07803d9..e4c030bbb 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -440,6 +440,11 @@ TEST(ProcSelfAuxv, EntryPresence) { EXPECT_EQ(auxv_entries.count(AT_PHENT), 1); EXPECT_EQ(auxv_entries.count(AT_PHNUM), 1); EXPECT_EQ(auxv_entries.count(AT_BASE), 1); + EXPECT_EQ(auxv_entries.count(AT_UID), 1); + EXPECT_EQ(auxv_entries.count(AT_EUID), 1); + EXPECT_EQ(auxv_entries.count(AT_GID), 1); + EXPECT_EQ(auxv_entries.count(AT_EGID), 1); + EXPECT_EQ(auxv_entries.count(AT_SECURE), 1); EXPECT_EQ(auxv_entries.count(AT_CLKTCK), 1); EXPECT_EQ(auxv_entries.count(AT_RANDOM), 1); EXPECT_EQ(auxv_entries.count(AT_EXECFN), 1); diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index efdaf202b..65bad06d4 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -12,8 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <arpa/inet.h> +#include <errno.h> +#include <netinet/in.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/syscall.h> +#include <sys/types.h> + #include "gtest/gtest.h" -#include "gtest/gtest.h" +#include "absl/strings/str_split.h" +#include "absl/time/clock.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" @@ -57,6 +67,265 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) { EXPECT_EQ(buf, to_write); } +PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp, + const std::string &type, + const std::string &item) { + std::vector<std::string> snmp_vec = absl::StrSplit(snmp, '\n'); + + // /proc/net/snmp prints a line of headers followed by a line of metrics. + // Only search the headers. + for (unsigned i = 0; i < snmp_vec.size(); i = i + 2) { + if (!absl::StartsWith(snmp_vec[i], type)) continue; + + std::vector<std::string> fields = + absl::StrSplit(snmp_vec[i], ' ', absl::SkipWhitespace()); + + EXPECT_TRUE((i + 1) < snmp_vec.size()); + std::vector<std::string> values = + absl::StrSplit(snmp_vec[i + 1], ' ', absl::SkipWhitespace()); + + EXPECT_TRUE(!fields.empty() && fields.size() == values.size()); + + // Metrics start at the first index. + for (unsigned j = 1; j < fields.size(); j++) { + if (fields[j] == item) { + uint64_t val; + if (!absl::SimpleAtoi(values[j], &val)) { + return PosixError(EINVAL, + absl::StrCat("field is not a number: ", values[j])); + } + + return val; + } + } + } + // We should never get here. + return PosixError( + EINVAL, absl::StrCat("failed to find ", type, "/", item, " in:", snmp)); +} + +TEST(ProcNetSnmp, TcpReset_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldAttemptFails; + uint64_t oldActiveOpens; + uint64_t oldOutRsts; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + oldOutRsts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts")); + oldAttemptFails = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails")); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(1234), + }; + + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(connect(s.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallFailsWithErrno(ECONNREFUSED)); + + uint64_t newAttemptFails; + uint64_t newActiveOpens; + uint64_t newOutRsts; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + newOutRsts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts")); + newAttemptFails = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails")); + + EXPECT_EQ(oldActiveOpens, newActiveOpens - 1); + EXPECT_EQ(oldOutRsts, newOutRsts - 1); + EXPECT_EQ(oldAttemptFails, newAttemptFails - 1); +} + +TEST(ProcNetSnmp, TcpEstab_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldEstabResets; + uint64_t oldActiveOpens; + uint64_t oldPassiveOpens; + uint64_t oldCurrEstab; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + oldPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens")); + oldCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + oldEstabResets = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); + + FileDescriptor s_listen = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = 0, + }; + + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(bind(s_listen.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + ASSERT_THAT(listen(s_listen.get(), 1), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = sizeof(sin); + ASSERT_THAT( + getsockname(s_listen.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen), + SyscallSucceeds()); + + FileDescriptor s_connect = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_THAT(connect(s_connect.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + + auto s_accept = + ASSERT_NO_ERRNO_AND_VALUE(Accept(s_listen.get(), nullptr, nullptr)); + + uint64_t newEstabResets; + uint64_t newActiveOpens; + uint64_t newPassiveOpens; + uint64_t newCurrEstab; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens")); + newPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens")); + newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + + EXPECT_EQ(oldActiveOpens, newActiveOpens - 1); + EXPECT_EQ(oldPassiveOpens, newPassiveOpens - 1); + EXPECT_EQ(oldCurrEstab, newCurrEstab - 2); + + // Send 1 byte from client to server. + ASSERT_THAT(send(s_connect.get(), "a", 1, 0), SyscallSucceedsWithValue(1)); + + constexpr int kPollTimeoutMs = 20000; // Wait up to 20 seconds for the data. + + // Wait until server-side fd sees the data on its side but don't read it. + struct pollfd poll_fd = {s_accept.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now close server-side fd without reading the data which leads to a RST + // packet sent to client side. + s_accept.reset(-1); + + // Wait until client-side fd sees RST packet. + struct pollfd poll_fd1 = {s_connect.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd1, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + // Now close client-side fd. + s_connect.reset(-1); + + // Wait until the process of the netstack. + absl::SleepFor(absl::Seconds(1)); + + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab")); + newEstabResets = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets")); + + EXPECT_EQ(oldCurrEstab, newCurrEstab); + EXPECT_EQ(oldEstabResets, newEstabResets - 2); +} + +TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + DisableSave ds; + + uint64_t oldOutDatagrams; + uint64_t oldNoPorts; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + oldNoPorts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts")); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(4444), + }; + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceedsWithValue(1)); + + uint64_t newOutDatagrams; + uint64_t newNoPorts; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + newNoPorts = + ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts")); + + EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1); + EXPECT_EQ(oldNoPorts, newNoPorts - 1); +} + +TEST(ProcNetSnmp, UdpIn) { + // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. + const DisableSave ds; + + uint64_t oldOutDatagrams; + uint64_t oldInDatagrams; + auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + oldInDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams")); + + std::cerr << "snmp: " << std::endl << snmp << std::endl; + FileDescriptor server = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_port = htons(0), + }; + ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1); + ASSERT_THAT(bind(server.get(), (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceeds()); + // Get the port bound by the server socket. + socklen_t addrlen = sizeof(sin); + ASSERT_THAT( + getsockname(server.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen), + SyscallSucceeds()); + + FileDescriptor client = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + ASSERT_THAT( + sendto(client.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)), + SyscallSucceedsWithValue(1)); + + char buf[128]; + ASSERT_THAT(recvfrom(server.get(), buf, sizeof(buf), 0, NULL, NULL), + SyscallSucceedsWithValue(1)); + + uint64_t newOutDatagrams; + uint64_t newInDatagrams; + snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp")); + std::cerr << "new snmp: " << std::endl << snmp << std::endl; + newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams")); + newInDatagrams = ASSERT_NO_ERRNO_AND_VALUE( + GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams")); + + EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1); + EXPECT_EQ(oldInDatagrams, newInDatagrams - 1); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc index f6d7ad0bb..2659f6a98 100644 --- a/test/syscalls/linux/proc_net_tcp.cc +++ b/test/syscalls/linux/proc_net_tcp.cc @@ -18,7 +18,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -250,6 +249,247 @@ TEST(ProcNetTCP, State) { EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED); } +constexpr char kProcNetTCP6Header[] = + " sl local_address remote_address" + " st tx_queue rx_queue tr tm->when retrnsmt" + " uid timeout inode"; + +// TCP6Entry represents a single entry from /proc/net/tcp6. +struct TCP6Entry { + struct in6_addr local_addr; + uint16_t local_port; + + struct in6_addr remote_addr; + uint16_t remote_port; + + uint64_t state; + uint64_t uid; + uint64_t inode; +}; + +bool IPv6AddrEqual(const struct in6_addr* a1, const struct in6_addr* a2) { + return memcmp(a1, a2, sizeof(struct in6_addr)) == 0; +} + +// Finds the first entry in 'entries' for which 'predicate' returns true. +// Returns true on match, and sets 'match' to a copy of the matching entry. If +// 'match' is null, it's ignored. +bool FindBy6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, + std::function<bool(const TCP6Entry&)> predicate) { + for (const TCP6Entry& entry : entries) { + if (predicate(entry)) { + if (match != nullptr) { + *match = entry; + } + return true; + } + } + return false; +} + +const struct in6_addr* IP6FromInetSockaddr(const struct sockaddr* addr) { + auto* addr6 = reinterpret_cast<const struct sockaddr_in6*>(addr); + return &addr6->sin6_addr; +} + +bool FindByLocalAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, + const struct sockaddr* addr) { + const struct in6_addr* local = IP6FromInetSockaddr(addr); + uint16_t port = PortFromInetSockaddr(addr); + return FindBy6(entries, match, [local, port](const TCP6Entry& e) { + return (IPv6AddrEqual(&e.local_addr, local) && e.local_port == port); + }); +} + +bool FindByRemoteAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match, + const struct sockaddr* addr) { + const struct in6_addr* remote = IP6FromInetSockaddr(addr); + uint16_t port = PortFromInetSockaddr(addr); + return FindBy6(entries, match, [remote, port](const TCP6Entry& e) { + return (IPv6AddrEqual(&e.remote_addr, remote) && e.remote_port == port); + }); +} + +void ReadIPv6Address(std::string s, struct in6_addr* addr) { + uint32_t a0, a1, a2, a3; + const char* fmt = "%08X%08X%08X%08X"; + EXPECT_EQ(sscanf(s.c_str(), fmt, &a0, &a1, &a2, &a3), 4); + + uint8_t* b = addr->s6_addr; + *((uint32_t*)&b[0]) = a0; + *((uint32_t*)&b[4]) = a1; + *((uint32_t*)&b[8]) = a2; + *((uint32_t*)&b[12]) = a3; +} + +// Returns a parsed representation of /proc/net/tcp6 entries. +PosixErrorOr<std::vector<TCP6Entry>> ProcNetTCP6Entries() { + std::string content; + RETURN_IF_ERRNO(GetContents("/proc/net/tcp6", &content)); + + bool found_header = false; + std::vector<TCP6Entry> entries; + std::vector<std::string> lines = StrSplit(content, '\n'); + std::cerr << "<contents of /proc/net/tcp6>" << std::endl; + for (const std::string& line : lines) { + std::cerr << line << std::endl; + + if (!found_header) { + EXPECT_EQ(line, kProcNetTCP6Header); + found_header = true; + continue; + } + if (line.empty()) { + continue; + } + + // Parse a single entry from /proc/net/tcp6. + // + // Example entries: + // + // clang-format off + // + // sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode + // 0: 00000000000000000000000000000000:1F90 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876340 1 ffff8803da9c9380 100 0 0 10 0 + // 1: 00000000000000000000000000000000:C350 00000000000000000000000000000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 876987 1 ffff8803ec408000 100 0 0 10 0 + // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 + // + // clang-format on + + TCP6Entry entry; + std::vector<std::string> fields = + StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty()); + + ReadIPv6Address(fields[1], &entry.local_addr); + ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16)); + ReadIPv6Address(fields[3], &entry.remote_addr); + ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16)); + ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16)); + ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11])); + ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13])); + + entries.push_back(entry); + } + std::cerr << "<end of /proc/net/tcp6>" << std::endl; + + return entries; +} + +TEST(ProcNetTCP6, Exists) { + const std::string content = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/tcp6")); + const std::string header_line = StrCat(kProcNetTCP6Header, "\n"); + if (IsRunningOnGvisor()) { + // Should be just the header since we don't have any tcp sockets yet. + EXPECT_EQ(content, header_line); + } else { + // On a general linux machine, we could have abitrary sockets on the system, + // so just check the header. + EXPECT_THAT(content, ::testing::StartsWith(header_line)); + } +} + +TEST(ProcNetTCP6, EntryUID) { + auto sockets = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); + std::vector<TCP6Entry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + TCP6Entry e; + + ASSERT_TRUE(FindByLocalAddr6(entries, &e, sockets->first_addr())); + EXPECT_EQ(e.uid, geteuid()); + ASSERT_TRUE(FindByRemoteAddr6(entries, &e, sockets->first_addr())); + EXPECT_EQ(e.uid, geteuid()); +} + +TEST(ProcNetTCP6, BindAcceptConnect) { + auto sockets = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); + std::vector<TCP6Entry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + // We can only make assertions about the total number of entries if we control + // the entire "machine". + if (IsRunningOnGvisor()) { + EXPECT_EQ(entries.size(), 2); + } + + EXPECT_TRUE(FindByLocalAddr6(entries, nullptr, sockets->first_addr())); + EXPECT_TRUE(FindByRemoteAddr6(entries, nullptr, sockets->first_addr())); +} + +TEST(ProcNetTCP6, InodeReasonable) { + auto sockets = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPAcceptBindSocketPair(0).Create()); + std::vector<TCP6Entry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + + TCP6Entry accepted_entry; + + ASSERT_TRUE( + FindByLocalAddr6(entries, &accepted_entry, sockets->first_addr())); + EXPECT_NE(accepted_entry.inode, 0); + + TCP6Entry client_entry; + ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, sockets->first_addr())); + EXPECT_NE(client_entry.inode, 0); + EXPECT_NE(accepted_entry.inode, client_entry.inode); +} + +TEST(ProcNetTCP6, State) { + std::unique_ptr<FileDescriptor> server = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create()); + + auto test_addr = V6Loopback(); + ASSERT_THAT( + bind(server->get(), reinterpret_cast<struct sockaddr*>(&test_addr.addr), + test_addr.addr_len), + SyscallSucceeds()); + + struct sockaddr_in6 addr6; + socklen_t addrlen = sizeof(struct sockaddr_in6); + auto* addr = reinterpret_cast<struct sockaddr*>(&addr6); + ASSERT_THAT(getsockname(server->get(), addr, &addrlen), SyscallSucceeds()); + ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); + + ASSERT_THAT(listen(server->get(), 10), SyscallSucceeds()); + std::vector<TCP6Entry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + TCP6Entry listen_entry; + + ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr)); + EXPECT_EQ(listen_entry.state, TCP_LISTEN); + + std::unique_ptr<FileDescriptor> client = + ASSERT_NO_ERRNO_AND_VALUE(IPv6TCPUnboundSocket(0).Create()); + ASSERT_THAT(RetryEINTR(connect)(client->get(), addr, addrlen), + SyscallSucceeds()); + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + ASSERT_TRUE(FindByLocalAddr6(entries, &listen_entry, addr)); + EXPECT_EQ(listen_entry.state, TCP_LISTEN); + TCP6Entry client_entry; + ASSERT_TRUE(FindByRemoteAddr6(entries, &client_entry, addr)); + EXPECT_EQ(client_entry.state, TCP_ESTABLISHED); + + FileDescriptor accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr)); + + const struct in6_addr* local = IP6FromInetSockaddr(addr); + const uint16_t accepted_local_port = PortFromInetSockaddr(addr); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries()); + TCP6Entry accepted_entry; + ASSERT_TRUE(FindBy6( + entries, &accepted_entry, + [client_entry, local, accepted_local_port](const TCP6Entry& e) { + return IPv6AddrEqual(&e.local_addr, local) && + e.local_port == accepted_local_port && + IPv6AddrEqual(&e.remote_addr, &client_entry.local_addr) && + e.remote_port == client_entry.local_port; + })); + EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc index 369df8e0e..f06f1a24b 100644 --- a/test/syscalls/linux/proc_net_udp.cc +++ b/test/syscalls/linux/proc_net_udp.cc @@ -18,7 +18,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 83dbd1364..66db0acaa 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index bf32efe1e..99a0df235 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -1527,27 +1527,36 @@ TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + int sync_setsid[2]; + int sync_exit[2]; + ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds()); + ASSERT_THAT(pipe(sync_exit), SyscallSucceeds()); + // Create a new process and put it in a new session. pid_t child = fork(); if (!child) { TEST_PCHECK(setsid() >= 0); // Tell the parent we're in a new session. - TEST_PCHECK(!raise(SIGSTOP)); - TEST_PCHECK(!pause()); - _exit(1); + char c = 'c'; + TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); + TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); + _exit(0); } // Wait for the child to tell us it's in a new session. - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED), - SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WSTOPSIG(wstatus)); + char c = 'c'; + ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1)); // Child is in a new session, so we can't make it the foregroup process group. EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), SyscallFailsWithErrno(EPERM)); - EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds()); + EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1)); + + int wstatus; + EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); + EXPECT_TRUE(WIFEXITED(wstatus)); + EXPECT_EQ(WEXITSTATUS(wstatus), 0); } // Verify that we don't hang when creating a new session from an orphaned diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index e1603fc2d..b48fe540d 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index f6a0fc96c..1dbc0d6df 100644 --- a/test/syscalls/linux/pwritev2.cc +++ b/test/syscalls/linux/pwritev2.cc @@ -21,7 +21,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc index a070817eb..0a27506aa 100644 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -63,7 +63,11 @@ class RawHDRINCL : public ::testing::Test { }; void RawHDRINCL::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_RAW), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds()); @@ -76,9 +80,10 @@ void RawHDRINCL::SetUp() { } void RawHDRINCL::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - EXPECT_THAT(close(socket_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(socket_), SyscallSucceeds()); + } } struct iphdr RawHDRINCL::LoopbackHeader() { @@ -123,8 +128,6 @@ bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port, // We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::Setup // creates the first one, so we only have to create one more here. TEST_F(RawHDRINCL, MultipleCreation) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - int s2; ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds()); @@ -133,23 +136,17 @@ TEST_F(RawHDRINCL, MultipleCreation) { // Test that shutting down an unconnected socket fails. TEST_F(RawHDRINCL, FailShutdownWithoutConnect) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); } // Test that listen() fails. TEST_F(RawHDRINCL, FailListen) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP)); } // Test that accept() fails. TEST_F(RawHDRINCL, FailAccept) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct sockaddr saddr; socklen_t addrlen; ASSERT_THAT(accept(socket_, &saddr, &addrlen), @@ -158,8 +155,6 @@ TEST_F(RawHDRINCL, FailAccept) { // Test that the socket is writable immediately. TEST_F(RawHDRINCL, PollWritableImmediately) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct pollfd pfd = {}; pfd.fd = socket_; pfd.events = POLLOUT; @@ -168,8 +163,6 @@ TEST_F(RawHDRINCL, PollWritableImmediately) { // Test that the socket isn't readable. TEST_F(RawHDRINCL, NotReadable) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - // Try to receive data with MSG_DONTWAIT, which returns immediately if there's // nothing to be read. char buf[117]; @@ -179,16 +172,12 @@ TEST_F(RawHDRINCL, NotReadable) { // Test that we can connect() to a valid IP (loopback). TEST_F(RawHDRINCL, ConnectToLoopback) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), SyscallSucceeds()); } TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct iphdr hdr = LoopbackHeader(); ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), SyscallSucceedsWithValue(sizeof(hdr))); @@ -197,8 +186,6 @@ TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { // HDRINCL implies write-only. Verify that we can't read a packet sent to // loopback. TEST_F(RawHDRINCL, NotReadableAfterWrite) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), SyscallSucceeds()); @@ -221,8 +208,6 @@ TEST_F(RawHDRINCL, NotReadableAfterWrite) { } TEST_F(RawHDRINCL, WriteTooSmall) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), SyscallSucceeds()); @@ -235,8 +220,6 @@ TEST_F(RawHDRINCL, WriteTooSmall) { // Bind to localhost. TEST_F(RawHDRINCL, BindToLocalhost) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - ASSERT_THAT( bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), SyscallSucceeds()); @@ -244,8 +227,6 @@ TEST_F(RawHDRINCL, BindToLocalhost) { // Bind to a different address. TEST_F(RawHDRINCL, BindToInvalid) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct sockaddr_in bind_addr = {}; bind_addr.sin_family = AF_INET; bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. @@ -256,8 +237,6 @@ TEST_F(RawHDRINCL, BindToInvalid) { // Send and receive a packet. TEST_F(RawHDRINCL, SendAndReceive) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( @@ -302,8 +281,6 @@ TEST_F(RawHDRINCL, SendAndReceive) { // Send and receive a packet with nonzero IP ID. TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( @@ -349,8 +326,6 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { // Send and receive a packet where the sendto address is not the same as the // provided destination. TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 971592d7d..3de898df7 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -77,7 +77,11 @@ class RawSocketICMPTest : public ::testing::Test { }; void RawSocketICMPTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); @@ -90,9 +94,10 @@ void RawSocketICMPTest::SetUp() { } void RawSocketICMPTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - EXPECT_THAT(close(s_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(s_), SyscallSucceeds()); + } } // We'll only read an echo in this case, as the kernel won't respond to the @@ -124,7 +129,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveBadChecksum) { EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } -// + // Send and receive an ICMP packet. TEST_F(RawSocketICMPTest, SendAndReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc index 352037c88..cde2f07c9 100644 --- a/test/syscalls/linux/raw_socket_ipv4.cc +++ b/test/syscalls/linux/raw_socket_ipv4.cc @@ -67,7 +67,11 @@ class RawSocketTest : public ::testing::TestWithParam<int> { }; void RawSocketTest::SetUp() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP(); + } ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); @@ -79,9 +83,10 @@ void RawSocketTest::SetUp() { } void RawSocketTest::TearDown() { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - - EXPECT_THAT(close(s_), SyscallSucceeds()); + // TearDown will be run even if we skip the test. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { + EXPECT_THAT(close(s_), SyscallSucceeds()); + } } // We should be able to create multiple raw sockets for the same protocol. diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index f327ec3a9..4069cbc7e 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/syscalls/linux/readv_common.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc index 35d2dd9e3..9658f7d42 100644 --- a/test/syscalls/linux/readv_common.cc +++ b/test/syscalls/linux/readv_common.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/readv_socket.cc b/test/syscalls/linux/readv_socket.cc index 3c315cc02..9b6972201 100644 --- a/test/syscalls/linux/readv_socket.cc +++ b/test/syscalls/linux/readv_socket.cc @@ -19,7 +19,6 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/file_base.h" #include "test/syscalls/linux/readv_common.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index c9d76c2e2..5b474ff32 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -17,7 +17,6 @@ #include <string> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/util/capability_util.h" #include "test/util/cleanup.h" diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc index 88c010aec..e06a2666d 100644 --- a/test/syscalls/linux/select.cc +++ b/test/syscalls/linux/select.cc @@ -21,7 +21,6 @@ #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/time.h" #include "test/syscalls/linux/base_poll_test.h" #include "test/util/file_descriptor.h" diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index 40c57f543..e9b131ca9 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -447,9 +447,8 @@ TEST(SemaphoreTest, SemCtlGetPidFork) { const pid_t child_pid = fork(); if (child_pid == 0) { - ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); - ASSERT_THAT(semctl(sem.get(), 0, GETPID), - SyscallSucceedsWithValue(getpid())); + TEST_PCHECK(semctl(sem.get(), 0, SETVAL, 1) == 0); + TEST_PCHECK(semctl(sem.get(), 0, GETPID) == getpid()); _exit(0); } diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 4502e7fb4..580ab5193 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <sys/eventfd.h> #include <sys/sendfile.h> #include <unistd.h> @@ -21,6 +22,7 @@ #include "absl/strings/string_view.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -511,6 +513,23 @@ TEST(SendFileTest, SendPipeBlocks) { SyscallSucceedsWithValue(kDataSize)); } +TEST(SendFileTest, SendToSpecialFile) { + // Create temp file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); + + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); + constexpr int kSize = 0x7ff; + ASSERT_THAT(ftruncate(inf.get(), kSize), SyscallSucceeds()); + + auto eventfd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); + + // eventfd can accept a number of bytes which is a multiple of 8. + EXPECT_THAT(sendfile(eventfd.get(), inf.get(), nullptr, 0xfffff), + SyscallSucceedsWithValue(kSize & (~7))); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc index 1c56540bc..3331288b7 100644 --- a/test/syscalls/linux/sendfile_socket.cc +++ b/test/syscalls/linux/sendfile_socket.cc @@ -185,7 +185,7 @@ TEST_P(SendFileTest, Shutdown) { // Create a socket. std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets()); const FileDescriptor client(std::get<0>(fds)); - FileDescriptor server(std::get<1>(fds)); // non-const, released below. + FileDescriptor server(std::get<1>(fds)); // non-const, reset below. // If this is a TCP socket, then turn off linger. if (GetParam() == AF_INET) { @@ -210,14 +210,14 @@ TEST_P(SendFileTest, Shutdown) { // checking the contents (other tests do that), so we just re-use the same // buffer as above. ScopedThread t([&]() { - int done = 0; + size_t done = 0; while (done < data.size()) { - int n = read(server.get(), data.data(), data.size()); + int n = RetryEINTR(read)(server.get(), data.data(), data.size()); ASSERT_THAT(n, SyscallSucceeds()); done += n; } // Close the server side socket. - ASSERT_THAT(close(server.release()), SyscallSucceeds()); + server.reset(); }); // Continuously stream from the file to the socket. Note we do not assert diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc index 69b6e4f90..6fd3989a4 100644 --- a/test/syscalls/linux/sigaltstack.cc +++ b/test/syscalls/linux/sigaltstack.cc @@ -22,7 +22,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/util/cleanup.h" #include "test/util/fs_util.h" #include "test/util/multiprocess_util.h" diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc index 54c598627..09ecad34a 100644 --- a/test/syscalls/linux/signalfd.cc +++ b/test/syscalls/linux/signalfd.cc @@ -24,7 +24,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/synchronization/mutex.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -312,6 +311,23 @@ TEST(Signalfd, KillStillKills) { EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), ""); } +TEST(Signalfd, Ppoll) { + sigset_t mask; + sigemptyset(&mask); + sigaddset(&mask, SIGKILL); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC)); + + // Ensure that the given ppoll blocks. + struct pollfd pfd = {}; + pfd.fd = fd.get(); + pfd.events = POLLIN; + struct timespec timeout = {}; + timeout.tv_sec = 1; + EXPECT_THAT(RetryEINTR(ppoll)(&pfd, 1, &timeout, &mask), + SyscallSucceedsWithValue(0)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index caae215b8..3a07ac8d2 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -17,6 +17,7 @@ #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" namespace gvisor { @@ -57,5 +58,28 @@ TEST(SocketTest, ProtocolInet) { } } +using SocketOpenTest = ::testing::TestWithParam<int>; + +// UDS cannot be opened. +TEST_P(SocketOpenTest, Unix) { + // FIXME(b/142001530): Open incorrectly succeeds on gVisor. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor bound = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); + + struct sockaddr_un addr = + ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX)); + + ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + EXPECT_THAT(open(addr.sun_path, GetParam()), SyscallFailsWithErrno(ENXIO)); +} + +INSTANTIATE_TEST_SUITE_P(OpenModes, SocketOpenTest, + ::testing::Values(O_RDONLY, O_RDWR)); + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc new file mode 100644 index 000000000..6b27f6eab --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device.cc @@ -0,0 +1,313 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; + +// Test fixture for SO_BINDTODEVICE tests. +class BindToDeviceTest : public ::testing::TestWithParam<SocketKind> { + protected: + void SetUp() override { + printf("Testing case: %s\n", GetParam().description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + + interface_name_ = "eth1"; + auto interface_names = GetInterfaceNames(); + if (interface_names.find(interface_name_) == interface_names.end()) { + // Need a tunnel. + tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()); + interface_name_ = tunnel_->GetName(); + ASSERT_FALSE(interface_name_.empty()); + } + socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create()); + } + + string interface_name() const { return interface_name_; } + + int socket_fd() const { return socket_->get(); } + + private: + std::unique_ptr<Tunnel> tunnel_; + string interface_name_; + std::unique_ptr<FileDescriptor> socket_; +}; + +constexpr char kIllegalIfnameChar = '/'; + +// Tests getsockopt of the default value. +TEST_P(BindToDeviceTest, GetsockoptDefault) { + char name_buffer[IFNAMSIZ * 2]; + char original_name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Read the default SO_BINDTODEVICE. + memset(original_name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + for (size_t i = 0; i <= sizeof(name_buffer); i++) { + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = i; + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(name_buffer_size, 0); + EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)), + 0); + } +} + +// Tests setsockopt of invalid device name. +TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Set an invalid device name. + memset(name_buffer, kIllegalIfnameChar, 5); + name_buffer_size = 5; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallFailsWithErrno(ENODEV)); +} + +// Tests setsockopt of a buffer with a valid device name but not +// null-terminated, with different sizes of buffer. +TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); + // Intentionally overwrite the null at the end. + memset(name_buffer + interface_name().size(), kIllegalIfnameChar, + sizeof(name_buffer) - interface_name().size()); + for (size_t i = 1; i <= sizeof(name_buffer); i++) { + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // It should only work if the size provided is exactly right. + if (name_buffer_size == interface_name().size()) { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallSucceeds()); + } else { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallFailsWithErrno(ENODEV)); + } + } +} + +// Tests setsockopt of a buffer with a valid device name and null-terminated, +// with different sizes of buffer. +TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); + // Don't overwrite the null at the end. + memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar, + sizeof(name_buffer) - interface_name().size() - 1); + for (size_t i = 1; i <= sizeof(name_buffer); i++) { + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // It should only work if the size provided is at least the right size. + if (name_buffer_size >= interface_name().size()) { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallSucceeds()); + } else { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallFailsWithErrno(ENODEV)); + } + } +} + +// Tests that setsockopt of an invalid device name doesn't unset the previous +// valid setsockopt. +TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Write unsuccessfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = 5; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallFailsWithErrno(ENODEV)); + + // Read it back successfully, it's unchanged. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); +} + +// Tests that setsockopt of zero-length string correctly unsets the previous +// value. +TEST_P(BindToDeviceTest, SetsockoptValidThenClear) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Clear it successfully. + name_buffer_size = 0; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallSucceeds()); + + // Read it back successfully, it's cleared. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, 0); +} + +// Tests that setsockopt of empty string correctly unsets the previous +// value. +TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Clear it successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer[0] = 0; + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallSucceeds()); + + // Read it back successfully, it's cleared. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, 0); +} + +// Tests getsockopt with different buffer sizes. +TEST_P(BindToDeviceTest, GetsockoptDevice) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back at various buffer sizes. + for (size_t i = 0; i <= sizeof(name_buffer); i++) { + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // Linux only allows a buffer at least IFNAMSIZ, even if less would suffice + // for this interface name. + if (name_buffer_size >= IFNAMSIZ) { + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + } else { + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallFailsWithErrno(EINVAL)); + EXPECT_EQ(name_buffer_size, i); + } + } +} + +INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest, + ::testing::Values(IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc new file mode 100644 index 000000000..5767181a1 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -0,0 +1,380 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <atomic> +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; +using std::vector; + +struct EndpointConfig { + std::string bind_to_device; + double expected_ratio; +}; + +struct DistributionTestCase { + std::string name; + std::vector<EndpointConfig> endpoints; +}; + +struct ListenerConnector { + TestAddress listener; + TestAddress connector; +}; + +// Test fixture for SO_BINDTODEVICE tests the distribution of packets received +// with varying SO_BINDTODEVICE settings. +class BindToDeviceDistributionTest + : public ::testing::TestWithParam< + ::testing::tuple<ListenerConnector, DistributionTestCase>> { + protected: + void SetUp() override { + printf("Testing case: %s, listener=%s, connector=%s\n", + ::testing::get<1>(GetParam()).name.c_str(), + ::testing::get<0>(GetParam()).listener.description.c_str(), + ::testing::get<0>(GetParam()).connector.description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + } +}; + +PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { + switch (family) { + case AF_INET: + return static_cast<uint16_t>( + reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); + case AF_INET6: + return static_cast<uint16_t>( + reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { + switch (family) { + case AF_INET: + reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; + return NoError(); + case AF_INET6: + reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; + return NoError(); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +// Binds sockets to different devices and then creates many TCP connections. +// Checks that the distribution of connections received on the sockets matches +// the expectation. +TEST_P(BindToDeviceDistributionTest, Tcp) { + auto const& [listener_connector, test] = GetParam(); + + TestAddress const& listener = listener_connector.listener; + TestAddress const& connector = listener_connector.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + + auto interface_names = GetInterfaceNames(); + + // Create the listening sockets. + std::vector<FileDescriptor> listener_fds; + std::vector<std::unique_ptr<Tunnel>> all_tunnels; + for (auto const& endpoint : test.endpoints) { + if (!endpoint.bind_to_device.empty() && + interface_names.find(endpoint.bind_to_device) == + interface_names.end()) { + all_tunnels.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); + interface_names.insert(endpoint.bind_to_device); + } + + listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP))); + int fd = listener_fds.back().get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + endpoint.bind_to_device.c_str(), + endpoint.bind_to_device.size() + 1), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (listener_fds.size() > 1) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); + std::vector<int> accept_counts(listener_fds.size(), 0); + std::vector<std::unique_ptr<ScopedThread>> listen_threads( + listener_fds.size()); + + for (int i = 0; i < listener_fds.size(); i++) { + listen_threads[i] = absl::make_unique<ScopedThread>( + [&listener_fds, &accept_counts, &connects_received, i, + kConnectAttempts]() { + do { + auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); + if (!fd.ok()) { + // Another thread has shutdown our read side causing the accept to + // fail. + ASSERT_GE(connects_received, kConnectAttempts) + << "errno = " << fd.error(); + return; + } + // Receive some data from a socket to be sure that the connect() + // system call has been completed on another side. + int data; + EXPECT_THAT( + RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + accept_counts[i]++; + } while (++connects_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (auto const& listener_fd : listener_fds) { + shutdown(listener_fd.get(), SHUT_RDWR); + } + }); + } + + for (int i = 0; i < kConnectAttempts; i++) { + FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), + SyscallSucceedsWithValue(sizeof(i))); + } + + // Join threads to be sure that all connections have been counted. + for (auto const& listen_thread : listen_threads) { + listen_thread->Join(); + } + // Check that connections are distributed correctly among listening sockets. + for (int i = 0; i < accept_counts.size(); i++) { + EXPECT_THAT( + accept_counts[i], + EquivalentWithin(static_cast<int>(kConnectAttempts * + test.endpoints[i].expected_ratio), + 0.10)) + << "endpoint " << i << " got the wrong number of packets"; + } +} + +// Binds sockets to different devices and then sends many UDP packets. Checks +// that the distribution of packets received on the sockets matches the +// expectation. +TEST_P(BindToDeviceDistributionTest, Udp) { + auto const& [listener_connector, test] = GetParam(); + + TestAddress const& listener = listener_connector.listener; + TestAddress const& connector = listener_connector.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + + auto interface_names = GetInterfaceNames(); + + // Create the listening socket. + std::vector<FileDescriptor> listener_fds; + std::vector<std::unique_ptr<Tunnel>> all_tunnels; + for (auto const& endpoint : test.endpoints) { + if (!endpoint.bind_to_device.empty() && + interface_names.find(endpoint.bind_to_device) == + interface_names.end()) { + all_tunnels.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); + interface_names.insert(endpoint.bind_to_device); + } + + listener_fds.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0))); + int fd = listener_fds.back().get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + endpoint.bind_to_device.c_str(), + endpoint.bind_to_device.size() + 1), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (listener_fds.size() > 1) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic<int> packets_received = ATOMIC_VAR_INIT(0); + std::vector<int> packets_per_socket(listener_fds.size(), 0); + std::vector<std::unique_ptr<ScopedThread>> receiver_threads( + listener_fds.size()); + + for (int i = 0; i < listener_fds.size(); i++) { + receiver_threads[i] = absl::make_unique<ScopedThread>( + [&listener_fds, &packets_per_socket, &packets_received, i]() { + do { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + int data; + + auto ret = RetryEINTR(recvfrom)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast<struct sockaddr*>(&addr), &addrlen); + + if (packets_received < kConnectAttempts) { + ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); + } + + if (ret != sizeof(data)) { + // Another thread may have shutdown our read side causing the + // recvfrom to fail. + break; + } + + packets_received++; + packets_per_socket[i]++; + + // A response is required to synchronize with the main thread, + // otherwise the main thread can send more than can fit into receive + // queues. + EXPECT_THAT(RetryEINTR(sendto)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(data))); + } while (packets_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (auto const& listener_fd : listener_fds) { + shutdown(listener_fd.get(), SHUT_RDWR); + } + }); + } + + for (int i = 0; i < kConnectAttempts; i++) { + FileDescriptor const fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + int data; + EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + } + + // Join threads to be sure that all connections have been counted. + for (auto const& receiver_thread : receiver_threads) { + receiver_thread->Join(); + } + // Check that packets are distributed correctly among listening sockets. + for (int i = 0; i < packets_per_socket.size(); i++) { + EXPECT_THAT( + packets_per_socket[i], + EquivalentWithin(static_cast<int>(kConnectAttempts * + test.endpoints[i].expected_ratio), + 0.10)) + << "endpoint " << i << " got the wrong number of packets"; + } +} + +std::vector<DistributionTestCase> GetDistributionTestCases() { + return std::vector<DistributionTestCase>{ + {"Even distribution among sockets not bound to device", + {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}}, + {"Sockets bound to other interfaces get no packets", + {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}}, + {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}}, + {"Even distribution among sockets bound to device", + {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}}, + }; +} + +INSTANTIATE_TEST_SUITE_P( + BindToDeviceTest, BindToDeviceDistributionTest, + ::testing::Combine(::testing::Values( + // Listeners bound to IPv4 addresses refuse + // connections using IPv6 addresses. + ListenerConnector{V4Any(), V4Loopback()}, + ListenerConnector{V4Loopback(), V4MappedLoopback()}), + ::testing::ValuesIn(GetDistributionTestCases()))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc new file mode 100644 index 000000000..e4641c62e --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -0,0 +1,315 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; +using std::vector; + +// Test fixture for SO_BINDTODEVICE tests the results of sequences of socket +// binding. +class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { + protected: + void SetUp() override { + printf("Testing case: %s\n", GetParam().description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + socket_factory_ = GetParam(); + + interface_names_ = GetInterfaceNames(); + } + + PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const { + return socket_factory_.Create(); + } + + // Gets a device by device_id. If the device_id has been seen before, returns + // the previously returned device. If not, finds or creates a new device. + // Returns an empty string on failure. + void GetDevice(int device_id, string *device_name) { + auto device = devices_.find(device_id); + if (device != devices_.end()) { + *device_name = device->second; + return; + } + + // Need to pick a new device. Try ethernet first. + *device_name = absl::StrCat("eth", next_unused_eth_); + if (interface_names_.find(*device_name) != interface_names_.end()) { + devices_[device_id] = *device_name; + next_unused_eth_++; + return; + } + + // Need to make a new tunnel device. gVisor tests should have enough + // ethernet devices to never reach here. + ASSERT_FALSE(IsRunningOnGvisor()); + // Need a tunnel. + tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New())); + devices_[device_id] = tunnels_.back()->GetName(); + *device_name = devices_[device_id]; + } + + // Release the socket + void ReleaseSocket(int socket_id) { + // Close the socket that was made in a previous action. The socket_id + // indicates which socket to close based on index into the list of actions. + sockets_to_close_.erase(socket_id); + } + + // Bind a socket with the reuse option and bind_to_device options. Checks + // that all steps succeed and that the bind command's error matches want. + // Sets the socket_id to uniquely identify the socket bound if it is not + // nullptr. + void BindSocket(bool reuse, int device_id = 0, int want = 0, + int *socket_id = nullptr) { + next_socket_id_++; + sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket_fd = sockets_to_close_[next_socket_id_]->get(); + if (socket_id != nullptr) { + *socket_id = next_socket_id_; + } + + // If reuse is indicated, do that. + if (reuse) { + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + } + + // If the device is non-zero, bind to that device. + if (device_id != 0) { + string device_name; + ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, + device_name.c_str(), device_name.size() + 1), + SyscallSucceedsWithValue(0)); + char get_device[100]; + socklen_t get_device_size = 100; + EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device, + &get_device_size), + SyscallSucceedsWithValue(0)); + } + + struct sockaddr_in addr = {}; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr.sin_port = port_; + if (want == 0) { + ASSERT_THAT( + bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr), + sizeof(addr)), + SyscallSucceeds()); + } else { + ASSERT_THAT( + bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr), + sizeof(addr)), + SyscallFailsWithErrno(want)); + } + + if (port_ == 0) { + // We don't yet know what port we'll be using so we need to fetch it and + // remember it for future commands. + socklen_t addr_size = sizeof(addr); + ASSERT_THAT( + getsockname(socket_fd, reinterpret_cast<struct sockaddr *>(&addr), + &addr_size), + SyscallSucceeds()); + port_ = addr.sin_port; + } + } + + private: + SocketKind socket_factory_; + // devices maps from the device id in the test case to the name of the device. + std::unordered_map<int, string> devices_; + // These are the tunnels that were created for the test and will be destroyed + // by the destructor. + vector<std::unique_ptr<Tunnel>> tunnels_; + // A list of all interface names before the test started. + std::unordered_set<string> interface_names_; + // The next ethernet device to use when requested a device. + int next_unused_eth_ = 1; + // The port for all tests. Originally 0 (any) and later set to the port that + // all further commands will use. + in_port_t port_ = 0; + // sockets_to_close_ is a map from action index to the socket that was + // created. + std::unordered_map<int, + std::unique_ptr<gvisor::testing::FileDescriptor>> + sockets_to_close_; + int next_socket_id_ = 0; +}; + +TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 3)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindToDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 1)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 2)); +} + +TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithReuse) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0)); +} + +TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0)); +} + +TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindAndRelease) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + int to_release; + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + // Release the bind to device 0 and try again. + ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 345)); +} + +TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest, + ::testing::Values(IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc new file mode 100644 index 000000000..f4ee775bd --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_util.cc @@ -0,0 +1,75 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_bind_to_device_util.h" + +#include <arpa/inet.h> +#include <fcntl.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +using std::string; + +PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) { + int fd; + RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR)); + + // Using `new` to access a non-public constructor. + auto new_tunnel = absl::WrapUnique(new Tunnel(fd)); + + ifreq ifr = {}; + ifr.ifr_flags = IFF_TUN; + strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name)); + + RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr)); + new_tunnel->name_ = ifr.ifr_name; + return new_tunnel; +} + +std::unordered_set<string> GetInterfaceNames() { + struct if_nameindex* interfaces = if_nameindex(); + std::unordered_set<string> names; + if (interfaces == nullptr) { + return names; + } + for (auto interface = interfaces; + interface->if_index != 0 || interface->if_name != nullptr; interface++) { + names.insert(interface->if_name); + } + if_freenameindex(interfaces); + return names; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h new file mode 100644 index 000000000..f941ccc86 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_util.h @@ -0,0 +1,67 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ +#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ + +#include <arpa/inet.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +class Tunnel { + public: + static PosixErrorOr<std::unique_ptr<Tunnel>> New( + std::string tunnel_name = ""); + const std::string& GetName() const { return name_; } + + ~Tunnel() { + if (fd_ != -1) { + close(fd_); + } + } + + private: + Tunnel(int fd) : fd_(fd) {} + int fd_ = -1; + std::string name_; +}; + +std::unordered_set<std::string> GetInterfaceNames(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ diff --git a/test/syscalls/linux/socket_blocking.cc b/test/syscalls/linux/socket_blocking.cc index 00c50d1bf..d7ce57566 100644 --- a/test/syscalls/linux/socket_blocking.cc +++ b/test/syscalls/linux/socket_blocking.cc @@ -20,7 +20,6 @@ #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index 51d614639..e8f24a59e 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 322ee07ad..2eeee352e 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -13,9 +13,11 @@ // limitations under the License. #include <arpa/inet.h> +#include <linux/tcp.h> #include <netinet/in.h> #include <poll.h> #include <string.h> +#include <sys/epoll.h> #include <sys/socket.h> #include <atomic> @@ -30,6 +32,7 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" @@ -266,6 +269,340 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) { } } +// TCPFinWait2Test creates a pair of connected sockets then closes one end to +// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local +// IP/port on a new socket and tries to connect. The connect should fail w/ +// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the +// connect again with a new socket and this time it should succeed. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Lower FIN_WAIT2 state to 5 seconds for test. + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + // Now bind and connect a new socket. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Disable cooperative saves after this point. As a save between the first + // bind/connect and the second one can cause the linger timeout timer to + // be restarted causing the final bind/connect to fail. + DisableSave ds; + + // TODO(gvisor.dev/issue/1030): Portmanager does not track all 5 tuple + // reservations which causes the bind() to succeed on gVisor but connect + // correctly fails. + if (IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } else { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } + + // Sleep for a little over the linger timeout to reduce flakiness in + // save/restore tests. + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + ds.reset(); + + if (!IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + } + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPLinger2TimeoutAfterClose creates a pair of connected sockets +// then closes one end to trigger FIN_WAIT2 state for the closed endpont. +// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ +// connecting the same address succeeds. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPResetAfterClose creates a pair of connected sockets then closes +// one end to trigger FIN_WAIT2 state for the closed endpoint verifies +// that we generate RSTs for any new data after the socket is fully +// closed. +TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + int data = 1234; + + // Now send data which should trigger a RST as the other end should + // have timed out and closed the socket. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceeds()); + // Sleep for a shortwhile to get a RST back. + absl::SleepFor(absl::Seconds(1)); + + // Try writing again and we should get an EPIPE back. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallFailsWithErrno(EPIPE)); + + // Trying to read should return zero as the other end did send + // us a FIN. We do it twice to verify that the RST does not cause an + // ECONNRESET on the read after EOF has been read by applicaiton. + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); +} + +// This test is disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // We disable saves after this point as a S/R causes the netstack seed + // to be regenerated which changes what ports/ISN is picked for a given + // tuple (src ip,src port, dst ip, dst port). This can cause the final + // SYN to use a sequence number that looks like one from the current + // connection in TIME_WAIT and will not be accepted causing the test + // to timeout. + // + // TODO(gvisor.dev/issue/940): S/R portSeed/portHint + DisableSave ds; + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the accept FD to trigger TIME_WAIT on the accepted socket which + // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of + // TIME_WAIT. + accepted.reset(); + absl::SleepFor(absl::Seconds(1)); + conn_fd.reset(); + absl::SleepFor(absl::Seconds(1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( @@ -516,6 +853,112 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // TODO(b/141211329): endpointsByNic.seed has to be saved/restored. + const DisableSave ds141211329; + + // Create listening sockets. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10; + FileDescriptor client_fds[kConnectAttempts]; + + // Do the first run without save/restore. + DisableSave ds; + for (int i = 0; i < kConnectAttempts; i++) { + client_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + ds.reset(); + + // Check that a mapping of client and server sockets has + // not been change after save/restore. + for (int i = 0; i < kConnectAttempts; i++) { + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + + int epollfd; + ASSERT_THAT(epollfd = epoll_create1(0), SyscallSucceeds()); + + for (int i = 0; i < kThreadCount; i++) { + int fd = listener_fds[i].get(); + struct epoll_event ev; + ev.data.fd = fd; + ev.events = EPOLLIN; + ASSERT_THAT(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev), SyscallSucceeds()); + } + + std::map<uint16_t, int> portToFD; + + for (int i = 0; i < kConnectAttempts * 2; i++) { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + struct epoll_event ev; + int data, fd; + + ASSERT_THAT(epoll_wait(epollfd, &ev, 1, -1), SyscallSucceedsWithValue(1)); + + fd = ev.data.fd; + EXPECT_THAT(RetryEINTR(recvfrom)(fd, &data, sizeof(data), 0, + reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceedsWithValue(sizeof(data))); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr)); + auto prev_port = portToFD.find(port); + // Check that all packets from one client have been delivered to the same + // server socket. + if (prev_port == portToFD.end()) { + portToFD[port] = ev.data.fd; + } else { + EXPECT_EQ(portToFD[port], ev.data.fd); + } + } +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetReusePortTest, ::testing::Values( diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index bfa7943b1..a37b49447 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -24,14 +24,14 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { -TEST_P(TCPSocketPairTest, TcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, TcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -40,7 +40,7 @@ TEST_P(TCPSocketPairTest, TcpInfoSucceedes) { SyscallSucceeds()); } -TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, ShortTcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -49,7 +49,7 @@ TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) { SyscallSucceeds()); } -TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceedes) { +TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct tcp_info opt = {}; @@ -244,6 +244,31 @@ TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) { SyscallSucceedsWithValue(0)); } +// This test verifies that a shutdown(wr) by the server after sending +// data allows the client to still read() the queued data and a client +// close after sending response allows server to read the incoming +// response. +TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char buf[10] = {}; + ScopedThread t([&]() { + ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(close(sockets->release_first_fd()), + SyscallSucceedsWithValue(0)); + }); + ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR), + SyscallSucceedsWithValue(0)); + t.Join(); + + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); +} + TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -697,5 +722,72 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); } +// Linux and Netstack both default to a 60s TCP_LINGER2 timeout. +constexpr int kDefaultTCPLingerTimeout = 60; + +TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + + constexpr int kNegative = -1234; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kNegative, sizeof(kNegative)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kAboveDefault, sizeof(kAboveDefault)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPLingerTimeout); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc index de63f79d9..f178f1af9 100644 --- a/test/syscalls/linux/socket_ip_tcp_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_udp_generic.cc @@ -22,7 +22,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 044394ba7..2a4ed04a5 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -24,7 +24,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc new file mode 100644 index 000000000..b02872308 --- /dev/null +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -0,0 +1,378 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <netinet/in.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cstdio> +#include <cstring> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to pairs of IP sockets. +using IPUnboundSocketTest = SimpleSocketTest; + +TEST_P(IPUnboundSocketTest, TtlDefault) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int get = -1; + socklen_t get_sz = sizeof(get); + EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, 64); + EXPECT_EQ(get_sz, sizeof(get)); +} + +TEST_P(IPUnboundSocketTest, SetTtl) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int get1 = -1; + socklen_t get1_sz = sizeof(get1); + EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get1, &get1_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get1_sz, sizeof(get1)); + + int set = 100; + if (set == get1) { + set += 1; + } + socklen_t set_sz = sizeof(set); + EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz), + SyscallSucceedsWithValue(0)); + + int get2 = -1; + socklen_t get2_sz = sizeof(get2); + EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get2, &get2_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get2_sz, sizeof(get2)); + EXPECT_EQ(get2, set); +} + +TEST_P(IPUnboundSocketTest, ResetTtlToDefault) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int get1 = -1; + socklen_t get1_sz = sizeof(get1); + EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get1, &get1_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get1_sz, sizeof(get1)); + + int set1 = 100; + if (set1 == get1) { + set1 += 1; + } + socklen_t set1_sz = sizeof(set1); + EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set1, set1_sz), + SyscallSucceedsWithValue(0)); + + int set2 = -1; + socklen_t set2_sz = sizeof(set2); + EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set2, set2_sz), + SyscallSucceedsWithValue(0)); + + int get2 = -1; + socklen_t get2_sz = sizeof(get2); + EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get2, &get2_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get2_sz, sizeof(get2)); + EXPECT_EQ(get2, get1); +} + +TEST_P(IPUnboundSocketTest, ZeroTtl) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int set = 0; + socklen_t set_sz = sizeof(set); + EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(IPUnboundSocketTest, InvalidLargeTtl) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int set = 256; + socklen_t set_sz = sizeof(set); + EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(IPUnboundSocketTest, InvalidNegativeTtl) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int set = -2; + socklen_t set_sz = sizeof(set); + EXPECT_THAT(setsockopt(socket->get(), IPPROTO_IP, IP_TTL, &set, set_sz), + SyscallFailsWithErrno(EINVAL)); +} + +struct TOSOption { + int level; + int option; +}; + +constexpr int INET_ECN_MASK = 3; + +static TOSOption GetTOSOption(int domain) { + TOSOption opt; + switch (domain) { + case AF_INET: + opt.level = IPPROTO_IP; + opt.option = IP_TOS; + break; + case AF_INET6: + opt.level = IPPROTO_IPV6; + opt.option = IPV6_TCLASS; + break; + } + return opt; +} + +TEST_P(IPUnboundSocketTest, TOSDefault) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + TOSOption t = GetTOSOption(GetParam().domain); + int get = -1; + socklen_t get_sz = sizeof(get); + constexpr int kDefaultTOS = 0; + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, sizeof(get)); + EXPECT_EQ(get, kDefaultTOS); +} + +TEST_P(IPUnboundSocketTest, SetTOS) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + int set = 0xC0; + socklen_t set_sz = sizeof(set); + TOSOption t = GetTOSOption(GetParam().domain); + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_sz = sizeof(get); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, sizeof(get)); + EXPECT_EQ(get, set); +} + +TEST_P(IPUnboundSocketTest, ZeroTOS) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + int set = 0; + socklen_t set_sz = sizeof(set); + TOSOption t = GetTOSOption(GetParam().domain); + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_sz = sizeof(get); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, sizeof(get)); + EXPECT_EQ(get, set); +} + +TEST_P(IPUnboundSocketTest, InvalidLargeTOS) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + // Test with exceeding the byte space. + int set = 256; + constexpr int kDefaultTOS = 0; + socklen_t set_sz = sizeof(set); + TOSOption t = GetTOSOption(GetParam().domain); + if (GetParam().domain == AF_INET) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallSucceedsWithValue(0)); + } else { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallFailsWithErrno(EINVAL)); + } + int get = -1; + socklen_t get_sz = sizeof(get); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, sizeof(get)); + EXPECT_EQ(get, kDefaultTOS); +} + +TEST_P(IPUnboundSocketTest, CheckSkipECN) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + int set = 0xFF; + socklen_t set_sz = sizeof(set); + TOSOption t = GetTOSOption(GetParam().domain); + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallSucceedsWithValue(0)); + int expect = static_cast<uint8_t>(set); + if (GetParam().protocol == IPPROTO_TCP) { + expect &= ~INET_ECN_MASK; + } + int get = -1; + socklen_t get_sz = sizeof(get); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, sizeof(get)); + EXPECT_EQ(get, expect); +} + +TEST_P(IPUnboundSocketTest, ZeroTOSOptionSize) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + int set = 0xC0; + socklen_t set_sz = 0; + TOSOption t = GetTOSOption(GetParam().domain); + if (GetParam().domain == AF_INET) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallSucceedsWithValue(0)); + } else { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallFailsWithErrno(EINVAL)); + } + int get = -1; + socklen_t get_sz = 0; + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, 0); + EXPECT_EQ(get, -1); +} + +TEST_P(IPUnboundSocketTest, SmallTOSOptionSize) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + int set = 0xC0; + constexpr int kDefaultTOS = 0; + TOSOption t = GetTOSOption(GetParam().domain); + for (socklen_t i = 1; i < sizeof(int); i++) { + int expect_tos; + socklen_t expect_sz; + if (GetParam().domain == AF_INET) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i), + SyscallSucceedsWithValue(0)); + expect_tos = set; + expect_sz = sizeof(uint8_t); + } else { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i), + SyscallFailsWithErrno(EINVAL)); + expect_tos = kDefaultTOS; + expect_sz = i; + } + uint get = -1; + socklen_t get_sz = i; + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, expect_sz); + // Account for partial copies by getsockopt, retrieve the lower + // bits specified by get_sz, while comparing against expect_tos. + EXPECT_EQ(get & ~(~0 << (get_sz * 8)), expect_tos); + } +} + +TEST_P(IPUnboundSocketTest, LargeTOSOptionSize) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + int set = 0xC0; + TOSOption t = GetTOSOption(GetParam().domain); + for (socklen_t i = sizeof(int); i < 10; i++) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, i), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_sz = i; + // We expect the system call handler to only copy atmost sizeof(int) bytes + // as asserted by the check below. Hence, we do not expect the copy to + // overflow in getsockopt. + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, sizeof(int)); + EXPECT_EQ(get, set); + } +} + +TEST_P(IPUnboundSocketTest, NegativeTOS) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int set = -1; + socklen_t set_sz = sizeof(set); + TOSOption t = GetTOSOption(GetParam().domain); + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallSucceedsWithValue(0)); + int expect; + if (GetParam().domain == AF_INET) { + expect = static_cast<uint8_t>(set); + if (GetParam().protocol == IPPROTO_TCP) { + expect &= ~INET_ECN_MASK; + } + } else { + // On IPv6 TCLASS, setting -1 has the effect of resetting the + // TrafficClass. + expect = 0; + } + int get = -1; + socklen_t get_sz = sizeof(get); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, sizeof(get)); + EXPECT_EQ(get, expect); +} + +TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + int set = -2; + socklen_t set_sz = sizeof(set); + TOSOption t = GetTOSOption(GetParam().domain); + int expect; + if (GetParam().domain == AF_INET) { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallSucceedsWithValue(0)); + expect = static_cast<uint8_t>(set); + if (GetParam().protocol == IPPROTO_TCP) { + expect &= ~INET_ECN_MASK; + } + } else { + EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + SyscallFailsWithErrno(EINVAL)); + expect = 0; + } + int get = 0; + socklen_t get_sz = sizeof(get); + EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_sz, sizeof(get)); + EXPECT_EQ(get, expect); +} + +INSTANTIATE_TEST_SUITE_P( + IPUnboundSockets, IPUnboundSocketTest, + ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>( + ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{SOCK_DGRAM}, + List<int>{0, + SOCK_NONBLOCK})), + ApplyVec<SocketKind>(IPv6UDPUnboundSocket, + AllBitwiseCombinations(List<int>{SOCK_DGRAM}, + List<int>{0, + SOCK_NONBLOCK})), + ApplyVec<SocketKind>(IPv4TCPUnboundSocket, + AllBitwiseCombinations(List<int>{SOCK_STREAM}, + List<int>{0, + SOCK_NONBLOCK})), + ApplyVec<SocketKind>(IPv6TCPUnboundSocket, + AllBitwiseCombinations(List<int>{SOCK_STREAM}, + List<int>{ + 0, SOCK_NONBLOCK})))))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc index 3a068aacf..3c3712b50 100644 --- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc @@ -23,7 +23,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 67d29af0a..b828b6844 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -21,7 +21,6 @@ #include <cstdio> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index c85ae30dc..98ae414f3 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -27,7 +27,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" @@ -42,6 +41,26 @@ TestAddress V4EmptyAddress() { return t; } +constexpr char kMulticastAddress[] = "224.0.2.1"; + +TestAddress V4Multicast() { + TestAddress t("V4Multicast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + inet_addr(kMulticastAddress); + return t; +} + +TestAddress V4Broadcast() { + TestAddress t("V4Broadcast"); + t.addr.ss_family = AF_INET; + t.addr_len = sizeof(sockaddr_in); + reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = + htonl(INADDR_BROADCAST); + return t; +} + void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { got_if_infos_ = false; @@ -116,7 +135,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) { // Verifies that a broadcast UDP packet will arrive at all UDP sockets with // the destination port number. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastReceivedOnAllExpectedEndpoints) { + UDPBroadcastReceivedOnExpectedPort) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -136,51 +155,134 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); - sockaddr_in rcv_addr = {}; - socklen_t rcv_addr_sz = sizeof(rcv_addr); - rcv_addr.sin_family = AF_INET; - rcv_addr.sin_addr.s_addr = htonl(INADDR_ANY); - ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr), - rcv_addr_sz), + // Bind the first socket to the ANY address and let the system assign a port. + auto rcv1_addr = V4Any(); + ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), + rcv1_addr.addr_len), SyscallSucceedsWithValue(0)); // Retrieve port number from first socket so that it can be bound to the // second socket. - rcv_addr = {}; + socklen_t rcv_addr_sz = rcv1_addr.addr_len; ASSERT_THAT( - getsockname(rcvr1->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr), + getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), &rcv_addr_sz), SyscallSucceedsWithValue(0)); - ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<struct sockaddr*>(&rcv_addr), + EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); + auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port; + + // Bind the second socket to the same address:port as the first. + ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), rcv_addr_sz), SyscallSucceedsWithValue(0)); // Bind the non-receiving socket to an ephemeral port. - sockaddr_in norcv_addr = {}; - norcv_addr.sin_family = AF_INET; - norcv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + auto norecv_addr = V4Any(); + ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), + norecv_addr.addr_len), + SyscallSucceedsWithValue(0)); + + // Broadcast a test message. + auto dst_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port; + constexpr char kTestMsg[] = "hello, world"; + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the receiving sockets received the test message. + char buf[sizeof(kTestMsg)] = {}; + EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); + memset(buf, 0, sizeof(buf)); + EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); + + // Verify that the non-receiving socket did not receive the test message. + memset(buf, 0, sizeof(buf)); + EXPECT_THAT(RetryEINTR(recv)(norcv->get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// Verifies that a broadcast UDP packet will arrive at all UDP sockets bound to +// the destination port number and either INADDR_ANY or INADDR_BROADCAST, but +// not a unicast address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + UDPBroadcastReceivedOnExpectedAddresses) { + // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its + // IPv4 address on eth0. + SKIP_IF(!got_if_infos_); + + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto norcv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Enable SO_BROADCAST on the sending socket. + ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Enable SO_REUSEPORT on all sockets so that they may all be bound to the + // broadcast messages destination port. + ASSERT_THAT(setsockopt(rcvr1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(setsockopt(rcvr2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(setsockopt(norcv->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Bind the first socket the ANY address and let the system assign a port. + auto rcv1_addr = V4Any(); + ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), + rcv1_addr.addr_len), + SyscallSucceedsWithValue(0)); + // Retrieve port number from first socket so that it can be bound to the + // second socket. + socklen_t rcv_addr_sz = rcv1_addr.addr_len; ASSERT_THAT( - bind(norcv->get(), reinterpret_cast<struct sockaddr*>(&norcv_addr), - sizeof(norcv_addr)), + getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), + &rcv_addr_sz), SyscallSucceedsWithValue(0)); + EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); + auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port; + + // Bind the second socket to the broadcast address. + auto rcv2_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&rcv2_addr.addr)->sin_port = port; + ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv2_addr.addr), + rcv2_addr.addr_len), + SyscallSucceedsWithValue(0)); + + // Bind the non-receiving socket to the unicast ethernet address. + auto norecv_addr = rcv1_addr; + reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr = + eth_if_sin_addr_; + ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), + norecv_addr.addr_len), + SyscallSucceedsWithValue(0)); // Broadcast a test message. - sockaddr_in dst_addr = {}; - dst_addr.sin_family = AF_INET; - dst_addr.sin_addr.s_addr = htonl(INADDR_BROADCAST); - dst_addr.sin_port = rcv_addr.sin_port; + auto dst_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port; constexpr char kTestMsg[] = "hello, world"; EXPECT_THAT( sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<struct sockaddr*>(&dst_addr), sizeof(dst_addr)), + reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), SyscallSucceedsWithValue(sizeof(kTestMsg))); // Verify that the receiving sockets received the test message. char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(read(rcvr1->get(), buf, sizeof(buf)), + EXPECT_THAT(recv(rcvr1->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); memset(buf, 0, sizeof(buf)); - EXPECT_THAT(read(rcvr2->get(), buf, sizeof(buf)), + EXPECT_THAT(recv(rcvr2->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); @@ -190,10 +292,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SyscallFailsWithErrno(EAGAIN)); } -// Verifies that a UDP broadcast sent via the loopback interface is not received -// by the sender. +// Verifies that a UDP broadcast can be sent and then received back on the same +// socket that is bound to the broadcast address (255.255.255.255). +// FIXME(b/141938460): This can be combined with the next test +// (UDPBroadcastSendRecvOnSocketBoundToAny). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, - UDPBroadcastViaLoopbackFails) { + UDPBroadcastSendRecvOnSocketBoundToBroadcast) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -201,33 +305,73 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, sizeof(kSockOptOn)), SyscallSucceedsWithValue(0)); - // Bind the sender to the loopback interface. - sockaddr_in src = {}; - socklen_t src_sz = sizeof(src); - src.sin_family = AF_INET; - src.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT( - bind(sender->get(), reinterpret_cast<struct sockaddr*>(&src), src_sz), - SyscallSucceedsWithValue(0)); + // Bind the sender to the broadcast address. + auto src_addr = V4Broadcast(); + ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr), + src_addr.addr_len), + SyscallSucceedsWithValue(0)); + socklen_t src_sz = src_addr.addr_len; ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast<struct sockaddr*>(&src), &src_sz), + reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz), SyscallSucceedsWithValue(0)); - ASSERT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(src_sz, src_addr.addr_len); // Send the message. - sockaddr_in dst = {}; - dst.sin_family = AF_INET; - dst.sin_addr.s_addr = htonl(INADDR_BROADCAST); - dst.sin_port = src.sin_port; + auto dst_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port; constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<struct sockaddr*>(&dst), sizeof(dst)), + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the message was received. + char buf[sizeof(kTestMsg)] = {}; + EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); +} + +// Verifies that a UDP broadcast can be sent and then received back on the same +// socket that is bound to the ANY address (0.0.0.0). +// FIXME(b/141938460): This can be combined with the previous test +// (UDPBroadcastSendRecvOnSocketBoundToBroadcast). +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + UDPBroadcastSendRecvOnSocketBoundToAny) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - // Verify that the message was not received by the sender (loopback). + // Enable SO_BROADCAST. + ASSERT_THAT(setsockopt(sender->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + // Bind the sender to the ANY address. + auto src_addr = V4Any(); + ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr), + src_addr.addr_len), + SyscallSucceedsWithValue(0)); + socklen_t src_sz = src_addr.addr_len; + ASSERT_THAT(getsockname(sender->get(), + reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(src_sz, src_addr.addr_len); + + // Send the message. + auto dst_addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port; + constexpr char kTestMsg[] = "hello, world"; + EXPECT_THAT( + sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + + // Verify that the message was received. char buf[sizeof(kTestMsg)] = {}; - EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT(RetryEINTR(recv)(sender->get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_EQ(0, memcmp(buf, kTestMsg, sizeof(kTestMsg))); } // Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST @@ -237,15 +381,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) { // Broadcast a test message without having enabled SO_BROADCAST on the sending // socket. - sockaddr_in addr = {}; - socklen_t addr_sz = sizeof(addr); - addr.sin_family = AF_INET; - addr.sin_port = htons(12345); - addr.sin_addr.s_addr = htonl(INADDR_BROADCAST); + auto addr = V4Broadcast(); + reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = htons(12345); constexpr char kTestMsg[] = "hello, world"; EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<struct sockaddr*>(&addr), addr_sz), + reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EACCES)); } @@ -274,21 +415,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) { reinterpret_cast<struct sockaddr*>(&addr), addr_sz), SyscallSucceedsWithValue(sizeof(kTestMsg))); char buf[sizeof(kTestMsg)] = {}; - ASSERT_THAT(read(rcvr->get(), buf, sizeof(buf)), + ASSERT_THAT(recv(rcvr->get(), buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(kTestMsg))); } -constexpr char kMulticastAddress[] = "224.0.2.1"; - -TestAddress V4Multicast() { - TestAddress t("V4Multicast"); - t.addr.ss_family = AF_INET; - t.addr_len = sizeof(sockaddr_in); - reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr = - inet_addr(kMulticastAddress); - return t; -} - // Check that multicast packets won't be delivered to the sending socket with no // set interface or group membership. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, @@ -609,8 +739,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, } // Check that two sockets can join the same multicast group at the same time, -// and both will receive data on it. -TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { +// and both will receive data on it when bound to the ANY address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToAny) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -624,8 +755,72 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - // Bind the receiver to the v4 any address to ensure that we can receive the - // multicast packet. + // Bind to ANY to receive multicast packets. + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + EXPECT_EQ( + htonl(INADDR_ANY), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); + // On the first iteration, save the port we are bound to. On the second + // iteration, verify the port is the same as the one from the first + // iteration. In other words, both sockets listen on the same port. + if (bound_port == 0) { + bound_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + } else { + EXPECT_EQ(bound_port, + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); + } + + // Register to receive multicast packets. + ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + } + + // Send a multicast packet to the group and verify both receivers get it. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + for (auto& receiver : receivers) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } +} + +// Check that two sockets can join the same multicast group at the same time, +// and both will receive data on it when bound to the multicast address. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToMulticastAddress) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + std::unique_ptr<FileDescriptor> receivers[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; + + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + auto receiver_addr = V4Multicast(); + int bound_port = 0; + for (auto& receiver : receivers) { + ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); ASSERT_THAT( bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), receiver_addr.addr_len), @@ -636,6 +831,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); // On the first iteration, save the port we are bound to. On the second // iteration, verify the port is the same as the one from the first // iteration. In other words, both sockets listen on the same port. @@ -643,6 +841,83 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwo) { bound_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; } else { + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); + EXPECT_EQ(bound_port, + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); + } + + // Register to receive multicast packets. + ASSERT_THAT(setsockopt(receiver->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + } + + // Send a multicast packet to the group and verify both receivers get it. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + for (auto& receiver : receivers) { + char recv_buf[sizeof(send_buf)] = {}; + ASSERT_THAT( + RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), 0), + SyscallSucceedsWithValue(sizeof(recv_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + } +} + +// Check that two sockets can join the same multicast group at the same time, +// and with one bound to the wildcard address and the other bound to the +// multicast address, both will receive data. +TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, + TestSendMulticastToTwoBoundToAnyAndMulticastAddress) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + std::unique_ptr<FileDescriptor> receivers[2] = { + ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), + ASSERT_NO_ERRNO_AND_VALUE(NewSocket())}; + + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + // The first receiver binds to the wildcard address. + auto receiver_addr = V4Any(); + int bound_port = 0; + for (auto& receiver : receivers) { + ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + // On the first iteration, save the port we are bound to and change the + // receiver address from V4Any to V4Multicast so the second receiver binds + // to that. On the second iteration, verify the port is the same as the one + // from the first iteration but the address is different. + if (bound_port == 0) { + EXPECT_EQ( + htonl(INADDR_ANY), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); + bound_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + receiver_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port = + bound_port; + } else { + EXPECT_EQ( + inet_addr(kMulticastAddress), + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_addr.s_addr); EXPECT_EQ(bound_port, reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port); } diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 765f8e0e4..405dbbd73 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -68,7 +68,8 @@ TEST(NetdeviceTest, Netmask) { // Use a netlink socket to get the netmask, which we'll then compare to the // netmask obtained via ioctl. - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { diff --git a/test/syscalls/linux/socket_netlink.cc b/test/syscalls/linux/socket_netlink.cc new file mode 100644 index 000000000..4ec0fd4fa --- /dev/null +++ b/test/syscalls/linux/socket_netlink.cc @@ -0,0 +1,153 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/netlink.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Tests for all netlink socket protocols. + +namespace gvisor { +namespace testing { + +namespace { + +// NetlinkTest parameter is the protocol to test. +using NetlinkTest = ::testing::TestWithParam<int>; + +// Netlink sockets must be SOCK_DGRAM or SOCK_RAW. +TEST_P(NetlinkTest, Types) { + const int protocol = GetParam(); + + EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, protocol), + SyscallFailsWithErrno(ESOCKTNOSUPPORT)); + + int fd; + EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, protocol), SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); + + EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, protocol), SyscallSucceeds()); + EXPECT_THAT(close(fd), SyscallSucceeds()); +} + +TEST_P(NetlinkTest, AutomaticPort) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + + EXPECT_THAT( + bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceeds()); + + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + // This is the only netlink socket in the process, so it should get the PID as + // the port id. + // + // N.B. Another process could theoretically have explicitly reserved our pid + // as a port ID, but that is very unlikely. + EXPECT_EQ(addr.nl_pid, getpid()); +} + +// Calling connect automatically binds to an automatic port. +TEST_P(NetlinkTest, ConnectBinds) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + + EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + + // Each test is running in a pid namespace, so another process can explicitly + // reserve our pid as a port ID. In this case, a negative portid value will be + // set. + if (static_cast<pid_t>(addr.nl_pid) > 0) { + EXPECT_EQ(addr.nl_pid, getpid()); + } + + memset(&addr, 0, sizeof(addr)); + addr.nl_family = AF_NETLINK; + + // Connecting again is allowed, but keeps the same port. + EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, sizeof(addr)); + EXPECT_EQ(addr.nl_pid, getpid()); +} + +TEST_P(NetlinkTest, GetPeerName) { + const int protocol = GetParam(); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol)); + + struct sockaddr_nl addr = {}; + socklen_t addrlen = sizeof(addr); + + EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), + &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, sizeof(addr)); + EXPECT_EQ(addr.nl_family, AF_NETLINK); + // Peer is the kernel if we didn't connect elsewhere. + EXPECT_EQ(addr.nl_pid, 0); +} + +INSTANTIATE_TEST_SUITE_P(ProtocolTest, NetlinkTest, + ::testing::Values(NETLINK_ROUTE, + NETLINK_KOBJECT_UEVENT)); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 32fe0d6d1..ef567f512 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -41,112 +41,7 @@ namespace { using ::testing::AnyOf; using ::testing::Eq; -// Netlink sockets must be SOCK_DGRAM or SOCK_RAW. -TEST(NetlinkRouteTest, Types) { - EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, NETLINK_ROUTE), - SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - - int fd; - EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); -} - -TEST(NetlinkRouteTest, AutomaticPort) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - EXPECT_THAT( - bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallSucceeds()); - - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - // This is the only netlink socket in the process, so it should get the PID as - // the port id. - // - // N.B. Another process could theoretically have explicitly reserved our pid - // as a port ID, but that is very unlikely. - EXPECT_EQ(addr.nl_pid, getpid()); -} - -// Calling connect automatically binds to an automatic port. -TEST(NetlinkRouteTest, ConnectBinds) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - addr.nl_family = AF_NETLINK; - - EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - - // Each test is running in a pid namespace, so another process can explicitly - // reserve our pid as a port ID. In this case, a negative portid value will be - // set. - if (static_cast<pid_t>(addr.nl_pid) > 0) { - EXPECT_EQ(addr.nl_pid, getpid()); - } - - memset(&addr, 0, sizeof(addr)); - addr.nl_family = AF_NETLINK; - - // Connecting again is allowed, but keeps the same port. - EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), - SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, sizeof(addr)); - EXPECT_EQ(addr.nl_pid, getpid()); -} - -TEST(NetlinkRouteTest, GetPeerName) { - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); - - struct sockaddr_nl addr = {}; - socklen_t addrlen = sizeof(addr); - - EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, sizeof(addr)); - EXPECT_EQ(addr.nl_family, AF_NETLINK); - // Peer is the kernel if we didn't connect elsewhere. - EXPECT_EQ(addr.nl_pid, 0); -} - -// Parameters for GetSockOpt test. They are: +// Parameters for SockOptTest. They are: // 0: Socket option to query. // 1: A predicate to run on the returned sockopt value. Should return true if // the value is considered ok. @@ -195,7 +90,8 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(SO_DOMAIN, IsEqual(AF_NETLINK), absl::StrFormat("AF_NETLINK (%d)", AF_NETLINK)), std::make_tuple(SO_PROTOCOL, IsEqual(NETLINK_ROUTE), - absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)))); + absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)), + std::make_tuple(SO_PASSCRED, IsEqual(0), "0"))); // Validates the reponses to RTM_GETLINK + NLM_F_DUMP. void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { @@ -218,7 +114,8 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { } TEST(NetlinkRouteTest, GetLinkDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -259,7 +156,8 @@ TEST(NetlinkRouteTest, GetLinkDump) { } TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; @@ -292,7 +190,8 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { } TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; @@ -331,7 +230,8 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { } TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; @@ -372,7 +272,8 @@ TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { } TEST(NetlinkRouteTest, ControlMessageIgnored) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -407,7 +308,8 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { } TEST(NetlinkRouteTest, GetAddrDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -467,7 +369,8 @@ TEST(NetlinkRouteTest, LookupAll) { // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. TEST(NetlinkRouteTest, GetRouteDump) { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); struct request { @@ -539,6 +442,270 @@ TEST(NetlinkRouteTest, GetRouteDump) { EXPECT_TRUE(dstFound); } +// RecvmsgTrunc tests the recvmsg MSG_TRUNC flag with zero length output +// buffer. MSG_TRUNC with a zero length buffer should consume subsequent +// messages off the socket. +TEST(NetlinkRouteTest, RecvmsgTrunc) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct rtgenmsg rgm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rgm.rtgen_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + iov.iov_base = NULL; + iov.iov_len = 0; + + int trunclen, trunclen2; + + // Note: This test assumes at least two messages are returned by the + // RTM_GETADDR request. That means at least one RTM_NEWLINK message and one + // NLMSG_DONE message. We cannot read all the messages without blocking + // because we would need to read the message into a buffer and check the + // nlmsg_type for NLMSG_DONE. However, the test depends on reading into a + // zero-length buffer. + + // First, call recvmsg with MSG_TRUNC. This will read the full message from + // the socket and return it's full length. Subsequent calls to recvmsg will + // read the next messages from the socket. + ASSERT_THAT(trunclen = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC), + SyscallSucceeds()); + + // Message should always be truncated. However, While the destination iov is + // zero length, MSG_TRUNC returns the size of the next message so it should + // not be zero. + ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); + ASSERT_NE(trunclen, 0); + // Returned length is at least the header and ifaddrmsg. + EXPECT_GE(trunclen, sizeof(struct nlmsghdr) + sizeof(struct ifaddrmsg)); + + // Reset the msg_flags to make sure that the recvmsg call is setting them + // properly. + msg.msg_flags = 0; + + // Make a second recvvmsg call to get the next message. + ASSERT_THAT(trunclen2 = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_TRUNC), + SyscallSucceeds()); + ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); + ASSERT_NE(trunclen2, 0); + + // Assert that the received messages are not the same. + // + // We are calling recvmsg with a zero length buffer so we have no way to + // inspect the messages to make sure they are not equal in value. The best + // we can do is to compare their lengths. + ASSERT_NE(trunclen, trunclen2); +} + +// RecvmsgTruncPeek tests recvmsg with the combination of the MSG_TRUNC and +// MSG_PEEK flags and a zero length output buffer. This is normally used to +// read the full length of the next message on the socket without consuming +// it, so a properly sized buffer can be allocated to store the message. This +// test tests that scenario. +TEST(NetlinkRouteTest, RecvmsgTruncPeek) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct rtgenmsg rgm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rgm.rtgen_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + int type = -1; + do { + int peeklen; + int len; + + iov.iov_base = NULL; + iov.iov_len = 0; + + // Call recvmsg with MSG_PEEK and MSG_TRUNC. This will peek at the message + // and return it's full length. + // See: MSG_TRUNC http://man7.org/linux/man-pages/man2/recv.2.html + ASSERT_THAT( + peeklen = RetryEINTR(recvmsg)(fd.get(), &msg, MSG_PEEK | MSG_TRUNC), + SyscallSucceeds()); + + // Message should always be truncated. + ASSERT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); + ASSERT_NE(peeklen, 0); + + // Reset the message flags for the next call. + msg.msg_flags = 0; + + // Make the actual call to recvmsg to get the actual data. We will use + // the length returned from the peek call for the allocated buffer size.. + std::vector<char> buf(peeklen); + iov.iov_base = buf.data(); + iov.iov_len = buf.size(); + ASSERT_THAT(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0), + SyscallSucceeds()); + + // Message should not be truncated since we allocated the correct buffer + // size. + EXPECT_NE(msg.msg_flags & MSG_TRUNC, MSG_TRUNC); + + // MSG_PEEK should have left data on the socket and the subsequent call + // with should have retrieved the same data. Both calls should have + // returned the message's full length so they should be equal. + ASSERT_NE(len, 0); + ASSERT_EQ(peeklen, len); + + for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data()); + NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { + type = hdr->nlmsg_type; + } + } while (type != NLMSG_DONE && type != NLMSG_ERROR); +} + +// No SCM_CREDENTIALS are received without SO_PASSCRED set. +TEST(NetlinkRouteTest, NoPasscredNoCreds) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceeds()); + + struct request { + struct nlmsghdr hdr; + struct rtgenmsg rgm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rgm.rtgen_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + iov.iov_base = NULL; + iov.iov_len = 0; + + char control[CMSG_SPACE(sizeof(struct ucred))] = {}; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + // Note: This test assumes at least one message is returned by the + // RTM_GETADDR request. + ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + // No control messages. + EXPECT_EQ(CMSG_FIRSTHDR(&msg), nullptr); +} + +// SCM_CREDENTIALS are received with SO_PASSCRED set. +TEST(NetlinkRouteTest, PasscredCreds) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + struct request { + struct nlmsghdr hdr; + struct rtgenmsg rgm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.rgm.rtgen_family = AF_UNSPEC; + + struct iovec iov = {}; + iov.iov_base = &req; + iov.iov_len = sizeof(req); + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + iov.iov_base = NULL; + iov.iov_len = 0; + + char control[CMSG_SPACE(sizeof(struct ucred))] = {}; + msg.msg_control = control; + msg.msg_controllen = sizeof(control); + + // Note: This test assumes at least one message is returned by the + // RTM_GETADDR request. + ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds()); + + struct ucred creds; + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(creds))); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS); + + memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds)); + + // The peer is the kernel, which is "PID" 0. + EXPECT_EQ(creds.pid, 0); + // The kernel identifies as root. Also allow nobody in case this test is + // running in a userns without root mapped. + EXPECT_THAT(creds.uid, AnyOf(Eq(0), Eq(65534))); + EXPECT_THAT(creds.gid, AnyOf(Eq(0), Eq(65534))); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_uevent.cc b/test/syscalls/linux/socket_netlink_uevent.cc new file mode 100644 index 000000000..da425bed4 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_uevent.cc @@ -0,0 +1,83 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/filter.h> +#include <linux/netlink.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +// Tests for NETLINK_KOBJECT_UEVENT sockets. +// +// gVisor never sends any messages on these sockets, so we don't test the events +// themselves. + +namespace gvisor { +namespace testing { + +namespace { + +// SO_PASSCRED can be enabled. Since no messages are sent in gVisor, we don't +// actually test receiving credentials. +TEST(NetlinkUeventTest, PassCred) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + EXPECT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); +} + +// SO_DETACH_FILTER fails without a filter already installed. +TEST(NetlinkUeventTest, DetachNoFilter) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + int opt; + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)), + SyscallFailsWithErrno(ENOENT)); +} + +// We can attach a BPF filter. +TEST(NetlinkUeventTest, AttachFilter) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT)); + + // Minimal BPF program: a single ret. + struct sock_filter filter = {0x6, 0, 0, 0}; + struct sock_fprog prog = {}; + prog.len = 1; + prog.filter = &filter; + + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_ATTACH_FILTER, &prog, sizeof(prog)), + SyscallSucceeds()); + + int opt; + EXPECT_THAT( + setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)), + SyscallSucceeds()); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index 36b6560c2..5f05bab10 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -16,7 +16,6 @@ #include <linux/if_arp.h> #include <linux/netlink.h> -#include <linux/rtnetlink.h> #include <vector> @@ -27,9 +26,9 @@ namespace gvisor { namespace testing { -PosixErrorOr<FileDescriptor> NetlinkBoundSocket() { +PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) { FileDescriptor fd; - ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)); + ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); struct sockaddr_nl addr = {}; addr.nl_family = AF_NETLINK; @@ -91,6 +90,13 @@ PosixError NetlinkRequestResponse( NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { fn(hdr); type = hdr->nlmsg_type; + // Done should include an integer payload for dump_done_errno. + // See net/netlink/af_netlink.c:netlink_dump + // Some tools like the 'ip' tool check the minimum length of the + // NLMSG_DONE message. + if (type == NLMSG_DONE) { + EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int))); + } } } while (type != NLMSG_DONE && type != NLMSG_ERROR); diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index db8639a2f..da99f0d60 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -17,7 +17,6 @@ #include <linux/if_arp.h> #include <linux/netlink.h> -#include <linux/rtnetlink.h> #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -25,8 +24,8 @@ namespace gvisor { namespace testing { -// Returns a bound NETLINK_ROUTE socket. -PosixErrorOr<FileDescriptor> NetlinkBoundSocket(); +// Returns a bound netlink socket. +PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol); // Returns the port ID of the passed socket. PosixErrorOr<uint32_t> NetlinkPortID(int fd); diff --git a/test/syscalls/linux/socket_non_blocking.cc b/test/syscalls/linux/socket_non_blocking.cc index 73e6dc618..c3520cadd 100644 --- a/test/syscalls/linux/socket_non_blocking.cc +++ b/test/syscalls/linux/socket_non_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc index 3c599b6e8..d91c5ed39 100644 --- a/test/syscalls/linux/socket_non_stream.cc +++ b/test/syscalls/linux/socket_non_stream.cc @@ -19,7 +19,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc index 76127d181..62d87c1af 100644 --- a/test/syscalls/linux/socket_non_stream_blocking.cc +++ b/test/syscalls/linux/socket_non_stream_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc index 0417dd347..346443f96 100644 --- a/test/syscalls/linux/socket_stream.cc +++ b/test/syscalls/linux/socket_stream.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc index 8367460d2..e9cc082bf 100644 --- a/test/syscalls/linux/socket_stream_blocking.cc +++ b/test/syscalls/linux/socket_stream_blocking.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" diff --git a/test/syscalls/linux/socket_stream_nonblock.cc b/test/syscalls/linux/socket_stream_nonblock.cc index b00748b97..74d608741 100644 --- a/test/syscalls/linux/socket_stream_nonblock.cc +++ b/test/syscalls/linux/socket_stream_nonblock.cc @@ -20,7 +20,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 6efa8055f..be38907c2 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -30,7 +30,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/str_format.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -83,6 +82,8 @@ inline ssize_t SendFd(int fd, void* buf, size_t count, int flags) { count); } +PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain); + // A Creator<T> is a function that attempts to create and return a new T. (This // is copy/pasted from cloud/gvisor/api/sandbox_util.h and is just duplicated // here for clarity.) diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc index 875f0391f..8a28202a8 100644 --- a/test/syscalls/linux/socket_unix.cc +++ b/test/syscalls/linux/socket_unix.cc @@ -25,7 +25,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc index 1092e29b1..1159c5229 100644 --- a/test/syscalls/linux/socket_unix_cmsg.cc +++ b/test/syscalls/linux/socket_unix_cmsg.cc @@ -25,7 +25,6 @@ #include <vector> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/string_view.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc index 3e0f611d2..3245cf7c9 100644 --- a/test/syscalls/linux/socket_unix_dgram.cc +++ b/test/syscalls/linux/socket_unix_dgram.cc @@ -17,7 +17,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc index 707052af8..cd4fba25c 100644 --- a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc +++ b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc index dafe82494..276a94eb8 100644 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ b/test/syscalls/linux/socket_unix_non_stream.cc @@ -19,7 +19,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/memory_util.h" @@ -231,11 +230,21 @@ TEST_P(UnixNonStreamSocketPairTest, SendTimeout) { setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), SyscallSucceeds()); - char buf[100] = {}; + const int buf_size = 5 * kPageSize; + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &buf_size, + sizeof(buf_size)), + SyscallSucceeds()); + EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVBUF, &buf_size, + sizeof(buf_size)), + SyscallSucceeds()); + + // The buffer size should be big enough to avoid many iterations in the next + // loop. Otherwise, this will slow down cooperative_save tests. + std::vector<char> buf(kPageSize); for (;;) { int ret; ASSERT_THAT( - ret = RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), + ret = RetryEINTR(send)(sockets->first_fd(), buf.data(), buf.size(), 0), ::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN))); if (ret == -1) { break; diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc index 6f6367dd5..60fa9e38a 100644 --- a/test/syscalls/linux/socket_unix_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_seqpacket.cc @@ -17,7 +17,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 659c93945..563467365 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <poll.h> #include <stdio.h> #include <sys/un.h> -#include "gtest/gtest.h" + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" @@ -44,6 +45,50 @@ TEST_P(StreamUnixSocketPairTest, ReadOneSideClosed) { SyscallSucceedsWithValue(0)); } +TEST_P(StreamUnixSocketPairTest, RecvmsgOneSideClosed) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set timeout so that it will not wait for ever. + struct timeval tv { + .tv_sec = 0, .tv_usec = 10 + }; + EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, + sizeof(tv)), + SyscallSucceeds()); + + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + + char received_data[10] = {}; + struct iovec iov; + iov.iov_base = received_data; + iov.iov_len = sizeof(received_data); + struct msghdr msg = {}; + msg.msg_flags = -1; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ASSERT_THAT(recvmsg(sockets->second_fd(), &msg, MSG_WAITALL), + SyscallSucceedsWithValue(0)); +} + +TEST_P(StreamUnixSocketPairTest, ReadOneSideClosedWithUnreadData) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char buf[10] = {}; + ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); + + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallFailsWithErrno(ECONNRESET)); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( diff --git a/test/syscalls/linux/socket_unix_unbound_abstract.cc b/test/syscalls/linux/socket_unix_unbound_abstract.cc index 4b5832de8..7f5816ace 100644 --- a/test/syscalls/linux/socket_unix_unbound_abstract.cc +++ b/test/syscalls/linux/socket_unix_unbound_abstract.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_dgram.cc b/test/syscalls/linux/socket_unix_unbound_dgram.cc index 52aef891f..907dca0f1 100644 --- a/test/syscalls/linux/socket_unix_unbound_dgram.cc +++ b/test/syscalls/linux/socket_unix_unbound_dgram.cc @@ -17,7 +17,6 @@ #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc index 8cb03c450..b14f24086 100644 --- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc +++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc index 0575f2e1d..50ffa1d04 100644 --- a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc +++ b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/socket_unix_unbound_stream.cc b/test/syscalls/linux/socket_unix_unbound_stream.cc index e483d2777..344918c34 100644 --- a/test/syscalls/linux/socket_unix_unbound_stream.cc +++ b/test/syscalls/linux/socket_unix_unbound_stream.cc @@ -15,7 +15,6 @@ #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/test_util.h" diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 88ab90b5b..30de2f8ff 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -24,7 +24,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "gtest/gtest.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index bfa031bce..99863b0ed 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -130,6 +130,19 @@ void TcpSocketTest::TearDown() { } } +TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) { + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + ASSERT_THAT( + connect(s_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); + ASSERT_THAT( + connect(t_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen), + SyscallFailsWithErrno(EISCONN)); +} + TEST_P(TcpSocketTest, DataCoalesced) { char buf[10]; @@ -394,8 +407,15 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) { sizeof(tcp_nodelay_flag)), SyscallSucceeds()); + // Set a 256KB send/receive buffer. + int buf_sz = 1 << 18; + EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); + // Create a large buffer that will be used for sending. - std::vector<char> buf(10 * sendbuf_size_); + std::vector<char> buf(1 << 16); // Write until we receive an error. while (RetryEINTR(send)(s_, buf.data(), buf.size(), 0) != -1) { @@ -405,6 +425,11 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) { } // The last error should have been EWOULDBLOCK. ASSERT_EQ(errno, EWOULDBLOCK); + + // Now polling on the FD with a timeout should return 0 corresponding to no + // FDs ready. + struct pollfd poll_fd = {s_, POLLOUT, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), SyscallSucceedsWithValue(0)); } TEST_P(TcpSocketTest, MsgTrunc) { diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index d48453a93..6218fbce1 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -25,6 +25,7 @@ #include "test/util/posix_error.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/uid_util.h" ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID"); ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); @@ -68,30 +69,6 @@ TEST(UidGidTest, Getgroups) { // here; see the setgroups test below. } -// If the caller's real/effective/saved user/group IDs are all 0, IsRoot returns -// true. Otherwise IsRoot logs an explanatory message and returns false. -PosixErrorOr<bool> IsRoot() { - uid_t ruid, euid, suid; - int rc = getresuid(&ruid, &euid, &suid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresuid"); - } - if (ruid != 0 || euid != 0 || suid != 0) { - return false; - } - gid_t rgid, egid, sgid; - rc = getresgid(&rgid, &egid, &sgid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresgid"); - } - if (rgid != 0 || egid != 0 || sgid != 0) { - return false; - } - return true; -} - // Checks that the calling process' real/effective/saved user IDs are // ruid/euid/suid respectively. PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) { diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc index 0a5d91017..d8824b171 100644 --- a/test/syscalls/linux/uname.cc +++ b/test/syscalls/linux/uname.cc @@ -41,6 +41,19 @@ TEST(UnameTest, Sanity) { TEST(UnameTest, SetNames) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + char hostname[65]; + ASSERT_THAT(sethostname("0123456789", 3), SyscallSucceeds()); + EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); + EXPECT_EQ(absl::string_view(hostname), "012"); + + ASSERT_THAT(sethostname("0123456789\0xxx", 11), SyscallSucceeds()); + EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); + EXPECT_EQ(absl::string_view(hostname), "0123456789"); + + ASSERT_THAT(sethostname("0123456789\0xxx", 12), SyscallSucceeds()); + EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); + EXPECT_EQ(absl::string_view(hostname), "0123456789"); + constexpr char kHostname[] = "wubbalubba"; ASSERT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds()); @@ -54,7 +67,6 @@ TEST(UnameTest, SetNames) { EXPECT_EQ(absl::string_view(buf.domainname), kDomainname); // These should just be glibc wrappers that also call uname(2). - char hostname[65]; EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds()); EXPECT_EQ(absl::string_view(hostname), kHostname); diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go index c1e9ce22c..accf46347 100644 --- a/test/syscalls/syscall_test_runner.go +++ b/test/syscalls/syscall_test_runner.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/runsc/specutils" "gvisor.dev/gvisor/runsc/testutil" "gvisor.dev/gvisor/test/syscalls/gtest" + "gvisor.dev/gvisor/test/uds" ) // Location of syscall tests, relative to the repo root. @@ -50,6 +51,8 @@ var ( overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") parallel = flag.Bool("parallel", false, "run tests in parallel") runscPath = flag.String("runsc", "", "path to runsc binary") + + addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests") ) // runTestCaseNative runs the test case directly on the host machine. @@ -86,6 +89,19 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { // intepret them. env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) + if *addUDSTree { + socketDir, cleanup, err := uds.CreateSocketTree("/tmp") + if err != nil { + t.Fatalf("failed to create socket tree: %v", err) + } + defer cleanup() + + env = append(env, "TEST_UDS_TREE="+socketDir) + // On Linux, the concept of "attach" location doesn't exist. + // Just pass the same path to make these test identical. + env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) + } + cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName()) cmd.Env = env cmd.Stdout = os.Stdout @@ -96,101 +112,39 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { } } -// runsTestCaseRunsc runs the test case in runsc. -func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { - rootDir, err := testutil.SetupRootDir() +// runRunsc runs spec in runsc in a standard test configuration. +// +// runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR. +// +// Returns an error if the sandboxed application exits non-zero. +func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { + bundleDir, err := testutil.SetupBundleDir(spec) if err != nil { - t.Fatalf("SetupRootDir failed: %v", err) + return fmt.Errorf("SetupBundleDir failed: %v", err) } - defer os.RemoveAll(rootDir) - - // Run a new container with the test executable and filter for the - // given test suite and name. - spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName()) - - // Mark the root as writeable, as some tests attempt to - // write to the rootfs, and expect EACCES, not EROFS. - spec.Root.Readonly = false - - // Test spec comes with pre-defined mounts that we don't want. Reset it. - spec.Mounts = nil - if *useTmpfs { - // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on - // features only available in gVisor's internal tmpfs may fail. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp", - Type: "tmpfs", - }) - } else { - // Use a gofer-backed directory as '/tmp'. - // - // Tests might be running in parallel, so make sure each has a - // unique test temp dir. - // - // Some tests (e.g., sticky) access this mount from other - // users, so make sure it is world-accessible. - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") - if err != nil { - t.Fatalf("could not create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - if err := os.Chmod(tmpDir, 0777); err != nil { - t.Fatalf("could not chmod temp dir: %v", err) - } - - spec.Mounts = append(spec.Mounts, specs.Mount{ - Type: "bind", - Destination: "/tmp", - Source: tmpDir, - }) - } - - // Set environment variable that indicates we are - // running in gVisor and with the given platform. - platformVar := "TEST_ON_GVISOR" - env := append(os.Environ(), platformVar+"="+*platform) - - // Remove env variables that cause the gunit binary to write output - // files, since they will stomp on eachother, and on the output files - // from this go test. - env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) - - // Remove shard env variables so that the gunit binary does not try to - // intepret them. - env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) - - // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to - // be backed by tmpfs. - for i, kv := range env { - if strings.HasPrefix(kv, "TEST_TMPDIR=") { - env[i] = "TEST_TMPDIR=/tmp" - break - } - } - - spec.Process.Env = env + defer os.RemoveAll(bundleDir) - bundleDir, err := testutil.SetupBundleDir(spec) + rootDir, err := testutil.SetupRootDir() if err != nil { - t.Fatalf("SetupBundleDir failed: %v", err) + return fmt.Errorf("SetupRootDir failed: %v", err) } - defer os.RemoveAll(bundleDir) + defer os.RemoveAll(rootDir) + name := tc.FullName() id := testutil.UniqueContainerID() - log.Infof("Running test %q in container %q", tc.FullName(), id) + log.Infof("Running test %q in container %q", name, id) specutils.LogSpec(spec) args := []string{ - "-platform", *platform, "-root", rootDir, - "-file-access", *fileAccess, "-network=none", "-log-format=text", "-TESTONLY-unsafe-nonroot=true", "-net-raw=true", fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM), "-watchdog-action=panic", + "-platform", *platform, + "-file-access", *fileAccess, } if *overlay { args = append(args, "-overlay") @@ -201,14 +155,18 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { if *strace { args = append(args, "-strace") } + if *addUDSTree { + args = append(args, "-fsgofer-host-uds") + } + if outDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { - tdir := filepath.Join(outDir, strings.Replace(tc.FullName(), "/", "_", -1)) + tdir := filepath.Join(outDir, strings.Replace(name, "/", "_", -1)) if err := os.MkdirAll(tdir, 0755); err != nil { - t.Fatalf("could not create test dir: %v", err) + return fmt.Errorf("could not create test dir: %v", err) } debugLogDir, err := ioutil.TempDir(tdir, "runsc") if err != nil { - t.Fatalf("could not create temp dir: %v", err) + return fmt.Errorf("could not create temp dir: %v", err) } debugLogDir += "/" log.Infof("runsc logs: %s", debugLogDir) @@ -248,38 +206,171 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { if !ok { return } - t.Errorf("%s: Got signal: %v", tc.FullName(), s) + log.Warningf("%s: Got signal: %v", name, s) done := make(chan bool) - go func() { - dArgs := append(args, "-alsologtostderr=true", "debug", "--stacks", id) + dArgs := append([]string{}, args...) + dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) + go func(dArgs []string) { cmd := exec.Command(*runscPath, dArgs...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Run() done <- true - }() + }(dArgs) - timeout := time.Tick(3 * time.Second) + timeout := time.After(3 * time.Second) select { case <-timeout: - t.Logf("runsc debug --stacks is timeouted") + log.Infof("runsc debug --stacks is timeouted") case <-done: } - t.Logf("Send SIGTERM to the sandbox process") - dArgs := append(args, "debug", + log.Warningf("Send SIGTERM to the sandbox process") + dArgs = append(args, "debug", fmt.Sprintf("--signal=%d", syscall.SIGTERM), id) - cmd = exec.Command(*runscPath, dArgs...) + cmd := exec.Command(*runscPath, dArgs...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Run() }() - if err = cmd.Run(); err != nil { - t.Errorf("test %q exited with status %v, want 0", tc.FullName(), err) - } + + err = cmd.Run() + signal.Stop(sig) close(sig) + + return err +} + +// setupUDSTree updates the spec to expose a UDS tree for gofer socket testing. +func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { + socketDir, cleanup, err := uds.CreateSocketTree("/tmp") + if err != nil { + return nil, fmt.Errorf("failed to create socket tree: %v", err) + } + + // Standard access to entire tree. + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets", + Source: socketDir, + Type: "bind", + }) + + // Individial attach points for each socket to test mounts that attach + // directly to the sockets. + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/stream/echo", + Source: filepath.Join(socketDir, "stream/echo"), + Type: "bind", + }) + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/stream/nonlistening", + Source: filepath.Join(socketDir, "stream/nonlistening"), + Type: "bind", + }) + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/seqpacket/echo", + Source: filepath.Join(socketDir, "seqpacket/echo"), + Type: "bind", + }) + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/seqpacket/nonlistening", + Source: filepath.Join(socketDir, "seqpacket/nonlistening"), + Type: "bind", + }) + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp/sockets-attach/dgram/null", + Source: filepath.Join(socketDir, "dgram/null"), + Type: "bind", + }) + + spec.Process.Env = append(spec.Process.Env, "TEST_UDS_TREE=/tmp/sockets") + spec.Process.Env = append(spec.Process.Env, "TEST_UDS_ATTACH_TREE=/tmp/sockets-attach") + + return cleanup, nil +} + +// runsTestCaseRunsc runs the test case in runsc. +func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { + // Run a new container with the test executable and filter for the + // given test suite and name. + spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName()) + + // Mark the root as writeable, as some tests attempt to + // write to the rootfs, and expect EACCES, not EROFS. + spec.Root.Readonly = false + + // Test spec comes with pre-defined mounts that we don't want. Reset it. + spec.Mounts = nil + if *useTmpfs { + // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on + // features only available in gVisor's internal tmpfs may fail. + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: "/tmp", + Type: "tmpfs", + }) + } else { + // Use a gofer-backed directory as '/tmp'. + // + // Tests might be running in parallel, so make sure each has a + // unique test temp dir. + // + // Some tests (e.g., sticky) access this mount from other + // users, so make sure it is world-accessible. + tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") + if err != nil { + t.Fatalf("could not create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + if err := os.Chmod(tmpDir, 0777); err != nil { + t.Fatalf("could not chmod temp dir: %v", err) + } + + spec.Mounts = append(spec.Mounts, specs.Mount{ + Type: "bind", + Destination: "/tmp", + Source: tmpDir, + }) + } + + // Set environment variable that indicates we are + // running in gVisor and with the given platform. + platformVar := "TEST_ON_GVISOR" + env := append(os.Environ(), platformVar+"="+*platform) + + // Remove env variables that cause the gunit binary to write output + // files, since they will stomp on eachother, and on the output files + // from this go test. + env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) + + // Remove shard env variables so that the gunit binary does not try to + // intepret them. + env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) + + // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to + // be backed by tmpfs. + for i, kv := range env { + if strings.HasPrefix(kv, "TEST_TMPDIR=") { + env[i] = "TEST_TMPDIR=/tmp" + break + } + } + + spec.Process.Env = env + + if *addUDSTree { + cleanup, err := setupUDSTree(spec) + if err != nil { + t.Fatalf("error creating UDS tree: %v", err) + } + defer cleanup() + } + + if err := runRunsc(tc, spec); err != nil { + t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err) + } } // filterEnv returns an environment with the blacklisted variables removed. diff --git a/test/uds/BUILD b/test/uds/BUILD new file mode 100644 index 000000000..a3843e699 --- /dev/null +++ b/test/uds/BUILD @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +go_library( + name = "uds", + testonly = 1, + srcs = ["uds.go"], + importpath = "gvisor.dev/gvisor/test/uds", + deps = [ + "//pkg/log", + "//pkg/unet", + ], +) diff --git a/test/uds/uds.go b/test/uds/uds.go new file mode 100644 index 000000000..b714c61b0 --- /dev/null +++ b/test/uds/uds.go @@ -0,0 +1,228 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package uds contains helpers for testing external UDS functionality. +package uds + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "syscall" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/unet" +) + +// createEchoSocket creates a socket that echoes back anything received. +// +// Only works for stream, seqpacket sockets. +func createEchoSocket(path string, protocol int) (cleanup func(), err error) { + fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + if err != nil { + return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err) + } + + if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err) + } + + if err := syscall.Listen(fd, 0); err != nil { + return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err) + } + + server, err := unet.NewServerSocket(fd) + if err != nil { + return nil, fmt.Errorf("error creating echo(%d) unet socket: %v", protocol, err) + } + + acceptAndEchoOne := func() error { + s, err := server.Accept() + if err != nil { + return fmt.Errorf("failed to accept: %v", err) + } + defer s.Close() + + for { + buf := make([]byte, 512) + for { + n, err := s.Read(buf) + if err == io.EOF { + return nil + } + if err != nil { + return fmt.Errorf("failed to read: %d, %v", n, err) + } + + n, err = s.Write(buf[:n]) + if err != nil { + return fmt.Errorf("failed to write: %d, %v", n, err) + } + } + } + } + + go func() { + for { + if err := acceptAndEchoOne(); err != nil { + log.Warningf("Failed to handle echo(%d) socket: %v", protocol, err) + return + } + } + }() + + cleanup = func() { + if err := server.Close(); err != nil { + log.Warningf("Failed to close echo(%d) socket: %v", protocol, err) + } + } + + return cleanup, nil +} + +// createNonListeningSocket creates a socket that is bound but not listening. +// +// Only relevant for stream, seqpacket sockets. +func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) { + fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + if err != nil { + return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err) + } + + if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err) + } + + cleanup = func() { + if err := syscall.Close(fd); err != nil { + log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err) + } + } + + return cleanup, nil +} + +// createNullSocket creates a socket that reads anything received. +// +// Only works for dgram sockets. +func createNullSocket(path string, protocol int) (cleanup func(), err error) { + fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + if err != nil { + return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err) + } + + if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err) + } + + s, err := unet.NewSocket(fd) + if err != nil { + return nil, fmt.Errorf("error creating null(%d) unet socket: %v", protocol, err) + } + + go func() { + buf := make([]byte, 512) + for { + n, err := s.Read(buf) + if err != nil { + log.Warningf("failed to read: %d, %v", n, err) + return + } + } + }() + + cleanup = func() { + if err := s.Close(); err != nil { + log.Warningf("Failed to close null(%d) socket: %v", protocol, err) + } + } + + return cleanup, nil +} + +type socketCreator func(path string, proto int) (cleanup func(), err error) + +// CreateSocketTree creates a local tree of unix domain sockets for use in +// testing: +// * /stream/echo +// * /stream/nonlistening +// * /seqpacket/echo +// * /seqpacket/nonlistening +// * /dgram/null +func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) { + dir, err = ioutil.TempDir(baseDir, "sockets") + if err != nil { + return "", nil, fmt.Errorf("error creating temp dir: %v", err) + } + + var protocols = []struct { + protocol int + name string + sockets map[string]socketCreator + }{ + { + protocol: syscall.SOCK_STREAM, + name: "stream", + sockets: map[string]socketCreator{ + "echo": createEchoSocket, + "nonlistening": createNonListeningSocket, + }, + }, + { + protocol: syscall.SOCK_SEQPACKET, + name: "seqpacket", + sockets: map[string]socketCreator{ + "echo": createEchoSocket, + "nonlistening": createNonListeningSocket, + }, + }, + { + protocol: syscall.SOCK_DGRAM, + name: "dgram", + sockets: map[string]socketCreator{ + "null": createNullSocket, + }, + }, + } + + var cleanups []func() + for _, proto := range protocols { + protoDir := filepath.Join(dir, proto.name) + if err := os.Mkdir(protoDir, 0755); err != nil { + return "", nil, fmt.Errorf("error creating %s dir: %v", proto.name, err) + } + + for name, fn := range proto.sockets { + path := filepath.Join(protoDir, name) + cleanup, err := fn(path, proto.protocol) + if err != nil { + return "", nil, fmt.Errorf("error creating %s %s socket: %v", proto.name, name, err) + } + + cleanups = append(cleanups, cleanup) + } + } + + cleanup = func() { + for _, c := range cleanups { + c() + } + + os.RemoveAll(dir) + } + + return dir, cleanup, nil +} diff --git a/test/util/BUILD b/test/util/BUILD index 25ed9c944..5d2a9cc2c 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -324,3 +324,14 @@ cc_library( ":test_util", ], ) + +cc_library( + name = "uid_util", + testonly = 1, + srcs = ["uid_util.cc"], + hdrs = ["uid_util.h"], + deps = [ + ":posix_error", + ":save_util", + ], +) diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index f7d231b14..88b1e7911 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -163,6 +163,26 @@ PosixError Chmod(absl::string_view path, int mode) { return NoError(); } +PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, + dev_t dev) { + int res = mknodat(dfd.get(), std::string(path).c_str(), mode, dev); + if (res < 0) { + return PosixError(errno, absl::StrCat("mknod ", path)); + } + + return NoError(); +} + +PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, + int flags) { + int res = unlinkat(dfd.get(), std::string(path).c_str(), flags); + if (res < 0) { + return PosixError(errno, absl::StrCat("unlink ", path)); + } + + return NoError(); +} + PosixError Mkdir(absl::string_view path, int mode) { int res = mkdir(std::string(path).c_str(), mode); if (res < 0) { diff --git a/test/util/fs_util.h b/test/util/fs_util.h index e5b555891..ee1b341d7 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -21,6 +21,7 @@ #include <unistd.h> #include "absl/strings/string_view.h" +#include "test/util/file_descriptor.h" #include "test/util/posix_error.h" namespace gvisor { @@ -44,6 +45,14 @@ PosixError Delete(absl::string_view path); // Changes the mode of a file or returns an error. PosixError Chmod(absl::string_view path, int mode); +// Create a special or ordinary file. +PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode, + dev_t dev); + +// Unlink the file. +PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path, + int flags); + // Truncates a file to the given length or returns an error. PosixError Truncate(absl::string_view path, int length); diff --git a/test/util/multiprocess_util.cc b/test/util/multiprocess_util.cc index 95f5f3b4f..8b676751b 100644 --- a/test/util/multiprocess_util.cc +++ b/test/util/multiprocess_util.cc @@ -14,6 +14,7 @@ #include "test/util/multiprocess_util.h" +#include <asm/unistd.h> #include <errno.h> #include <fcntl.h> #include <signal.h> @@ -30,11 +31,12 @@ namespace gvisor { namespace testing { -PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, - const ExecveArray& argv, - const ExecveArray& envv, - const std::function<void()>& fn, pid_t* child, - int* execve_errno) { +namespace { + +// exec_fn wraps a variant of the exec family, e.g. execve or execveat. +PosixErrorOr<Cleanup> ForkAndExecHelper(const std::function<void()>& exec_fn, + const std::function<void()>& fn, + pid_t* child, int* execve_errno) { int pfds[2]; int ret = pipe2(pfds, O_CLOEXEC); if (ret < 0) { @@ -76,7 +78,9 @@ PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, fn(); } - execve(filename.c_str(), argv.get(), envv.get()); + // Call variant of exec function. + exec_fn(); + int error = errno; if (WriteFd(pfds[1], &error, sizeof(error)) != sizeof(error)) { // We can't do much if the write fails, but we can at least exit with a @@ -116,6 +120,36 @@ PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, return std::move(cleanup); } +} // namespace + +PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, + const ExecveArray& argv, + const ExecveArray& envv, + const std::function<void()>& fn, pid_t* child, + int* execve_errno) { + char* const* argv_data = argv.get(); + char* const* envv_data = envv.get(); + const std::function<void()> exec_fn = [=] { + execve(filename.c_str(), argv_data, envv_data); + }; + return ForkAndExecHelper(exec_fn, fn, child, execve_errno); +} + +PosixErrorOr<Cleanup> ForkAndExecveat(const int32_t dirfd, + const std::string& pathname, + const ExecveArray& argv, + const ExecveArray& envv, const int flags, + const std::function<void()>& fn, + pid_t* child, int* execve_errno) { + char* const* argv_data = argv.get(); + char* const* envv_data = envv.get(); + const std::function<void()> exec_fn = [=] { + syscall(__NR_execveat, dirfd, pathname.c_str(), argv_data, envv_data, + flags); + }; + return ForkAndExecHelper(exec_fn, fn, child, execve_errno); +} + PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn) { pid_t pid = fork(); if (pid == 0) { diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h index 0aecd3439..61526b4e7 100644 --- a/test/util/multiprocess_util.h +++ b/test/util/multiprocess_util.h @@ -102,6 +102,22 @@ inline PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename, return ForkAndExec(filename, argv, envv, [] {}, child, execve_errno); } +// Equivalent to ForkAndExec, except using dirfd and flags with execveat. +PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, const std::string& pathname, + const ExecveArray& argv, + const ExecveArray& envv, int flags, + const std::function<void()>& fn, + pid_t* child, int* execve_errno); + +inline PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, + const std::string& pathname, + const ExecveArray& argv, + const ExecveArray& envv, int flags, + pid_t* child, int* execve_errno) { + return ForkAndExecveat( + dirfd, pathname, argv, envv, flags, [] {}, child, execve_errno); +} + // Calls fn in a forked subprocess and returns the exit status of the // subprocess. // diff --git a/test/util/proc_util.cc b/test/util/proc_util.cc index 75b24da37..34d636ba9 100644 --- a/test/util/proc_util.cc +++ b/test/util/proc_util.cc @@ -88,7 +88,7 @@ PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps( std::vector<ProcMapsEntry> entries; auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty()); for (const auto& l : lines) { - std::cout << "line: " << l; + std::cout << "line: " << l << std::endl; ASSIGN_OR_RETURN_ERRNO(auto entry, ParseProcMapsLine(l)); entries.push_back(entry); } diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc new file mode 100644 index 000000000..b131b4b99 --- /dev/null +++ b/test/util/uid_util.cc @@ -0,0 +1,44 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/posix_error.h" +#include "test/util/save_util.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<bool> IsRoot() { + uid_t ruid, euid, suid; + int rc = getresuid(&ruid, &euid, &suid); + MaybeSave(); + if (rc < 0) { + return PosixError(errno, "getresuid"); + } + if (ruid != 0 || euid != 0 || suid != 0) { + return false; + } + gid_t rgid, egid, sgid; + rc = getresgid(&rgid, &egid, &sgid); + MaybeSave(); + if (rc < 0) { + return PosixError(errno, "getresgid"); + } + if (rgid != 0 || egid != 0 || sgid != 0) { + return false; + } + return true; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/uid_util.h b/test/util/uid_util.h new file mode 100644 index 000000000..2cd387fb0 --- /dev/null +++ b/test/util/uid_util.h @@ -0,0 +1,29 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_ + +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +// Returns true if the caller's real/effective/saved user/group IDs are all 0. +PosixErrorOr<bool> IsRoot(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_ diff --git a/third_party/gvsync/BUILD b/third_party/gvsync/BUILD index 8dab51daa..7d6d59c48 100644 --- a/third_party/gvsync/BUILD +++ b/third_party/gvsync/BUILD @@ -1,4 +1,5 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template") package( default_visibility = ["//:sandbox"], @@ -7,8 +8,6 @@ package( exports_files(["LICENSE"]) -load("//tools/go_generics:defs.bzl", "go_template") - go_template( name = "generic_atomicptr", srcs = ["atomicptr_unsafe.go"], diff --git a/third_party/gvsync/atomicptrtest/BUILD b/third_party/gvsync/atomicptrtest/BUILD index 6cf69ea91..447ecf96a 100644 --- a/third_party/gvsync/atomicptrtest/BUILD +++ b/third_party/gvsync/atomicptrtest/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "atomicptr_int", out = "atomicptr_int_unsafe.go", diff --git a/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go b/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go index 8baec5458..3b9346843 100644 --- a/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go +++ b/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/third_party/gvsync/downgradable_rwmutex_unsafe.go b/third_party/gvsync/downgradable_rwmutex_unsafe.go index 1f6007aa1..b7862d185 100644 --- a/third_party/gvsync/downgradable_rwmutex_unsafe.go +++ b/third_party/gvsync/downgradable_rwmutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/third_party/gvsync/memmove_unsafe.go b/third_party/gvsync/memmove_unsafe.go index 84b69f215..9dd1d6142 100644 --- a/third_party/gvsync/memmove_unsafe.go +++ b/third_party/gvsync/memmove_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.14 +// +build !go1.15 // Check go:linkname function signatures when updating Go version. diff --git a/third_party/gvsync/seqatomictest/BUILD b/third_party/gvsync/seqatomictest/BUILD index 9e87e0bc5..c858c20c4 100644 --- a/third_party/gvsync/seqatomictest/BUILD +++ b/third_party/gvsync/seqatomictest/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template_instance") - go_template_instance( name = "seqatomic_int", out = "seqatomic_int_unsafe.go", diff --git a/tools/go_branch.sh b/tools/go_branch.sh index ddb9b6e7b..0ac16e266 100755 --- a/tools/go_branch.sh +++ b/tools/go_branch.sh @@ -17,9 +17,9 @@ set -eo pipefail # Discovery the package name from the go.mod file. -declare -r gomod="$(pwd)/go.mod" -declare -r module=$(cat "${gomod}" | grep -E "^module" | cut -d' ' -f2) -declare -r gosum="$(pwd)/go.sum" +declare -r module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2) +declare -r origpwd=$(pwd) +declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE") # Check that gopath has been built. declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}" @@ -65,10 +65,22 @@ git checkout -b go "${go_branch}" git merge --no-commit --strategy ours ${head} || \ git merge --allow-unrelated-histories --no-commit --strategy ours ${head} -# Sync the entire gopath_dir and go.mod. -rsync --recursive --verbose --delete --exclude .git --exclude README.md -L "${gopath_dir}/" . -cp "${gomod}" . -cp "${gosum}" . +# Sync the entire gopath_dir. +rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" . + +# Add additional files. +for file in "${othersrc[@]}"; do + cp "${origpwd}"/"${file}" . +done + +# Construct a new README.md. +cat > README.md <<EOF +# gVisor + +This branch is a synthetic branch, containing only Go sources, that is +compatible with standard Go tools. See the `master` branch for authoritative +sources and tests. +EOF # There are a few solitary files that can get left behind due to the way bazel # constructs the gopath target. Note that we don't find all Go files here diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD index a6f8cdd3c..9d26a88b7 100644 --- a/tools/go_generics/rules_tests/BUILD +++ b/tools/go_generics/rules_tests/BUILD @@ -1,9 +1,8 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") - go_template_instance( name = "instance", out = "instance_test.go", diff --git a/tools/make_repository.sh b/tools/make_repository.sh index 071f72b74..27ffbc9f3 100755 --- a/tools/make_repository.sh +++ b/tools/make_repository.sh @@ -17,13 +17,13 @@ # Parse arguments. We require more than two arguments, which are the private # keyring, the e-mail associated with the signer, and the list of packages. if [ "$#" -le 3 ]; then - echo "usage: $0 <private-key> <signer-email> <component> <packages...>" + echo "usage: $0 <private-key> <signer-email> <component> <root> <packages...>" exit 1 fi -declare -r private_key=$(readlink -e "$1") -declare -r signer="$2" -declare -r component="$3" -shift; shift; shift +declare -r private_key=$(readlink -e "$1"); shift +declare -r signer="$1"; shift +declare -r component="$1"; shift +declare -r root="$1"; shift # Verbose from this point. set -xeo pipefail @@ -40,7 +40,7 @@ cleanup() { trap cleanup EXIT gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" >&2 -# Copy the packages, and ensure permissions are correct. +# Copy the packages into the root. for pkg in "$@"; do name=$(basename "${pkg}" .deb) name=$(basename "${name}" .changes) @@ -48,32 +48,61 @@ for pkg in "$@"; do if [[ "${name}" == "${arch}" ]]; then continue # Not a regular package. fi - mkdir -p "${tmpdir}"/"${component}"/binary-"${arch}" - cp -a "${pkg}" "${tmpdir}"/"${component}"/binary-"${arch}" + if [[ "${pkg}" =~ ^.*\.deb$ ]]; then + # Extract from the debian file. + version=$(dpkg --info "${pkg}" | grep -E 'Version:' | cut -d':' -f2) + elif [[ "${pkg}" =~ ^.*\.changes$ ]]; then + # Extract from the changes file. + version=$(grep -E 'Version:' "${pkg}" | cut -d':' -f2) + else + # Unsupported file type. + echo "Unknown file type: ${pkg}" + exit 1 + fi + version=${version// /} # Trim whitespace. + mkdir -p "${root}"/pool/"${version}"/binary-"${arch}" + cp -a "${pkg}" "${root}"/pool/"${version}"/binary-"${arch}" done -find "${tmpdir}" -type f -exec chmod 0644 {} \; -# Ensure there are no symlinks hanging around; these may be remnants of the -# build process. They may be useful for other things, but we are going to build -# an index of the actual packages here. -find "${tmpdir}" -type l -exec rm -f {} \; +# Ensure all permissions are correct. +find "${root}"/pool -type f -exec chmod 0644 {} \; # Sign all packages. -for file in "${tmpdir}"/"${component}"/binary-*/*.deb; do +for file in "${root}"/pool/*/binary-*/*.deb; do dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" >&2 done # Build the package list. -for dir in "${tmpdir}"/"${component}"/binary-*; do - (cd "${dir}" && apt-ftparchive packages . | gzip > Packages.gz) +declare arches=() +for dir in "${root}"/pool/*/binary-*; do + name=$(basename "${dir}") + arch=${name##binary-} + arches+=("${arch}") + repo_packages="${tmpdir}"/"${component}"/"${name}" + mkdir -p "${repo_packages}" + (cd "${root}" && apt-ftparchive --arch "${arch}" packages pool > "${repo_packages}"/Packages) + (cd "${repo_packages}" && cat Packages | gzip > Packages.gz) + (cd "${repo_packages}" && cat Packages | xz > Packages.xz) done # Build the release list. -(cd "${tmpdir}" && apt-ftparchive release . > Release) +cat > "${tmpdir}"/apt.conf <<EOF +APT { + FTPArchive { + Release { + Architectures "${arches[@]}"; + Components "${component}"; + }; + }; +}; +EOF +(cd "${tmpdir}" && apt-ftparchive -c=apt.conf release . > Release) +rm "${tmpdir}"/apt.conf # Sign the release. -(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release >&2) -(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release >&2) +declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512") +(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release >&2) +(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release >&2) # Show the results. echo "${tmpdir}" |