From 3f7d9370909a598cf83dfa07a1e87545a66e182f Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 14 Nov 2019 10:14:07 -0800 Subject: Use PacketBuffers for outgoing packets. PiperOrigin-RevId: 280455453 --- pkg/tcpip/buffer/prependable.go | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'pkg/tcpip/buffer') diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 48a2a2713..ba21f4eca 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -77,3 +77,9 @@ func (p *Prependable) Prepend(size int) []byte { p.usedIdx -= size return p.View()[:size:size] } + +// DeepCopy copies p and the bytes backing it. +func (p Prependable) DeepCopy() Prependable { + p.buf = append(View(nil), p.buf...) + return p +} -- cgit v1.2.3 From d29e59af9fbd420e34378bcbf7ae543134070217 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Mon, 27 Jan 2020 10:04:07 -0800 Subject: Standardize on tools directory. PiperOrigin-RevId: 291745021 --- .bazelrc | 8 +- BUILD | 49 ++++++- benchmarks/defs.bzl | 18 --- benchmarks/harness/BUILD | 74 +++++----- benchmarks/harness/machine_producers/BUILD | 4 +- benchmarks/runner/BUILD | 24 ++-- benchmarks/tcp/BUILD | 3 +- benchmarks/workloads/ab/BUILD | 19 ++- benchmarks/workloads/absl/BUILD | 19 ++- benchmarks/workloads/curl/BUILD | 2 +- benchmarks/workloads/ffmpeg/BUILD | 2 +- benchmarks/workloads/fio/BUILD | 19 ++- benchmarks/workloads/httpd/BUILD | 2 +- benchmarks/workloads/iperf/BUILD | 19 ++- benchmarks/workloads/netcat/BUILD | 2 +- benchmarks/workloads/nginx/BUILD | 2 +- benchmarks/workloads/node/BUILD | 2 +- benchmarks/workloads/node_template/BUILD | 2 +- benchmarks/workloads/redis/BUILD | 2 +- benchmarks/workloads/redisbenchmark/BUILD | 19 ++- benchmarks/workloads/ruby/BUILD | 2 +- benchmarks/workloads/ruby_template/BUILD | 2 +- benchmarks/workloads/sleep/BUILD | 2 +- benchmarks/workloads/sysbench/BUILD | 19 ++- benchmarks/workloads/syscall/BUILD | 19 ++- benchmarks/workloads/tensorflow/BUILD | 2 +- benchmarks/workloads/true/BUILD | 2 +- pkg/abi/BUILD | 3 +- pkg/abi/linux/BUILD | 6 +- pkg/amutex/BUILD | 6 +- pkg/atomicbitops/BUILD | 6 +- pkg/binary/BUILD | 6 +- pkg/bits/BUILD | 6 +- pkg/bpf/BUILD | 6 +- pkg/compressio/BUILD | 6 +- pkg/control/client/BUILD | 3 +- pkg/control/server/BUILD | 3 +- pkg/cpuid/BUILD | 8 +- pkg/eventchannel/BUILD | 16 +-- pkg/fd/BUILD | 6 +- pkg/fdchannel/BUILD | 8 +- pkg/fdnotifier/BUILD | 3 +- pkg/flipcall/BUILD | 8 +- pkg/fspath/BUILD | 13 +- pkg/gate/BUILD | 4 +- pkg/goid/BUILD | 6 +- pkg/ilist/BUILD | 6 +- pkg/linewriter/BUILD | 6 +- pkg/log/BUILD | 6 +- pkg/memutil/BUILD | 3 +- pkg/metric/BUILD | 23 +-- pkg/p9/BUILD | 6 +- pkg/p9/p9test/BUILD | 6 +- pkg/procid/BUILD | 8 +- pkg/rand/BUILD | 3 +- pkg/refs/BUILD | 6 +- pkg/seccomp/BUILD | 6 +- pkg/secio/BUILD | 6 +- pkg/segment/test/BUILD | 6 +- pkg/sentry/BUILD | 2 + pkg/sentry/arch/BUILD | 20 +-- pkg/sentry/context/BUILD | 3 +- pkg/sentry/context/contexttest/BUILD | 3 +- pkg/sentry/control/BUILD | 8 +- pkg/sentry/device/BUILD | 6 +- pkg/sentry/fs/BUILD | 6 +- pkg/sentry/fs/anon/BUILD | 3 +- pkg/sentry/fs/dev/BUILD | 3 +- pkg/sentry/fs/fdpipe/BUILD | 6 +- pkg/sentry/fs/filetest/BUILD | 3 +- pkg/sentry/fs/fsutil/BUILD | 6 +- pkg/sentry/fs/gofer/BUILD | 6 +- pkg/sentry/fs/host/BUILD | 6 +- pkg/sentry/fs/lock/BUILD | 6 +- pkg/sentry/fs/proc/BUILD | 6 +- pkg/sentry/fs/proc/device/BUILD | 3 +- pkg/sentry/fs/proc/seqfile/BUILD | 6 +- pkg/sentry/fs/ramfs/BUILD | 6 +- pkg/sentry/fs/sys/BUILD | 3 +- pkg/sentry/fs/timerfd/BUILD | 3 +- pkg/sentry/fs/tmpfs/BUILD | 6 +- pkg/sentry/fs/tty/BUILD | 6 +- pkg/sentry/fsimpl/ext/BUILD | 6 +- pkg/sentry/fsimpl/ext/benchmark/BUILD | 2 +- pkg/sentry/fsimpl/ext/disklayout/BUILD | 6 +- pkg/sentry/fsimpl/kernfs/BUILD | 6 +- pkg/sentry/fsimpl/proc/BUILD | 8 +- pkg/sentry/fsimpl/sys/BUILD | 6 +- pkg/sentry/fsimpl/testutil/BUILD | 5 +- pkg/sentry/fsimpl/tmpfs/BUILD | 8 +- pkg/sentry/hostcpu/BUILD | 6 +- pkg/sentry/hostmm/BUILD | 3 +- pkg/sentry/inet/BUILD | 3 +- pkg/sentry/kernel/BUILD | 24 +--- pkg/sentry/kernel/auth/BUILD | 3 +- pkg/sentry/kernel/contexttest/BUILD | 3 +- pkg/sentry/kernel/epoll/BUILD | 6 +- pkg/sentry/kernel/eventfd/BUILD | 6 +- pkg/sentry/kernel/fasync/BUILD | 3 +- pkg/sentry/kernel/futex/BUILD | 6 +- pkg/sentry/kernel/memevent/BUILD | 20 +-- pkg/sentry/kernel/pipe/BUILD | 6 +- pkg/sentry/kernel/sched/BUILD | 6 +- pkg/sentry/kernel/semaphore/BUILD | 6 +- pkg/sentry/kernel/shm/BUILD | 3 +- pkg/sentry/kernel/signalfd/BUILD | 5 +- pkg/sentry/kernel/time/BUILD | 3 +- pkg/sentry/limits/BUILD | 6 +- pkg/sentry/loader/BUILD | 4 +- pkg/sentry/memmap/BUILD | 6 +- pkg/sentry/mm/BUILD | 6 +- pkg/sentry/pgalloc/BUILD | 6 +- pkg/sentry/platform/BUILD | 3 +- pkg/sentry/platform/interrupt/BUILD | 6 +- pkg/sentry/platform/kvm/BUILD | 6 +- pkg/sentry/platform/kvm/testutil/BUILD | 3 +- pkg/sentry/platform/ptrace/BUILD | 3 +- pkg/sentry/platform/ring0/BUILD | 3 +- pkg/sentry/platform/ring0/gen_offsets/BUILD | 2 +- pkg/sentry/platform/ring0/pagetables/BUILD | 16 +-- pkg/sentry/platform/safecopy/BUILD | 6 +- pkg/sentry/safemem/BUILD | 6 +- pkg/sentry/sighandling/BUILD | 3 +- pkg/sentry/socket/BUILD | 3 +- pkg/sentry/socket/control/BUILD | 3 +- pkg/sentry/socket/hostinet/BUILD | 3 +- pkg/sentry/socket/netfilter/BUILD | 3 +- pkg/sentry/socket/netlink/BUILD | 3 +- pkg/sentry/socket/netlink/port/BUILD | 6 +- pkg/sentry/socket/netlink/route/BUILD | 3 +- pkg/sentry/socket/netlink/uevent/BUILD | 3 +- pkg/sentry/socket/netstack/BUILD | 3 +- pkg/sentry/socket/unix/BUILD | 3 +- pkg/sentry/socket/unix/transport/BUILD | 3 +- pkg/sentry/state/BUILD | 3 +- pkg/sentry/strace/BUILD | 20 +-- pkg/sentry/syscalls/BUILD | 3 +- pkg/sentry/syscalls/linux/BUILD | 3 +- pkg/sentry/time/BUILD | 6 +- pkg/sentry/unimpl/BUILD | 21 +-- pkg/sentry/uniqueid/BUILD | 3 +- pkg/sentry/usage/BUILD | 5 +- pkg/sentry/usermem/BUILD | 7 +- pkg/sentry/vfs/BUILD | 8 +- pkg/sentry/watchdog/BUILD | 3 +- pkg/sleep/BUILD | 6 +- pkg/state/BUILD | 17 +-- pkg/state/statefile/BUILD | 6 +- pkg/sync/BUILD | 6 +- pkg/sync/atomicptrtest/BUILD | 6 +- pkg/sync/seqatomictest/BUILD | 6 +- pkg/syserr/BUILD | 3 +- pkg/syserror/BUILD | 4 +- pkg/tcpip/BUILD | 6 +- pkg/tcpip/adapters/gonet/BUILD | 6 +- pkg/tcpip/buffer/BUILD | 6 +- pkg/tcpip/checker/BUILD | 3 +- pkg/tcpip/hash/jenkins/BUILD | 6 +- pkg/tcpip/header/BUILD | 6 +- pkg/tcpip/iptables/BUILD | 3 +- pkg/tcpip/link/channel/BUILD | 3 +- pkg/tcpip/link/fdbased/BUILD | 6 +- pkg/tcpip/link/loopback/BUILD | 3 +- pkg/tcpip/link/muxed/BUILD | 6 +- pkg/tcpip/link/rawfile/BUILD | 3 +- pkg/tcpip/link/sharedmem/BUILD | 6 +- pkg/tcpip/link/sharedmem/pipe/BUILD | 6 +- pkg/tcpip/link/sharedmem/queue/BUILD | 6 +- pkg/tcpip/link/sniffer/BUILD | 3 +- pkg/tcpip/link/tun/BUILD | 3 +- pkg/tcpip/link/waitable/BUILD | 6 +- pkg/tcpip/network/BUILD | 2 +- pkg/tcpip/network/arp/BUILD | 4 +- pkg/tcpip/network/fragmentation/BUILD | 6 +- pkg/tcpip/network/hash/BUILD | 3 +- pkg/tcpip/network/ipv4/BUILD | 4 +- pkg/tcpip/network/ipv6/BUILD | 6 +- pkg/tcpip/ports/BUILD | 6 +- pkg/tcpip/sample/tun_tcp_connect/BUILD | 2 +- pkg/tcpip/sample/tun_tcp_echo/BUILD | 2 +- pkg/tcpip/seqnum/BUILD | 3 +- pkg/tcpip/stack/BUILD | 6 +- pkg/tcpip/transport/icmp/BUILD | 3 +- pkg/tcpip/transport/packet/BUILD | 3 +- pkg/tcpip/transport/raw/BUILD | 3 +- pkg/tcpip/transport/tcp/BUILD | 4 +- pkg/tcpip/transport/tcp/testing/context/BUILD | 3 +- pkg/tcpip/transport/tcpconntrack/BUILD | 4 +- pkg/tcpip/transport/udp/BUILD | 4 +- pkg/tmutex/BUILD | 6 +- pkg/unet/BUILD | 6 +- pkg/urpc/BUILD | 6 +- pkg/waiter/BUILD | 6 +- runsc/BUILD | 27 ++-- runsc/boot/BUILD | 5 +- runsc/boot/filter/BUILD | 3 +- runsc/boot/platforms/BUILD | 3 +- runsc/cgroup/BUILD | 5 +- runsc/cmd/BUILD | 5 +- runsc/console/BUILD | 3 +- runsc/container/BUILD | 5 +- runsc/container/test_app/BUILD | 4 +- runsc/criutil/BUILD | 3 +- runsc/dockerutil/BUILD | 3 +- runsc/fsgofer/BUILD | 9 +- runsc/fsgofer/filter/BUILD | 3 +- runsc/sandbox/BUILD | 3 +- runsc/specutils/BUILD | 5 +- runsc/testutil/BUILD | 3 +- runsc/version_test.sh | 2 +- scripts/common.sh | 6 +- scripts/common_bazel.sh | 99 ------------- scripts/common_build.sh | 99 +++++++++++++ test/BUILD | 45 +----- test/e2e/BUILD | 5 +- test/image/BUILD | 5 +- test/iptables/BUILD | 5 +- test/iptables/runner/BUILD | 12 +- test/root/BUILD | 5 +- test/root/testdata/BUILD | 3 +- test/runtimes/BUILD | 4 +- test/runtimes/build_defs.bzl | 5 +- test/runtimes/images/proctor/BUILD | 4 +- test/syscalls/BUILD | 2 +- test/syscalls/build_defs.bzl | 6 +- test/syscalls/gtest/BUILD | 7 +- test/syscalls/linux/BUILD | 23 ++- test/syscalls/linux/arch_prctl.cc | 2 + test/syscalls/linux/rseq/BUILD | 5 +- .../linux/udp_socket_errqueue_test_case.cc | 4 + test/uds/BUILD | 3 +- test/util/BUILD | 27 ++-- test/util/save_util_linux.cc | 4 + test/util/save_util_other.cc | 4 + test/util/test_util_runfiles.cc | 4 + tools/BUILD | 3 + tools/build/BUILD | 10 ++ tools/build/defs.bzl | 91 ++++++++++++ tools/checkunsafe/BUILD | 3 +- tools/defs.bzl | 154 +++++++++++++++++++++ tools/go_generics/BUILD | 2 +- tools/go_generics/globals/BUILD | 4 +- tools/go_generics/go_merge/BUILD | 2 +- tools/go_generics/rules_tests/BUILD | 2 +- tools/go_marshal/BUILD | 4 +- tools/go_marshal/README.md | 52 +------ tools/go_marshal/analysis/BUILD | 5 +- tools/go_marshal/defs.bzl | 112 ++------------- tools/go_marshal/gomarshal/BUILD | 6 +- tools/go_marshal/gomarshal/generator.go | 20 ++- tools/go_marshal/gomarshal/generator_tests.go | 6 +- tools/go_marshal/main.go | 11 +- tools/go_marshal/marshal/BUILD | 5 +- tools/go_marshal/test/BUILD | 7 +- tools/go_marshal/test/external/BUILD | 6 +- tools/go_stateify/BUILD | 2 +- tools/go_stateify/defs.bzl | 79 +---------- tools/images/BUILD | 2 +- tools/images/defs.bzl | 6 +- tools/issue_reviver/BUILD | 2 +- tools/issue_reviver/github/BUILD | 3 +- tools/issue_reviver/reviver/BUILD | 5 +- tools/workspace_status.sh | 2 +- vdso/BUILD | 33 ++--- 264 files changed, 1012 insertions(+), 1380 deletions(-) delete mode 100644 benchmarks/defs.bzl delete mode 100755 scripts/common_bazel.sh create mode 100755 scripts/common_build.sh create mode 100644 tools/BUILD create mode 100644 tools/build/BUILD create mode 100644 tools/build/defs.bzl create mode 100644 tools/defs.bzl (limited to 'pkg/tcpip/buffer') diff --git a/.bazelrc b/.bazelrc index 9c35c5e7b..ef214bcfa 100644 --- a/.bazelrc +++ b/.bazelrc @@ -30,10 +30,10 @@ build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" # Add a custom platform and toolchain that builds in a privileged docker # container, which is required by our syscall tests. -build:remote --host_platform=//test:rbe_ubuntu1604 -build:remote --extra_toolchains=//test:cc-toolchain-clang-x86_64-default -build:remote --extra_execution_platforms=//test:rbe_ubuntu1604 -build:remote --platforms=//test:rbe_ubuntu1604 +build:remote --host_platform=//:rbe_ubuntu1604 +build:remote --extra_toolchains=//:cc-toolchain-clang-x86_64-default +build:remote --extra_execution_platforms=//:rbe_ubuntu1604 +build:remote --platforms=//:rbe_ubuntu1604 build:remote --crosstool_top=@rbe_default//cc:toolchain build:remote --jobs=50 build:remote --remote_timeout=3600 diff --git a/BUILD b/BUILD index 76286174f..5fd929378 100644 --- a/BUILD +++ b/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) # Apache 2.0 - load("@io_bazel_rules_go//go:def.bzl", "go_path", "nogo") load("@bazel_gazelle//:def.bzl", "gazelle") +package(licenses = ["notice"]) + # The sandbox filegroup is used for sandbox-internal dependencies. package_group( name = "sandbox", @@ -49,9 +49,52 @@ gazelle(name = "gazelle") # live in the tools subdirectory (unless they are standard). nogo( name = "nogo", - config = "tools/nogo.js", + config = "//tools:nogo.js", visibility = ["//visibility:public"], deps = [ "//tools/checkunsafe", ], ) + +# We need to define a bazel platform and toolchain to specify dockerPrivileged +# and dockerRunAsRoot options, they are required to run tests on the RBE +# cluster in Kokoro. +alias( + name = "rbe_ubuntu1604", + actual = ":rbe_ubuntu1604_r346485", +) + +platform( + name = "rbe_ubuntu1604_r346485", + constraint_values = [ + "@bazel_tools//platforms:x86_64", + "@bazel_tools//platforms:linux", + "@bazel_tools//tools/cpp:clang", + "@bazel_toolchains//constraints:xenial", + "@bazel_toolchains//constraints/sanitizers:support_msan", + ], + remote_execution_properties = """ + properties: { + name: "container-image" + value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" + } + properties: { + name: "dockerAddCapabilities" + value: "SYS_ADMIN" + } + properties: { + name: "dockerPrivileged" + value: "true" + } + """, +) + +toolchain( + name = "cc-toolchain-clang-x86_64-default", + exec_compatible_with = [ + ], + target_compatible_with = [ + ], + toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl deleted file mode 100644 index 79e6cdbc8..000000000 --- a/benchmarks/defs.bzl +++ /dev/null @@ -1,18 +0,0 @@ -"""Provides python helper functions.""" - -load("@pydeps//:requirements.bzl", _requirement = "requirement") - -def filter_deps(deps = None): - if deps == None: - deps = [] - return [dep for dep in deps if dep] - -def py_library(deps = None, **kwargs): - return native.py_library(deps = filter_deps(deps), **kwargs) - -def py_test(deps = None, **kwargs): - return native.py_test(deps = filter_deps(deps), **kwargs) - -def requirement(name, direct = True): - """ requirement returns the required dependency. """ - return _requirement(name) diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD index 081a74243..52d4e42f8 100644 --- a/benchmarks/harness/BUILD +++ b/benchmarks/harness/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -25,16 +25,16 @@ py_library( srcs = ["container.py"], deps = [ "//benchmarks/workloads", - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) @@ -47,17 +47,17 @@ py_library( "//benchmarks/harness:ssh_connection", "//benchmarks/harness:tunnel_dispatcher", "//benchmarks/harness/machine_mocks", - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("six", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("six", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) @@ -66,10 +66,10 @@ py_library( srcs = ["ssh_connection.py"], deps = [ "//benchmarks/harness", - requirement("bcrypt", False), - requirement("cffi", True), - requirement("paramiko", True), - requirement("cryptography", False), + py_requirement("bcrypt", False), + py_requirement("cffi", True), + py_requirement("paramiko", True), + py_requirement("cryptography", False), ], ) @@ -77,16 +77,16 @@ py_library( name = "tunnel_dispatcher", srcs = ["tunnel_dispatcher.py"], deps = [ - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("pexpect", True), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("pexpect", True), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD index c4e943882..48ea0ef39 100644 --- a/benchmarks/harness/machine_producers/BUILD +++ b/benchmarks/harness/machine_producers/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -31,7 +31,7 @@ py_library( deps = [ "//benchmarks/harness:machine", "//benchmarks/harness/machine_producers:machine_producer", - requirement("PyYAML", False), + py_requirement("PyYAML", False), ], ) diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD index e1b2ea550..fae0ca800 100644 --- a/benchmarks/runner/BUILD +++ b/benchmarks/runner/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement", "py_test") package(licenses = ["notice"]) @@ -28,7 +28,7 @@ py_library( "//benchmarks/suites:startup", "//benchmarks/suites:sysbench", "//benchmarks/suites:syscall", - requirement("click", True), + py_requirement("click", True), ], ) @@ -36,7 +36,7 @@ py_library( name = "commands", srcs = ["commands.py"], deps = [ - requirement("click", True), + py_requirement("click", True), ], ) @@ -50,14 +50,14 @@ py_test( ], deps = [ ":runner", - requirement("click", True), - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("click", True), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/tcp/BUILD b/benchmarks/tcp/BUILD index 735d7127f..d5e401acc 100644 --- a/benchmarks/tcp/BUILD +++ b/benchmarks/tcp/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") -load("@rules_cc//cc:defs.bzl", "cc_binary") +load("//tools:defs.bzl", "cc_binary", "go_binary") package(licenses = ["notice"]) diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD index 4fc0ab735..4dd91ceb3 100644 --- a/benchmarks/workloads/ab/BUILD +++ b/benchmarks/workloads/ab/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":ab", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD index 61e010096..55dae3baa 100644 --- a/benchmarks/workloads/absl/BUILD +++ b/benchmarks/workloads/absl/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":absl", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/curl/BUILD +++ b/benchmarks/workloads/curl/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD index be472dfb2..7c41ba631 100644 --- a/benchmarks/workloads/ffmpeg/BUILD +++ b/benchmarks/workloads/ffmpeg/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD index de257adad..7b78e8e75 100644 --- a/benchmarks/workloads/fio/BUILD +++ b/benchmarks/workloads/fio/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":fio", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/httpd/BUILD +++ b/benchmarks/workloads/httpd/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD index 8832a996c..570f40148 100644 --- a/benchmarks/workloads/iperf/BUILD +++ b/benchmarks/workloads/iperf/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":iperf", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/netcat/BUILD +++ b/benchmarks/workloads/netcat/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/nginx/BUILD +++ b/benchmarks/workloads/nginx/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD index 71cd9f519..bfcf78cf9 100644 --- a/benchmarks/workloads/node/BUILD +++ b/benchmarks/workloads/node/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD index ca996f068..e142f082a 100644 --- a/benchmarks/workloads/node_template/BUILD +++ b/benchmarks/workloads/node_template/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/redis/BUILD +++ b/benchmarks/workloads/redis/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD index f5994a815..f472a4443 100644 --- a/benchmarks/workloads/redisbenchmark/BUILD +++ b/benchmarks/workloads/redisbenchmark/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":redisbenchmark", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD index e37d77804..a3be4fe92 100644 --- a/benchmarks/workloads/ruby/BUILD +++ b/benchmarks/workloads/ruby/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD index 27f7c0c46..59443b14a 100644 --- a/benchmarks/workloads/ruby_template/BUILD +++ b/benchmarks/workloads/ruby_template/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/sleep/BUILD +++ b/benchmarks/workloads/sleep/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD index fd2f8f03d..3834af7ed 100644 --- a/benchmarks/workloads/sysbench/BUILD +++ b/benchmarks/workloads/sysbench/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":sysbench", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD index 5100cbb21..dba4bb1e7 100644 --- a/benchmarks/workloads/syscall/BUILD +++ b/benchmarks/workloads/syscall/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":syscall", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD index 026c3b316..a7b7742f4 100644 --- a/benchmarks/workloads/tensorflow/BUILD +++ b/benchmarks/workloads/tensorflow/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD index 221c4b9a7..eba23d325 100644 --- a/benchmarks/workloads/true/BUILD +++ b/benchmarks/workloads/true/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD index f5c08ea06..839f822eb 100644 --- a/pkg/abi/BUILD +++ b/pkg/abi/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,6 +9,5 @@ go_library( "abi_linux.go", "flag.go", ], - importpath = "gvisor.dev/gvisor/pkg/abi", visibility = ["//:sandbox"], ) diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 716ff22d2..1f3c0c687 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") # Package linux contains the constants and types needed to interface with a # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix @@ -60,7 +59,6 @@ go_library( "wait.go", "xattr.go", ], - importpath = "gvisor.dev/gvisor/pkg/abi/linux", visibility = ["//visibility:public"], deps = [ "//pkg/abi", @@ -73,7 +71,7 @@ go_test( name = "linux_test", size = "small", srcs = ["netfilter_test.go"], - embed = [":linux"], + library = ":linux", deps = [ "//pkg/binary", ], diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index d99e37b40..9612f072e 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "amutex", srcs = ["amutex.go"], - importpath = "gvisor.dev/gvisor/pkg/amutex", visibility = ["//:sandbox"], ) @@ -14,6 +12,6 @@ go_test( name = "amutex_test", size = "small", srcs = ["amutex_test.go"], - embed = [":amutex"], + library = ":amutex", deps = ["//pkg/sync"], ) diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 6403c60c2..3948074ba 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "atomic_bitops_arm64.s", "atomic_bitops_common.go", ], - importpath = "gvisor.dev/gvisor/pkg/atomicbitops", visibility = ["//:sandbox"], ) @@ -19,6 +17,6 @@ go_test( name = "atomicbitops_test", size = "small", srcs = ["atomic_bitops_test.go"], - embed = [":atomicbitops"], + library = ":atomicbitops", deps = ["//pkg/sync"], ) diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD index 543fb54bf..7ca2fda90 100644 --- a/pkg/binary/BUILD +++ b/pkg/binary/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "binary", srcs = ["binary.go"], - importpath = "gvisor.dev/gvisor/pkg/binary", visibility = ["//:sandbox"], ) @@ -14,5 +12,5 @@ go_test( name = "binary_test", size = "small", srcs = ["binary_test.go"], - embed = [":binary"], + library = ":binary", ) diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 93b88a29a..63f4670d7 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -15,7 +14,6 @@ go_library( "uint64_arch_arm64_asm.s", "uint64_arch_generic.go", ], - importpath = "gvisor.dev/gvisor/pkg/bits", visibility = ["//:sandbox"], ) @@ -53,5 +51,5 @@ go_test( name = "bits_test", size = "small", srcs = ["uint64_test.go"], - embed = [":bits"], + library = ":bits", ) diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index fba5643e8..2a6977f85 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "interpreter.go", "program_builder.go", ], - importpath = "gvisor.dev/gvisor/pkg/bpf", visibility = ["//visibility:public"], deps = ["//pkg/abi/linux"], ) @@ -25,7 +23,7 @@ go_test( "interpreter_test.go", "program_builder_test.go", ], - embed = [":bpf"], + library = ":bpf", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index 2bb581b18..1f75319a7 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "compressio", srcs = ["compressio.go"], - importpath = "gvisor.dev/gvisor/pkg/compressio", visibility = ["//:sandbox"], deps = [ "//pkg/binary", @@ -18,5 +16,5 @@ go_test( name = "compressio_test", size = "medium", srcs = ["compressio_test.go"], - embed = [":compressio"], + library = ":compressio", ) diff --git a/pkg/control/client/BUILD b/pkg/control/client/BUILD index 066d7b1a1..1b9e10ee7 100644 --- a/pkg/control/client/BUILD +++ b/pkg/control/client/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "client.go", ], - importpath = "gvisor.dev/gvisor/pkg/control/client", visibility = ["//:sandbox"], deps = [ "//pkg/unet", diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD index adbd1e3f8..002d2ef44 100644 --- a/pkg/control/server/BUILD +++ b/pkg/control/server/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "server", srcs = ["server.go"], - importpath = "gvisor.dev/gvisor/pkg/control/server", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index ed111fd2a..43a432190 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "cpu_amd64.s", "cpuid.go", ], - importpath = "gvisor.dev/gvisor/pkg/cpuid", visibility = ["//:sandbox"], deps = ["//pkg/log"], ) @@ -18,7 +16,7 @@ go_test( name = "cpuid_test", size = "small", srcs = ["cpuid_test.go"], - embed = [":cpuid"], + library = ":cpuid", ) go_test( @@ -27,6 +25,6 @@ go_test( srcs = [ "cpuid_parse_test.go", ], - embed = [":cpuid"], + library = ":cpuid", tags = ["manual"], ) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 9d68682c7..bee28b68d 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") package(licenses = ["notice"]) @@ -10,7 +8,6 @@ go_library( "event.go", "rate.go", ], - importpath = "gvisor.dev/gvisor/pkg/eventchannel", visibility = ["//:sandbox"], deps = [ ":eventchannel_go_proto", @@ -24,22 +21,15 @@ go_library( ) proto_library( - name = "eventchannel_proto", + name = "eventchannel", srcs = ["event.proto"], visibility = ["//:sandbox"], ) -go_proto_library( - name = "eventchannel_go_proto", - importpath = "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto", - proto = ":eventchannel_proto", - visibility = ["//:sandbox"], -) - go_test( name = "eventchannel_test", srcs = ["event_test.go"], - embed = [":eventchannel"], + library = ":eventchannel", deps = [ "//pkg/sync", "@com_github_golang_protobuf//proto:go_default_library", diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index afa8f7659..872361546 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "fd", srcs = ["fd.go"], - importpath = "gvisor.dev/gvisor/pkg/fd", visibility = ["//visibility:public"], ) @@ -14,5 +12,5 @@ go_test( name = "fd_test", size = "small", srcs = ["fd_test.go"], - embed = [":fd"], + library = ":fd", ) diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index b0478c672..d9104ef02 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "fdchannel", srcs = ["fdchannel_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/fdchannel", visibility = ["//visibility:public"], ) @@ -14,6 +12,6 @@ go_test( name = "fdchannel_test", size = "small", srcs = ["fdchannel_test.go"], - embed = [":fdchannel"], + library = ":fdchannel", deps = ["//pkg/sync"], ) diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD index 91a202a30..235dcc490 100644 --- a/pkg/fdnotifier/BUILD +++ b/pkg/fdnotifier/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "fdnotifier.go", "poll_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/fdnotifier", visibility = ["//:sandbox"], deps = [ "//pkg/sync", diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index 85bd83af1..9c5ad500b 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "flipcall", @@ -13,7 +12,6 @@ go_library( "io.go", "packet_window_allocator.go", ], - importpath = "gvisor.dev/gvisor/pkg/flipcall", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", @@ -30,6 +28,6 @@ go_test( "flipcall_example_test.go", "flipcall_test.go", ], - embed = [":flipcall"], + library = ":flipcall", deps = ["//pkg/sync"], ) diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index ca540363c..ee84471b2 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -1,10 +1,8 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) go_library( name = "fspath", @@ -13,7 +11,6 @@ go_library( "builder_unsafe.go", "fspath.go", ], - importpath = "gvisor.dev/gvisor/pkg/fspath", ) go_test( @@ -23,5 +20,5 @@ go_test( "builder_test.go", "fspath_test.go", ], - embed = [":fspath"], + library = ":fspath", ) diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index f22bd070d..dd3141143 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "gate.go", ], - importpath = "gvisor.dev/gvisor/pkg/gate", visibility = ["//visibility:public"], ) diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD index 5d31e5366..ea8d2422c 100644 --- a/pkg/goid/BUILD +++ b/pkg/goid/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "goid_race.go", "goid_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/goid", visibility = ["//visibility:public"], ) @@ -22,5 +20,5 @@ go_test( "empty_test.go", "goid_test.go", ], - embed = [":goid"], + library = ":goid", ) diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD index 34d2673ef..3f6eb07df 100644 --- a/pkg/ilist/BUILD +++ b/pkg/ilist/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( srcs = [ "interface_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/ilist", visibility = ["//visibility:public"], ) @@ -41,7 +39,7 @@ go_test( "list_test.go", "test_list.go", ], - embed = [":ilist"], + library = ":ilist", ) go_template( diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index bcde6d308..41bf104d0 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "linewriter", srcs = ["linewriter.go"], - importpath = "gvisor.dev/gvisor/pkg/linewriter", visibility = ["//visibility:public"], deps = ["//pkg/sync"], ) @@ -14,5 +12,5 @@ go_library( go_test( name = "linewriter_test", srcs = ["linewriter_test.go"], - embed = [":linewriter"], + library = ":linewriter", ) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 0df0f2849..935d06963 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "json_k8s.go", "log.go", ], - importpath = "gvisor.dev/gvisor/pkg/log", visibility = [ "//visibility:public", ], @@ -29,5 +27,5 @@ go_test( "json_test.go", "log_test.go", ], - embed = [":log"], + library = ":log", ) diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD index 7b50e2b28..9d07d98b4 100644 --- a/pkg/memutil/BUILD +++ b/pkg/memutil/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "memutil", srcs = ["memutil_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/memutil", visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 9145f3233..58305009d 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -1,14 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") package(licenses = ["notice"]) go_library( name = "metric", srcs = ["metric.go"], - importpath = "gvisor.dev/gvisor/pkg/metric", visibility = ["//:sandbox"], deps = [ ":metric_go_proto", @@ -19,28 +15,15 @@ go_library( ) proto_library( - name = "metric_proto", + name = "metric", srcs = ["metric.proto"], visibility = ["//:sandbox"], ) -cc_proto_library( - name = "metric_cc_proto", - visibility = ["//:sandbox"], - deps = [":metric_proto"], -) - -go_proto_library( - name = "metric_go_proto", - importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto", - proto = ":metric_proto", - visibility = ["//:sandbox"], -) - go_test( name = "metric_test", srcs = ["metric_test.go"], - embed = [":metric"], + library = ":metric", deps = [ ":metric_go_proto", "//pkg/eventchannel", diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index a3e05c96d..4ccc1de86 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package( default_visibility = ["//visibility:public"], @@ -23,7 +22,6 @@ go_library( "transport_flipcall.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/p9", deps = [ "//pkg/fd", "//pkg/fdchannel", @@ -47,7 +45,7 @@ go_test( "transport_test.go", "version_test.go", ], - embed = [":p9"], + library = ":p9", deps = [ "//pkg/fd", "//pkg/unet", diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index f4edd68b2..7ca67cb19 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") package(licenses = ["notice"]) @@ -64,7 +63,6 @@ go_library( "mocks.go", "p9test.go", ], - importpath = "gvisor.dev/gvisor/pkg/p9/p9test", visibility = ["//:sandbox"], deps = [ "//pkg/fd", @@ -80,7 +78,7 @@ go_test( name = "client_test", size = "medium", srcs = ["client_test.go"], - embed = [":p9test"], + library = ":p9test", deps = [ "//pkg/fd", "//pkg/p9", diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index b506813f0..aa3e3ac0b 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "procid_amd64.s", "procid_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/procid", visibility = ["//visibility:public"], ) @@ -20,7 +18,7 @@ go_test( srcs = [ "procid_test.go", ], - embed = [":procid"], + library = ":procid", deps = ["//pkg/sync"], ) @@ -31,6 +29,6 @@ go_test( "procid_net_test.go", "procid_test.go", ], - embed = [":procid"], + library = ":procid", deps = ["//pkg/sync"], ) diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD index 9d5b4859b..80b8ceb02 100644 --- a/pkg/rand/BUILD +++ b/pkg/rand/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "rand.go", "rand_linux.go", ], - importpath = "gvisor.dev/gvisor/pkg/rand", visibility = ["//:sandbox"], deps = [ "//pkg/sync", diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 974d9af9b..74affc887 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "refcounter_state.go", "weak_ref_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/refs", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -35,6 +33,6 @@ go_test( name = "refs_test", size = "small", srcs = ["refcounter_test.go"], - embed = [":refs"], + library = ":refs", deps = ["//pkg/sync"], ) diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index af94e944d..742c8b79b 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test") +load("//tools:defs.bzl", "go_binary", "go_embed_data", "go_library", "go_test") package(licenses = ["notice"]) @@ -27,7 +26,6 @@ go_library( "seccomp_rules.go", "seccomp_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/seccomp", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", @@ -43,7 +41,7 @@ go_test( "seccomp_test.go", ":victim_data", ], - embed = [":seccomp"], + library = ":seccomp", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD index 22abdc69f..60f63c7a6 100644 --- a/pkg/secio/BUILD +++ b/pkg/secio/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "full_reader.go", "secio.go", ], - importpath = "gvisor.dev/gvisor/pkg/secio", visibility = ["//pkg/sentry:internal"], ) @@ -17,5 +15,5 @@ go_test( name = "secio_test", size = "small", srcs = ["secio_test.go"], - embed = [":secio"], + library = ":secio", ) diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index a27c35e21..f2d8462d8 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package( @@ -38,7 +37,6 @@ go_library( "int_set.go", "set_functions.go", ], - importpath = "gvisor.dev/gvisor/pkg/segment/segment", deps = [ "//pkg/state", ], @@ -48,5 +46,5 @@ go_test( name = "segment_test", size = "small", srcs = ["segment_test.go"], - embed = [":segment"], + library = ":segment", ) diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD index 2d6379c86..e8b794179 100644 --- a/pkg/sentry/BUILD +++ b/pkg/sentry/BUILD @@ -6,6 +6,8 @@ package(licenses = ["notice"]) package_group( name = "internal", packages = [ + "//cloud/gvisor/gopkg/sentry/...", + "//cloud/gvisor/sentry/...", "//pkg/sentry/...", "//runsc/...", # Code generated by go_marshal relies on go_marshal libraries. diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 65f22af2b..51ca09b24 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -1,6 +1,4 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) @@ -27,7 +25,6 @@ go_library( "syscalls_amd64.go", "syscalls_arm64.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/arch", visibility = ["//:sandbox"], deps = [ ":registers_go_proto", @@ -44,20 +41,7 @@ go_library( ) proto_library( - name = "registers_proto", + name = "registers", srcs = ["registers.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "registers_cc_proto", - visibility = ["//visibility:public"], - deps = [":registers_proto"], -) - -go_proto_library( - name = "registers_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto", - proto = ":registers_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/context/BUILD b/pkg/sentry/context/BUILD index 8dc1a77b1..e13a9ce20 100644 --- a/pkg/sentry/context/BUILD +++ b/pkg/sentry/context/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "context", srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/context", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/amutex", diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD index 581e7aa96..f91a6d4ed 100644 --- a/pkg/sentry/context/contexttest/BUILD +++ b/pkg/sentry/context/contexttest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "contexttest", testonly = 1, srcs = ["contexttest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/context/contexttest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/memutil", diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 2561a6109..e69496477 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,9 +11,8 @@ go_library( "proc.go", "state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/control", visibility = [ - "//pkg/sentry:internal", + "//:sandbox", ], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "control_test", size = "small", srcs = ["proc_test.go"], - embed = [":control"], + library = ":control", deps = [ "//pkg/log", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 97fa1512c..e403cbd8b 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "device", srcs = ["device.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/device", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -18,5 +16,5 @@ go_test( name = "device_test", size = "small", srcs = ["device_test.go"], - embed = [":device"], + library = ":device", ) diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 7d5d72d5a..605d61dbe 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -44,7 +43,6 @@ go_library( "splice.go", "sync.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -129,7 +127,7 @@ go_test( "mount_test.go", "path_test.go", ], - embed = [":fs"], + library = ":fs", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD index ae1c9cf76..c14e5405e 100644 --- a/pkg/sentry/fs/anon/BUILD +++ b/pkg/sentry/fs/anon/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "anon.go", "device.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/anon", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index a0d9e8496..0c7247bd7 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "random.go", "tty.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/dev", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index cc43de69d..25ef96299 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "pipe_opener.go", "pipe_state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fdpipe", imports = ["gvisor.dev/gvisor/pkg/sentry/fs"], visibility = ["//pkg/sentry:internal"], deps = [ @@ -36,7 +34,7 @@ go_test( "pipe_opener_test.go", "pipe_test.go", ], - embed = [":fdpipe"], + library = ":fdpipe", deps = [ "//pkg/fd", "//pkg/fdnotifier", diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD index 358dc2be3..9a7608cae 100644 --- a/pkg/sentry/fs/filetest/BUILD +++ b/pkg/sentry/fs/filetest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "filetest", testonly = 1, srcs = ["filetest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/filetest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 945b6270d..9142f5bdf 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -75,7 +74,6 @@ go_library( "inode.go", "inode_cached.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fsutil", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -106,7 +104,7 @@ go_test( "dirty_set_test.go", "inode_cached_test.go", ], - embed = [":fsutil"], + library = ":fsutil", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index fd870e8e1..cf48e7c03 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "socket.go", "util.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/gofer", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -56,7 +54,7 @@ go_test( name = "gofer_test", size = "small", srcs = ["gofer_test.go"], - embed = [":gofer"], + library = ":gofer", deps = [ "//pkg/p9", "//pkg/p9/p9test", diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 2b581aa69..f586f47c1 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -25,7 +24,6 @@ go_library( "util_arm64_unsafe.go", "util_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -69,7 +67,7 @@ go_test( "socket_test.go", "wait_test.go", ], - embed = [":host"], + library = ":host", deps = [ "//pkg/fd", "//pkg/fdnotifier", diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 2c332a82a..ae3331737 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -40,7 +39,6 @@ go_library( "lock_set.go", "lock_set_functions.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/lock", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -56,5 +54,5 @@ go_test( "lock_range_test.go", "lock_test.go", ], - embed = [":lock"], + library = ":lock", ) diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index cb37c6c6b..b06bead41 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -27,7 +26,6 @@ go_library( "uptime.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -63,7 +61,7 @@ go_test( "net_test.go", "sys_net_test.go", ], - embed = [":proc"], + library = ":proc", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/fs/proc/device/BUILD b/pkg/sentry/fs/proc/device/BUILD index 0394451d4..52c9aa93d 100644 --- a/pkg/sentry/fs/proc/device/BUILD +++ b/pkg/sentry/fs/proc/device/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "device", srcs = ["device.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/device", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sentry/device"], ) diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 38b246dff..310d8dd52 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "seqfile", srcs = ["seqfile.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -26,7 +24,7 @@ go_test( name = "seqfile_test", size = "small", srcs = ["seqfile_test.go"], - embed = [":seqfile"], + library = ":seqfile", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 3fb7b0633..39c4b84f8 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "symlink.go", "tree.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ramfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -31,7 +29,7 @@ go_test( name = "ramfs_test", size = "small", srcs = ["tree_test.go"], - embed = [":ramfs"], + library = ":ramfs", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD index 25f0f124e..cc6b3bfbf 100644 --- a/pkg/sentry/fs/sys/BUILD +++ b/pkg/sentry/fs/sys/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "fs.go", "sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/sys", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index a215c1b95..092668e8d 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "timerfd", srcs = ["timerfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/timerfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 3400b940c..04776555f 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "inode_file.go", "tmpfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -41,7 +39,7 @@ go_test( name = "tmpfs_test", size = "small", srcs = ["file_test.go"], - embed = [":tmpfs"], + library = ":tmpfs", deps = [ "//pkg/sentry/context", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index f6f60d0cf..29f804c6c 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -14,7 +13,6 @@ go_library( "slave.go", "terminal.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tty", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "tty_test", size = "small", srcs = ["tty_test.go"], - embed = [":tty"], + library = ":tty", deps = [ "//pkg/abi/linux", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 903874141..a718920d5 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -32,7 +31,6 @@ go_library( "symlink.go", "utils.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -71,7 +69,7 @@ go_test( "//pkg/sentry/fsimpl/ext:assets/tiny.ext3", "//pkg/sentry/fsimpl/ext:assets/tiny.ext4", ], - embed = [":ext"], + library = ":ext", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 4fc8296ef..12f3990c1 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index fcfaf5c3e..9bd9c76c0 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "superblock_old.go", "test_utils.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -44,6 +42,6 @@ go_test( "inode_test.go", "superblock_test.go", ], - embed = [":disklayout"], + library = ":disklayout", deps = ["//pkg/sentry/kernel/time"], ) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 66d409785..7bf83ccba 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -1,8 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -package(licenses = ["notice"]) +licenses(["notice"]) go_template_instance( name = "slot_list", @@ -27,7 +26,6 @@ go_library( "slot_list.go", "symlink.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index c5b79fb38..3768f55b2 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "proc", @@ -15,7 +14,6 @@ go_library( "tasks_net.go", "tasks_sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc", deps = [ "//pkg/abi/linux", "//pkg/log", @@ -47,7 +45,7 @@ go_test( "tasks_sys_test.go", "tasks_test.go", ], - embed = [":proc"], + library = ":proc", deps = [ "//pkg/abi/linux", "//pkg/fspath", diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index ee3c842bd..beda141f1 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -1,14 +1,12 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "sys", srcs = [ "sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 4e70d84a7..12053a5b6 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -1,6 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "testutil", @@ -9,7 +9,6 @@ go_library( "kernel.go", "testutil.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 691476b4f..857e98bc5 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -1,8 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -package(licenses = ["notice"]) +licenses(["notice"]) go_template_instance( name = "dentry_list", @@ -28,7 +27,6 @@ go_library( "symlink.go", "tmpfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs", deps = [ "//pkg/abi/linux", "//pkg/amutex", @@ -81,7 +79,7 @@ go_test( "regular_file_test.go", "stat_test.go", ], - embed = [":tmpfs"], + library = ":tmpfs", deps = [ "//pkg/abi/linux", "//pkg/fspath", diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD index 359468ccc..e6933aa70 100644 --- a/pkg/sentry/hostcpu/BUILD +++ b/pkg/sentry/hostcpu/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "getcpu_arm64.s", "hostcpu.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu", visibility = ["//:sandbox"], ) @@ -18,5 +16,5 @@ go_test( name = "hostcpu_test", size = "small", srcs = ["hostcpu_test.go"], - embed = [":hostcpu"], + library = ":hostcpu", ) diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index 67831d5a1..a145a5ca3 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "cgroup.go", "hostmm.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/hostmm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/fd", diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 8d60ad4ad..aa621b724 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package( default_visibility = ["//:sandbox"], @@ -12,7 +12,6 @@ go_library( "inet.go", "test_stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/inet", deps = [ "//pkg/sentry/context", "//pkg/tcpip/stack", diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index ac85ba0c8..cebaccd92 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,8 +1,5 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -78,26 +75,12 @@ go_template_instance( ) proto_library( - name = "uncaught_signal_proto", + name = "uncaught_signal", srcs = ["uncaught_signal.proto"], visibility = ["//visibility:public"], deps = ["//pkg/sentry/arch:registers_proto"], ) -cc_proto_library( - name = "uncaught_signal_cc_proto", - visibility = ["//visibility:public"], - deps = [":uncaught_signal_proto"], -) - -go_proto_library( - name = "uncaught_signal_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto", - proto = ":uncaught_signal_proto", - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_go_proto"], -) - go_library( name = "kernel", srcs = [ @@ -156,7 +139,6 @@ go_library( "vdso.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel", imports = [ "gvisor.dev/gvisor/pkg/bpf", "gvisor.dev/gvisor/pkg/sentry/device", @@ -227,7 +209,7 @@ go_test( "task_test.go", "timekeeper_test.go", ], - embed = [":kernel"], + library = ":kernel", deps = [ "//pkg/abi", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 1aa72fa47..64537c9be 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -57,7 +57,6 @@ go_library( "id_map_set.go", "user_namespace.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/auth", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD index 3a88a585c..daff608d7 100644 --- a/pkg/sentry/kernel/contexttest/BUILD +++ b/pkg/sentry/kernel/contexttest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "contexttest", testonly = 1, srcs = ["contexttest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index c47f6b6fc..19e16ab3a 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "epoll_list.go", "epoll_state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/epoll", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/refs", @@ -43,7 +41,7 @@ go_test( srcs = [ "epoll_test.go", ], - embed = [":epoll"], + library = ":epoll", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/fs/filetest", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index c831fbab2..ee2d74864 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "eventfd", srcs = ["eventfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -26,7 +24,7 @@ go_test( name = "eventfd_test", size = "small", srcs = ["eventfd_test.go"], - embed = [":eventfd"], + library = ":eventfd", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/usermem", diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 6b36bc63e..b9126e946 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "fasync", srcs = ["fasync.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/fasync", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 50db443ce..f413d8ae2 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -34,7 +33,6 @@ go_library( "futex.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/futex", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -51,7 +49,7 @@ go_test( name = "futex_test", size = "small", srcs = ["futex_test.go"], - embed = [":futex"], + library = ":futex", deps = [ "//pkg/sentry/usermem", "//pkg/sync", diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index 7f36252a9..4486848d2 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -1,13 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) go_library( name = "memevent", srcs = ["memory_events.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent", visibility = ["//:sandbox"], deps = [ ":memory_events_go_proto", @@ -21,20 +18,7 @@ go_library( ) proto_library( - name = "memory_events_proto", + name = "memory_events", srcs = ["memory_events.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "memory_events_cc_proto", - visibility = ["//visibility:public"], - deps = [":memory_events_proto"], -) - -go_proto_library( - name = "memory_events_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto", - proto = ":memory_events_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 5eeaeff66..2c7b6206f 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -30,7 +29,6 @@ go_library( "vfs.go", "writer.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -57,7 +55,7 @@ go_test( "node_test.go", "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD index 98ea7a0d8..1b82e087b 100644 --- a/pkg/sentry/kernel/sched/BUILD +++ b/pkg/sentry/kernel/sched/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "cpuset.go", "sched.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/sched", visibility = ["//pkg/sentry:internal"], ) @@ -17,5 +15,5 @@ go_test( name = "sched_test", size = "small", srcs = ["cpuset_test.go"], - embed = [":sched"], + library = ":sched", ) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 13a961594..76e19b551 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "semaphore.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/semaphore", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "semaphore_test", size = "small", srcs = ["semaphore_test.go"], - embed = [":semaphore"], + library = ":semaphore", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 7321b22ed..5547c5abf 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "device.go", "shm.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/shm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 89e4d84b1..5d44773d4 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "signalfd", srcs = ["signalfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 4e4de0512..d49594d9f 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "context.go", "time.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/time", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 9fa841e8b..67869757f 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "limits.go", "linux.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/limits", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", @@ -25,5 +23,5 @@ go_test( srcs = [ "limits_test.go", ], - embed = [":limits"], + library = ":limits", ) diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 2890393bd..d4ad2bd6c 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_embed_data") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_embed_data", "go_library") package(licenses = ["notice"]) @@ -20,7 +19,6 @@ go_library( "vdso_state.go", ":vdso_bin", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/loader", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi", diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 112794e9c..f9a65f086 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -37,7 +36,6 @@ go_library( "mapping_set_impl.go", "memmap.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/memmap", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -52,6 +50,6 @@ go_test( name = "memmap_test", size = "small", srcs = ["mapping_set_test.go"], - embed = [":memmap"], + library = ":memmap", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 83e248431..bd6399fa2 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -96,7 +95,6 @@ go_library( "vma.go", "vma_set.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/mm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -128,7 +126,7 @@ go_test( name = "mm_test", size = "small", srcs = ["mm_test.go"], - embed = [":mm"], + library = ":mm", deps = [ "//pkg/sentry/arch", "//pkg/sentry/context", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index a9a2642c5..02385a3ce 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -60,7 +59,6 @@ go_library( "save_restore.go", "usage_set.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/pgalloc", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -82,6 +80,6 @@ go_test( name = "pgalloc_test", size = "small", srcs = ["pgalloc_test.go"], - embed = [":pgalloc"], + library = ":pgalloc", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 157bffa81..006450b2d 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +22,6 @@ go_library( "mmap_min_addr.go", "platform.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index 85e882df9..83b385f14 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "interrupt.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sync"], ) @@ -17,5 +15,5 @@ go_test( name = "interrupt_test", size = "small", srcs = ["interrupt_test.go"], - embed = [":interrupt"], + library = ":interrupt", ) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 6a358d1d4..a4532a766 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -38,7 +37,6 @@ go_library( "physical_map_arm64.go", "virtual_map.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -65,7 +63,7 @@ go_test( "kvm_test.go", "virtual_map_test.go", ], - embed = [":kvm"], + library = ":kvm", tags = [ "manual", "nogotsan", diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD index b0e45f159..f7605df8a 100644 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ b/pkg/sentry/platform/kvm/testutil/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,6 +12,5 @@ go_library( "testutil_arm64.go", "testutil_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil", visibility = ["//pkg/sentry/platform/kvm:__pkg__"], ) diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index cd13390c3..3bcc5e040 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -20,7 +20,6 @@ go_library( "subprocess_linux_unsafe.go", "subprocess_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ptrace", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 87f4552b5..6dee8fcc5 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -74,7 +74,6 @@ go_library( "lib_arm64.s", "ring0.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/cpuid", diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 42076fb04..147311ed3 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 387a7f6c3..8b5cdd6c1 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,17 +1,14 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test", "select_arch") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -config_setting( - name = "aarch64", - constraint_values = ["@bazel_tools//platforms:aarch64"], -) - go_template( name = "generic_walker", - srcs = ["walker_amd64.go"], + srcs = select_arch( + amd64 = ["walker_amd64.go"], + arm64 = ["walker_amd64.go"], + ), opt_types = [ "Visitor", ], @@ -91,7 +88,6 @@ go_library( "walker_map.go", "walker_unmap.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables", visibility = [ "//pkg/sentry/platform/kvm:__subpackages__", "//pkg/sentry/platform/ring0:__subpackages__", @@ -111,6 +107,6 @@ go_test( "pagetables_test.go", "walker_check.go", ], - embed = [":pagetables"], + library = ":pagetables", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD index 6769cd0a5..b8747585b 100644 --- a/pkg/sentry/platform/safecopy/BUILD +++ b/pkg/sentry/platform/safecopy/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -17,7 +16,6 @@ go_library( "sighandler_amd64.s", "sighandler_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/syserror"], ) @@ -27,5 +25,5 @@ go_test( srcs = [ "safecopy_test.go", ], - embed = [":safecopy"], + library = ":safecopy", ) diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD index 884020f7b..3ab76da97 100644 --- a/pkg/sentry/safemem/BUILD +++ b/pkg/sentry/safemem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "safemem.go", "seq_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/safemem", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/platform/safecopy", @@ -25,5 +23,5 @@ go_test( "io_test.go", "seq_test.go", ], - embed = [":safemem"], + library = ":safemem", ) diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sentry/sighandling/BUILD index f561670c7..6c38a3f44 100644 --- a/pkg/sentry/sighandling/BUILD +++ b/pkg/sentry/sighandling/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "sighandling.go", "sighandling_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/sighandling", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/abi/linux"], ) diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 26176b10d..8e2b97afb 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "socket", srcs = ["socket.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 357517ed4..3850f6345 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "control", srcs = ["control.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/control", imports = [ "gvisor.dev/gvisor/pkg/sentry/fs", ], diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 4c44c7c0f..42bf7be6a 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "socket_unsafe.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/hostinet", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index b70047d81..ed34a8308 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "netfilter.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netfilter", # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 103933144..baaac13c6 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "provider.go", "socket.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 2d9f4ba9b..3a22923d8 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "port", srcs = ["port.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sync"], ) @@ -14,5 +12,5 @@ go_library( go_test( name = "port_test", srcs = ["port_test.go"], - embed = [":port"], + library = ":port", ) diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 1d4912753..2137c7aeb 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "route", srcs = ["protocol.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD index 0777f3baf..73fbdf1eb 100644 --- a/pkg/sentry/socket/netlink/uevent/BUILD +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "uevent", srcs = ["protocol.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index f78784569..e3d1f90cb 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -11,7 +11,6 @@ go_library( "save_restore.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack", visibility = [ "//pkg/sentry:internal", ], diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 5b6a154f6..bade18686 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "io.go", "unix.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index d7ba95dff..4bdfc9208 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -25,7 +25,6 @@ go_library( "transport_message_list.go", "unix.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 88765f4d6..0ea4aab8b 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "state_metadata.go", "state_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/state", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index aa1ac720c..ff6fafa63 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) @@ -21,7 +19,6 @@ go_library( "strace.go", "syscalls.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/strace", visibility = ["//:sandbox"], deps = [ ":strace_go_proto", @@ -42,20 +39,7 @@ go_library( ) proto_library( - name = "strace_proto", + name = "strace", srcs = ["strace.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "strace_cc_proto", - visibility = ["//visibility:public"], - deps = [":strace_proto"], -) - -go_proto_library( - name = "strace_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto", - proto = ":strace_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD index 79d972202..b8d1bd415 100644 --- a/pkg/sentry/syscalls/BUILD +++ b/pkg/sentry/syscalls/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "epoll.go", "syscalls.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 917f74e07..7d74e0f70 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -57,7 +57,6 @@ go_library( "sys_xattr.go", "timespec.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux", visibility = ["//:sandbox"], deps = [ "//pkg/abi", diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 3cde3a0be..04f81a35b 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -31,7 +30,6 @@ go_library( "tsc_amd64.s", "tsc_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/time", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -48,5 +46,5 @@ go_test( "parameters_test.go", "sampler_test.go", ], - embed = [":time"], + library = ":time", ) diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD index fc7614fff..370fa6ec5 100644 --- a/pkg/sentry/unimpl/BUILD +++ b/pkg/sentry/unimpl/BUILD @@ -1,34 +1,17 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) proto_library( - name = "unimplemented_syscall_proto", + name = "unimplemented_syscall", srcs = ["unimplemented_syscall.proto"], visibility = ["//visibility:public"], deps = ["//pkg/sentry/arch:registers_proto"], ) -cc_proto_library( - name = "unimplemented_syscall_cc_proto", - visibility = ["//visibility:public"], - deps = [":unimplemented_syscall_proto"], -) - -go_proto_library( - name = "unimplemented_syscall_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto", - proto = ":unimplemented_syscall_proto", - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_go_proto"], -) - go_library( name = "unimpl", srcs = ["events.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD index 86a87edd4..e9c18f170 100644 --- a/pkg/sentry/uniqueid/BUILD +++ b/pkg/sentry/uniqueid/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "uniqueid", srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/uniqueid", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index 5518ac3d0..099315613 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -11,9 +11,8 @@ go_library( "memory_unsafe.go", "usage.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/usage", visibility = [ - "//pkg/sentry:internal", + "//:sandbox", ], deps = [ "//pkg/bits", diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index 684f59a6b..c8322e29e 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -29,7 +28,6 @@ go_library( "usermem_unsafe.go", "usermem_x86.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/usermem", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/atomicbitops", @@ -38,7 +36,6 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/safemem", "//pkg/syserror", - "//pkg/tcpip/buffer", ], ) @@ -49,7 +46,7 @@ go_test( "addr_range_seq_test.go", "usermem_test.go", ], - embed = [":usermem"], + library = ":usermem", deps = [ "//pkg/sentry/context", "//pkg/sentry/safemem", diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 35c7be259..51acdc4e9 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "vfs", @@ -24,7 +23,6 @@ go_library( "testutil.go", "vfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/vfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -47,7 +45,7 @@ go_test( "file_description_impl_util_test.go", "mount_test.go", ], - embed = [":vfs"], + library = ":vfs", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD index 28f21f13d..1c5a1c9b6 100644 --- a/pkg/sentry/watchdog/BUILD +++ b/pkg/sentry/watchdog/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "watchdog", srcs = ["watchdog.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/watchdog", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index a23c86fb1..e131455f7 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "commit_noasm.go", "sleep_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sleep", visibility = ["//:sandbox"], ) @@ -22,5 +20,5 @@ go_test( srcs = [ "sleep_test.go", ], - embed = [":sleep"], + library = ":sleep", ) diff --git a/pkg/state/BUILD b/pkg/state/BUILD index be93750bf..921af9d63 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -49,7 +47,7 @@ go_library( "state.go", "stats.go", ], - importpath = "gvisor.dev/gvisor/pkg/state", + stateify = False, visibility = ["//:sandbox"], deps = [ ":object_go_proto", @@ -58,21 +56,14 @@ go_library( ) proto_library( - name = "object_proto", + name = "object", srcs = ["object.proto"], visibility = ["//:sandbox"], ) -go_proto_library( - name = "object_go_proto", - importpath = "gvisor.dev/gvisor/pkg/state/object_go_proto", - proto = ":object_proto", - visibility = ["//:sandbox"], -) - go_test( name = "state_test", timeout = "long", srcs = ["state_test.go"], - embed = [":state"], + library = ":state", ) diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index 8a865d229..e7581c09b 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "statefile", srcs = ["statefile.go"], - importpath = "gvisor.dev/gvisor/pkg/state/statefile", visibility = ["//:sandbox"], deps = [ "//pkg/binary", @@ -18,6 +16,6 @@ go_test( name = "statefile_test", size = "small", srcs = ["statefile_test.go"], - embed = [":statefile"], + library = ":statefile", deps = ["//pkg/compressio"], ) diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 97c4b3b1e..5340cf0d6 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template") package( @@ -40,7 +39,6 @@ go_library( "syncutil.go", "tmutex_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sync", ) go_test( @@ -51,5 +49,5 @@ go_test( "seqcount_test.go", "tmutex_test.go", ], - embed = [":sync"], + library = ":sync", ) diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD index 418eda29c..e97553254 100644 --- a/pkg/sync/atomicptrtest/BUILD +++ b/pkg/sync/atomicptrtest/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -18,12 +17,11 @@ go_template_instance( go_library( name = "atomicptr", srcs = ["atomicptr_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr", ) go_test( name = "atomicptr_test", size = "small", srcs = ["atomicptr_test.go"], - embed = [":atomicptr"], + library = ":atomicptr", ) diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD index eba21518d..5c38c783e 100644 --- a/pkg/sync/seqatomictest/BUILD +++ b/pkg/sync/seqatomictest/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -18,7 +17,6 @@ go_template_instance( go_library( name = "seqatomic", srcs = ["seqatomic_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic", deps = [ "//pkg/sync", ], @@ -28,6 +26,6 @@ go_test( name = "seqatomic_test", size = "small", srcs = ["seqatomic_test.go"], - embed = [":seqatomic"], + library = ":seqatomic", deps = ["//pkg/sync"], ) diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD index 5665ad4ee..7d760344a 100644 --- a/pkg/syserr/BUILD +++ b/pkg/syserr/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "netstack.go", "syserr.go", ], - importpath = "gvisor.dev/gvisor/pkg/syserr", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index bd3f9fd28..b13c15d9b 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "syserror", srcs = ["syserror.go"], - importpath = "gvisor.dev/gvisor/pkg/syserror", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 23e4b09e7..26f7ba86b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "time_unsafe.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -25,7 +23,7 @@ go_test( name = "tcpip_test", size = "small", srcs = ["tcpip_test.go"], - embed = [":tcpip"], + library = ":tcpip", ) go_test( diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 3df7d18d3..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "gonet", srcs = ["gonet.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -23,7 +21,7 @@ go_test( name = "gonet_test", size = "small", srcs = ["gonet_test.go"], - embed = [":gonet"], + library = ":gonet", deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index d6c31bfa2..563bc78ea 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "prependable.go", "view.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer", visibility = ["//visibility:public"], ) @@ -17,5 +15,5 @@ go_test( name = "buffer_test", size = "small", srcs = ["view_test.go"], - embed = [":buffer"], + library = ":buffer", ) diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index b6fa6fc37..ed434807f 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "checker", testonly = 1, srcs = ["checker.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/checker", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index e648efa71..ff2719291 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "jenkins", srcs = ["jenkins.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", visibility = ["//visibility:public"], ) @@ -16,5 +14,5 @@ go_test( srcs = [ "jenkins_test.go", ], - embed = [":jenkins"], + library = ":jenkins", ) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index cd747d100..9da0d71f8 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "tcp.go", "udp.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/header", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -59,7 +57,7 @@ go_test( "eth_test.go", "ndp_test.go", ], - embed = [":header"], + library = ":header", deps = [ "//pkg/tcpip", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 297eaccaf..d1b73cfdf 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "targets.go", "types.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], deps = [ "//pkg/log", diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 7dbc05754..3974c464e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "channel", srcs = ["channel.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 66cc53ed4..abe725548 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -13,7 +12,6 @@ go_library( "mmap_unsafe.go", "packet_dispatchers.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -30,7 +28,7 @@ go_test( name = "fdbased_test", size = "small", srcs = ["endpoint_test.go"], - embed = [":fdbased"], + library = ":fdbased", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index f35fcdff4..6bf3805b7 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "loopback", srcs = ["loopback.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1ac7948b6..82b441b79 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "muxed", srcs = ["injectable.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -19,7 +17,7 @@ go_test( name = "muxed_test", size = "small", srcs = ["injectable_test.go"], - embed = [":muxed"], + library = ":muxed", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index d8211e93d..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "errors.go", "rawfile_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 09165dd4c..13243ebbb 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "sharedmem_unsafe.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -30,7 +28,7 @@ go_test( srcs = [ "sharedmem_test.go", ], - embed = [":sharedmem"], + library = ":sharedmem", deps = [ "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index a0d4ad0be..87020ec08 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", visibility = ["//visibility:public"], ) @@ -20,6 +18,6 @@ go_test( srcs = [ "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 8c9234d54..3ba06af73 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -22,7 +20,7 @@ go_test( srcs = [ "queue_test.go", ], - embed = [":queue"], + library = ":queue", deps = [ "//pkg/tcpip/link/sharedmem/pipe", ], diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index d6ae0368a..230a8d53a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "pcap.go", "sniffer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", visibility = ["//visibility:public"], deps = [ "//pkg/log", diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index a71a493fc..e5096ea38 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "tun", srcs = ["tun_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 134837943..0956d2c65 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "waitable.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", visibility = ["//visibility:public"], deps = [ "//pkg/gate", @@ -23,7 +21,7 @@ go_test( srcs = [ "waitable_test.go", ], - embed = [":waitable"], + library = ":waitable", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9d16ff8c9..6a4839fb8 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index e7617229b..eddf7b725 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "arp", srcs = ["arp.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index ed16076fd..d1c728ccf 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "reassembler.go", "reassembler_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -42,6 +40,6 @@ go_test( "fragmentation_test.go", "reassembler_test.go", ], - embed = [":fragmentation"], + library = ":fragmentation", deps = ["//pkg/tcpip/buffer"], ) diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD index e6db5c0b0..872165866 100644 --- a/pkg/tcpip/network/hash/BUILD +++ b/pkg/tcpip/network/hash/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "hash", srcs = ["hash.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash", visibility = ["//visibility:public"], deps = [ "//pkg/rand", diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 4e2aae9a3..0fef2b1f1 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv4.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index e4e273460..fb11874c6 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv6.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -27,7 +25,7 @@ go_test( "ipv6_test.go", "ndp_test.go", ], - embed = [":ipv6"], + library = ":ipv6", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index a6ef3bdcc..2bad05a2e 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "ports", srcs = ["ports.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -17,7 +15,7 @@ go_library( go_test( name = "ports_test", srcs = ["ports_test.go"], - embed = [":ports"], + library = ":ports", deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index d7496fde6..cf0a5fefe 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index 875561566..43264b76d 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index b31ddba2f..45f503845 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "seqnum", srcs = ["seqnum.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 783351a69..f5b750046 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -30,7 +29,6 @@ go_library( "stack_global_state.go", "transport_demuxer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", visibility = ["//visibility:public"], deps = [ "//pkg/ilist", @@ -81,7 +79,7 @@ go_test( name = "stack_test", size = "small", srcs = ["linkaddrcache_test.go"], - embed = [":stack"], + library = ":stack", deps = [ "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 3aa23d529..ac18ec5b1 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "icmp_packet_list.go", "protocol.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 4858d150c..d22de6b26 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +22,6 @@ go_library( "endpoint_state.go", "packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 2f2131ff7..c9baf4600 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "protocol.go", "raw_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 0e3ab05ad..4acd9fb9a 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -55,7 +54,6 @@ go_library( "tcp_segment_list.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index b33ec2087..ce6a2c31d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "context", testonly = 1, srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ "//visibility:public", ], diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 43fcc27f0..3ad6994a7 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "tcpconntrack", srcs = ["tcp_conntrack.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 57ff123e3..adc908e24 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -25,7 +24,6 @@ go_library( "protocol.go", "udp_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 07778e4f7..2dcba84ae 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "tmutex", srcs = ["tmutex.go"], - importpath = "gvisor.dev/gvisor/pkg/tmutex", visibility = ["//:sandbox"], ) @@ -14,6 +12,6 @@ go_test( name = "tmutex_test", size = "medium", srcs = ["tmutex_test.go"], - embed = [":tmutex"], + library = ":tmutex", deps = ["//pkg/sync"], ) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index d1885ae66..a86501fa2 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "unet.go", "unet_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/unet", visibility = ["//visibility:public"], deps = [ "//pkg/gate", @@ -23,6 +21,6 @@ go_test( srcs = [ "unet_test.go", ], - embed = [":unet"], + library = ":unet", deps = ["//pkg/sync"], ) diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b8fdc3125..850c34ed0 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "urpc", srcs = ["urpc.go"], - importpath = "gvisor.dev/gvisor/pkg/urpc", visibility = ["//:sandbox"], deps = [ "//pkg/fd", @@ -20,6 +18,6 @@ go_test( name = "urpc_test", size = "small", srcs = ["urpc_test.go"], - embed = [":urpc"], + library = ":urpc", deps = ["//pkg/unet"], ) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 1c6890e52..852480a09 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "waiter.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/waiter", visibility = ["//visibility:public"], deps = ["//pkg/sync"], ) @@ -33,5 +31,5 @@ go_test( srcs = [ "waiter_test.go", ], - embed = [":waiter"], + library = ":waiter", ) diff --git a/runsc/BUILD b/runsc/BUILD index e5587421d..b35b41d81 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -1,7 +1,6 @@ -package(licenses = ["notice"]) # Apache 2.0 +load("//tools:defs.bzl", "go_binary", "pkg_deb", "pkg_tar") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") -load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar") +package(licenses = ["notice"]) go_binary( name = "runsc", @@ -9,7 +8,7 @@ go_binary( "main.go", "version.go", ], - pure = "on", + pure = True, visibility = [ "//visibility:public", ], @@ -26,10 +25,12 @@ go_binary( ) # The runsc-race target is a race-compatible BUILD target. This must be built -# via "bazel build --features=race //runsc:runsc-race", since the race feature -# must apply to all dependencies due a bug in gazelle file selection. The pure -# attribute must be off because the race detector requires linking with non-Go -# components, although we still require a static binary. +# via: bazel build --features=race //runsc:runsc-race +# +# This is neccessary because the race feature must apply to all dependencies +# due a bug in gazelle file selection. The pure attribute must be off because +# the race detector requires linking with non-Go components, although we still +# require a static binary. # # Note that in the future this might be convertible to a compatible target by # using the pure and static attributes within a select function, but select is @@ -42,7 +43,7 @@ go_binary( "main.go", "version.go", ], - static = "on", + static = True, visibility = [ "//visibility:public", ], @@ -82,7 +83,12 @@ genrule( # because they are assumes to be hermetic). srcs = [":runsc"], outs = ["version.txt"], - cmd = "$(location :runsc) -version | grep 'runsc version' | sed 's/^[^0-9]*//' > $@", + # Note that the little dance here is necessary because files in the $(SRCS) + # attribute are not executable by default, and we can't touch in place. + cmd = "cp $(location :runsc) $(@D)/runsc && \ + chmod a+x $(@D)/runsc && \ + $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \ + rm -f $(@D)/runsc", stamp = 1, ) @@ -109,5 +115,6 @@ sh_test( name = "version_test", size = "small", srcs = ["version_test.sh"], + args = ["$(location :runsc)"], data = [":runsc"], ) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 3e20f8f2f..f3ebc0231 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "strace.go", "user.go", ], - importpath = "gvisor.dev/gvisor/runsc/boot", visibility = [ "//runsc:__subpackages__", "//test:__subpackages__", @@ -107,7 +106,7 @@ go_test( "loader_test.go", "user_test.go", ], - embed = [":boot"], + library = ":boot", deps = [ "//pkg/control/server", "//pkg/log", diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD index 3a9dcfc04..ce30f6c53 100644 --- a/runsc/boot/filter/BUILD +++ b/runsc/boot/filter/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "extra_filters_race.go", "filter.go", ], - importpath = "gvisor.dev/gvisor/runsc/boot/filter", visibility = [ "//runsc/boot:__subpackages__", ], diff --git a/runsc/boot/platforms/BUILD b/runsc/boot/platforms/BUILD index 03391cdca..77774f43c 100644 --- a/runsc/boot/platforms/BUILD +++ b/runsc/boot/platforms/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "platforms", srcs = ["platforms.go"], - importpath = "gvisor.dev/gvisor/runsc/boot/platforms", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD index d6165f9e5..d4c7bdfbb 100644 --- a/runsc/cgroup/BUILD +++ b/runsc/cgroup/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "cgroup", srcs = ["cgroup.go"], - importpath = "gvisor.dev/gvisor/runsc/cgroup", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -19,6 +18,6 @@ go_test( name = "cgroup_test", size = "small", srcs = ["cgroup_test.go"], - embed = [":cgroup"], + library = ":cgroup", tags = ["local"], ) diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index b94bc4fa0..09aa46434 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -34,7 +34,6 @@ go_library( "syscalls.go", "wait.go", ], - importpath = "gvisor.dev/gvisor/runsc/cmd", visibility = [ "//runsc:__subpackages__", ], @@ -73,7 +72,7 @@ go_test( data = [ "//runsc", ], - embed = [":cmd"], + library = ":cmd", deps = [ "//pkg/abi/linux", "//pkg/log", diff --git a/runsc/console/BUILD b/runsc/console/BUILD index e623c1a0f..06924bccd 100644 --- a/runsc/console/BUILD +++ b/runsc/console/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "console.go", ], - importpath = "gvisor.dev/gvisor/runsc/console", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 6dea179e4..e21431e4c 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "state_file.go", "status.go", ], - importpath = "gvisor.dev/gvisor/runsc/container", visibility = [ "//runsc:__subpackages__", "//test:__subpackages__", @@ -42,7 +41,7 @@ go_test( "//runsc", "//runsc/container/test_app", ], - embed = [":container"], + library = ":container", shard_count = 5, tags = [ "requires-kvm", diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD index bfd338bb6..e200bafd9 100644 --- a/runsc/container/test_app/BUILD +++ b/runsc/container/test_app/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -9,7 +9,7 @@ go_binary( "fds.go", "test_app.go", ], - pure = "on", + pure = True, visibility = ["//runsc/container:__pkg__"], deps = [ "//pkg/unet", diff --git a/runsc/criutil/BUILD b/runsc/criutil/BUILD index 558133a0e..8a571a000 100644 --- a/runsc/criutil/BUILD +++ b/runsc/criutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "criutil", testonly = 1, srcs = ["criutil.go"], - importpath = "gvisor.dev/gvisor/runsc/criutil", visibility = ["//:sandbox"], deps = ["//runsc/testutil"], ) diff --git a/runsc/dockerutil/BUILD b/runsc/dockerutil/BUILD index 0e0423504..8621af901 100644 --- a/runsc/dockerutil/BUILD +++ b/runsc/dockerutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "dockerutil", testonly = 1, srcs = ["dockerutil.go"], - importpath = "gvisor.dev/gvisor/runsc/dockerutil", visibility = ["//:sandbox"], deps = [ "//runsc/testutil", diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index a9582d92b..64a406ae2 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,10 +10,7 @@ go_library( "fsgofer_arm64_unsafe.go", "fsgofer_unsafe.go", ], - importpath = "gvisor.dev/gvisor/runsc/fsgofer", - visibility = [ - "//runsc:__subpackages__", - ], + visibility = ["//runsc:__subpackages__"], deps = [ "//pkg/abi/linux", "//pkg/fd", @@ -30,7 +27,7 @@ go_test( name = "fsgofer_test", size = "small", srcs = ["fsgofer_test.go"], - embed = [":fsgofer"], + library = ":fsgofer", deps = [ "//pkg/log", "//pkg/p9", diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD index bac73f89d..82b48ef32 100644 --- a/runsc/fsgofer/filter/BUILD +++ b/runsc/fsgofer/filter/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "extra_filters_race.go", "filter.go", ], - importpath = "gvisor.dev/gvisor/runsc/fsgofer/filter", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index ddbc37456..c95d50294 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "network_unsafe.go", "sandbox.go", ], - importpath = "gvisor.dev/gvisor/runsc/sandbox", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD index 205638803..4ccd77f63 100644 --- a/runsc/specutils/BUILD +++ b/runsc/specutils/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "namespace.go", "specutils.go", ], - importpath = "gvisor.dev/gvisor/runsc/specutils", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", @@ -28,6 +27,6 @@ go_test( name = "specutils_test", size = "small", srcs = ["specutils_test.go"], - embed = [":specutils"], + library = ":specutils", deps = ["@com_github_opencontainers_runtime-spec//specs-go:go_default_library"], ) diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index 3c3027cb5..f845120b0 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "testutil", testonly = 1, srcs = ["testutil.go"], - importpath = "gvisor.dev/gvisor/runsc/testutil", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/runsc/version_test.sh b/runsc/version_test.sh index cc0ca3f05..747350654 100755 --- a/runsc/version_test.sh +++ b/runsc/version_test.sh @@ -16,7 +16,7 @@ set -euf -x -o pipefail -readonly runsc="${TEST_SRCDIR}/__main__/runsc/linux_amd64_pure_stripped/runsc" +readonly runsc="$1" readonly version=$($runsc --version) # Version should should not match VERSION, which is the default and which will diff --git a/scripts/common.sh b/scripts/common.sh index fdb1aa142..cd91b9f8e 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -16,11 +16,7 @@ set -xeou pipefail -if [[ -f $(dirname $0)/common_google.sh ]]; then - source $(dirname $0)/common_google.sh -else - source $(dirname $0)/common_bazel.sh -fi +source $(dirname $0)/common_build.sh # Ensure it attempts to collect logs in all cases. trap collect_logs EXIT diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh deleted file mode 100755 index a473a88a4..000000000 --- a/scripts/common_bazel.sh +++ /dev/null @@ -1,99 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Install the latest version of Bazel and log the version. -(which use_bazel.sh && use_bazel.sh latest) || which bazel -bazel version - -# Switch into the workspace; only necessary if run with kokoro. -if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then - cd git/repo -elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then - cd github/repo -fi - -# Set the standard bazel flags. -declare -r BAZEL_FLAGS=( - "--show_timestamps" - "--test_output=errors" - "--keep_going" - "--verbose_failures=true" -) -if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then - declare -r BAZEL_RBE_AUTH_FLAGS=( - "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" - ) - declare -r BAZEL_RBE_FLAGS=("--config=remote") -fi - -# Wrap bazel. -function build() { - bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 | - tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }' -} - -function test() { - bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" -} - -function run() { - local binary=$1 - shift - bazel run "${binary}" -- "$@" -} - -function run_as_root() { - local binary=$1 - shift - bazel run --run_under="sudo" "${binary}" -- "$@" -} - -function collect_logs() { - # Zip out everything into a convenient form. - if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then - # Merge results files of all shards for each test suite. - for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do - junitparser merge `find $d -name test.xml` $d/test.xml - cat $d/shard_*_of_*/test.log > $d/test.log - ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip - done - find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf - # Move test logs to Kokoro directory. tar is used to conveniently perform - # renames while moving files. - find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | - tar --create --files-from - --transform 's/test\./sponge_log./' | - tar --extract --directory ${KOKORO_ARTIFACTS_DIR} - - # Collect sentry logs, if any. - if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then - # Check if the directory is empty or not (only the first line it needed). - local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) - if [[ "${logs}" ]]; then - local -r archive=runsc_logs_"${RUNTIME}".tar.gz - if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then - echo "runsc logs will be uploaded to:" - echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" - echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" - fi - tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" . - fi - fi - fi -} - -function find_branch_name() { - git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename -} diff --git a/scripts/common_build.sh b/scripts/common_build.sh new file mode 100755 index 000000000..a473a88a4 --- /dev/null +++ b/scripts/common_build.sh @@ -0,0 +1,99 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Install the latest version of Bazel and log the version. +(which use_bazel.sh && use_bazel.sh latest) || which bazel +bazel version + +# Switch into the workspace; only necessary if run with kokoro. +if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then + cd git/repo +elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then + cd github/repo +fi + +# Set the standard bazel flags. +declare -r BAZEL_FLAGS=( + "--show_timestamps" + "--test_output=errors" + "--keep_going" + "--verbose_failures=true" +) +if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then + declare -r BAZEL_RBE_AUTH_FLAGS=( + "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" + ) + declare -r BAZEL_RBE_FLAGS=("--config=remote") +fi + +# Wrap bazel. +function build() { + bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 | + tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }' +} + +function test() { + bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" +} + +function run() { + local binary=$1 + shift + bazel run "${binary}" -- "$@" +} + +function run_as_root() { + local binary=$1 + shift + bazel run --run_under="sudo" "${binary}" -- "$@" +} + +function collect_logs() { + # Zip out everything into a convenient form. + if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then + # Merge results files of all shards for each test suite. + for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do + junitparser merge `find $d -name test.xml` $d/test.xml + cat $d/shard_*_of_*/test.log > $d/test.log + ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip + done + find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf + # Move test logs to Kokoro directory. tar is used to conveniently perform + # renames while moving files. + find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | + tar --create --files-from - --transform 's/test\./sponge_log./' | + tar --extract --directory ${KOKORO_ARTIFACTS_DIR} + + # Collect sentry logs, if any. + if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then + # Check if the directory is empty or not (only the first line it needed). + local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) + if [[ "${logs}" ]]; then + local -r archive=runsc_logs_"${RUNTIME}".tar.gz + if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then + echo "runsc logs will be uploaded to:" + echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" + echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" + fi + tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" . + fi + fi + fi +} + +function find_branch_name() { + git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename +} diff --git a/test/BUILD b/test/BUILD index bf834d994..34b950644 100644 --- a/test/BUILD +++ b/test/BUILD @@ -1,44 +1 @@ -package(licenses = ["notice"]) # Apache 2.0 - -# We need to define a bazel platform and toolchain to specify dockerPrivileged -# and dockerRunAsRoot options, they are required to run tests on the RBE -# cluster in Kokoro. -alias( - name = "rbe_ubuntu1604", - actual = ":rbe_ubuntu1604_r346485", -) - -platform( - name = "rbe_ubuntu1604_r346485", - constraint_values = [ - "@bazel_tools//platforms:x86_64", - "@bazel_tools//platforms:linux", - "@bazel_tools//tools/cpp:clang", - "@bazel_toolchains//constraints:xenial", - "@bazel_toolchains//constraints/sanitizers:support_msan", - ], - remote_execution_properties = """ - properties: { - name: "container-image" - value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" - } - properties: { - name: "dockerAddCapabilities" - value: "SYS_ADMIN" - } - properties: { - name: "dockerPrivileged" - value: "true" - } - """, -) - -toolchain( - name = "cc-toolchain-clang-x86_64-default", - exec_compatible_with = [ - ], - target_compatible_with = [ - ], - toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", - toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", -) +package(licenses = ["notice"]) diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 4fe03a220..76e04f878 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,7 @@ go_test( "integration_test.go", "regression_test.go", ], - embed = [":integration"], + library = ":integration", tags = [ # Requires docker and runsc to be configured before the test runs. "manual", @@ -29,5 +29,4 @@ go_test( go_library( name = "integration", srcs = ["integration.go"], - importpath = "gvisor.dev/gvisor/test/integration", ) diff --git a/test/image/BUILD b/test/image/BUILD index 09b0a0ad5..7392ac54e 100644 --- a/test/image/BUILD +++ b/test/image/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -14,7 +14,7 @@ go_test( "ruby.rb", "ruby.sh", ], - embed = [":image"], + library = ":image", tags = [ # Requires docker and runsc to be configured before the test runs. "manual", @@ -30,5 +30,4 @@ go_test( go_library( name = "image", srcs = ["image.go"], - importpath = "gvisor.dev/gvisor/test/image", ) diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 22f470092..6bb3b82b5 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "iptables_util.go", "nat.go", ], - importpath = "gvisor.dev/gvisor/test/iptables", visibility = ["//test/iptables:__subpackages__"], deps = [ "//runsc/testutil", @@ -24,7 +23,7 @@ go_test( srcs = [ "iptables_test.go", ], - embed = [":iptables"], + library = ":iptables", tags = [ "local", "manual", diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD index a5b6f082c..b9199387a 100644 --- a/test/iptables/runner/BUILD +++ b/test/iptables/runner/BUILD @@ -1,15 +1,21 @@ -load("@io_bazel_rules_docker//go:image.bzl", "go_image") -load("@io_bazel_rules_docker//container:container.bzl", "container_image") +load("//tools:defs.bzl", "container_image", "go_binary", "go_image") package(licenses = ["notice"]) +go_binary( + name = "runner", + testonly = 1, + srcs = ["main.go"], + deps = ["//test/iptables"], +) + container_image( name = "iptables-base", base = "@iptables-test//image", ) go_image( - name = "runner", + name = "runner-image", testonly = 1, srcs = ["main.go"], base = ":iptables-base", diff --git a/test/root/BUILD b/test/root/BUILD index d5dd9bca2..23ce2a70f 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "root", srcs = ["root.go"], - importpath = "gvisor.dev/gvisor/test/root", ) go_test( @@ -21,7 +20,7 @@ go_test( data = [ "//runsc", ], - embed = [":root"], + library = ":root", tags = [ # Requires docker and runsc to be configured before the test runs. # Also test only runs as root. diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD index 125633680..bca5f9cab 100644 --- a/test/root/testdata/BUILD +++ b/test/root/testdata/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "sandbox.go", "simple.go", ], - importpath = "gvisor.dev/gvisor/test/root/testdata", visibility = [ "//visibility:public", ], diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 367295206..2c472bf8d 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,6 +1,6 @@ # These packages are used to run language runtime tests inside gVisor sandboxes. -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_test") load("//test/runtimes:build_defs.bzl", "runtime_test") package(licenses = ["notice"]) @@ -49,5 +49,5 @@ go_test( name = "blacklist_test", size = "small", srcs = ["blacklist_test.go"], - embed = [":runner"], + library = ":runner", ) diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl index 6f84ca852..92e275a76 100644 --- a/test/runtimes/build_defs.bzl +++ b/test/runtimes/build_defs.bzl @@ -1,6 +1,6 @@ """Defines a rule for runtime test targets.""" -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test", "loopback") def runtime_test( name, @@ -34,6 +34,7 @@ def runtime_test( ] data = [ ":runner", + loopback, ] if blacklist_file: args += ["--blacklist_file", "test/runtimes/" + blacklist_file] @@ -61,7 +62,7 @@ def blacklist_test(name, blacklist_file): """Test that a blacklist parses correctly.""" go_test( name = name + "_blacklist_test", - embed = [":runner"], + library = ":runner", srcs = ["blacklist_test.go"], args = ["--blacklist_file", "test/runtimes/" + blacklist_file], data = [blacklist_file], diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/images/proctor/BUILD index 09dc6c42f..85e004c45 100644 --- a/test/runtimes/images/proctor/BUILD +++ b/test/runtimes/images/proctor/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_test") package(licenses = ["notice"]) @@ -19,7 +19,7 @@ go_test( name = "proctor_test", size = "small", srcs = ["proctor_test.go"], - embed = [":proctor"], + library = ":proctor", deps = [ "//runsc/testutil", ], diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 90d52e73b..40e974314 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") load("//test/syscalls:build_defs.bzl", "syscall_test") package(licenses = ["notice"]) diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index aaf77c65b..1df761dd0 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -1,5 +1,7 @@ """Defines a rule for syscall test targets.""" +load("//tools:defs.bzl", "loopback") + # syscall_test is a macro that will create targets to run the given test target # on the host (native) and runsc. def syscall_test( @@ -135,6 +137,7 @@ def _syscall_test( name = name, data = [ ":syscall_test_runner", + loopback, test, ], args = args, @@ -148,6 +151,3 @@ def sh_test(**kwargs): native.sh_test( **kwargs ) - -def select_for_linux(for_linux, for_others = []): - return for_linux diff --git a/test/syscalls/gtest/BUILD b/test/syscalls/gtest/BUILD index 9293f25cb..de4b2727c 100644 --- a/test/syscalls/gtest/BUILD +++ b/test/syscalls/gtest/BUILD @@ -1,12 +1,9 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "gtest", srcs = ["gtest.go"], - importpath = "gvisor.dev/gvisor/test/syscalls/gtest", - visibility = [ - "//test:__subpackages__", - ], + visibility = ["//:sandbox"], ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 4c7ec3f06..c2ef50c1d 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,5 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") -load("//test/syscalls:build_defs.bzl", "select_for_linux") +load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "select_system") package( default_visibility = ["//:sandbox"], @@ -126,13 +125,11 @@ cc_library( testonly = 1, srcs = [ "socket_test_util.cc", - ] + select_for_linux( - [ - "socket_test_util_impl.cc", - ], - ), + "socket_test_util_impl.cc", + ], hdrs = ["socket_test_util.h"], - deps = [ + defines = select_system(), + deps = default_net_util() + [ "@com_google_googletest//:gtest", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -143,8 +140,7 @@ cc_library( "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", - ] + select_for_linux([ - ]), + ], ) cc_library( @@ -1443,6 +1439,7 @@ cc_binary( srcs = ["arch_prctl.cc"], linkstatic = 1, deps = [ + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", "@com_google_googletest//:gtest", @@ -3383,11 +3380,11 @@ cc_library( name = "udp_socket_test_cases", testonly = 1, srcs = [ - "udp_socket_test_cases.cc", - ] + select_for_linux([ "udp_socket_errqueue_test_case.cc", - ]), + "udp_socket_test_cases.cc", + ], hdrs = ["udp_socket_test_cases.h"], + defines = select_system(), deps = [ ":socket_test_util", ":unix_domain_socket_test_util", diff --git a/test/syscalls/linux/arch_prctl.cc b/test/syscalls/linux/arch_prctl.cc index 81bf5a775..3a901faf5 100644 --- a/test/syscalls/linux/arch_prctl.cc +++ b/test/syscalls/linux/arch_prctl.cc @@ -14,8 +14,10 @@ #include #include +#include #include "gtest/gtest.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" // glibc does not provide a prototype for arch_prctl() so declare it here. diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD index 5cfe4e56f..ed488dbc2 100644 --- a/test/syscalls/linux/rseq/BUILD +++ b/test/syscalls/linux/rseq/BUILD @@ -1,8 +1,7 @@ # This package contains a standalone rseq test binary. This binary must not # depend on libc, which might use rseq itself. -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier") -load("@rules_cc//cc:defs.bzl", "cc_library") +load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain") package(licenses = ["notice"]) @@ -37,8 +36,8 @@ genrule( "$(location start.S)", ]), toolchains = [ + cc_toolchain, ":no_pie_cc_flags", - "@bazel_tools//tools/cpp:current_cc_toolchain", ], visibility = ["//:sandbox"], ) diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc index 147978f46..9a24e1df0 100644 --- a/test/syscalls/linux/udp_socket_errqueue_test_case.cc +++ b/test/syscalls/linux/udp_socket_errqueue_test_case.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __fuchsia__ + #include "test/syscalls/linux/udp_socket_test_cases.h" #include @@ -52,3 +54,5 @@ TEST_P(UdpSocketTest, ErrorQueue) { } // namespace testing } // namespace gvisor + +#endif // __fuchsia__ diff --git a/test/uds/BUILD b/test/uds/BUILD index a3843e699..51e2c7ce8 100644 --- a/test/uds/BUILD +++ b/test/uds/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package( default_visibility = ["//:sandbox"], @@ -9,7 +9,6 @@ go_library( name = "uds", testonly = 1, srcs = ["uds.go"], - importpath = "gvisor.dev/gvisor/test/uds", deps = [ "//pkg/log", "//pkg/unet", diff --git a/test/util/BUILD b/test/util/BUILD index cbc728159..3c732be62 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,5 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") -load("//test/syscalls:build_defs.bzl", "select_for_linux") +load("//tools:defs.bzl", "cc_library", "cc_test", "select_system") package( default_visibility = ["//:sandbox"], @@ -142,12 +141,13 @@ cc_library( cc_library( name = "save_util", testonly = 1, - srcs = ["save_util.cc"] + - select_for_linux( - ["save_util_linux.cc"], - ["save_util_other.cc"], - ), + srcs = [ + "save_util.cc", + "save_util_linux.cc", + "save_util_other.cc", + ], hdrs = ["save_util.h"], + defines = select_system(), ) cc_library( @@ -234,13 +234,16 @@ cc_library( testonly = 1, srcs = [ "test_util.cc", - ] + select_for_linux( - [ - "test_util_impl.cc", - "test_util_runfiles.cc", + "test_util_impl.cc", + "test_util_runfiles.cc", + ], + hdrs = ["test_util.h"], + defines = select_system( + fuchsia = [ + "__opensource__", + "__fuchsia__", ], ), - hdrs = ["test_util.h"], deps = [ ":fs_util", ":logging", diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc index cd56118c0..d0aea8e6a 100644 --- a/test/util/save_util_linux.cc +++ b/test/util/save_util_linux.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifdef __linux__ + #include #include #include @@ -43,3 +45,5 @@ void MaybeSave() { } // namespace testing } // namespace gvisor + +#endif diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc index 1aca663b7..931af2c29 100644 --- a/test/util/save_util_other.cc +++ b/test/util/save_util_other.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __linux__ + namespace gvisor { namespace testing { @@ -21,3 +23,5 @@ void MaybeSave() { } // namespace testing } // namespace gvisor + +#endif diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc index 7210094eb..694d21692 100644 --- a/test/util/test_util_runfiles.cc +++ b/test/util/test_util_runfiles.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __fuchsia__ + #include #include @@ -44,3 +46,5 @@ std::string RunfilePath(std::string path) { } // namespace testing } // namespace gvisor + +#endif // __fuchsia__ diff --git a/tools/BUILD b/tools/BUILD new file mode 100644 index 000000000..e73a9c885 --- /dev/null +++ b/tools/BUILD @@ -0,0 +1,3 @@ +package(licenses = ["notice"]) + +exports_files(["nogo.js"]) diff --git a/tools/build/BUILD b/tools/build/BUILD new file mode 100644 index 000000000..0c0ce3f4d --- /dev/null +++ b/tools/build/BUILD @@ -0,0 +1,10 @@ +package(licenses = ["notice"]) + +# In bazel, no special support is required for loopback networking. This is +# just a dummy data target that does not change the test environment. +genrule( + name = "loopback", + outs = ["loopback.txt"], + cmd = "touch $@", + visibility = ["//visibility:public"], +) diff --git a/tools/build/defs.bzl b/tools/build/defs.bzl new file mode 100644 index 000000000..d0556abd1 --- /dev/null +++ b/tools/build/defs.bzl @@ -0,0 +1,91 @@ +"""Bazel implementations of standard rules.""" + +load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") +load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library") +load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library") +load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") +load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") +load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") +load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") +load("@pydeps//:requirements.bzl", _py_requirement = "requirement") + +container_image = _container_image +cc_binary = _cc_binary +cc_library = _cc_library +cc_flags_supplier = _cc_flags_supplier +cc_proto_library = _cc_proto_library +cc_test = _cc_test +cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" +go_image = _go_image +go_embed_data = _go_embed_data +loopback = "//tools/build:loopback" +proto_library = native.proto_library +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_library = native.py_library +py_binary = native.py_binary +py_test = native.py_test + +def go_binary(name, static = False, pure = False, **kwargs): + if static: + kwargs["static"] = "on" + if pure: + kwargs["pure"] = "on" + _go_binary( + name = name, + **kwargs + ) + +def go_library(name, **kwargs): + _go_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_tool_library(name, **kwargs): + _go_tool_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_proto_library(name, proto, **kwargs): + deps = kwargs.pop("deps", []) + _go_proto_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name, + proto = proto, + deps = [dep.replace("_proto", "_go_proto") for dep in deps], + **kwargs + ) + +def go_test(name, **kwargs): + library = kwargs.pop("library", None) + if library: + kwargs["embed"] = [library] + _go_test( + name = name, + **kwargs + ) + +def py_requirement(name, direct = False): + return _py_requirement(name) + +def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): + values = { + "@bazel_tools//src/conditions:linux_x86_64": amd64, + "@bazel_tools//src/conditions:linux_aarch64": arm64, + } + if default: + values["//conditions:default"] = default + return select(values, **kwargs) + +def select_system(linux = ["__linux__"], **kwargs): + return linux # Only Linux supported. + +def default_installer(): + return None + +def default_net_util(): + return [] # Nothing needed. diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD index d85c56131..92ba8ab06 100644 --- a/tools/checkunsafe/BUILD +++ b/tools/checkunsafe/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_tool_library") +load("//tools:defs.bzl", "go_tool_library") package(licenses = ["notice"]) go_tool_library( name = "checkunsafe", srcs = ["check_unsafe.go"], - importpath = "checkunsafe", visibility = ["//visibility:public"], deps = [ "@org_golang_x_tools//go/analysis:go_tool_library", diff --git a/tools/defs.bzl b/tools/defs.bzl new file mode 100644 index 000000000..819f12b0d --- /dev/null +++ b/tools/defs.bzl @@ -0,0 +1,154 @@ +"""Wrappers for common build rules. + +These wrappers apply common BUILD configurations (e.g., proto_library +automagically creating cc_ and go_ proto targets) and act as a single point of +change for Google-internal and bazel-compatible rules. +""" + +load("//tools/go_stateify:defs.bzl", "go_stateify") +load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") +load("//tools/build:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") + +# Delegate directly. +cc_binary = _cc_binary +cc_library = _cc_library +cc_test = _cc_test +cc_toolchain = _cc_toolchain +cc_flags_supplier = _cc_flags_supplier +container_image = _container_image +go_embed_data = _go_embed_data +go_image = _go_image +go_test = _go_test +go_tool_library = _go_tool_library +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_library = _py_library +py_binary = _py_binary +py_test = _py_test +py_requirement = _py_requirement +select_arch = _select_arch +select_system = _select_system +loopback = _loopback +default_installer = _default_installer +default_net_util = _default_net_util + +def go_binary(name, **kwargs): + """Wraps the standard go_binary. + + Args: + name: the rule name. + **kwargs: standard go_binary arguments. + """ + _go_binary( + name = name, + **kwargs + ) + +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): + """Wraps the standard go_library and does stateification and marshalling. + + The recommended way is to use this rule with mostly identical configuration as the native + go_library rule. + + These definitions provide additional flags (stateify, marshal) that can be used + with the generators to automatically supplement the library code. + + load("//tools:defs.bzl", "go_library") + + go_library( + name = "foo", + srcs = ["foo.go"], + ) + + Args: + name: the rule name. + srcs: the library sources. + deps: the library dependencies. + imports: imports required for stateify. + stateify: whether statify is enabled (default: true). + marshal: whether marshal is enabled (default: false). + **kwargs: standard go_library arguments. + """ + if stateify: + # Only do stateification for non-state packages without manual autogen. + go_stateify( + name = name + "_state_autogen", + srcs = [src for src in srcs if src.endswith(".go")], + imports = imports, + package = name, + arch = select_arch(), + out = name + "_state_autogen.go", + ) + all_srcs = srcs + [name + "_state_autogen.go"] + if "//pkg/state" not in deps: + all_deps = deps + ["//pkg/state"] + else: + all_deps = deps + else: + all_deps = deps + all_srcs = srcs + if marshal: + go_marshal( + name = name + "_abi_autogen", + srcs = [src for src in srcs if src.endswith(".go")], + debug = False, + imports = imports, + package = name, + ) + extra_deps = [ + dep + for dep in marshal_deps + if not dep in all_deps + ] + all_deps = all_deps + extra_deps + all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] + + _go_library( + name = name, + srcs = all_srcs, + deps = all_deps, + **kwargs + ) + + if marshal: + # Ignore importpath for go_test. + kwargs.pop("importpath", None) + + _go_test( + name = name + "_abi_autogen_test", + srcs = [name + "_abi_autogen_test.go"], + library = ":" + name, + deps = marshal_test_deps, + **kwargs + ) + +def proto_library(name, srcs, **kwargs): + """Wraps the standard proto_library. + + Given a proto_library named "foo", this produces three different targets: + - foo_proto: proto_library rule. + - foo_go_proto: go_proto_library rule. + - foo_cc_proto: cc_proto_library rule. + + Args: + srcs: the proto sources. + **kwargs: standard proto_library arguments. + """ + deps = kwargs.pop("deps", []) + _proto_library( + name = name + "_proto", + srcs = srcs, + deps = deps, + **kwargs + ) + _go_proto_library( + name = name + "_go_proto", + proto = ":" + name + "_proto", + deps = deps, + **kwargs + ) + _cc_proto_library( + name = name + "_cc_proto", + deps = [":" + name + "_proto"], + **kwargs + ) diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD index 39318b877..069df3856 100644 --- a/tools/go_generics/BUILD +++ b/tools/go_generics/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD index 74853c7d2..38caa3ce7 100644 --- a/tools/go_generics/globals/BUILD +++ b/tools/go_generics/globals/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,6 +8,6 @@ go_library( "globals_visitor.go", "scope.go", ], - importpath = "gvisor.dev/gvisor/tools/go_generics/globals", + stateify = False, visibility = ["//tools/go_generics:__pkg__"], ) diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD index 02b09120e..b7d35e272 100644 --- a/tools/go_generics/go_merge/BUILD +++ b/tools/go_generics/go_merge/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD index 9d26a88b7..8a329dfc6 100644 --- a/tools/go_generics/rules_tests/BUILD +++ b/tools/go_generics/rules_tests/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index c862b277c..80d9c0504 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -1,6 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") -package(licenses = ["notice"]) +licenses(["notice"]) go_binary( name = "go_marshal", diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 481575bd3..4886efddf 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -20,19 +20,7 @@ comment `// +marshal`. # Usage -See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`. - -The recommended way to generate a go library with marshalling is to use the -`go_library` with mostly identical configuration as the native go_library rule. - -``` -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) -``` +See `defs.bzl`: a new rule is provided, `go_marshal`. Under the hood, the `go_marshal` rule is used to generate a file that will appear in a Go target; the output file should appear explicitly in a srcs list. @@ -54,11 +42,7 @@ go_library( "foo.go", "foo_abi.go", ], - deps = [ - "/gvisor/pkg/abi", - "/gvisor/pkg/sentry/safemem/safemem", - "/gvisor/pkg/sentry/usermem/usermem", - ], + ... ) ``` @@ -69,22 +53,6 @@ These tests use reflection to verify properties of the ABI struct, and should be considered part of the generated interfaces (but are too expensive to execute at runtime). Ensure these tests run at some point. -``` -$ cat BUILD -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) -$ blaze build :foo -$ blaze query ... -:foo_abi_autogen -:foo_abi_autogen_test -$ blaze test :foo_abi_autogen_test - -``` - # Restrictions Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is @@ -131,22 +99,6 @@ for embedded structs that are not aligned. Because of this, it's generally best to avoid using `marshal:"unaligned"` and insert explicit padding fields instead. -## Debugging go_marshal - -To enable debugging output from the go marshal tool, pass the `-debug` flag to -the tool. When using the build rules from above, add a `debug = True` field to -the build rule like this: - -``` -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], - debug = True, -) -``` - ## Modifying the `go_marshal` Tool The following are some guidelines for modifying the `go_marshal` tool: diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD index c859ced77..c2a4d45c4 100644 --- a/tools/go_marshal/analysis/BUILD +++ b/tools/go_marshal/analysis/BUILD @@ -1,12 +1,11 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "analysis", testonly = 1, srcs = ["analysis_unsafe.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis", visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index c32eb559f..2918ceffe 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -1,57 +1,14 @@ -"""Marshal is a tool for generating marshalling interfaces for Go types. - -The recommended way is to use the go_library rule defined below with mostly -identical configuration as the native go_library rule. - -load("//tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) - -Under the hood, the go_marshal rule is used to generate a file that will -appear in a Go target; the output file should appear explicitly in a srcs list. -For example (the above is still the preferred way): - -load("//tools/go_marshal:defs.bzl", "go_marshal") - -go_marshal( - name = "foo_abi", - srcs = ["foo.go"], - out = "foo_abi.go", - package = "foo", -) - -go_library( - name = "foo", - srcs = [ - "foo.go", - "foo_abi.go", - ], - deps = [ - "//tools/go_marshal:marshal", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", - ], -) -""" - -load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test") +"""Marshal is a tool for generating marshalling interfaces for Go types.""" def _go_marshal_impl(ctx): """Execute the go_marshal tool.""" output = ctx.outputs.lib output_test = ctx.outputs.test - (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD") - - decl = "/".join(["gvisor.dev/gvisor", build_dir]) # Run the marshal command. args = ["-output=%s" % output.path] args += ["-pkg=%s" % ctx.attr.package] args += ["-output_test=%s" % output_test.path] - args += ["-declarationPkg=%s" % decl] if ctx.attr.debug: args += ["-debug"] @@ -83,7 +40,6 @@ go_marshal = rule( implementation = _go_marshal_impl, attrs = { "srcs": attr.label_list(mandatory = True, allow_files = True), - "libname": attr.string(mandatory = True), "imports": attr.string_list(mandatory = False), "package": attr.string(mandatory = True), "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"), @@ -95,58 +51,14 @@ go_marshal = rule( }, ) -def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs): - """wraps the standard go_library and does mashalling interface generation. - - Args: - name: Same as native go_library. - srcs: Same as native go_library. - deps: Same as native go_library. - imports: Extra import paths to pass to the go_marshal tool. - debug: Enables debugging output from the go_marshal tool. - **kwargs: Remaining args to pass to the native go_library rule unmodified. - """ - go_marshal( - name = name + "_abi_autogen", - libname = name, - srcs = [src for src in srcs if src.endswith(".go")], - debug = debug, - imports = imports, - package = name, - ) - - extra_deps = [ - "//tools/go_marshal/marshal", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", - ] - - all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] - all_deps = deps + [] # + extra_deps - - for extra in extra_deps: - if extra not in deps: - all_deps.append(extra) - - _go_library( - name = name, - srcs = all_srcs, - deps = all_deps, - **kwargs - ) - - # Don't pass importpath arg to go_test. - kwargs.pop("importpath", "") - - _go_test( - name = name + "_abi_autogen_test", - srcs = [name + "_abi_autogen_test.go"], - # Generated test has a fixed set of dependencies since we generate these - # tests. They should only depend on the library generated above, and the - # Marshallable interface. - deps = [ - ":" + name, - "//tools/go_marshal/analysis", - ], - **kwargs - ) +# marshal_deps are the dependencies requied by generated code. +marshal_deps = [ + "//tools/go_marshal/marshal", + "//pkg/sentry/platform/safecopy", + "//pkg/sentry/usermem", +] + +# marshal_test_deps are required by test targets. +marshal_test_deps = [ + "//tools/go_marshal/analysis", +] diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index a0eae6492..c92b59dd6 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -1,6 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "gomarshal", @@ -10,7 +10,7 @@ go_library( "generator_tests.go", "util.go", ], - importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal", + stateify = False, visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 641ccd938..8392f3f6d 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -62,15 +62,12 @@ type Generator struct { outputTest *os.File // Package name for the generated file. pkg string - // Go import path for package we're processing. This package should directly - // declare the type we're generating code for. - declaration string // Set of extra packages to import in the generated file. imports *importTable } // NewGenerator creates a new code Generator. -func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) { +func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) { f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) @@ -80,12 +77,11 @@ func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) } g := Generator{ - inputs: srcs, - output: f, - outputTest: fTest, - pkg: pkg, - declaration: declaration, - imports: newImportTable(), + inputs: srcs, + output: f, + outputTest: fTest, + pkg: pkg, + imports: newImportTable(), } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as @@ -264,7 +260,7 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface // generateOneTestSuite generates a test suite for the automatically generated // implementations type t. func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator { - i := newTestGenerator(t, g.declaration) + i := newTestGenerator(t) i.emitTests() return i } @@ -359,7 +355,7 @@ func (g *Generator) Run() error { // source file. func (g *Generator) writeTests(ts []*testGenerator) error { var b sourceBuffer - b.emit("package %s_test\n\n", g.pkg) + b.emit("package %s\n\n", g.pkg) if err := b.write(g.outputTest); err != nil { return err } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index df25cb5b2..bcda17c3b 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -46,7 +46,7 @@ type testGenerator struct { decl *importStmt } -func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator { +func newTestGenerator(t *ast.TypeSpec) *testGenerator { if _, ok := t.Type.(*ast.StructType); !ok { panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) } @@ -59,14 +59,12 @@ func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator { for _, i := range standardImports { g.imports.add(i).markUsed() } - g.decl = g.imports.add(declaration) - g.decl.markUsed() return g } func (g *testGenerator) typeName() string { - return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name) + return g.t.Name.Name } func (g *testGenerator) forEachField(fn func(f *ast.Field)) { diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go index 3d12eb93c..e1a97b311 100644 --- a/tools/go_marshal/main.go +++ b/tools/go_marshal/main.go @@ -31,11 +31,10 @@ import ( ) var ( - pkg = flag.String("pkg", "", "output package") - output = flag.String("output", "", "output file") - outputTest = flag.String("output_test", "", "output file for tests") - imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") - declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on") + pkg = flag.String("pkg", "", "output package") + output = flag.String("output", "", "output file") + outputTest = flag.String("output_test", "", "output file for tests") + imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") ) func main() { @@ -62,7 +61,7 @@ func main() { // as an import. extraImports = strings.Split(*imports, ",") } - g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports) + g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports) if err != nil { panic(err) } diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD index 47dda97a1..ad508c72f 100644 --- a/tools/go_marshal/marshal/BUILD +++ b/tools/go_marshal/marshal/BUILD @@ -1,13 +1,12 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "marshal", srcs = [ "marshal.go", ], - importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal", visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index d412e1ccf..38ba49fed 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,7 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_marshal:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) package_group( name = "gomarshal_test", @@ -25,6 +24,6 @@ go_library( name = "test", testonly = 1, srcs = ["test.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/test", + marshal = True, deps = ["//tools/go_marshal/test/external"], ) diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD index 9bb89e1da..0cf6da603 100644 --- a/tools/go_marshal/test/external/BUILD +++ b/tools/go_marshal/test/external/BUILD @@ -1,11 +1,11 @@ -load("//tools/go_marshal:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "external", testonly = 1, srcs = ["external.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external", + marshal = True, visibility = ["//tools/go_marshal/test:gomarshal_test"], ) diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index bb53f8ae9..a133d6f8b 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index 33267c074..0f261d89f 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -1,41 +1,4 @@ -"""Stateify is a tool for generating state wrappers for Go types. - -The recommended way is to use the go_library rule defined below with mostly -identical configuration as the native go_library rule. - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) - -Under the hood, the go_stateify rule is used to generate a file that will -appear in a Go target; the output file should appear explicitly in a srcs list. -For example (the above is still the preferred way): - -load("//tools/go_stateify:defs.bzl", "go_stateify") - -go_stateify( - name = "foo_state", - srcs = ["foo.go"], - out = "foo_state.go", - package = "foo", -) - -go_library( - name = "foo", - srcs = [ - "foo.go", - "foo_state.go", - ], - deps = [ - "//pkg/state", - ], -) -""" - -load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library") +"""Stateify is a tool for generating state wrappers for Go types.""" def _go_stateify_impl(ctx): """Implementation for the stateify tool.""" @@ -103,43 +66,3 @@ files and must be added to the srcs of the relevant go_library. "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), }, ) - -def go_library(name, srcs, deps = [], imports = [], **kwargs): - """Standard go_library wrapped which generates state source files. - - Args: - name: the name of the go_library rule. - srcs: sources of the go_library. Each will be processed for stateify - annotations. - deps: dependencies for the go_library. - imports: an optional list of extra non-aliased, Go-style absolute import - paths required for stateified types. - **kwargs: passed to go_library. - """ - if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs: - # Only do stateification for non-state packages without manual autogen. - go_stateify( - name = name + "_state_autogen", - srcs = [src for src in srcs if src.endswith(".go")], - imports = imports, - package = name, - arch = select({ - "@bazel_tools//src/conditions:linux_aarch64": "arm64", - "//conditions:default": "amd64", - }), - out = name + "_state_autogen.go", - ) - all_srcs = srcs + [name + "_state_autogen.go"] - if "//pkg/state" not in deps: - all_deps = deps + ["//pkg/state"] - else: - all_deps = deps - else: - all_deps = deps - all_srcs = srcs - _go_library( - name = name, - srcs = all_srcs, - deps = all_deps, - **kwargs - ) diff --git a/tools/images/BUILD b/tools/images/BUILD index 2b77c2737..f1699b184 100644 --- a/tools/images/BUILD +++ b/tools/images/BUILD @@ -1,4 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary") +load("//tools:defs.bzl", "cc_binary") load("//tools/images:defs.bzl", "vm_image", "vm_test") package( diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl index d8e422a5d..32235813a 100644 --- a/tools/images/defs.bzl +++ b/tools/images/defs.bzl @@ -28,6 +28,8 @@ The vm_test rule can be used to execute a command remotely. For example, ) """ +load("//tools:defs.bzl", "default_installer") + def _vm_image_impl(ctx): script_paths = [] for script in ctx.files.scripts: @@ -165,8 +167,8 @@ def vm_test( targets = kwargs.pop("targets", []) if installer: targets = [installer] + targets - targets = [ - ] + targets + if default_installer(): + targets = [default_installer()] + targets _vm_test( tags = [ "local", diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD index ee7ea11fd..4ef1a3124 100644 --- a/tools/issue_reviver/BUILD +++ b/tools/issue_reviver/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD index 6da22ba1c..da4133472 100644 --- a/tools/issue_reviver/github/BUILD +++ b/tools/issue_reviver/github/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "github", srcs = ["github.go"], - importpath = "gvisor.dev/gvisor/tools/issue_reviver/github", visibility = [ "//tools/issue_reviver:__subpackages__", ], diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD index 2c3675977..d262932bd 100644 --- a/tools/issue_reviver/reviver/BUILD +++ b/tools/issue_reviver/reviver/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "reviver", srcs = ["reviver.go"], - importpath = "gvisor.dev/gvisor/tools/issue_reviver/reviver", visibility = [ "//tools/issue_reviver:__subpackages__", ], @@ -15,5 +14,5 @@ go_test( name = "reviver_test", size = "small", srcs = ["reviver_test.go"], - embed = [":reviver"], + library = ":reviver", ) diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh index fb09ff331..a22c8c9f2 100755 --- a/tools/workspace_status.sh +++ b/tools/workspace_status.sh @@ -15,4 +15,4 @@ # limitations under the License. # The STABLE_ prefix will trigger a re-link if it changes. -echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty) +echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty || echo 0.0.0) diff --git a/vdso/BUILD b/vdso/BUILD index 2b6744c26..d37d4266d 100644 --- a/vdso/BUILD +++ b/vdso/BUILD @@ -3,20 +3,10 @@ # normal system VDSO (time, gettimeofday, clock_gettimeofday) but which uses # timekeeping parameters managed by the sandbox kernel. -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier") +load("//tools:defs.bzl", "cc_flags_supplier", "cc_toolchain", "select_arch") package(licenses = ["notice"]) -config_setting( - name = "x86_64", - constraint_values = ["@bazel_tools//platforms:x86_64"], -) - -config_setting( - name = "aarch64", - constraint_values = ["@bazel_tools//platforms:aarch64"], -) - genrule( name = "vdso", srcs = [ @@ -39,14 +29,15 @@ genrule( "-O2 " + "-std=c++11 " + "-fPIC " + + "-fno-sanitize=all " + # Some toolchains enable stack protector by default. Disable it, the # VDSO has no hooks to handle failures. "-fno-stack-protector " + "-fuse-ld=gold " + - select({ - ":x86_64": "-m64 ", - "//conditions:default": "", - }) + + select_arch( + amd64 = "-m64 ", + arm64 = "", + ) + "-shared " + "-nostdlib " + "-Wl,-soname=linux-vdso.so.1 " + @@ -55,12 +46,10 @@ genrule( "-Wl,-Bsymbolic " + "-Wl,-z,max-page-size=4096 " + "-Wl,-z,common-page-size=4096 " + - select( - { - ":x86_64": "-Wl,-T$(location vdso_amd64.lds) ", - ":aarch64": "-Wl,-T$(location vdso_arm64.lds) ", - }, - no_match_error = "Unsupported architecture", + select_arch( + amd64 = "-Wl,-T$(location vdso_amd64.lds) ", + arm64 = "-Wl,-T$(location vdso_arm64.lds) ", + no_match_error = "unsupported architecture", ) + "-o $(location vdso.so) " + "$(location vdso.cc) " + @@ -73,7 +62,7 @@ genrule( ], features = ["-pie"], toolchains = [ - "@bazel_tools//tools/cpp:current_cc_toolchain", + cc_toolchain, ":no_pie_cc_flags", ], visibility = ["//:sandbox"], -- cgit v1.2.3 From b8f56c79be40d9c75f4e2f279c9d821d1c1c3569 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Fri, 21 Feb 2020 15:41:56 -0800 Subject: Implement tap/tun device in vfs. PiperOrigin-RevId: 296526279 --- pkg/abi/linux/BUILD | 1 + pkg/abi/linux/ioctl.go | 26 ++ pkg/abi/linux/ioctl_tun.go | 29 ++ pkg/sentry/fs/dev/BUILD | 5 + pkg/sentry/fs/dev/dev.go | 10 +- pkg/sentry/fs/dev/net_tun.go | 170 +++++++++++ pkg/syserror/syserror.go | 1 + pkg/tcpip/buffer/view.go | 6 + pkg/tcpip/link/channel/BUILD | 1 + pkg/tcpip/link/channel/channel.go | 180 +++++++++--- pkg/tcpip/link/tun/BUILD | 18 +- pkg/tcpip/link/tun/device.go | 352 +++++++++++++++++++++++ pkg/tcpip/link/tun/protocol.go | 56 ++++ pkg/tcpip/stack/nic.go | 32 +++ pkg/tcpip/stack/stack.go | 39 +++ test/syscalls/BUILD | 2 + test/syscalls/linux/BUILD | 30 ++ test/syscalls/linux/dev.cc | 7 + test/syscalls/linux/socket_netlink_route_util.cc | 163 +++++++++++ test/syscalls/linux/socket_netlink_route_util.h | 55 ++++ test/syscalls/linux/tuntap.cc | 346 ++++++++++++++++++++++ 21 files changed, 1490 insertions(+), 39 deletions(-) create mode 100644 pkg/abi/linux/ioctl_tun.go create mode 100644 pkg/sentry/fs/dev/net_tun.go create mode 100644 pkg/tcpip/link/tun/device.go create mode 100644 pkg/tcpip/link/tun/protocol.go create mode 100644 test/syscalls/linux/socket_netlink_route_util.cc create mode 100644 test/syscalls/linux/socket_netlink_route_util.h create mode 100644 test/syscalls/linux/tuntap.cc (limited to 'pkg/tcpip/buffer') diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a89f34d4b..322d1ccc4 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -30,6 +30,7 @@ go_library( "futex.go", "inotify.go", "ioctl.go", + "ioctl_tun.go", "ip.go", "ipc.go", "limits.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 0e18db9ef..2062e6a4b 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -72,3 +72,29 @@ const ( SIOCGMIIPHY = 0x8947 SIOCGMIIREG = 0x8948 ) + +// ioctl(2) directions. Used to calculate requests number. +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NONE = 0 + _IOC_WRITE = 1 + _IOC_READ = 2 +) + +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NRBITS = 8 + _IOC_TYPEBITS = 8 + _IOC_SIZEBITS = 14 + _IOC_DIRBITS = 2 + + _IOC_NRSHIFT = 0 + _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS + _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS + _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS +) + +// IOC outputs the result of _IOC macro in asm-generic/ioctl.h. +func IOC(dir, typ, nr, size uint32) uint32 { + return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT +} diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go new file mode 100644 index 000000000..c59c9c136 --- /dev/null +++ b/pkg/abi/linux/ioctl_tun.go @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// ioctl(2) request numbers from linux/if_tun.h +var ( + TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4) + TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4) +) + +// Flags from net/if_tun.h +const ( + IFF_TUN = 0x0001 + IFF_TAP = 0x0002 + IFF_NO_PI = 0x1000 + IFF_NOFILTER = 0x1000 +) diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 4c4b7d5cc..9b6bb26d0 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -9,6 +9,7 @@ go_library( "device.go", "fs.go", "full.go", + "net_tun.go", "null.go", "random.go", "tty.go", @@ -19,15 +20,19 @@ go_library( "//pkg/context", "//pkg/rand", "//pkg/safemem", + "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/pgalloc", + "//pkg/sentry/socket/netstack", "//pkg/syserror", + "//pkg/tcpip/link/tun", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index 35bd23991..7e66c29b0 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo }) } -func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode { - iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555)) +func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode { + iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), @@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { // A devpts is typically mounted at /dev/pts to provide // pseudoterminal support. Place an empty directory there for // the devpts to be mounted over. - "pts": newDirectory(ctx, msrc), + "pts": newDirectory(ctx, nil, msrc), // Similarly, applications expect a ptmx device at /dev/ptmx // connected to the terminals provided by /dev/pts/. Rather // than creating a device directly (which requires a hairy @@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "ptmx": newSymlink(ctx, "pts/ptmx", msrc), "tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor), + + "net": newDirectory(ctx, map[string]*fs.Inode{ + "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor), + }, msrc), } iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go new file mode 100644 index 000000000..755644488 --- /dev/null +++ b/pkg/sentry/fs/dev/net_tun.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + netTunDevMajor = 10 + netTunDevMinor = 200 +) + +// +stateify savable +type netTunInodeOperations struct { + fsutil.InodeGenericChecker `state:"nosave"` + fsutil.InodeNoExtendedAttributes `state:"nosave"` + fsutil.InodeNoopAllocate `state:"nosave"` + fsutil.InodeNoopRelease `state:"nosave"` + fsutil.InodeNoopTruncate `state:"nosave"` + fsutil.InodeNoopWriteOut `state:"nosave"` + fsutil.InodeNotDirectory `state:"nosave"` + fsutil.InodeNotMappable `state:"nosave"` + fsutil.InodeNotSocket `state:"nosave"` + fsutil.InodeNotSymlink `state:"nosave"` + fsutil.InodeVirtual `state:"nosave"` + + fsutil.InodeSimpleAttributes +} + +var _ fs.InodeOperations = (*netTunInodeOperations)(nil) + +func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations { + return &netTunInodeOperations{ + InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC), + } +} + +// GetFile implements fs.InodeOperations.GetFile. +func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil +} + +// +stateify savable +type netTunFileOperations struct { + fsutil.FileNoSeek `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + device tun.Device +} + +var _ fs.FileOperations = (*netTunFileOperations)(nil) + +// Release implements fs.FileOperations.Release. +func (fops *netTunFileOperations) Release() { + fops.device.Release() +} + +// Ioctl implements fs.FileOperations.Ioctl. +func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + request := args[1].Uint() + data := args[2].Pointer() + + switch request { + case linux.TUNSETIFF: + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return 0, syserror.EPERM + } + stack, ok := t.NetworkContext().(*netstack.Stack) + if !ok { + return 0, syserror.EINVAL + } + + var req linux.IFReq + if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + flags := usermem.ByteOrder.Uint16(req.Data[:]) + return 0, fops.device.SetIff(stack.Stack, req.Name(), flags) + + case linux.TUNGETIFF: + var req linux.IFReq + + copy(req.IFName[:], fops.device.Name()) + + // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when + // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. + flags := fops.device.Flags() | linux.IFF_NOFILTER + usermem.ByteOrder.PutUint16(req.Data[:], flags) + + _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +// Write implements fs.FileOperations.Write. +func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + data := make([]byte, src.NumBytes()) + if _, err := src.CopyIn(ctx, data); err != nil { + return 0, err + } + return fops.device.Write(data) +} + +// Read implements fs.FileOperations.Read. +func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + data, err := fops.device.Read() + if err != nil { + return 0, err + } + n, err := dst.CopyOut(ctx, data) + if n > 0 && n < len(data) { + // Not an error for partial copying. Packet truncated. + err = nil + } + return int64(n), err +} + +// Readiness implements watier.Waitable.Readiness. +func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return fops.device.Readiness(mask) +} + +// EventRegister implements watier.Waitable.EventRegister. +func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fops.device.EventRegister(e, mask) +} + +// EventUnregister implements watier.Waitable.EventUnregister. +func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { + fops.device.EventUnregister(e) +} diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 2269f6237..4b5a0fca6 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -29,6 +29,7 @@ var ( EACCES = error(syscall.EACCES) EAGAIN = error(syscall.EAGAIN) EBADF = error(syscall.EBADF) + EBADFD = error(syscall.EBADFD) EBUSY = error(syscall.EBUSY) ECHILD = error(syscall.ECHILD) ECONNREFUSED = error(syscall.ECONNREFUSED) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 150310c11..17e94c562 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) { vv.views = append(vv.views, vv2.views...) vv.size += vv2.size } + +// AppendView appends the given view into this vectorised view. +func (vv *VectorisedView) AppendView(v View) { + vv.views = append(vv.views, v) + vv.size += len(v) +} diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 3974c464e..b8b93e78e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -7,6 +7,7 @@ go_library( srcs = ["channel.go"], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 78d447acd..5944ba190 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -20,6 +20,7 @@ package channel import ( "context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -33,6 +34,118 @@ type PacketInfo struct { Route stack.Route } +// Notification is the interface for receiving notification from the packet +// queue. +type Notification interface { + // WriteNotify will be called when a write happens to the queue. + WriteNotify() +} + +// NotificationHandle is an opaque handle to the registered notification target. +// It can be used to unregister the notification when no longer interested. +// +// +stateify savable +type NotificationHandle struct { + n Notification +} + +type queue struct { + // mu protects fields below. + mu sync.RWMutex + // c is the outbound packet channel. Sending to c should hold mu. + c chan PacketInfo + numWrite int + numRead int + notify []*NotificationHandle +} + +func (q *queue) Close() { + close(q.c) +} + +func (q *queue) Read() (PacketInfo, bool) { + q.mu.Lock() + defer q.mu.Unlock() + select { + case p := <-q.c: + q.numRead++ + return p, true + default: + return PacketInfo{}, false + } +} + +func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { + // We have to receive from channel without holding the lock, since it can + // block indefinitely. This will cause a window that numWrite - numRead + // produces a larger number, but won't go to negative. numWrite >= numRead + // still holds. + select { + case pkt := <-q.c: + q.mu.Lock() + defer q.mu.Unlock() + q.numRead++ + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + +func (q *queue) Write(p PacketInfo) bool { + wrote := false + + // It's important to make sure nobody can see numWrite until we increment it, + // so numWrite >= numRead holds. + q.mu.Lock() + select { + case q.c <- p: + wrote = true + q.numWrite++ + default: + } + notify := q.notify + q.mu.Unlock() + + if wrote { + // Send notification outside of lock. + for _, h := range notify { + h.n.WriteNotify() + } + } + return wrote +} + +func (q *queue) Num() int { + q.mu.RLock() + defer q.mu.RUnlock() + n := q.numWrite - q.numRead + if n < 0 { + panic("numWrite < numRead") + } + return n +} + +func (q *queue) AddNotify(notify Notification) *NotificationHandle { + q.mu.Lock() + defer q.mu.Unlock() + h := &NotificationHandle{n: notify} + q.notify = append(q.notify, h) + return h +} + +func (q *queue) RemoveNotify(handle *NotificationHandle) { + q.mu.Lock() + defer q.mu.Unlock() + // Make a copy, since we reads the array outside of lock when notifying. + notify := make([]*NotificationHandle, 0, len(q.notify)) + for _, h := range q.notify { + if h != handle { + notify = append(notify, h) + } + } + q.notify = notify +} + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { @@ -41,14 +154,16 @@ type Endpoint struct { linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities - // c is where outbound packets are queued. - c chan PacketInfo + // Outbound packet queue. + q *queue } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - c: make(chan PacketInfo, size), + q: &queue{ + c: make(chan PacketInfo, size), + }, mtu: mtu, linkAddr: linkAddr, } @@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { // Close closes e. Further packet injections will panic. Reads continue to // succeed until all packets are read. func (e *Endpoint) Close() { - close(e.c) + e.q.Close() } -// Read does non-blocking read for one packet from the outbound packet queue. +// Read does non-blocking read one packet from the outbound packet queue. func (e *Endpoint) Read() (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - default: - return PacketInfo{}, false - } + return e.q.Read() } // ReadContext does blocking read for one packet from the outbound packet queue. // It can be cancelled by ctx, and in this case, it returns false. func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - case <-ctx.Done(): - return PacketInfo{}, false - } + return e.q.ReadContext(ctx) } // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { - select { - case <-e.c: - c++ - default: + if _, ok := e.Read(); !ok { return c } + c++ } } +// NumQueued returns the number of packet queued for outbound. +func (e *Endpoint) NumQueued() int { + return e.q.Num() +} + // InjectInbound injects an inbound packet. func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) @@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne Route: route, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } @@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac route.Release() payloadView := pkts[0].Data.ToView() n := 0 -packetLoop: for _, pkt := range pkts { off := pkt.DataOffset size := pkt.DataSize @@ -185,12 +289,10 @@ packetLoop: Route: route, } - select { - case e.c <- p: - n++ - default: - break packetLoop + if !e.q.Write(p) { + break } + n++ } return n, nil @@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { GSO: nil, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} + +// AddNotify adds a notification target for receiving event about outgoing +// packets. +func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { + return e.q.AddNotify(notify) +} + +// RemoveNotify removes handle from the list of notification targets. +func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { + e.q.RemoveNotify(handle) +} diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index e5096ea38..e0db6cf54 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -4,6 +4,22 @@ package(licenses = ["notice"]) go_library( name = "tun", - srcs = ["tun_unsafe.go"], + srcs = [ + "device.go", + "protocol.go", + "tun_unsafe.go", + ], visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux", + "//pkg/refs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], ) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go new file mode 100644 index 000000000..6ff47a742 --- /dev/null +++ b/pkg/tcpip/link/tun/device.go @@ -0,0 +1,352 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // drivers/net/tun.c:tun_net_init() + defaultDevMtu = 1500 + + // Queue length for outbound packet, arriving at fd side for read. Overflow + // causes packet drops. gVisor implementation-specific. + defaultDevOutQueueLen = 1024 +) + +var zeroMAC [6]byte + +// Device is an opened /dev/net/tun device. +// +// +stateify savable +type Device struct { + waiter.Queue + + mu sync.RWMutex `state:"nosave"` + endpoint *tunEndpoint + notifyHandle *channel.NotificationHandle + flags uint16 +} + +// beforeSave is invoked by stateify. +func (d *Device) beforeSave() { + d.mu.Lock() + defer d.mu.Unlock() + // TODO(b/110961832): Restore the device to stack. At this moment, the stack + // is not savable. + if d.endpoint != nil { + panic("/dev/net/tun does not support save/restore when a device is associated with it.") + } +} + +// Release implements fs.FileOperations.Release. +func (d *Device) Release() { + d.mu.Lock() + defer d.mu.Unlock() + + // Decrease refcount if there is an endpoint associated with this file. + if d.endpoint != nil { + d.endpoint.RemoveNotify(d.notifyHandle) + d.endpoint.DecRef() + d.endpoint = nil + } +} + +// SetIff services TUNSETIFF ioctl(2) request. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.endpoint != nil { + return syserror.EINVAL + } + + // Input validations. + isTun := flags&linux.IFF_TUN != 0 + isTap := flags&linux.IFF_TAP != 0 + supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) + if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { + return syserror.EINVAL + } + + prefix := "tun" + if isTap { + prefix = "tap" + } + + endpoint, err := attachOrCreateNIC(s, name, prefix) + if err != nil { + return syserror.EINVAL + } + + d.endpoint = endpoint + d.notifyHandle = d.endpoint.AddNotify(d) + d.flags = flags + return nil +} + +func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) { + for { + // 1. Try to attach to an existing NIC. + if name != "" { + if nic, found := s.GetNICByName(name); found { + endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if !ok { + // Not a NIC created by tun device. + return nil, syserror.EOPNOTSUPP + } + if !endpoint.TryIncRef() { + // Race detected: NIC got deleted in between. + continue + } + return endpoint, nil + } + } + + // 2. Creating a new NIC. + id := tcpip.NICID(s.UniqueID()) + endpoint := &tunEndpoint{ + Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), + stack: s, + nicID: id, + name: name, + } + if endpoint.name == "" { + endpoint.name = fmt.Sprintf("%s%d", prefix, id) + } + err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ + Name: endpoint.name, + }) + switch err { + case nil: + return endpoint, nil + case tcpip.ErrDuplicateNICID: + // Race detected: A NIC has been created in between. + continue + default: + return nil, syserror.EINVAL + } + } +} + +// Write inject one inbound packet to the network interface. +func (d *Device) Write(data []byte) (int64, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return 0, syserror.EBADFD + } + if !endpoint.IsAttached() { + return 0, syserror.EIO + } + + dataLen := int64(len(data)) + + // Packet information. + var pktInfoHdr PacketInfoHeader + if !d.hasFlags(linux.IFF_NO_PI) { + if len(data) < PacketInfoHeaderSize { + // Ignore bad packet. + return dataLen, nil + } + pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize]) + data = data[PacketInfoHeaderSize:] + } + + // Ethernet header (TAP only). + var ethHdr header.Ethernet + if d.hasFlags(linux.IFF_TAP) { + if len(data) < header.EthernetMinimumSize { + // Ignore bad packet. + return dataLen, nil + } + ethHdr = header.Ethernet(data[:header.EthernetMinimumSize]) + data = data[header.EthernetMinimumSize:] + } + + // Try to determine network protocol number, default zero. + var protocol tcpip.NetworkProtocolNumber + switch { + case pktInfoHdr != nil: + protocol = pktInfoHdr.Protocol() + case ethHdr != nil: + protocol = ethHdr.Type() + } + + // Try to determine remote link address, default zero. + var remote tcpip.LinkAddress + switch { + case ethHdr != nil: + remote = ethHdr.SourceAddress() + default: + remote = tcpip.LinkAddress(zeroMAC[:]) + } + + pkt := tcpip.PacketBuffer{ + Data: buffer.View(data).ToVectorisedView(), + } + if ethHdr != nil { + pkt.LinkHeader = buffer.View(ethHdr) + } + endpoint.InjectLinkAddr(protocol, remote, pkt) + return dataLen, nil +} + +// Read reads one outgoing packet from the network interface. +func (d *Device) Read() ([]byte, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return nil, syserror.EBADFD + } + + for { + info, ok := endpoint.Read() + if !ok { + return nil, syserror.ErrWouldBlock + } + + v, ok := d.encodePkt(&info) + if !ok { + // Ignore unsupported packet. + continue + } + return v, nil + } +} + +// encodePkt encodes packet for fd side. +func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { + var vv buffer.VectorisedView + + // Packet information. + if !d.hasFlags(linux.IFF_NO_PI) { + hdr := make(PacketInfoHeader, PacketInfoHeaderSize) + hdr.Encode(&PacketInfoFields{ + Protocol: info.Proto, + }) + vv.AppendView(buffer.View(hdr)) + } + + // If the packet does not already have link layer header, and the route + // does not exist, we can't compute it. This is possibly a raw packet, tun + // device doesn't support this at the moment. + if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + return nil, false + } + + // Ethernet header (TAP only). + if d.hasFlags(linux.IFF_TAP) { + // Add ethernet header if not provided. + if info.Pkt.LinkHeader == nil { + hdr := &header.EthernetFields{ + SrcAddr: info.Route.LocalLinkAddress, + DstAddr: info.Route.RemoteLinkAddress, + Type: info.Proto, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = d.endpoint.LinkAddress() + } + + eth := make(header.Ethernet, header.EthernetMinimumSize) + eth.Encode(hdr) + vv.AppendView(buffer.View(eth)) + } else { + vv.AppendView(info.Pkt.LinkHeader) + } + } + + // Append upper headers. + vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + // Append data payload. + vv.Append(info.Pkt.Data) + + return vv.ToView(), true +} + +// Name returns the name of the attached network interface. Empty string if +// unattached. +func (d *Device) Name() string { + d.mu.RLock() + defer d.mu.RUnlock() + if d.endpoint != nil { + return d.endpoint.name + } + return "" +} + +// Flags returns the flags set for d. Zero value if unset. +func (d *Device) Flags() uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + return d.flags +} + +func (d *Device) hasFlags(flags uint16) bool { + return d.flags&flags == flags +} + +// Readiness implements watier.Waitable.Readiness. +func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask { + if mask&waiter.EventIn != 0 { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint != nil && endpoint.NumQueued() == 0 { + mask &= ^waiter.EventIn + } + } + return mask & (waiter.EventIn | waiter.EventOut) +} + +// WriteNotify implements channel.Notification.WriteNotify. +func (d *Device) WriteNotify() { + d.Notify(waiter.EventIn) +} + +// tunEndpoint is the link endpoint for the NIC created by the tun device. +// +// It is ref-counted as multiple opening files can attach to the same NIC. +// The last owner is responsible for deleting the NIC. +type tunEndpoint struct { + *channel.Endpoint + + refs.AtomicRefCount + + stack *stack.Stack + nicID tcpip.NICID + name string +} + +// DecRef decrements refcount of e, removes NIC if refcount goes to 0. +func (e *tunEndpoint) DecRef() { + e.DecRefWithDestructor(func() { + e.stack.RemoveNIC(e.nicID) + }) +} diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go new file mode 100644 index 000000000..89d9d91a9 --- /dev/null +++ b/pkg/tcpip/link/tun/protocol.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // PacketInfoHeaderSize is the size of the packet information header. + PacketInfoHeaderSize = 4 + + offsetFlags = 0 + offsetProtocol = 2 +) + +// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is +// not set. +type PacketInfoFields struct { + Flags uint16 + Protocol tcpip.NetworkProtocolNumber +} + +// PacketInfoHeader is the wire representation of the packet information sent if +// IFF_NO_PI flag is not set. +type PacketInfoHeader []byte + +// Encode encodes f into h. +func (h PacketInfoHeader) Encode(f *PacketInfoFields) { + binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags) + binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol)) +} + +// Flags returns the flag field in h. +func (h PacketInfoHeader) Flags() uint16 { + return binary.BigEndian.Uint16(h[offsetFlags:]) +} + +// Protocol returns the protocol field in h. +func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber { + return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:])) +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 862954ab2..46d3a6646 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -298,6 +298,33 @@ func (n *NIC) enable() *tcpip.Error { return nil } +// remove detaches NIC from the link endpoint, and marks existing referenced +// network endpoints expired. This guarantees no packets between this NIC and +// the network stack. +func (n *NIC) remove() *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) + + // Remove permanent and permanentTentative addresses, so no packet goes out. + var errs []*tcpip.Error + for nid, ref := range n.mu.endpoints { + switch ref.getKind() { + case permanentTentative, permanent: + if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil { + errs = append(errs, err) + } + } + } + if len(errs) > 0 { + return errs[0] + } + + return nil +} + // becomeIPv6Router transitions n into an IPv6 router. // // When transitioning into an IPv6 router, host-only state (NDP discovered @@ -1302,6 +1329,11 @@ func (n *NIC) Stack() *Stack { return n.stack } +// LinkEndpoint returns the link endpoint of n. +func (n *NIC) LinkEndpoint() LinkEndpoint { + return n.linkEP +} + // isAddrTentative returns true if addr is tentative on n. // // Note that if addr is not associated with n, then this function will return diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f0ed76fbe..900dd46c5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -916,6 +916,18 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } +// GetNICByName gets the NIC specified by name. +func (s *Stack) GetNICByName(name string) (*NIC, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, nic := range s.nics { + if nic.Name() == name { + return nic, true + } + } + return nil, false +} + // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { @@ -956,6 +968,33 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return nic.enabled() } +// RemoveNIC removes NIC and all related routes from the network stack. +func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + delete(s.nics, id) + + // Remove routes in-place. n tracks the number of routes written. + n := 0 + for i, r := range s.routeTable { + if r.NIC != id { + // Keep this route. + if i > n { + s.routeTable[n] = r + } + n++ + } + } + s.routeTable = s.routeTable[:n] + + return nic.remove() +} + // NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index d1977d4de..3518e862d 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -678,6 +678,8 @@ syscall_test( test = "//test/syscalls/linux:truncate_test", ) +syscall_test(test = "//test/syscalls/linux:tuntap_test") + syscall_test(test = "//test/syscalls/linux:udp_bind_test") syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index aa303af84..704bae17b 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -131,6 +131,17 @@ cc_library( ], ) +cc_library( + name = "socket_netlink_route_util", + testonly = 1, + srcs = ["socket_netlink_route_util.cc"], + hdrs = ["socket_netlink_route_util.h"], + deps = [ + ":socket_netlink_util", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "socket_test_util", testonly = 1, @@ -3430,6 +3441,25 @@ cc_binary( ], ) +cc_binary( + name = "tuntap_test", + testonly = 1, + srcs = ["tuntap.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + "//test/syscalls/linux:socket_netlink_route_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "udp_socket_test_cases", testonly = 1, diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 4dd302eed..4e473268c 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -153,6 +153,13 @@ TEST(DevTest, TTYExists) { EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); } +TEST(DevTest, NetTunExists) { + struct stat statbuf = {}; + ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds()); + // Check that it's a character device with rw-rw-rw- permissions. + EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc new file mode 100644 index 000000000..53eb3b6b2 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -0,0 +1,163 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_netlink_route_util.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr uint32_t kSeq = 12345; + +} // namespace + +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + +PosixErrorOr> DumpLinks() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + std::vector links; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + links.emplace_back(); + links.back().index = msg->ifi_index; + links.back().type = msg->ifi_type; + links.back().name = + std::string(reinterpret_cast(RTA_DATA(rta))); + })); + return links; +} + +PosixErrorOr> FindLoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.type == ARPHRD_LOOPBACK) { + return absl::optional(link); + } + } + return absl::optional(); +} + +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast( + reinterpret_cast(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char pad[NLMSG_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + req.ifinfo.ifi_flags = flags; + req.ifinfo.ifi_change = change; + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + + struct rtattr* rta = reinterpret_cast( + reinterpret_cast(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFLA_ADDRESS; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h new file mode 100644 index 000000000..2c018e487 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -0,0 +1,55 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ + +#include +#include + +#include + +#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { + +struct Link { + int index; + int16_t type; + std::string name; +}; + +PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, + const std::function& fn); + +PosixErrorOr> DumpLinks(); + +PosixErrorOr> FindLoopbackLink(); + +// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + +// LinkChangeFlags changes interface flags. E.g. IFF_UP. +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); + +// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface. +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc new file mode 100644 index 000000000..f6ac9d7b8 --- /dev/null +++ b/test/syscalls/linux/tuntap.cc @@ -0,0 +1,346 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr int kIPLen = 4; + +constexpr const char kDevNetTun[] = "/dev/net/tun"; +constexpr const char kTapName[] = "tap0"; + +constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}; +constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}; + +PosixErrorOr> DumpLinkNames() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + std::set names; + for (const auto& link : links) { + names.emplace(link.name); + } + return names; +} + +PosixErrorOr> GetLinkByName(const std::string& name) { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.name == name) { + return absl::optional(link); + } + } + return absl::optional(); +} + +struct pihdr { + uint16_t pi_flags; + uint16_t pi_protocol; +} __attribute__((packed)); + +struct ping_pkt { + pihdr pi; + struct ethhdr eth; + struct iphdr ip; + struct icmphdr icmp; + char payload[64]; +} __attribute__((packed)); + +ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + ping_pkt pkt = {}; + + pkt.pi.pi_protocol = htons(ETH_P_IP); + + memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest)); + memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source)); + pkt.eth.h_proto = htons(ETH_P_IP); + + pkt.ip.ihl = 5; + pkt.ip.version = 4; + pkt.ip.tos = 0; + pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) + + sizeof(pkt.payload)); + pkt.ip.id = 1; + pkt.ip.frag_off = 1 << 6; // Do not fragment + pkt.ip.ttl = 64; + pkt.ip.protocol = IPPROTO_ICMP; + inet_pton(AF_INET, dstip, &pkt.ip.daddr); + inet_pton(AF_INET, srcip, &pkt.ip.saddr); + pkt.ip.check = IPChecksum(pkt.ip); + + pkt.icmp.type = ICMP_ECHO; + pkt.icmp.code = 0; + pkt.icmp.checksum = 0; + pkt.icmp.un.echo.sequence = 1; + pkt.icmp.un.echo.id = 1; + + strncpy(pkt.payload, "abcd", sizeof(pkt.payload)); + pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload)); + + return pkt; +} + +struct arp_pkt { + pihdr pi; + struct ethhdr eth; + struct arphdr arp; + uint8_t arp_sha[ETH_ALEN]; + uint8_t arp_spa[kIPLen]; + uint8_t arp_tha[ETH_ALEN]; + uint8_t arp_tpa[kIPLen]; +} __attribute__((packed)); + +std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + std::string buffer; + buffer.resize(sizeof(arp_pkt)); + + arp_pkt* pkt = reinterpret_cast(&buffer[0]); + { + pkt->pi.pi_protocol = htons(ETH_P_ARP); + + memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest)); + memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source)); + pkt->eth.h_proto = htons(ETH_P_ARP); + + pkt->arp.ar_hrd = htons(ARPHRD_ETHER); + pkt->arp.ar_pro = htons(ETH_P_IP); + pkt->arp.ar_hln = ETH_ALEN; + pkt->arp.ar_pln = kIPLen; + pkt->arp.ar_op = htons(ARPOP_REPLY); + + memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha)); + inet_pton(AF_INET, srcip, pkt->arp_spa); + memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha)); + inet_pton(AF_INET, dstip, pkt->arp_tpa); + } + return buffer; +} + +} // namespace + +class TuntapTest : public ::testing::Test { + protected: + void TearDown() override { + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) { + // Bring back capability if we had dropped it in test case. + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); + } + } +}; + +TEST_F(TuntapTest, CreateInterfaceNoCap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + strncpy(ifr.ifr_name, kTapName, IFNAMSIZ); + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM)); +} + +TEST_F(TuntapTest, CreateFixedNameInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_expect = ifr_set; + // See __tun_chr_ioctl() in net/drivers/tun.c. + ifr_expect.ifr_flags |= IFF_NOFILTER; + + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(kTapName))); + EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0)); +} + +TEST_F(TuntapTest, CreateInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + // Empty ifr.ifr_name. Let kernel assign. + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + std::string ifname = ifr_get.ifr_name; + EXPECT_THAT(ifname, ::testing::StartsWith("tap")); + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(ifname))); +} + +TEST_F(TuntapTest, InvalidReadWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + char buf[128] = {}; + EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); +} + +TEST_F(TuntapTest, WriteToDownDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // FIXME: gVisor always creates enabled/up'd interfaces. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + // Device created should be down by default. + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + char buf[128] = {}; + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO)); +} + +// This test sets up a TAP device and pings kernel by sending ICMP echo request. +// +// It works as the following: +// * Open /dev/net/tun, and create kTapName interface. +// * Use rtnetlink to do initial setup of the interface: +// * Assign IP address 10.0.0.1/24 to kernel. +// * MAC address: kMacA +// * Bring up the interface. +// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel. +// * Loop to receive packets from TAP device/fd: +// * If packet is an ICMP echo reply, it stops and passes the test. +// * If packet is an ARP request, it responds with canned reply and resends +// the +// ICMP request packet. +TEST_F(TuntapTest, PingKernel) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Interface creation. + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + absl::optional link = + ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName)); + ASSERT_TRUE(link.has_value()); + + // Interface setup. + struct in_addr addr; + inet_pton(AF_INET, "10.0.0.1", &addr); + EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24, + &addr, sizeof(addr))); + + if (!IsRunningOnGvisor()) { + // FIXME: gVisor doesn't support setting MAC address on interfaces yet. + EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA))); + + // FIXME: gVisor always creates enabled/up'd interfaces. + EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP)); + } + + ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + + // Send ping, this would trigger an ARP request on Linux. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + + // Receive loop to process inbound packets. + struct inpkt { + union { + pihdr pi; + ping_pkt ping; + arp_pkt arp; + }; + }; + while (1) { + inpkt r = {}; + int n = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(n, SyscallSucceeds()); + + if (n < sizeof(pihdr)) { + std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol + << " len: " << n << std::endl; + continue; + } + + // Process ARP packet. + if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { + // Respond with canned ARP reply. + EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()), + SyscallSucceedsWithValue(arp_rep.size())); + // First ping request might have been dropped due to mac address not in + // ARP cache. Send it again. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + } + + // Process ping response packet. + if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol && + r.ping.ip.protocol == ping_req.ip.protocol && + !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) && + !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) && + r.ping.icmp.type == 0 && r.ping.icmp.code == 0) { + // Ends and passes the test. + break; + } + } +} + +} // namespace testing +} // namespace gvisor -- cgit v1.2.3 From fbe80460a7eb34147b928fa1023b28a3c094c070 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 26 Mar 2020 14:04:28 -0700 Subject: Handle IPv6 Fragment & Routing extension headers Enables the reassembly of fragmented IPv6 packets and handling of the Routing extension header with a Segments Left value of 0. Atomic fragments are handled as described in RFC 6946 to not interfere with "normal" fragment traffic. No specific routing header type is supported. Note, the stack does not yet support sending ICMPv6 error messages in response to IPv6 packets that cannot be handled/parsed. That will come in a later change (Issue #2211). Test: - header_test.TestIPv6RoutingExtHdr - header_test.TestIPv6FragmentExtHdr - header_test.TestIPv6ExtHdrIterErr - header_test.TestIPv6ExtHdrIter - ipv6_test.TestReceiveIPv6ExtHdrs - ipv6_test.TestReceiveIPv6Fragments RELNOTES: n/a PiperOrigin-RevId: 303189584 --- pkg/tcpip/buffer/view.go | 20 + pkg/tcpip/header/BUILD | 3 + pkg/tcpip/header/ipv6_extension_headers.go | 344 +++++++++++ pkg/tcpip/header/ipv6_extension_headers_test.go | 515 ++++++++++++++++ pkg/tcpip/network/hash/hash.go | 4 +- pkg/tcpip/network/ipv6/BUILD | 3 + pkg/tcpip/network/ipv6/ipv6.go | 122 +++- pkg/tcpip/network/ipv6/ipv6_test.go | 768 ++++++++++++++++++++++++ 8 files changed, 1772 insertions(+), 7 deletions(-) create mode 100644 pkg/tcpip/header/ipv6_extension_headers.go create mode 100644 pkg/tcpip/header/ipv6_extension_headers_test.go (limited to 'pkg/tcpip/buffer') diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 17e94c562..8d42cd066 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -15,6 +15,10 @@ // Package buffer provides the implementation of a buffer view. package buffer +import ( + "bytes" +) + // View is a slice of a buffer, with convenience methods. type View []byte @@ -45,6 +49,13 @@ func (v *View) CapLength(length int) { *v = (*v)[:length:length] } +// Reader returns a bytes.Reader for v. +func (v *View) Reader() bytes.Reader { + var r bytes.Reader + r.Reset(*v) + return r +} + // ToVectorisedView returns a VectorisedView containing the receiver. func (v View) ToVectorisedView() VectorisedView { return NewVectorisedView(len(v), []View{v}) @@ -162,3 +173,12 @@ func (vv *VectorisedView) AppendView(v View) { vv.views = append(vv.views, v) vv.size += len(v) } + +// Readers returns a bytes.Reader for each of vv's views. +func (vv *VectorisedView) Readers() []bytes.Reader { + readers := make([]bytes.Reader, 0, len(vv.views)) + for _, v := range vv.views { + readers = append(readers, v.Reader()) + } + return readers +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 9da0d71f8..7094f3f0b 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -14,6 +14,7 @@ go_library( "interfaces.go", "ipv4.go", "ipv6.go", + "ipv6_extension_headers.go", "ipv6_fragment.go", "ndp_neighbor_advert.go", "ndp_neighbor_solicit.go", @@ -55,11 +56,13 @@ go_test( size = "small", srcs = [ "eth_test.go", + "ipv6_extension_headers_test.go", "ndp_test.go", ], library = ":header", deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go new file mode 100644 index 000000000..b8866d4d2 --- /dev/null +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -0,0 +1,344 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier. +type IPv6ExtensionHeaderIdentifier uint8 + +const ( + // IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension + // header, as per RFC 8200 section 4.4. + IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43 + + // IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment + // extension header, as per RFC 8200 section 4.5. + IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44 + + // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end + // of an IPv6 payload, as per RFC 8200 section 4.7. + IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59 +) + +const ( + // ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field + // within an IPv6RoutingExtHdr. + ipv6RoutingExtHdrSegmentsLeftIdx = 1 + + // ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the + // Fragment Offset field within an IPv6FragmentExtHdr. + ipv6FragmentExtHdrFragmentOffsetOffset = 0 + + // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to + // discard from the Fragment Offset. + ipv6FragmentExtHdrFragmentOffsetShift = 3 + + // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an + // IPv6FragmentExtHdr. + ipv6FragmentExtHdrFlagsIdx = 1 + + // ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the + // flags field of an IPv6FragmentExtHdr. + ipv6FragmentExtHdrMFlagMask = 1 + + // ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification + // field within an IPv6FragmentExtHdr. + ipv6FragmentExtHdrIdentificationOffset = 2 + + // ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length + // field. That is, given a Length field of 2, the extension header expects + // 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for + // details about the first 8 bytes' exclusion from the Length field). + ipv6ExtHdrLenBytesPerUnit = 8 + + // ipv6ExtHdrLenBytesExcluded is the number of bytes excluded from an + // extension header's Length field following the Length field. + // + // The Length field excludes the first 8 bytes, but the Next Header and Length + // field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes + // after the Length field. + // + // This ensures that every extension header is at least 8 bytes. + ipv6ExtHdrLenBytesExcluded = 6 + + // IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment + // extension header's Fragment Offset field. That is, given a Fragment Offset + // of 2, the extension header is indiciating that the fragment's payload + // starts at the 16th byte in the reassembled packet. + IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8 +) + +// IPv6PayloadHeader is implemented by the various headers that can be found +// in an IPv6 payload. +// +// These headers include IPv6 extension headers or upper layer data. +type IPv6PayloadHeader interface { + isIPv6PayloadHeader() +} + +// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator +// encounters a Next Header field it does not recognize as an IPv6 extension +// header. +type IPv6RawPayloadHeader struct { + Identifier IPv6ExtensionHeaderIdentifier + Buf buffer.VectorisedView +} + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {} + +// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific +// data as outlined in RFC 8200 section 4.4. +type IPv6RoutingExtHdr []byte + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {} + +// SegmentsLeft returns the Segments Left field. +func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 { + return b[ipv6RoutingExtHdrSegmentsLeftIdx] +} + +// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific +// data as outlined in RFC 8200 section 4.5. +// +// Note, the buffer does not include the Next Header and Reserved fields. +type IPv6FragmentExtHdr [6]byte + +// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader. +func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {} + +// FragmentOffset returns the Fragment Offset field. +// +// This value indicates where the buffer following the Fragment extension header +// starts in the target (reassembled) packet. +func (b IPv6FragmentExtHdr) FragmentOffset() uint16 { + return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift +} + +// More returns the More (M) flag. +// +// This indicates whether any fragments are expected to succeed b. +func (b IPv6FragmentExtHdr) More() bool { + return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0 +} + +// ID returns the Identification field. +// +// This value is used to uniquely identify the packet, between a +// souce and destination. +func (b IPv6FragmentExtHdr) ID() uint32 { + return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:]) +} + +// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload. +// +// The IPv6 payload may contain IPv6 extension headers before any upper layer +// data. +// +// Note, between when an IPv6PayloadIterator is obtained and last used, no +// changes to the payload may happen. Doing so may cause undefined and +// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate +// over the first few headers then modify the backing payload so long as the +// IPv6PayloadIterator obtained before modification is no longer used. +type IPv6PayloadIterator struct { + // The identifier of the next header to parse. + nextHdrIdentifier IPv6ExtensionHeaderIdentifier + + // reader is an io.Reader over payload. + reader bufio.Reader + payload buffer.VectorisedView + + // Indicates to the iterator that it should return the remaining payload as a + // raw payload on the next call to Next. + forceRaw bool +} + +// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing +// extension headers, or a raw payload if the payload cannot be parsed. +func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.VectorisedView, check bool) (IPv6PayloadIterator, error) { + readers := payload.Readers() + readerPs := make([]io.Reader, 0, len(readers)) + for i := range readers { + readerPs = append(readerPs, &readers[i]) + } + + // We need a buffer of size 1 for calls to bufio.Reader.ReadByte. + reader := *bufio.NewReaderSize(io.MultiReader(readerPs...), 1) + + it := IPv6PayloadIterator{ + nextHdrIdentifier: nextHdrIdentifier, + payload: payload.Clone(nil), + reader: reader, + } + + var err error + + if check { + for { + var done bool + if _, done, err = it.Next(); err != nil || done { + break + } + } + + // Reset it (and its underlying readers) before returning it. + for i := range readers { + readers[i].Seek(0, io.SeekStart) + } + reader.Reset(io.MultiReader(readerPs...)) + it = IPv6PayloadIterator{ + nextHdrIdentifier: nextHdrIdentifier, + payload: payload.Clone(nil), + reader: reader, + } + } + + return it, err +} + +// AsRawHeader returns the remaining payload of i as a raw header and +// completes the iterator. +// +// Calls to Next after calling AsRawHeader on i will indicate that the +// iterator is done. +func (i *IPv6PayloadIterator) AsRawHeader() IPv6RawPayloadHeader { + buf := i.payload + identifier := i.nextHdrIdentifier + + // Mark i as done. + *i = IPv6PayloadIterator{ + nextHdrIdentifier: IPv6NoNextHeaderIdentifier, + } + + return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf} +} + +// Next returns the next item in the payload. +// +// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader +// will be returned with the remaining bytes and next header identifier. +// +// The return is of the format (header, done, error). done will be true when +// Next is unable to return anything because the iterator has reached the end of +// the payload, or an error occured. +func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { + // We could be forced to return i as a raw header when the previous header was + // a fragment extension header as the data following the fragment extension + // header may not be complete. + if i.forceRaw { + return i.AsRawHeader(), false, nil + } + + // Is the header we are parsing a known extension header? + switch i.nextHdrIdentifier { + case IPv6RoutingExtHdrIdentifier: + nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil) + if err != nil { + return nil, true, err + } + + i.nextHdrIdentifier = nextHdrIdentifier + return IPv6RoutingExtHdr(bytes), false, nil + case IPv6FragmentExtHdrIdentifier: + var data [6]byte + // We ignore the returned bytes becauase we know the fragment extension + // header specific data will fit in data. + nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) + if err != nil { + return nil, true, err + } + + fragmentExtHdr := IPv6FragmentExtHdr(data) + + // If the packet is a fragmented packet, do not attempt to parse + // anything after the fragment extension header as the data following + // the extension header may not be complete. + if fragmentExtHdr.More() || fragmentExtHdr.FragmentOffset() != 0 { + i.forceRaw = true + } + + i.nextHdrIdentifier = nextHdrIdentifier + return fragmentExtHdr, false, nil + case IPv6NoNextHeaderIdentifier: + // This indicates the end of the IPv6 payload. + return nil, true, nil + + default: + // The header we are parsing is not a known extension header. Return the + // raw payload. + return i.AsRawHeader(), false, nil + } +} + +// nextHeaderData returns the extension header's Next Header field and raw data. +// +// fragmentHdr indicates that the extension header being parsed is the Fragment +// extension header so the Length field should be ignored as it is Reserved +// for the Fragment extension header. +// +// If bytes is not nil, extension header specific data will be read into bytes +// if it has enough capacity. If bytes is provided but does not have enough +// capacity for the data, nextHeaderData will panic. +func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, []byte, error) { + // We ignore the number of bytes read because we know we will only ever read + // at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read + // would return io.EOF to indicate that io.Reader has reached the end of the + // payload. + nextHdrIdentifier, err := i.reader.ReadByte() + i.payload.TrimFront(1) + if err != nil { + return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } + + var length uint8 + length, err = i.reader.ReadByte() + i.payload.TrimFront(1) + if err != nil { + var ret error + if fragmentHdr { + ret = fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } else { + ret = fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err) + } + return 0, nil, ret + } + if fragmentHdr { + length = 0 + } + + bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded + if bytes == nil { + bytes = make([]byte, bytesLen) + } else if n := len(bytes); n < bytesLen { + panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier)) + } + + n, err := io.ReadFull(&i.reader, bytes) + i.payload.TrimFront(n) + if err != nil { + return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err) + } + + return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil +} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go new file mode 100644 index 000000000..4bfdc77c4 --- /dev/null +++ b/pkg/tcpip/header/ipv6_extension_headers_test.go @@ -0,0 +1,515 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +// Equal returns true of a and b are equivalent. +// +// Note, Equal will return true if a and b hold the same Identifier value and +// contain the same bytes in Buf, even if the bytes are split across views +// differently. +// +// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported +// fields. +func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool { + return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView()) +} + +func TestIPv6RoutingExtHdr(t *testing.T) { + tests := []struct { + name string + bytes []byte + segmentsLeft uint8 + }{ + { + name: "Zeroes", + bytes: []byte{0, 0, 0, 0, 0, 0}, + segmentsLeft: 0, + }, + { + name: "Ones", + bytes: []byte{1, 1, 1, 1, 1, 1}, + segmentsLeft: 1, + }, + { + name: "Mixed", + bytes: []byte{1, 2, 3, 4, 5, 6}, + segmentsLeft: 2, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + extHdr := IPv6RoutingExtHdr(test.bytes) + if got := extHdr.SegmentsLeft(); got != test.segmentsLeft { + t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft) + } + }) + } +} + +func TestIPv6FragmentExtHdr(t *testing.T) { + tests := []struct { + name string + bytes [6]byte + fragmentOffset uint16 + more bool + id uint32 + }{ + { + name: "Zeroes", + bytes: [6]byte{0, 0, 0, 0, 0, 0}, + fragmentOffset: 0, + more: false, + id: 0, + }, + { + name: "Ones", + bytes: [6]byte{0, 9, 0, 0, 0, 1}, + fragmentOffset: 1, + more: true, + id: 1, + }, + { + name: "Mixed", + bytes: [6]byte{68, 9, 128, 4, 2, 1}, + fragmentOffset: 2177, + more: true, + id: 2147746305, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + extHdr := IPv6FragmentExtHdr(test.bytes) + if got := extHdr.FragmentOffset(); got != test.fragmentOffset { + t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset) + } + if got := extHdr.More(); got != test.more { + t.Errorf("got More() = %t, want = %t", got, test.more) + } + if got := extHdr.ID(); got != test.id { + t.Errorf("got ID() = %d, want = %d", got, test.id) + } + }) + } +} + +func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView { + size := 0 + var vs []buffer.View + + for _, b := range bs { + vs = append(vs, buffer.View(b)) + size += len(b) + } + + return buffer.NewVectorisedView(size, vs) +} + +func TestIPv6ExtHdrIterErr(t *testing.T) { + tests := []struct { + name string + firstNextHdr IPv6ExtensionHeaderIdentifier + payload buffer.VectorisedView + err error + }{ + { + name: "Upper layer only without data", + firstNextHdr: 255, + }, + { + name: "Upper layer only with data", + firstNextHdr: 255, + payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}), + }, + + { + name: "No next header", + firstNextHdr: IPv6NoNextHeaderIdentifier, + }, + { + name: "No next header with data", + firstNextHdr: IPv6NoNextHeaderIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}), + }, + + { + name: "Valid single fragment", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}), + }, + { + name: "Fragment too small", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}), + err: io.ErrUnexpectedEOF, + }, + + { + name: "Valid single routing", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}), + }, + { + name: "Valid single routing across views", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}), + }, + { + name: "Routing too small with zero length field", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}), + err: io.ErrUnexpectedEOF, + }, + { + name: "Valid routing with non-zero length field", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}), + }, + { + name: "Valid routing with non-zero length field across views", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}), + }, + { + name: "Routing too small with non-zero length field", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}), + err: io.ErrUnexpectedEOF, + }, + { + name: "Routing too small with non-zero length field across views", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}), + err: io.ErrUnexpectedEOF, + }, + + { + name: "Mixed", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2, 3, 4, 5, 6, + + // Upper layer data. + 1, 2, 3, 4, + }), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if _, err := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload, false); err != nil { + t.Errorf("got MakeIPv6PayloadIterator(%d, _, false) = %s, want = nil", test.firstNextHdr, err) + } + + if _, err := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload, true); !errors.Is(err, test.err) { + t.Errorf("got MakeIPv6PayloadIterator(%d, _, true) = %v, want = %v", test.firstNextHdr, err, test.err) + } + }) + } +} + +func TestIPv6ExtHdrIter(t *testing.T) { + routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4}) + upperLayerData := buffer.View([]byte{1, 2, 3, 4}) + tests := []struct { + name string + firstNextHdr IPv6ExtensionHeaderIdentifier + payload buffer.VectorisedView + expected []IPv6PayloadHeader + }{ + // With a non-atomic fragment, the payload after the fragment will not be + // parsed because the payload may not be complete. + { + name: "fragment - routing - upper", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2, 3, 4, 5, 6, + + // Upper layer data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}), + IPv6RawPayloadHeader{ + Identifier: IPv6RoutingExtHdrIdentifier, + Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(), + }, + }, + }, + { + name: "fragment - routing - upper (across views)", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2}, []byte{3, 4, 5, 6, + + // Upper layer data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}), + IPv6RawPayloadHeader{ + Identifier: IPv6RoutingExtHdrIdentifier, + Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(), + }, + }, + }, + + // If we have an atomic fragment, the payload following the fragment + // extension header should be parsed normally. + { + name: "atomic fragment - routing - upper", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2, 3, 4, 5, 6, + + // Upper layer data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6RawPayloadHeader{ + Identifier: 255, + Buf: upperLayerData.ToVectorisedView(), + }, + }, + }, + { + name: "atomic fragment - routing - upper (across views)", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1, + + // Routing extension header. + 255, 0, 1, 2}, []byte{3, 4, 5, 6, + + // Upper layer data. + 1, 2}, []byte{3, 4}), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6RawPayloadHeader{ + Identifier: 255, + Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), + }, + }, + }, + { + name: "atomic fragment - no next header", + firstNextHdr: IPv6FragmentExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Fragment extension header. + // + // Res (Reserved) bits are 1 which should not affect anything. + uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1, + + // Random data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + }, + }, + { + name: "routing - atomic fragment - no next header", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Routing extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1, + + // Random data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + }, + }, + { + name: "routing - atomic fragment - no next header (across views)", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Routing extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Fragment extension header. + // + // Reserved bits are 1 which should not affect anything. + uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1, + + // Random data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}), + }, + }, + { + name: "routing - fragment - no next header", + firstNextHdr: IPv6RoutingExtHdrIdentifier, + payload: makeVectorisedViewFromByteBuffers([]byte{ + // Routing extension header. + uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6, + + // Fragment extension header. + // + // Fragment Offset = 32; Res = 6. + uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1, + + // Random data. + 1, 2, 3, 4, + }), + expected: []IPv6PayloadHeader{ + IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}), + IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}), + IPv6RawPayloadHeader{ + Identifier: IPv6NoNextHeaderIdentifier, + Buf: upperLayerData.ToVectorisedView(), + }, + }, + }, + + // Test the raw payload for common transport layer protocol numbers. + { + name: "TCP raw payload", + firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber), + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber), + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "UDP raw payload", + firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber), + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber), + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "ICMPv4 raw payload", + firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber), + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber), + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "ICMPv6 raw payload", + firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber), + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber), + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "Unknwon next header raw payload", + firstNextHdr: 255, + payload: makeVectorisedViewFromByteBuffers(upperLayerData), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: 255, + Buf: upperLayerData.ToVectorisedView(), + }}, + }, + { + name: "Unknwon next header raw payload (across views)", + firstNextHdr: 255, + payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), + expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{ + Identifier: 255, + Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]), + }}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + it, err := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload, true) + if err != nil { + t.Fatalf("MakeIPv6PayloadIterator(%d, _ true): %s", test.firstNextHdr, err) + } + + for i, e := range test.expected { + extHdr, done, err := it.Next() + if err != nil { + t.Errorf("(i=%d) Next(): %s", i, err) + } + if done { + t.Errorf("(i=%d) unexpectedly done iterating", i) + } + if diff := cmp.Diff(e, extHdr); diff != "" { + t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff) + } + + if t.Failed() { + t.FailNow() + } + } + + extHdr, done, err := it.Next() + if err != nil { + t.Errorf("(last) Next(): %s", err) + } + if !done { + t.Errorf("(last) iterator unexpectedly not done") + } + if extHdr != nil { + t.Errorf("(last) got Next() = %T, want = nil", extHdr) + } + }) + } +} diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go index 6a215938b..8f65713c5 100644 --- a/pkg/tcpip/network/hash/hash.go +++ b/pkg/tcpip/network/hash/hash.go @@ -80,12 +80,12 @@ func IPv4FragmentHash(h header.IPv4) uint32 { // RFC 2640 (sec 4.5) is not very sharp on this aspect. // As a reference, also Linux ignores the protocol to compute // the hash (inet6_hash_frag). -func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 { +func IPv6FragmentHash(h header.IPv6, id uint32) uint32 { t := h.SourceAddress() y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 t = h.DestinationAddress() z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 - return Hash3Words(f.ID(), y, z, hashIV) + return Hash3Words(id, y, z, hashIV) } func rol32(v, shift uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index fb11874c6..a93a7621a 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -13,6 +13,8 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/network/fragmentation", + "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", ], ) @@ -36,5 +38,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 29e597002..a703a768c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -21,11 +21,14 @@ package ipv6 import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -49,6 +52,7 @@ type endpoint struct { linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher + fragmentation *fragmentation.Fragmentation protocol *protocol } @@ -172,6 +176,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { + r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -179,14 +184,120 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { pkt.Data.TrimFront(header.IPv6MinimumSize) pkt.Data.CapLength(int(h.PayloadLength())) - p := h.TransportProtocol() - if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, pkt) + it, err := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data, true) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() return } - r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, pkt) + for { + extHdr, done, err := it.Next() + if err != nil { + // This should never happen as MakeIPv6PayloadIterator above did not + // return an error. + panic(fmt.Sprintf("unexpected error when iterating over IPv6 payload: %s", err)) + } + if done { + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6RoutingExtHdr: + // As per RFC 8200 section 4.4, if a node encounters a routing header with + // an unrecognized routing type value, with a non-zero Segments Left + // value, the node must discard the packet and send an ICMP Parameter + // Problem, Code 0. If the Segments Left is 0, the node must ignore the + // Routing extension header and process the next header in the packet. + // + // Note, the stack does not yet handle any type of routing extension + // header, so we just make sure Segments Left is zero before processing + // the next extension header. + // + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for + // unrecognized routing types with a non-zero Segments Left value. + if extHdr.SegmentsLeft() != 0 { + return + } + + case header.IPv6FragmentExtHdr: + fragmentOffset := extHdr.FragmentOffset() + more := extHdr.More() + if !more && fragmentOffset == 0 { + // This fragment extension header indicates that this packet is an + // atomic fragment. An atomic fragment is a fragment that contains + // all the data required to reassemble a full packet. As per RFC 6946, + // atomic fragments must not interfere with "normal" fragmented traffic + // so we skip processing the fragment instead of feeding it through the + // reassembly process below. + continue + } + + rawPayload := it.AsRawHeader() + fragmentPayloadLen := rawPayload.Buf.Size() + if fragmentPayloadLen == 0 { + // Drop the packet as it's marked as a fragment but has no payload. + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + // The packet is a fragment, let's try to reassemble it. + start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + last := start + uint16(fragmentPayloadLen) - 1 + + // Drop the packet if the fragmentOffset is incorrect. i.e the + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. + if last < start { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + var ready bool + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + if ready { + // We create a new iterator with the reassembled packet because we could + // have more extension headers in the reassembled payload, as per RFC + // 8200 section 4.5. + it, err = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data, true) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + } + + case header.IPv6RawPayloadHeader: + // If the last header in the payload isn't a known IPv6 extension header, + // handle it as if it is transport layer data. + pkt.Data = extHdr.Buf + + if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { + e.handleICMP(r, headerView, pkt) + } else { + r.Stats().IP.PacketsDelivered.Increment() + // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error + // in response to unrecognized next header values. + e.dispatcher.DeliverTransportPacket(r, p, pkt) + } + + default: + // If we receive a packet for an extension header we do not yet handle, + // drop the packet for now. + // + // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error + // in response to unrecognized next header values. + r.Stats().UnknownProtocolRcvdPackets.Increment() + return + } + } } // Close cleans up resources associated with the endpoint. @@ -229,6 +340,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, + fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, }, nil } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index ed98ef22a..86bfda85e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -17,6 +17,7 @@ package ipv6 import ( "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -33,6 +34,12 @@ const ( // The least significant 3 bytes are the same as addr2 so both addr2 and // addr3 will have the same solicited-node address. addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" + + // Tests use the extension header identifier values as uint8 instead of + // header.IPv6ExtensionHeaderIdentifier. + routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier) + fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) + noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) ) // testReceiveICMP tests receiving an ICMP packet from src to dst. want is the @@ -268,3 +275,764 @@ func TestAddIpv6Address(t *testing.T) { }) } } + +func TestReceiveIPv6ExtHdrs(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + extHdr func(nextHdr uint8) ([]byte, uint8) + shouldAccept bool + }{ + { + name: "None", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, + shouldAccept: true, + }, + { + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + shouldAccept: true, + }, + { + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + shouldAccept: false, + }, + { + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + shouldAccept: true, + }, + { + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + shouldAccept: true, + }, + { + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + shouldAccept: false, + }, + { + name: "routing - atomic fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + nextHdr, 0, 0, 0, 1, 2, 3, 4, + }, routingExtHdrID + }, + shouldAccept: true, + }, + { + name: "atomic fragment - routing", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Fragment extension header. + routingExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Routing extension header. + nextHdr, 0, 1, 0, 2, 3, 4, 5, + }, fragmentExtHdrID + }, + shouldAccept: true, + }, + { + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + shouldAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8} + udpLength := header.UDPMinimumSize + len(udpPayload) + extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber)) + extHdrLen := len(extHdrBytes) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength) + + // Serialize UDP message. + u := header.UDP(hdr.Prepend(udpLength)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: uint16(udpLength), + }) + copy(u.Payload(), udpPayload) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum = header.Checksum(udpPayload, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + // Copy extension header bytes between the UDP message and the IPv6 + // fixed header. + copy(hdr.Prepend(extHdrLen), extHdrBytes) + + // Serialize IPv6 fixed header. + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: ipv6NextHdr, + HopLimit: 255, + SrcAddr: addr1, + DstAddr: addr2, + }) + + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stats := s.Stats().UDP.PacketsReceived + + if !test.shouldAccept { + if got := stats.Value(); got != 0 { + t.Errorf("got UDP Rx Packets = %d, want = 0", got) + } + + return + } + + // Expect a UDP packet. + if got := stats.Value(); got != 1 { + t.Errorf("got UDP Rx Packets = %d, want = 1", got) + } + gotPayload, _, err := ep.Read(nil) + if err != nil { + t.Fatalf("Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + + // Should not have any more UDP packets. + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } +} + +// fragmentData holds the IPv6 payload for a fragmented IPv6 packet. +type fragmentData struct { + nextHdr uint8 + data buffer.VectorisedView +} + +func TestReceiveIPv6Fragments(t *testing.T) { + const nicID = 1 + const udpPayload1Length = 256 + const udpPayload2Length = 128 + const fragmentExtHdrLen = 8 + // Note, not all routing extension headers will be 8 bytes but this test + // uses 8 byte routing extension headers for most sub tests. + const routingExtHdrLen = 8 + + udpGen := func(payload []byte, multiplier uint8) buffer.View { + payloadLen := len(payload) + for i := 0; i < payloadLen; i++ { + payload[i] = uint8(i) * multiplier + } + + udpLength := header.UDPMinimumSize + payloadLen + + hdr := buffer.NewPrependable(udpLength) + u := header.UDP(hdr.Prepend(udpLength)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: uint16(udpLength), + }) + copy(u.Payload(), payload) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum = header.Checksum(payload, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + return hdr.View() + } + + var udpPayload1Buf [udpPayload1Length]byte + udpPayload1 := udpPayload1Buf[:] + ipv6Payload1 := udpGen(udpPayload1, 1) + + var udpPayload2Buf [udpPayload2Length]byte + udpPayload2 := udpPayload2Buf[:] + ipv6Payload2 := udpGen(udpPayload2, 2) + + tests := []struct { + name string + expectedPayload []byte + fragments []fragmentData + expectedPayloads [][]byte + }{ + { + name: "No fragmentation", + fragments: []fragmentData{ + { + nextHdr: uint8(header.UDPProtocolNumber), + data: ipv6Payload1.ToVectorisedView(), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Atomic fragment", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1), + []buffer.View{ + // Fragment extension header. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + + ipv6Payload1, + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments with different IDs", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with per-fragment routing header with zero segments left", + fragments: []fragmentData{ + { + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments with per-fragment routing header with non-zero segments left", + fragments: []fragmentData{ + { + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with routing header with zero segments left", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments with routing header with non-zero segments left", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with routing header with zero segments left across fragments", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is fragmentExtHdrLen+8 because the + // first 8 bytes of the 16 byte routing extension header is in + // this fragment. + fragmentExtHdrLen+8, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header (part 1) + // + // Segments left = 0. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}), + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is + // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of + // the 16 byte routing extension header is in this fagment. + fragmentExtHdrLen+8+len(ipv6Payload1), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 1, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + + // Routing extension header (part 2) + buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + + ipv6Payload1, + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments with routing header with non-zero segments left across fragments", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is fragmentExtHdrLen+8 because the + // first 8 bytes of the 16 byte routing extension header is in + // this fragment. + fragmentExtHdrLen+8, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header (part 1) + // + // Segments left = 1. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}), + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is + // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of + // the 16 byte routing extension header is in this fagment. + fragmentExtHdrLen+8+len(ipv6Payload1), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 1, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + + // Routing extension header (part 2) + buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + + ipv6Payload1, + }, + ), + }, + }, + expectedPayloads: nil, + }, + // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal" + // fragmented traffic. + { + name: "Two fragments with atomic", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + // This fragment has the same ID as the other fragments but is an atomic + // fragment. It should not interfere with the other fragments. + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload2), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), + + ipv6Payload2, + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload2, udpPayload1}, + }, + { + name: "Two interleaved fragmented packets", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), + + ipv6Payload2[:32], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload2)-32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 4, More = false, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), + + ipv6Payload2[32:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1, udpPayload2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + + // Serialize IPv6 fixed header. + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(f.data.Size()), + NextHeader: f.nextHdr, + HopLimit: 255, + SrcAddr: addr1, + DstAddr: addr2, + }) + + vv := hdr.View().ToVectorisedView() + vv.Append(f.data) + + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: vv, + }) + } + + if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { + t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) + } + + for i, p := range test.expectedPayloads { + gotPayload, _, err := ep.Read(nil) + if err != nil { + t.Fatalf("(i=%d) Read(nil): %s", i, err) + } + if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" { + t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) + } + } + + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } +} -- cgit v1.2.3 From fc99a7ebf0c24b6f7b3cfd6351436373ed54548b Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Fri, 3 Apr 2020 18:34:48 -0700 Subject: Refactor software GSO code. Software GSO implementation currently has a complicated code path with implicit assumptions that all packets to WritePackets carry same Data and it does this to avoid allocations on the path etc. But this makes it hard to reuse the WritePackets API. This change breaks all such assumptions by introducing a new Vectorised View API ReadToVV which can be used to cleanly split a VV into multiple independent VVs. Further this change also makes packet buffers linkable to form an intrusive list. This allows us to get rid of the array of packet buffers that are passed in the WritePackets API call and replace it with a list of packet buffers. While this code does introduce some more allocations in the benchmarks it doesn't cause any degradation. Updates #231 PiperOrigin-RevId: 304731742 --- pkg/ilist/list.go | 13 ++- pkg/sentry/kernel/kernel.go | 22 +++-- pkg/tcpip/buffer/view.go | 53 +++++++++- pkg/tcpip/buffer/view_test.go | 137 ++++++++++++++++++++++++++ pkg/tcpip/link/channel/channel.go | 18 ++-- pkg/tcpip/link/fdbased/endpoint.go | 162 ++++++++++++++----------------- pkg/tcpip/link/loopback/loopback.go | 2 +- pkg/tcpip/link/muxed/injectable.go | 2 +- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 14 +-- pkg/tcpip/link/waitable/waitable.go | 4 +- pkg/tcpip/link/waitable/waitable_test.go | 6 +- pkg/tcpip/network/arp/arp.go | 2 +- pkg/tcpip/network/ip_test.go | 2 +- pkg/tcpip/network/ipv4/ipv4.go | 37 +++++-- pkg/tcpip/network/ipv6/icmp.go | 2 +- pkg/tcpip/network/ipv6/ipv6.go | 12 +-- pkg/tcpip/stack/BUILD | 14 ++- pkg/tcpip/stack/forwarder_test.go | 8 +- pkg/tcpip/stack/iptables.go | 17 ++++ pkg/tcpip/stack/ndp_test.go | 2 +- pkg/tcpip/stack/packet_buffer.go | 14 +-- pkg/tcpip/stack/packet_buffer_state.go | 27 ------ pkg/tcpip/stack/registration.go | 4 +- pkg/tcpip/stack/route.go | 19 ++-- pkg/tcpip/stack/stack_test.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 47 +++++---- pkg/tcpip/transport/tcp/segment.go | 6 +- 28 files changed, 420 insertions(+), 230 deletions(-) delete mode 100644 pkg/tcpip/stack/packet_buffer_state.go (limited to 'pkg/tcpip/buffer') diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 8f93e4d6d..0d07da3b1 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -86,12 +86,21 @@ func (l *List) Back() Element { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *List) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *List) PushFront(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { ElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -106,7 +115,6 @@ func (l *List) PushBack(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { ElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -127,7 +135,6 @@ func (l *List) PushBackList(m *List) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0a448b57c..2e6f42b92 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -564,15 +564,25 @@ func (ts *TaskSet) unregisterEpollWaiters() { ts.mu.RLock() defer ts.mu.RUnlock() + + // Tasks that belong to the same process could potentially point to the + // same FDTable. So we retain a map of processed ones to avoid + // processing the same FDTable multiple times. + processed := make(map[*FDTable]struct{}) for t := range ts.Root.tids { // We can skip locking Task.mu here since the kernel is paused. - if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - if e, ok := file.FileOperations.(*epoll.EventPoll); ok { - e.UnregisterEpollWaiters() - } - }) + if t.fdTable == nil { + continue + } + if _, ok := processed[t.fdTable]; ok { + continue } + t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + if e, ok := file.FileOperations.(*epoll.EventPoll); ok { + e.UnregisterEpollWaiters() + } + }) + processed[t.fdTable] = struct{}{} } } diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8d42cd066..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -17,6 +17,7 @@ package buffer import ( "bytes" + "io" ) // View is a slice of a buffer, with convenience methods. @@ -89,6 +90,47 @@ func (vv *VectorisedView) TrimFront(count int) { } } +// Read implements io.Reader. +func (vv *VectorisedView) Read(v View) (copied int, err error) { + count := len(v) + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + copy(v[copied:], vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return copied, nil + } + count -= len(vv.views[0]) + copy(v[copied:], vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + if copied == 0 { + return 0, io.EOF + } + return copied, nil +} + +// ReadToVV reads up to n bytes from vv to dstVV and removes them from vv. It +// returns the number of bytes copied. +func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) { + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + dstVV.AppendView(vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return + } + count -= len(vv.views[0]) + dstVV.AppendView(vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + return copied +} + // CapLength irreversibly reduces the length of the vectorised view. func (vv *VectorisedView) CapLength(length int) { if length < 0 { @@ -116,12 +158,12 @@ func (vv *VectorisedView) CapLength(length int) { // Clone returns a clone of this VectorisedView. // If the buffer argument is large enough to contain all the Views of this VectorisedView, // the method will avoid allocations and use the buffer to store the Views of the clone. -func (vv VectorisedView) Clone(buffer []View) VectorisedView { +func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } // First returns the first view of the vectorised view. -func (vv VectorisedView) First() View { +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { return nil } @@ -134,11 +176,12 @@ func (vv *VectorisedView) RemoveFirst() { return } vv.size -= len(vv.views[0]) + vv.views[0] = nil vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. -func (vv VectorisedView) Size() int { +func (vv *VectorisedView) Size() int { return vv.size } @@ -146,7 +189,7 @@ func (vv VectorisedView) Size() int { // // If the vectorised view contains a single view, that view will be returned // directly. -func (vv VectorisedView) ToView() View { +func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } @@ -158,7 +201,7 @@ func (vv VectorisedView) ToView() View { } // Views returns the slice containing the all views. -func (vv VectorisedView) Views() []View { +func (vv *VectorisedView) Views() []View { return vv.views } diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index ebc3a17b7..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -233,3 +233,140 @@ func TestToClone(t *testing.T) { }) } } + +func TestVVReadToVV(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + wantBytes string + leftVV VectorisedView + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + wantBytes: "0123456789", + leftVV: vv(20, "01234567890123456789"), + }, + { + comment: "largeVV, multiple views, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + wantBytes: "123345", + leftVV: vv(7, "567", "8910"), + }, + { + comment: "smallVV (multiple views), large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + wantBytes: "123", + leftVV: vv(0, ""), + }, + { + comment: "smallVV (single view), large read", + vv: vv(1, "1"), + bytesToRead: 10, + wantBytes: "1", + leftVV: vv(0, ""), + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + wantBytes: "", + leftVV: vv(0, ""), + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + var readTo VectorisedView + inSize := tc.vv.Size() + copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead) + if got, want := copied, len(tc.wantBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc) + } + if got, want := string(readTo.ToView()), tc.wantBytes; got != want { + t.Errorf("unexpected content in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { + t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want) + } + }) + } +} + +func TestVVRead(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + readBytes string + leftBytes string + wantError bool + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + readBytes: "0123456789", + leftBytes: "01234567890123456789", + }, + { + comment: "largeVV, multiple buffers, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + readBytes: "123345", + leftBytes: "5678910", + }, + { + comment: "smallVV, large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + readBytes: "123", + leftBytes: "", + }, + { + comment: "smallVV, large read", + vv: vv(1, "1"), + bytesToRead: 10, + readBytes: "1", + leftBytes: "", + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + readBytes: "", + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + readTo := NewView(tc.bytesToRead) + inSize := tc.vv.Size() + copied, err := tc.vv.Read(readTo) + if !tc.wantError && err != nil { + t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err) + } + readTo = readTo[:copied] + if got, want := copied, len(tc.readBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(readTo), tc.readBytes; got != want { + t.Errorf("unexpected data in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want { + t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want) + } + }) + } +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a8d6653ce..b4a0ae53d 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -28,7 +28,7 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Pkt stack.PacketBuffer + Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO Route stack.Route @@ -257,7 +257,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne route := r.Clone() route.Release() p := PacketInfo{ - Pkt: pkt, + Pkt: &pkt, Proto: protocol, GSO: gso, Route: route, @@ -269,21 +269,15 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() - payloadView := pkts[0].Data.ToView() n := 0 - for _, pkt := range pkts { - off := pkt.DataOffset - size := pkt.DataSize + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ - Pkt: stack.PacketBuffer{ - Header: pkt.Header, - Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), - }, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, @@ -301,7 +295,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: stack.PacketBuffer{Data: vv}, + Pkt: &stack.PacketBuffer{Data: vv}, Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 3b3b6909b..7198742b7 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -441,118 +441,106 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - var ethHdrBuf []byte - // hdr + data - iovLen := 2 - if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, - } - - // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ - } +// +// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD +// picked to write the packet out has to be the same for all packets in the +// list. In other words all packets in the batch should belong to the same +// flow. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := pkts.Len() - n := len(pkts) - - views := pkts[0].Data.Views() - /* - * Each boundary in views can add one more iovec. - * - * payload | | | | - * ----------------------------- - * packets | | | | | | | - * ----------------------------- - * iovecs | | | | | | | | | - */ - iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) mmsgHdrs := make([]rawfile.MMsgHdr, n) + i := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + var ethHdrBuf []byte + iovLen := 0 + if e.hdrSize > 0 { + // Add ethernet header if needed. + ethHdrBuf = make([]byte, header.EthernetMinimumSize) + eth := header.Ethernet(ethHdrBuf) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } - iovecIdx := 0 - viewIdx := 0 - viewOff := 0 - off := 0 - nextOff := 0 - for i := range pkts { - // TODO(b/134618279): Different packets may have different data - // in the future. We should handle this. - if !viewsEqual(pkts[i].Data.Views(), views) { - panic("All packets in pkts should have the same Data.") + // Preserve the src address if it's set in the route. + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + iovLen++ } - prevIovecIdx := iovecIdx - mmsgHdr := &mmsgHdrs[i] - mmsgHdr.Msg.Iov = &iovec[iovecIdx] - packetSize := pkts[i].DataSize - hdr := &pkts[i].Header - - off = pkts[i].DataOffset - if off != nextOff { - // We stop in a different point last time. - size := packetSize - viewIdx = 0 - viewOff = 0 - for size > 0 { - if size >= len(views[viewIdx]) { - viewIdx++ - viewOff = 0 - size -= len(views[viewIdx]) - } else { - viewOff = size - size = 0 + var vnetHdrBuf []byte + vnetHdr := virtioNetHdr{} + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if gso != nil { + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + if gso.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen + vnetHdr.csumOffset = gso.CsumOffset + } + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { + switch gso.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + } + vnetHdr.gsoSize = gso.MSS } } + vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr) + iovLen++ } - nextOff = off + packetSize + iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + mmsgHdr := &mmsgHdrs[i] + mmsgHdr.Msg.Iov = &iovecs[0] + iovecIdx := 0 + if vnetHdrBuf != nil { + v := &iovecs[iovecIdx] + v.Base = &vnetHdrBuf[0] + v.Len = uint64(len(vnetHdrBuf)) + iovecIdx++ + } if ethHdrBuf != nil { - v := &iovec[iovecIdx] + v := &iovecs[iovecIdx] v.Base = ðHdrBuf[0] v.Len = uint64(len(ethHdrBuf)) iovecIdx++ } - - v := &iovec[iovecIdx] + pktSize := uint64(0) + // Encode L3 Header + v := &iovecs[iovecIdx] + hdr := &pkt.Header hdrView := hdr.View() v.Base = &hdrView[0] v.Len = uint64(len(hdrView)) + pktSize += v.Len iovecIdx++ - for packetSize > 0 { - vec := &iovec[iovecIdx] + // Now encode the Transport Payload. + pktViews := pkt.Data.Views() + for i := range pktViews { + vec := &iovecs[iovecIdx] iovecIdx++ - - v := views[viewIdx] - vec.Base = &v[viewOff] - s := len(v) - viewOff - if s <= packetSize { - viewIdx++ - viewOff = 0 - } else { - s = packetSize - viewOff += s - } - vec.Len = uint64(s) - packetSize -= s + vec.Base = &pktViews[i][0] + vec.Len = uint64(len(pktViews[i])) + pktSize += vec.Len } - - mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) + mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + i++ } packets := 0 for packets < n { - fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))] + fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))] sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) if err != nil { return packets, err diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 4039753b7..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index f5973066d..a5478ce17 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,7 +87,7 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6461d0108..0796d717e 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0a6b8945c..062388f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -200,7 +200,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, pkt.Header.View(), gso) } @@ -233,20 +233,16 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket(gso, protocol, pkt) + e.dumpPacket(gso, protocol, &pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - view := pkts[0].Data.ToView() - for _, pkt := range pkts { - e.dumpPacket(gso, protocol, stack.PacketBuffer{ - Header: pkt.Header, - Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), - }) +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.dumpPacket(gso, protocol, pkt) } return e.lower.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 52fe397bf..2b3741276 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -112,9 +112,9 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { - return len(pkts), nil + return pkts.Len(), nil } n, err := e.lower.WritePackets(r, gso, pkts, protocol) diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 88224e494..54eb5322b 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -71,9 +71,9 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcp } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - e.writeCount += len(pkts) - return len(pkts), nil +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += pkts.Len() + return pkts.Len(), nil } func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 255098372..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4950d69fc..4c20301c6 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a7d9a8b25..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -280,28 +280,47 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil + } + + for pkt := pkts.Front(); pkt != nil; { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + pkt = pkt.Next() } // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - for i := range pkts { - if ok := ipt.Check(stack.Output, pkts[i]); !ok { - // iptables is telling us to drop the packet. + dropped := ipt.CheckPackets(stack.Output, pkts) + if len(dropped) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow Path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { continue } - ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) - pkts[i].NetworkHeader = buffer.View(ip) + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacket writes a packet already containing a network diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6d2d2c034..f91180aa3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -79,7 +79,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Only the first view in vv is accounted for by h. To account for the // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. - payload := pkt.Data + payload := pkt.Data.Clone(nil) payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b462b8604..a815b4d9b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -143,19 +143,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil } - for i := range pkts { - hdr := &pkts[i].Header - size := pkts[i].DataSize - ip := e.addIPHeader(r, hdr, size, params) - pkts[i].NetworkHeader = buffer.View(ip) + for pb := pkts.Front(); pb != nil; pb = pb.Next() { + ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) + pb.NetworkHeader = buffer.View(ip) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 8d80e9cee..5e963a4af 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -15,6 +15,18 @@ go_template_instance( }, ) +go_template_instance( + name = "packet_buffer_list", + out = "packet_buffer_list.go", + package = "stack", + prefix = "PacketBuffer", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*PacketBuffer", + "Linker": "*PacketBuffer", + }, +) + go_library( name = "stack", srcs = [ @@ -29,7 +41,7 @@ go_library( "ndp.go", "nic.go", "packet_buffer.go", - "packet_buffer_state.go", + "packet_buffer_list.go", "rand.go", "registration.go", "route.go", diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c45c43d21..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } @@ -260,10 +260,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 - for _, pkt := range pkts { - e.WritePacket(r, gso, protocol, pkt) + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.WritePacket(r, gso, protocol, *pkt) n++ } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 37907ae24..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -209,6 +209,23 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { return true } +// CheckPackets runs pkts through the rules for hook and returns a map of packets that +// should not go forward. +// +// NOTE: unlike the Check API the returned map contains packets that should be +// dropped. +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if ok := it.Check(hook, *pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + } + return drop +} + // Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 598468bdd..27dc8baf9 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -468,7 +468,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.IPv6(t, p.Pkt.Header.View(), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9367de180..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -23,9 +23,11 @@ import ( // As a PacketBuffer traverses up the stack, it may be necessary to pass it to // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. -// -// +stateify savable type PacketBuffer struct { + // PacketBufferEntry is used to build an intrusive list of + // PacketBuffers. + PacketBufferEntry + // Data holds the payload of the packet. For inbound packets, it also // holds the headers, which are consumed as the packet moves up the // stack. Headers are guaranteed not to be split across views. @@ -34,14 +36,6 @@ type PacketBuffer struct { // or otherwise modified. Data buffer.VectorisedView - // DataOffset is used for GSO output. It is the offset into the Data - // field where the payload of this packet starts. - DataOffset int - - // DataSize is used for GSO output. It is the size of this packet's - // payload. - DataSize int - // Header holds the headers of outbound packets. As a packet is passed // down the stack, each layer adds to Header. Header buffer.Prependable diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go deleted file mode 100644 index 0c6b7924c..000000000 --- a/pkg/tcpip/stack/packet_buffer_state.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// beforeSave is invoked by stateify. -func (pk *PacketBuffer) beforeSave() { - // Non-Data fields may be slices of the Data field. This causes - // problems for SR, so during save we make each header independent. - pk.Header = pk.Header.DeepCopy() - pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) - pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) - pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) -} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ac043b722..23ca9ee03 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -246,7 +246,7 @@ type NetworkEndpoint interface { // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. @@ -393,7 +393,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9fbe8a411..a0e5e0300 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -168,23 +168,26 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuff return err } -// WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +// WritePackets writes a list of n packets through the given route and returns +// the number of packets written. +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - payloadSize := 0 - for i := 0; i < n; i++ { - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) - payloadSize += pkts[i].DataSize + + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Header.UsedLength() + writtenBytes += pb.Data.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b8543b71e..3f8a2a095 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -153,7 +153,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3239a5911..2ca3fb809 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -756,8 +756,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) hdr := &pkt.Header - packetSize := pkt.DataSize - off := pkt.DataOffset + packetSize := pkt.Data.Size() // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) pkt.TransportHeader = buffer.View(tcp) @@ -782,12 +781,18 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) + xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } } func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { + // We need to shallow clone the VectorisedView here as ReadToView will + // split the VectorisedView and Trim underlying views as it splits. Not + // doing the clone here will cause the underlying views of data itself + // to be altered. + data = data.Clone(nil) + optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff @@ -796,31 +801,25 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso mss := int(gso.MSS) n := (data.Size() + mss - 1) / mss - // Allocate one big slice for all the headers. - hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen - buf := make([]byte, n*hdrSize) - pkts := make([]stack.PacketBuffer, n) - for i := range pkts { - pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) - } - size := data.Size() - off := 0 + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + var pkts stack.PacketBufferList for i := 0; i < n; i++ { packetSize := mss if packetSize > size { packetSize = size } size -= packetSize - pkts[i].DataOffset = off - pkts[i].DataSize = packetSize - pkts[i].Data = data - pkts[i].Hash = tf.txHash - pkts[i].Owner = owner - buildTCPHdr(r, tf, &pkts[i], gso) - off += packetSize + var pkt stack.PacketBuffer + pkt.Header = buffer.NewPrependable(hdrSize) + pkt.Hash = tf.txHash + pkt.Owner = owner + data.ReadToVV(&pkt.Data, packetSize) + buildTCPHdr(r, tf, &pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkts.PushBack(&pkt) } + if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } @@ -845,12 +844,10 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac } pkt := stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - DataOffset: 0, - DataSize: data.Size(), - Data: data, - Hash: tf.txHash, - Owner: owner, + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + Data: data, + Hash: tf.txHash, + Owner: owner, } buildTCPHdr(r, tf, &pkt, gso) diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index e6fe7985d..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -77,9 +77,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V id: id, route: r.Clone(), } - s.views[0] = v - s.data = buffer.NewVectorisedView(len(v), s.views[:1]) s.rcvdTime = time.Now() + if len(v) != 0 { + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + } return s } -- cgit v1.2.3 From a551add5d8a5bf631cd9859c761e579fdb33ec82 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 13 Apr 2020 17:37:21 -0700 Subject: Remove View.First() and View.RemoveFirst() These methods let users eaily break the VectorisedView abstraction, and allowed netstack to slip into pseudo-enforcement of the "all headers are in the first View" invariant. Removing them and replacing with PullUp(n) breaks this reliance and will make it easier to add iptables support and rework network buffer management. The new View.PullUp(n) method is low cost in the common case, when when all the headers fit in the first View. --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 4 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 25 files changed, 395 insertions(+), 147 deletions(-) (limited to 'pkg/tcpip/buffer') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7acbfa0a8..cf73a939e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,7 +93,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 104aafbed..17202cc7a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,7 +328,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -378,7 +382,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331b0817b..486725131 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index e9c652042..c7c663498 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -473,7 +476,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +520,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +567,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +622,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 016dbe15e..0c2b1f36a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,22 +1289,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1332,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1375,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index dc125f25e..e954a8b7e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,9 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c7634ceb1..d45d2cc1f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ab1014c7f..286c66cf5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 120d3b50f4875824ec69f0cc39a09ac84fced35c Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 21 Apr 2020 07:15:25 -0700 Subject: Automated rollback of changelist 307477185 PiperOrigin-RevId: 307598974 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++---------- pkg/tcpip/buffer/view_test.go | 113 ----------------------------- pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 ++++------------- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 ++--- pkg/tcpip/network/ipv4/ipv4.go | 12 +-- pkg/tcpip/network/ipv6/icmp.go | 74 +++++++------------ pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +----- pkg/tcpip/stack/iptables_targets.go | 23 ++---- pkg/tcpip/stack/nic.go | 34 ++++++--- pkg/tcpip/stack/packet_buffer.go | 4 +- pkg/tcpip/stack/stack_test.go | 10 +-- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++----- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 +-- 25 files changed, 147 insertions(+), 395 deletions(-) (limited to 'pkg/tcpip/buffer') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,13 +121,12 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.TCPMinimumSize { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,13 +120,12 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,8 +77,7 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. It panics -// if count > vv.Size(). +// TrimFront removes the first "count" bytes of the vectorised view. func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -87,7 +86,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } } @@ -105,7 +104,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } if copied == 0 { return 0, io.EOF @@ -127,7 +126,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } return copied } @@ -163,37 +162,22 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// PullUp returns the first "count" bytes of the vectorised view. If those -// bytes aren't already contiguous inside the vectorised view, PullUp will -// reallocate as needed to make them contiguous. PullUp fails and returns false -// when count > vv.Size(). -func (vv *VectorisedView) PullUp(count int) (View, bool) { +// First returns the first view of the vectorised view. +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { - return nil, count == 0 - } - if count <= len(vv.views[0]) { - return vv.views[0][:count], true - } - if count > vv.size { - return nil, false + return nil } + return vv.views[0] +} - newFirst := NewView(count) - i := 0 - for offset := 0; offset < count; i++ { - copy(newFirst[offset:], vv.views[i]) - if count-offset < len(vv.views[i]) { - vv.views[i].TrimFront(count - offset) - break - } - offset += len(vv.views[i]) - vv.views[i] = nil +// RemoveFirst removes the first view of the vectorised view. +func (vv *VectorisedView) RemoveFirst() { + if len(vv.views) == 0 { + return } - // We're guaranteed that i > 0, since count is too large for the first - // view. - vv.views[i-1] = newFirst - vv.views = vv.views[i-1:] - return newFirst, true + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -241,10 +225,3 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } - -// removeFirst panics when len(vv.views) < 1. -func (vv *VectorisedView) removeFirst() { - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] -} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,7 +16,6 @@ package buffer import ( - "bytes" "reflect" "testing" ) @@ -371,115 +370,3 @@ func TestVVRead(t *testing.T) { }) } } - -var pullUpTestCases = []struct { - comment string - in VectorisedView - count int - want []byte - result VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - linkHeader := header.Ethernet(hdr) + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].vv.First()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0799c8f4d..be2537a82 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,7 +171,11 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + first := pkt.Header.View() + if len(first) == 0 { + first = pkt.Data.First() + } + logPacket(prefix, protocol, first, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -234,7 +238,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -243,49 +247,28 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P size := uint16(0) var fragmentOffset uint16 var moreFragments bool - - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) - switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - ipv4 := header.IPv4(hdr) + ipv4 := header.IPv4(b) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) + b = b[ipv4.HeaderLength():] id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - ipv6 := header.IPv6(hdr) + ipv6 := header.IPv6(b) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + b = b[header.IPv6MinimumSize:] case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { - return - } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + arp := header.ARP(b) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -301,7 +284,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -314,11 +297,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) - if !ok { - break - } - icmp := header.ICMPv4(hdr) + icmp := header.ICMPv4(b) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -351,11 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) - if !ok { - break - } - icmp := header.ICMPv6(hdr) + icmp := header.ICMPv6(b) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -386,11 +361,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { - break - } - udp := header.UDP(hdr) + udp := header.UDP(b) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -400,11 +371,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { - break - } - tcp := header.TCP(hdr) + tcp := header.TCP(b) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cf73a939e..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,10 +93,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,11 +25,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - hdr := header.IPv4(h) + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -38,12 +34,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } - hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + hlen := int(h.HeaderLength()) + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -52,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 17202cc7a..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,11 +328,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return tcpip.ErrInvalidOptionValue - } - ip := header.IPv4(h) + ip := header.IPv4(pkt.Data.First()) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -382,11 +378,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b68983d10 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,11 +28,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - hdr := header.IPv6(h) + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,21 +36,17 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) - if !ok { - return - } - fragHdr := header.IPv6Fragment(f) - if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { + f := header.IPv6Fragment(pkt.Data.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -63,19 +55,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } @@ -84,9 +76,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) + payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -107,40 +101,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { + if len(v) < header.ICMPv6PacketTooBigMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { + if len(v) < header.ICMPv6DstUnreachableMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + switch h.Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + ns := header.NDPNeighborSolicit(h.NDPPayload()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -298,16 +286,12 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + na := header.NDPNeighborAdvert(h.NDPPayload()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -379,15 +363,14 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) - if !ok { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, icmpHdr) + copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -401,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -423,9 +406,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + p := h.NDPPayload() + if len(p) < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -443,11 +425,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(p) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,8 +166,7 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, + size: header.ICMPv6NeighborSolicitMinimumSize}, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 486725131..331b0817b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,11 +171,7 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c7c663498..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,10 +70,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -476,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -520,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -567,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -622,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,11 +212,6 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. -// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -231,9 +226,7 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -278,21 +271,14 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pk.NetworkHeader is set. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. + // pkt.Data.First(). if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } + pkt.NetworkHeader = pkt.Data.First() } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..7b4543caf 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,12 +96,9 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 - } + headerView := newPkt.Data.First() netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -120,14 +117,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { return RuleDrop, 0 } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(newPkt.Data.First()) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -135,14 +128,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.TCPMinimumSize { + if len(pkt.Data.First()) < header.TCPMinimumSize { return RuleDrop, 0 } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(newPkt.TransportHeader) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0c2b1f36a..016dbe15e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + + src, dst := netProto.ParseAddresses(pkt.Data.First()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,8 +1289,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + + firstData := pkt.Data.First() + pkt.Data.RemoveFirst() + + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1318,13 +1332,12 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,12 +1375,11 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) - if !ok { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index e954a8b7e..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,9 +37,7 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. + // down the stack, each layer adds to Header. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d45d2cc1f..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,18 +95,16 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { + nb := pkt.Data.First() + if len(nb) < fakeNetHeaderLen { return } + pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,11 +642,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..feef8dca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.Data.First()) + if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.Data.First()) + if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7712ce652..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,11 +144,7 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) + h := header.TCP(s.data.First()) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -160,16 +156,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(h[header.TCPMinimumSize:offset]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -181,19 +173,18 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() + s.csum = h.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) + xsum = h.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 286c66cf5..ab1014c7f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..edb54f0be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,13 +68,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From eccae0f77d3708d591119488f427eca90de7c711 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 23 Apr 2020 17:27:24 -0700 Subject: Remove View.First() and View.RemoveFirst() These methods let users eaily break the VectorisedView abstraction, and allowed netstack to slip into pseudo-enforcement of the "all headers are in the first View" invariant. Removing them and replacing with PullUp(n) breaks this reliance and will make it easier to add iptables support and rework network buffer management. The new View.PullUp(n) method is low cost in the common case, when when all the headers fit in the first View. PiperOrigin-RevId: 308163542 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/rawfile/BUILD | 9 ++- pkg/tcpip/link/rawfile/rawfile_test.go | 46 ++++++++++++ pkg/tcpip/link/rawfile/rawfile_unsafe.go | 6 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 28 files changed, 458 insertions(+), 149 deletions(-) create mode 100644 pkg/tcpip/link/rawfile/rawfile_test.go (limited to 'pkg/tcpip/buffer') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 14b527bc2..9cc08d0e2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -18,3 +18,10 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "rawfile_test", + size = "small", + srcs = ["rawfile_test.go"], + library = ":rawfile", +) diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go new file mode 100644 index 000000000..8f14ba761 --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_test.go @@ -0,0 +1,46 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "testing" +) + +func TestNonBlockingWrite3ZeroLength(t *testing.T) { + fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + t.Fatalf("failed to open /dev/null: %v", err) + } + defer syscall.Close(fd) + + if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil { + t.Fatalf("failed to write: %v", err) + } +} + +func TestNonBlockingWrite3Nil(t *testing.T) { + fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + t.Fatalf("failed to open /dev/null: %v", err) + } + defer syscall.Close(fd) + + if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil { + t.Fatalf("failed to write: %v", err) + } +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 44e25d475..92efd0bf8 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -76,9 +76,13 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { // We have two buffers. Build the iovec that represents them and issue // a writev syscall. + var base *byte + if len(b1) > 0 { + base = &b1[0] + } iovec := [3]syscall.Iovec{ { - Base: &b1[0], + Base: base, Len: uint64(len(b1)), }, { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7acbfa0a8..cf73a939e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,7 +93,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 104aafbed..17202cc7a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,7 +328,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -378,7 +382,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331b0817b..486725131 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index e9c652042..c7c663498 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -473,7 +476,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +520,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +567,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +622,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 016dbe15e..0c2b1f36a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,22 +1289,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1332,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1375,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index dc125f25e..7d36f8e84 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,13 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. + // + // TODO(gvisor.dev/issue/170): Forwarded packets don't currently + // populate Header, but should. This will be doable once early parsing + // (https://github.com/google/gvisor/pull/1995) is supported. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c7634ceb1..d45d2cc1f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ab1014c7f..286c66cf5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 55f0c3316af8ea2a1fcc16511efc580f307623f6 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Mon, 27 Apr 2020 12:25:10 -0700 Subject: Automated rollback of changelist 308163542 PiperOrigin-RevId: 308674219 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++---------- pkg/tcpip/buffer/view_test.go | 113 ----------------------------- pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/rawfile/BUILD | 9 +-- pkg/tcpip/link/rawfile/rawfile_test.go | 46 ------------ pkg/tcpip/link/rawfile/rawfile_unsafe.go | 6 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 ++++------------- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 ++--- pkg/tcpip/network/ipv4/ipv4.go | 12 +-- pkg/tcpip/network/ipv6/icmp.go | 74 +++++++------------ pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +----- pkg/tcpip/stack/iptables_targets.go | 23 ++---- pkg/tcpip/stack/nic.go | 34 ++++++--- pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 +-- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++----- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 +-- 28 files changed, 149 insertions(+), 458 deletions(-) delete mode 100644 pkg/tcpip/link/rawfile/rawfile_test.go (limited to 'pkg/tcpip/buffer') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,13 +121,12 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.TCPMinimumSize { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,13 +120,12 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,8 +77,7 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. It panics -// if count > vv.Size(). +// TrimFront removes the first "count" bytes of the vectorised view. func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -87,7 +86,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } } @@ -105,7 +104,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } if copied == 0 { return 0, io.EOF @@ -127,7 +126,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } return copied } @@ -163,37 +162,22 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// PullUp returns the first "count" bytes of the vectorised view. If those -// bytes aren't already contiguous inside the vectorised view, PullUp will -// reallocate as needed to make them contiguous. PullUp fails and returns false -// when count > vv.Size(). -func (vv *VectorisedView) PullUp(count int) (View, bool) { +// First returns the first view of the vectorised view. +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { - return nil, count == 0 - } - if count <= len(vv.views[0]) { - return vv.views[0][:count], true - } - if count > vv.size { - return nil, false + return nil } + return vv.views[0] +} - newFirst := NewView(count) - i := 0 - for offset := 0; offset < count; i++ { - copy(newFirst[offset:], vv.views[i]) - if count-offset < len(vv.views[i]) { - vv.views[i].TrimFront(count - offset) - break - } - offset += len(vv.views[i]) - vv.views[i] = nil +// RemoveFirst removes the first view of the vectorised view. +func (vv *VectorisedView) RemoveFirst() { + if len(vv.views) == 0 { + return } - // We're guaranteed that i > 0, since count is too large for the first - // view. - vv.views[i-1] = newFirst - vv.views = vv.views[i-1:] - return newFirst, true + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -241,10 +225,3 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } - -// removeFirst panics when len(vv.views) < 1. -func (vv *VectorisedView) removeFirst() { - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] -} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,7 +16,6 @@ package buffer import ( - "bytes" "reflect" "testing" ) @@ -371,115 +370,3 @@ func TestVVRead(t *testing.T) { }) } } - -var pullUpTestCases = []struct { - comment string - in VectorisedView - count int - want []byte - result VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - linkHeader := header.Ethernet(hdr) + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 9cc08d0e2..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -18,10 +18,3 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) - -go_test( - name = "rawfile_test", - size = "small", - srcs = ["rawfile_test.go"], - library = ":rawfile", -) diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go deleted file mode 100644 index 8f14ba761..000000000 --- a/pkg/tcpip/link/rawfile/rawfile_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package rawfile - -import ( - "syscall" - "testing" -) - -func TestNonBlockingWrite3ZeroLength(t *testing.T) { - fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) - if err != nil { - t.Fatalf("failed to open /dev/null: %v", err) - } - defer syscall.Close(fd) - - if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil { - t.Fatalf("failed to write: %v", err) - } -} - -func TestNonBlockingWrite3Nil(t *testing.T) { - fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) - if err != nil { - t.Fatalf("failed to open /dev/null: %v", err) - } - defer syscall.Close(fd) - - if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil { - t.Fatalf("failed to write: %v", err) - } -} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 92efd0bf8..44e25d475 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -76,13 +76,9 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { // We have two buffers. Build the iovec that represents them and issue // a writev syscall. - var base *byte - if len(b1) > 0 { - base = &b1[0] - } iovec := [3]syscall.Iovec{ { - Base: base, + Base: &b1[0], Len: uint64(len(b1)), }, { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].vv.First()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0799c8f4d..be2537a82 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,7 +171,11 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + first := pkt.Header.View() + if len(first) == 0 { + first = pkt.Data.First() + } + logPacket(prefix, protocol, first, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -234,7 +238,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -243,49 +247,28 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P size := uint16(0) var fragmentOffset uint16 var moreFragments bool - - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) - switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - ipv4 := header.IPv4(hdr) + ipv4 := header.IPv4(b) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) + b = b[ipv4.HeaderLength():] id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - ipv6 := header.IPv6(hdr) + ipv6 := header.IPv6(b) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + b = b[header.IPv6MinimumSize:] case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { - return - } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + arp := header.ARP(b) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -301,7 +284,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -314,11 +297,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) - if !ok { - break - } - icmp := header.ICMPv4(hdr) + icmp := header.ICMPv4(b) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -351,11 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) - if !ok { - break - } - icmp := header.ICMPv6(hdr) + icmp := header.ICMPv6(b) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -386,11 +361,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { - break - } - udp := header.UDP(hdr) + udp := header.UDP(b) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -400,11 +371,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { - break - } - tcp := header.TCP(hdr) + tcp := header.TCP(b) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cf73a939e..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,10 +93,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,11 +25,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - hdr := header.IPv4(h) + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -38,12 +34,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } - hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + hlen := int(h.HeaderLength()) + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -52,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 17202cc7a..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,11 +328,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return tcpip.ErrInvalidOptionValue - } - ip := header.IPv4(h) + ip := header.IPv4(pkt.Data.First()) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -382,11 +378,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b68983d10 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,11 +28,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - hdr := header.IPv6(h) + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,21 +36,17 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) - if !ok { - return - } - fragHdr := header.IPv6Fragment(f) - if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { + f := header.IPv6Fragment(pkt.Data.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -63,19 +55,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } @@ -84,9 +76,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) + payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -107,40 +101,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { + if len(v) < header.ICMPv6PacketTooBigMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { + if len(v) < header.ICMPv6DstUnreachableMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + switch h.Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + ns := header.NDPNeighborSolicit(h.NDPPayload()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -298,16 +286,12 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + na := header.NDPNeighborAdvert(h.NDPPayload()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -379,15 +363,14 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) - if !ok { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, icmpHdr) + copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -401,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -423,9 +406,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + p := h.NDPPayload() + if len(p) < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -443,11 +425,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(p) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,8 +166,7 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, + size: header.ICMPv6NeighborSolicitMinimumSize}, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 486725131..331b0817b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,11 +171,7 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c7c663498..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,10 +70,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -476,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -520,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -567,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -622,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,11 +212,6 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. -// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -231,9 +226,7 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -278,21 +271,14 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pk.NetworkHeader is set. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. + // pkt.Data.First(). if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } + pkt.NetworkHeader = pkt.Data.First() } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..7b4543caf 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,12 +96,9 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 - } + headerView := newPkt.Data.First() netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -120,14 +117,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { return RuleDrop, 0 } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(newPkt.Data.First()) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -135,14 +128,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.TCPMinimumSize { + if len(pkt.Data.First()) < header.TCPMinimumSize { return RuleDrop, 0 } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(newPkt.TransportHeader) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0c2b1f36a..016dbe15e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + + src, dst := netProto.ParseAddresses(pkt.Data.First()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,8 +1289,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + + firstData := pkt.Data.First() + pkt.Data.RemoveFirst() + + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1318,13 +1332,12 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,12 +1375,11 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) - if !ok { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7d36f8e84..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,13 +37,7 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. + // down the stack, each layer adds to Header. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d45d2cc1f..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,18 +95,16 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { + nb := pkt.Data.First() + if len(nb) < fakeNetHeaderLen { return } + pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,11 +642,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..feef8dca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.Data.First()) + if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.Data.First()) + if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7712ce652..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,11 +144,7 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) + h := header.TCP(s.data.First()) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -160,16 +156,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(h[header.TCPMinimumSize:offset]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -181,19 +173,18 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() + s.csum = h.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) + xsum = h.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 286c66cf5..ab1014c7f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..edb54f0be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,13 +68,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 5e1e61fbcbe8fa3cc8b104fadb8cdef3ad29c31f Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Fri, 1 May 2020 16:08:26 -0700 Subject: Automated rollback of changelist 308674219 PiperOrigin-RevId: 309491861 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/fdbased/endpoint.go | 3 + pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 26 files changed, 402 insertions(+), 147 deletions(-) (limited to 'pkg/tcpip/buffer') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 53a9712c6..affa1bbdf 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -436,6 +436,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne if pkt.Data.Size() == 0 { return rawfile.NonBlockingWrite(fd, pkt.Header.View()) } + if pkt.Header.UsedLength() == 0 { + return rawfile.NonBlockingWrite(fd, pkt.Data.ToView()) + } return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9f47b4ff2..9d0797af7 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,7 +99,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a9dec0c0e..1d61fddad 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -333,7 +333,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -383,7 +387,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 82928fb66..daf1fcbc6 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 9bc97b84e..8084d50bc 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -477,7 +480,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -521,7 +524,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -568,7 +571,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -623,7 +626,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 440970a21..7b54919bb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1212,12 +1212,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1298,22 +1298,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1341,12 +1327,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1384,11 +1371,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 06d312207..9ff80ab24 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,13 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. + // + // TODO(gvisor.dev/issue/170): Forwarded packets don't currently + // populate Header, but should. This will be doable once early parsing + // (https://github.com/google/gvisor/pull/1995) is supported. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3f4e08434..1a2cf007c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 7e574859b..33e2b9a09 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3556,7 +3556,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3583,7 +3583,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 0cb9e1d021607723e008de877a056f2db2a32cef Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Mon, 11 May 2020 10:31:49 -0700 Subject: Fix view.ToVectorisedView(). view.ToVectorisedView() now just returns an empty vectorised view if the view is of zero length. Earlier it would return a VectorisedView of zero length but with 1 empty view. This has been a source of bugs as lower layers don't expect zero length views in VectorisedViews. VectorisedView.AppendView() now is a no-op if the view being appended is of zero length. Fixes #2658 PiperOrigin-RevId: 310942269 --- pkg/tcpip/buffer/view.go | 6 ++++++ pkg/tcpip/buffer/view_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) (limited to 'pkg/tcpip/buffer') diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..9a3c5d6c3 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -59,6 +59,9 @@ func (v *View) Reader() bytes.Reader { // ToVectorisedView returns a VectorisedView containing the receiver. func (v View) ToVectorisedView() VectorisedView { + if len(v) == 0 { + return VectorisedView{} + } return NewVectorisedView(len(v), []View{v}) } @@ -229,6 +232,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) { // AppendView appends the given view into this vectorised view. func (vv *VectorisedView) AppendView(v View) { + if len(v) == 0 { + return + } vv.views = append(vv.views, v) vv.size += len(v) } diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..726e54de9 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -483,3 +483,39 @@ func TestPullUp(t *testing.T) { } } } + +func TestToVectorisedView(t *testing.T) { + testCases := []struct { + in View + want VectorisedView + }{ + {nil, VectorisedView{}}, + {View{}, VectorisedView{}}, + {View{'a'}, VectorisedView{size: 1, views: []View{{'a'}}}}, + } + for _, tc := range testCases { + if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) + } + } +} + +func TestAppendView(t *testing.T) { + testCases := []struct { + vv VectorisedView + in View + want VectorisedView + }{ + {VectorisedView{}, nil, VectorisedView{}}, + {VectorisedView{}, View{}, VectorisedView{}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, nil, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{'e'}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}, {'e'}}, 5}}, + } + for _, tc := range testCases { + tc.vv.AppendView(tc.in) + if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) + } + } +} -- cgit v1.2.3 From 47515f475167ffa23267ca0b9d1b39e7907587d6 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Thu, 13 Aug 2020 13:07:03 -0700 Subject: Migrate to PacketHeader API for PacketBuffer. Formerly, when a packet is constructed or parsed, all headers are set by the client code. This almost always involved prepending to pk.Header buffer or trimming pk.Data portion. This is known to prone to bugs, due to the complexity and number of the invariants assumed across netstack to maintain. In the new PacketHeader API, client will call Push()/Consume() method to construct/parse an outgoing/incoming packet. All invariants, such as slicing and trimming, are maintained by the API itself. NewPacketBuffer() is introduced to create new PacketBuffer. Zero value is no longer valid. PacketBuffer now assumes the packet is a concatenation of following portions: * LinkHeader * NetworkHeader * TransportHeader * Data Any of them could be empty, or zero-length. PiperOrigin-RevId: 326507688 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 4 +- pkg/sentry/socket/netfilter/udp_matcher.go | 4 +- pkg/tcpip/buffer/view.go | 10 + pkg/tcpip/link/channel/channel.go | 4 +- pkg/tcpip/link/fdbased/BUILD | 1 + pkg/tcpip/link/fdbased/endpoint.go | 14 +- pkg/tcpip/link/fdbased/endpoint_test.go | 134 ++++--- pkg/tcpip/link/fdbased/mmap.go | 16 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 45 ++- pkg/tcpip/link/loopback/loopback.go | 21 +- pkg/tcpip/link/muxed/injectable_test.go | 25 +- pkg/tcpip/link/nested/nested_test.go | 4 +- pkg/tcpip/link/sharedmem/sharedmem.go | 25 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 135 +++---- pkg/tcpip/link/sharedmem/tx.go | 14 +- pkg/tcpip/link/sniffer/sniffer.go | 19 +- pkg/tcpip/link/tun/device.go | 23 +- pkg/tcpip/link/waitable/waitable_test.go | 12 +- pkg/tcpip/network/arp/arp.go | 26 +- pkg/tcpip/network/arp/arp_test.go | 8 +- pkg/tcpip/network/ip_test.go | 80 +++-- pkg/tcpip/network/ipv4/icmp.go | 38 +- pkg/tcpip/network/ipv4/ipv4.go | 188 +++++----- pkg/tcpip/network/ipv4/ipv4_test.go | 72 ++-- pkg/tcpip/network/ipv6/icmp.go | 55 ++- pkg/tcpip/network/ipv6/icmp_test.go | 46 +-- pkg/tcpip/network/ipv6/ipv6.go | 47 +-- pkg/tcpip/network/ipv6/ipv6_test.go | 16 +- pkg/tcpip/network/ipv6/ndp_test.go | 29 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/conntrack.go | 31 +- pkg/tcpip/stack/forwarder_test.go | 65 ++-- pkg/tcpip/stack/headertype_string.go | 39 ++ pkg/tcpip/stack/iptables.go | 2 +- pkg/tcpip/stack/iptables_targets.go | 9 +- pkg/tcpip/stack/ndp.go | 32 +- pkg/tcpip/stack/ndp_test.go | 22 +- pkg/tcpip/stack/nic.go | 40 +-- pkg/tcpip/stack/nic_test.go | 4 +- pkg/tcpip/stack/packet_buffer.go | 260 ++++++++++++-- pkg/tcpip/stack/packet_buffer_test.go | 397 +++++++++++++++++++++ pkg/tcpip/stack/route.go | 5 +- pkg/tcpip/stack/stack_test.go | 61 ++-- pkg/tcpip/stack/transport_demuxer_test.go | 14 +- pkg/tcpip/stack/transport_test.go | 54 ++- .../tests/integration/multicast_broadcast_test.go | 18 +- pkg/tcpip/transport/icmp/endpoint.go | 39 +- pkg/tcpip/transport/packet/endpoint.go | 35 +- pkg/tcpip/transport/raw/endpoint.go | 30 +- pkg/tcpip/transport/tcp/connect.go | 30 +- pkg/tcpip/transport/tcp/protocol.go | 17 +- pkg/tcpip/transport/tcp/segment.go | 2 +- pkg/tcpip/transport/tcp/testing/context/context.go | 28 +- pkg/tcpip/transport/udp/endpoint.go | 26 +- pkg/tcpip/transport/udp/protocol.go | 57 ++- pkg/tcpip/transport/udp/udp_test.go | 58 ++- 56 files changed, 1568 insertions(+), 924 deletions(-) create mode 100644 pkg/tcpip/stack/headertype_string.go create mode 100644 pkg/tcpip/stack/packet_buffer_test.go (limited to 'pkg/tcpip/buffer') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 4f98ee2d5..0bfd6c1f4 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,7 +97,7 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) if netHeader.TransportProtocol() != header.TCPProtocolNumber { return false, false @@ -111,7 +111,7 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - tcpHeader := header.TCP(pkt.TransportHeader) + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { // There's no valid TCP header here, so we drop the packet immediately. return false, true diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3f20fc891..7ed05461d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,7 +94,7 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. @@ -110,7 +110,7 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) if len(udpHeader) < header.UDPMinimumSize { // There's no valid UDP header here, so we drop the packet immediately. return false, true diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 9a3c5d6c3..ea0c5413d 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -65,6 +65,16 @@ func (v View) ToVectorisedView() VectorisedView { return NewVectorisedView(len(v), []View{v}) } +// IsEmpty returns whether v is of length zero. +func (v View) IsEmpty() bool { + return len(v) == 0 +} + +// Size returns the length of v. +func (v View) Size() int { + return len(v) +} + // VectorisedView is a vectorised version of View using non contiguous memory. // It supports all the convenience methods supported by View. // diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index e12a5929b..c95aef63c 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -274,7 +274,9 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: &stack.PacketBuffer{Data: vv}, + Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }), Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 507b44abc..10072eac1 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -37,5 +37,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index c18bb91fb..975309fc8 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -390,8 +390,7 @@ const ( func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ DstAddr: remote, Type: protocol, @@ -420,7 +419,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if gso != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen @@ -443,11 +442,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne builder.Add(vnetHdrBuf) } - builder.Add(pkt.Header.View()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Views() { builder.Add(v) } - return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } @@ -463,7 +460,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if pkt.GSOOptions != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen @@ -486,8 +483,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc var builder iovec.Builder builder.Add(vnetHdrBuf) - builder.Add(pkt.Header.View()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Views() { builder.Add(v) } iovecs := builder.Build() diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 7b995b85a..709f829c8 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -26,6 +26,7 @@ import ( "time" "unsafe" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,9 +44,36 @@ const ( ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents *stack.PacketBuffer + Raddr tcpip.LinkAddress + Proto tcpip.NetworkProtocolNumber + Contents *stack.PacketBuffer +} + +type packetContents struct { + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View + Data buffer.View +} + +func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { + t.Helper() + if diff := cmp.Diff( + want, got, + cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { + if pk == nil { + return nil + } + return &packetContents{ + LinkHeader: pk.LinkHeader().View(), + NetworkHeader: pk.NetworkHeader().View(), + TransportHeader: pk.TransportHeader().View(), + Data: pk.Data.ToView(), + } + }), + ); diff != "" { + t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) + } } type context struct { @@ -159,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u RemoteLinkAddress: raddr, } - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) + // Build payload. + payload := buffer.NewView(plen) + if _, err := rand.Read(payload); err != nil { + t.Fatalf("rand.Read(payload): %s", err) } - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) + // Build packet buffer. + const netHdrLen = 100 + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, + Data: payload.ToVectorisedView(), + }) + pkt.Hash = hash + + // Build header. + b := pkt.NetworkHeader().Push(netHdrLen) + if _, err := rand.Read(b); err != nil { + t.Fatalf("rand.Read(b): %s", err) } - want := append(hdr.View(), payload...) + + // Write. + want := append(append(buffer.View(nil), b...), payload...) var gso *stack.GSO if gsoMaxSize != 0 { gso = &stack.GSO{ @@ -183,11 +220,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - Hash: hash, - }); err != nil { + if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -296,13 +329,14 @@ func TestPreserveSrcAddress(t *testing.T) { LocalLinkAddress: baddr, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.VectorisedView{}, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.VectorisedView{}, + }) + if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -331,24 +365,25 @@ func TestDeliverPacket(t *testing.T) { defer c.cleanup() // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) + all := make([]byte, plen) + if _, err := rand.Read(all); err != nil { + t.Fatalf("rand.Read(all): %s", err) } - - var hdr header.Ethernet - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr = make(header.Ethernet, header.EthernetMinimumSize) + // Make it look like an IPv4 packet. + all[0] = 0x40 + + wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.NewViewFromBytes(all).ToVectorisedView(), + }) + if eth { + hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, Type: proto, }) - all = append(hdr, b...) + all = append(hdr, all...) } // Write packet via the file descriptor. @@ -360,24 +395,15 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: &stack.PacketBuffer{ - Data: buffer.View(b).ToVectorisedView(), - LinkHeader: buffer.View(hdr), - }, + Raddr: raddr, + Proto: proto, + Contents: wantPkt, } if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - // want.contents.Data will be a single - // view, so make pi do the same for the - // DeepEqual check. - pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + want.Proto = header.IPv4ProtocolNumber + want.Raddr = "" } + checkPacketInfoEqual(t, pi, want) case <-time.After(10 * time.Second): t.Fatalf("Timed out waiting for packet") } @@ -572,8 +598,8 @@ func TestDispatchPacketFormat(t *testing.T) { t.Fatalf("len(sink.pkts) = %d, want %d", got, want) } pkt := sink.pkts[0] - if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want { - t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want) + if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { + t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) } if got, want := pkt.Data.Size(), 4; got != want { t.Errorf("pkt.Data.Size() = %d, want %d", got, want) diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 2dfd29aa9..c475dda20 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -18,6 +18,7 @@ package fdbased import ( "encoding/binary" + "fmt" "syscall" "golang.org/x/sys/unix" @@ -170,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(pkt) + eth := header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -190,10 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } } - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ - Data: buffer.View(pkt).ToVectorisedView(), - LinkHeader: buffer.View(eth), + pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(pkt).ToVectorisedView(), }) + if d.e.hdrSize > 0 { + if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok { + panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize)) + } + } + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index d8f2504b3..8c3ca86d6 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.allocateViews(BufConfig) n, err := rawfile.BlockingReadv(d.fd, d.iovecs) - if err != nil { + if n == 0 || err != nil { return false, err } if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { // isn't used and it isn't in a view. n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(n, BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,13 +143,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(n, BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. @@ -268,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(k, int(n), BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -295,12 +298,6 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(k, int(n), BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 781cdd317..38aa694e4 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -77,16 +77,16 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) + // Construct data as the unparsed portion for the loopback packet. + data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt) return nil } @@ -98,18 +98,17 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) if !ok { // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } linkHeader := header.Ethernet(hdr) - vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ - Data: vv, - LinkHeader: buffer.View(linkHeader), - }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) return nil } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 0744f66d6..3e4afcdad 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -46,14 +46,14 @@ func TestInjectableEndpointRawDispatch(t *testing.T) { func TestInjectableEndpointDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) + pkt.TransportHeader().Push(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -67,13 +67,14 @@ func TestInjectableEndpointDispatch(t *testing.T) { func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewView(0).ToVectorisedView(), + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewView(0).ToVectorisedView(), }) + pkt.TransportHeader().Push(1)[0] = 0xFA + packetRoute := stack.Route{RemoteAddress: dstIP} + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go index 7d9249c1c..c1f9d308c 100644 --- a/pkg/tcpip/link/nested/nested_test.go +++ b/pkg/tcpip/link/nested/nested_test.go @@ -87,7 +87,7 @@ func TestNestedLinkEndpoint(t *testing.T) { t.Error("After attach, nestedEP.IsAttached() = false, want = true") } - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 1 { t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) } @@ -101,7 +101,7 @@ func TestNestedLinkEndpoint(t *testing.T) { } disp.count = 0 - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 0 { t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 507c76b76..7fb8a6c49 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -186,8 +186,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // AddHeader implements stack.LinkEndpoint.AddHeader. func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ DstAddr: remote, Type: protocol, @@ -207,10 +206,10 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) - v := pkt.Data.ToView() + views := pkt.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(pkt.Header.View(), v) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -227,10 +226,10 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - v := vv.ToView() + views := vv.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(v, buffer.View{}) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -276,16 +275,18 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { rxb[i].Size = e.bufferSize } - if n < header.EthernetMinimumSize { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { continue } + eth := header.Ethernet(hdr) // Send packet up the stack. - eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ - Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), - LinkHeader: buffer.View(eth), - }) + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 8f3cd9449..22d5c97f1 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -266,21 +266,23 @@ func TestSimpleSend(t *testing.T) { for iters := 1000; iters > 0; iters-- { func() { + hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000) + // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) + hdrBuf := buffer.NewView(hdrLen) randomFill(hdrBuf) - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) + data := buffer.NewView(dataLen) + randomFill(data) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()), + Data: data.ToVectorisedView(), + }) + copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -317,7 +319,7 @@ func TestSimpleSend(t *testing.T) { // Compare contents skipping the ethernet header added by the // endpoint. - merged := append(hdrBuf, buf...) + merged := append(hdrBuf, data...) if uint32(len(contents)) < pi.Size { t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) } @@ -344,14 +346,14 @@ func TestPreserveSrcAddressInSend(t *testing.T) { LocalLinkAddress: newLocalLinkAddress, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + ReserveHeaderBytes: header.EthernetMinimumSize, + }) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -403,12 +405,12 @@ func TestFillTxQueue(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -422,11 +424,11 @@ func TestFillTxQueue(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -450,11 +452,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -473,11 +475,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -491,11 +493,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -517,11 +519,11 @@ func TestFillTxMemory(t *testing.T) { // we fill the memory. ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -536,11 +538,11 @@ func TestFillTxMemory(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), }) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -564,11 +566,11 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Each packet is uses up one buffer, so write as many as possible // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -579,23 +581,22 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write a two-buffer packet. It must fail. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: uu, - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buffer.NewView(bufferSize).ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } // Attempt to write the one-buffer packet again. It must succeed. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index 6b8d7859d..44f421c2d 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "syscall" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -76,9 +77,9 @@ func (t *tx) cleanup() { syscall.Munmap(t.data) } -// transmit sends a packet made up of up to two buffers. Returns a boolean that -// specifies whether the packet was successfully transmitted. -func (t *tx) transmit(a, b []byte) bool { +// transmit sends a packet made of bufs. Returns a boolean that specifies +// whether the packet was successfully transmitted. +func (t *tx) transmit(bufs ...buffer.View) bool { // Pull completions from the tx queue and add their buffers back to the // pool so that we can reuse them. for { @@ -93,7 +94,10 @@ func (t *tx) transmit(a, b []byte) bool { } bSize := t.bufs.entrySize - total := uint32(len(a) + len(b)) + total := uint32(0) + for _, data := range bufs { + total += uint32(len(data)) + } bufCount := (total + bSize - 1) / bSize // Allocate enough buffers to hold all the data. @@ -115,7 +119,7 @@ func (t *tx) transmit(a, b []byte) bool { // Copy data into allocated buffers. nBuf := buf var dBuf []byte - for _, data := range [][]byte{a, b} { + for _, data := range bufs { for len(data) > 0 { if len(dBuf) == 0 { dBuf = t.data[nBuf.Offset:][:nBuf.Size] diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 509076643..4fb127978 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -134,7 +134,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { - totalLength := pkt.Header.UsedLength() + pkt.Data.Size() + totalLength := pkt.Size() length := totalLength if max := int(e.maxPCAPLen); length > max { length = max @@ -155,12 +155,11 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw length -= n } } - write(pkt.Header.View()) - for _, view := range pkt.Data.Views() { + for _, v := range pkt.Views() { if length == 0 { break } - write(view) + write(v) } } } @@ -185,9 +184,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ + e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) return e.Endpoint.WriteRawPacket(vv) } @@ -201,12 +200,8 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var fragmentOffset uint16 var moreFragments bool - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) + // Examine the packet using a new VV. Backing storage must not be written. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) switch protocol { case header.IPv4ProtocolNumber: diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 22b0a12bd..3b1510a33 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -215,12 +215,11 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := &stack.PacketBuffer{ - Data: buffer.View(data).ToVectorisedView(), - } - if ethHdr != nil { - pkt.LinkHeader = buffer.View(ethHdr) - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: len(ethHdr), + Data: buffer.View(data).ToVectorisedView(), + }) + copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr) endpoint.InjectLinkAddr(protocol, remote, pkt) return dataLen, nil } @@ -265,21 +264,22 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { return nil, false } // Ethernet header (TAP only). if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. - if info.Pkt.LinkHeader == nil { + if info.Pkt.LinkHeader().View().IsEmpty() { d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } - vv.AppendView(info.Pkt.LinkHeader) + vv.AppendView(info.Pkt.LinkHeader().View()) } // Append upper headers. - vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + vv.AppendView(info.Pkt.NetworkHeader().View()) + vv.AppendView(info.Pkt.TransportHeader().View()) // Append data payload. vv.Append(info.Pkt.Data) @@ -361,8 +361,7 @@ func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip. if !e.isTap { return } - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr := &header.EthernetFields{ SrcAddr: local, DstAddr: remote, diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index c448a888f..94827fc56 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -104,21 +104,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -135,21 +135,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 31a242482..1ad788a17 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,7 +99,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.ARP(pkt.NetworkHeader) + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return } @@ -110,17 +110,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } - hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - packet := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + }) + packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -168,17 +168,17 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetBroadcastAddress } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) - h := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + }) + h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) copy(h.HardwareAddressSender(), linkEP.LinkAddress()) copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -210,12 +210,10 @@ func (*protocol) Wait() {} // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.ARPSize) + _, ok = pkt.NetworkHeader().Consume(header.ARPSize) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(header.ARPSize) return 0, false, true } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index a35a64a0f..c2c3e6891 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -106,9 +106,9 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { @@ -118,9 +118,9 @@ func TestDirectRequest(t *testing.T) { if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pi.Pkt.Header.View()) + rep := header.ARP(pi.Pkt.NetworkHeader().View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 615bae648..e6768258a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -156,13 +156,13 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(pkt.Header.View()) + h := header.IPv4(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(pkt.Header.View()) + h := header.IPv6(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() @@ -243,8 +243,11 @@ func TestIPv4Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -260,10 +263,7 @@ func TestIPv4Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -303,9 +303,13 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -455,17 +459,25 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag1.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag2.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -485,8 +497,11 @@ func TestIPv6Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -502,10 +517,7 @@ func TestIPv6Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -545,9 +557,13 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -673,11 +689,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // becomes Data. func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { v := view[:len(view)-trunc] - if len(v) < netHdrLen { - return &stack.PacketBuffer{Data: v.ToVectorisedView()} - } - return &stack.PacketBuffer{ - NetworkHeader: v[:netHdrLen], - Data: v[netHdrLen:].ToVectorisedView(), - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + _, _ = pkt.NetworkHeader().Consume(netHdrLen) + return pkt } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 94803a359..067d770f3 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -89,12 +89,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { return } + // Make a copy of data before pkt gets sent to raw socket. + // DeliverTransportPacket will take ownership of pkt. + replyData := pkt.Data.Clone(nil) + replyData.TrimFront(header.ICMPv4MinimumSize) + // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ - Data: pkt.Data.Clone(nil), - NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), - }) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) remoteLinkAddr := r.RemoteLinkAddress @@ -116,24 +118,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // Use the remote link address from the incoming packet. r.ResolveWith(remoteLinkAddr) - vv := pkt.Data.Clone(nil) - vv.TrimFront(header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv4EchoReply) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) + // Prepare a reply packet. + icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize) + copy(icmpHdr, h) + icmpHdr.SetType(header.ICMPv4EchoReply) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0))) + dataVV := buffer.View(icmpHdr).ToVectorisedView() + dataVV.Append(replyData) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: dataVV, + }) + + // Send out the reply packet. sent := stats.ICMP.V4PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: vv, - TransportHeader: buffer.View(pkt), - }); err != nil { + }, replyPkt); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9ff27a363..3cd48ceb3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -21,7 +21,6 @@ package ipv4 import ( - "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -127,14 +126,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in pkt.Header but does not -// assume that only the IP header is in pkt.Header. It assumes that the input -// packet's stated length matches the length of the header+payload. mtu -// includes the IP header and options. This does not support the DontFragment -// IP flag. +// write. It assumes that the IP header is already present in pkt.NetworkHeader. +// pkt.TransportHeader may be set. mtu includes the IP header and options. This +// does not support the DontFragment IP flag. func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.Header.View()) + ip := header.IPv4(pkt.NetworkHeader().View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -148,88 +145,84 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := pkt.Header.AvailableLength() + + // Keep the length reserved for link-layer, we need to create fragments with + // the same reserved length. + reservedForLink := pkt.AvailableHeaderBytes() + + // Destroy the packet, pull all payloads out for fragmentation. + transHeader, data := pkt.TransportHeader().View(), pkt.Data + + // Where possible, the first fragment that is sent has the same + // number of bytes reserved for header as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. + transFitsFirst := len(transHeader) <= innerMTU + for i := 0; i < n; i++ { - // Where possible, the first fragment that is sent has the same - // pkt.Header.UsedLength() as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - h := ip - if i > 0 { - pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) - copy(h, ip[:ip.HeaderLength()]) + reserve := reservedForLink + int(ip.HeaderLength()) + if i == 0 && transFitsFirst { + // Reserve for transport header if it's going to be put in the first + // fragment. + reserve += len(transHeader) + } + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: reserve, + }) + fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + + // Copy data for the fragment. + avail := innerMTU + + if n := len(transHeader); n > 0 { + if n > avail { + n = avail + } + if i == 0 && transFitsFirst { + copy(fragPkt.TransportHeader().Push(n), transHeader) + } else { + fragPkt.Data.AppendView(transHeader[:n:n]) + } + transHeader = transHeader[n:] + avail -= n } + + if avail > 0 { + n := data.Size() + if n > avail { + n = avail + } + data.ReadToVV(&fragPkt.Data, n) + avail -= n + } + + copied := uint16(innerMTU - avail) + + // Set lengths in header and calculate checksum. + h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) + copy(h, ip) if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + copied) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) - offset += uint16(innerMTU) - if i > 0 { - newPayload := pkt.Data.Clone(nil) - newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayload.Size()) - continue - } - // Special handling for the first fragment because it comes - // from the header. - if outerMTU >= pkt.Header.UsedLength() { - // This fragment can fit all of pkt.Header and possibly - // some of pkt.Data, too. - newPayload := pkt.Data.Clone(nil) - newPayloadLength := outerMTU - pkt.Header.UsedLength() - newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayloadLength) - } else { - // The fragment is too small to fit all of pkt.Header. - startOfHdr := pkt.Header - startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) - emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of pkt.Header into the pkt.Data - // that remains to be sent. - restOfHdr := pkt.Header.View()[outerMTU:] - tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(pkt.Data) - pkt.Data = tmp + offset += copied + + // Send out the fragment. + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + return err } + r.Stats().IP.PacketsSent.Increment() } return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payloadSize) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + length := uint16(pkt.Size()) // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. @@ -245,14 +238,12 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - return ip + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + e.addIPHeader(r, pkt, params) // iptables filtering. All packets that reach here are locally // generated. @@ -269,7 +260,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // only NATted packets, but removing this check short circuits broadcasts // before they are sent out to other hosts. if pkt.NatDone { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) @@ -286,7 +277,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { @@ -306,9 +297,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + e.addIPHeader(r, pkt, params) pkt = pkt.Next() } @@ -333,7 +322,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() @@ -402,17 +391,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu r.Stats().IP.PacketsSent.Increment() - ip = ip[:ip.HeaderLength()] - pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) - pkt.Data.TrimFront(int(ip.HeaderLength())) return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv4(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + h := header.IPv4(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -426,7 +412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } if h.More() || h.FragmentOffset() != 0 { - if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { + if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -470,7 +456,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - pkt.NetworkHeader.CapLength(int(h.HeaderLength())) e.handleICMP(r, pkt) return } @@ -560,14 +545,19 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } ipHdr := header.IPv4(hdr) - // If there are options, pull those into hdr as well. - if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(headerLen) - if !ok { - panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) - } - ipHdr = header.IPv4(hdr) + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return 0, false, false } + ipHdr = header.IPv4(hdr) // If this is a fragment, don't bother parsing the transport header. parseTransportHeader := true @@ -576,8 +566,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) return ipHdr.TransportProtocol(), parseTransportHeader, true } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 63e2c36c2..afd3ac06d 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,6 +17,7 @@ package ipv4_test import ( "bytes" "encoding/hex" + "fmt" "math/rand" "testing" @@ -91,15 +92,11 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much +// makeRandPkt generates a randomize packet. hdrLength indicates how much // data should already be in the header before WritePacket. extraLength // indicates how much extra space should be in the header. The payload is made // from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - +func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer { var views []buffer.View totalLength := 0 for _, s := range viewSizes { @@ -108,8 +105,16 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. views = append(views, newView) totalLength += s } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLength + extraLength, + Data: buffer.NewVectorisedView(totalLength, views), + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt } // comparePayloads compared the contents of all the packets against the contents @@ -117,9 +122,9 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Data.ToView()...) + source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) + vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) + source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -132,8 +137,7 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI var reassembledPayload []byte for i, packet := range packets { // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Data) + allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -144,10 +148,17 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + if i == 0 { + got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() + // sourcePacketInfo does not have NetworkHeader added, simulate one. + want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() + // Check that it kept the transport header in packet.TransportHeader if + // it fits in the first fragment. + if want < int(mtu) && got != want { + t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + } } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { + if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { @@ -284,22 +295,14 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := &stack.PacketBuffer{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - } + pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + source := pkt.Clone() c := buildContext(t, nil, ft.mtu) err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) if err != nil { t.Errorf("err got %v, want %v", err, nil) } @@ -344,16 +347,13 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) + pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -472,9 +472,9 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), - }) + })) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { @@ -859,9 +859,9 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ded91d83a..39ae19295 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -83,7 +83,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } h := header.ICMPv6(v) - iph := header.IPv6(pkt.NetworkHeader) + iph := header.IPv6(pkt.NetworkHeader().View()) // Validate ICMPv6 checksum before processing the packet. // @@ -276,8 +276,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme optsSerializer := header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), + }) + packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(solicited) @@ -293,9 +295,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.Dropped.Increment() return } @@ -384,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { received.Invalid.Increment() return @@ -409,16 +409,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Use the link address from the source of the original packet. r.ResolveWith(remoteLinkAddr) - pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, + Data: pkt.Data, + }) + packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: pkt.Data, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -539,17 +538,19 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - copy(pkt[icmpV6OptOffset-len(addr):], addr) - pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - length := uint16(hdr.UsedLength()) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, + }) + icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + icmpHdr.SetType(header.ICMPv6NeighborSolicit) + copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) + icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr + icmpHdr[icmpV6LengthOffset] = 1 + copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress()) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -559,9 +560,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index f86aaed1d..2a2f7de01 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -183,7 +183,11 @@ func TestICMPCounts(t *testing.T) { } handleIPv6Payload := func(icmp header.ICMPv6) { - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -191,10 +195,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: buffer.View(icmp).ToVectorisedView(), - }) + ep.HandlePacket(&r, pkt) } for _, typ := range types { @@ -323,12 +324,10 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. pi, _ := args.src.ReadContext(context.Background()) { - views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} - size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ - Data: vv, + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), }) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) } if pi.Proto != ProtocolNumber { @@ -340,7 +339,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } - ipv6 := header.IPv6(pi.Pkt.Header.View()) + // Pull the full payload since network header. Needed for header.IPv6 to + // extract its payload. + ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -558,9 +559,10 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -719,12 +721,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { icmpSize := size + payloadSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) + icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) + icmpHdr.SetType(typ) + payloadFn(icmpHdr.Payload()) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -735,9 +737,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -895,14 +898,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + icmpHdr := header.ICMPv6(hdr.Prepend(size)) + icmpHdr.SetType(typ) payload := buffer.NewView(payloadSize) payloadFn(payload) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -913,9 +916,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index d7d7fc611..0ade655b2 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -99,9 +99,9 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { - length := uint16(hdr.UsedLength() + payloadSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(params.Protocol), @@ -110,26 +110,20 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - return ip + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + e.addIPHeader(r, pkt, params) if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // The inbound path expects an unparsed packet. + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) loopedR.Release() } @@ -151,9 +145,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pb := pkts.Front(); pb != nil; pb = pb.Next() { - ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) - pb.NetworkHeader = buffer.View(ip) - pb.NetworkProtocolNumber = header.IPv6ProtocolNumber + e.addIPHeader(r, pb, params) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -171,8 +163,8 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv6(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + h := header.IPv6(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -181,8 +173,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. - vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() - vv.AppendView(pkt.TransportHeader) + vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader().View()) vv.Append(pkt.Data) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false @@ -410,7 +402,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(len(pkt.TransportHeader)) + extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -581,17 +573,14 @@ traverseExtensions: } } - // Put the IPv6 header with extensions in pkt.NetworkHeader. - hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) if !ok { panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) } ipHdr = header.IPv6(hdr) - - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber return nextHdr, foundNext, true } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 3d65814de..081afb051 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,9 +65,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().ICMP.V6PacketsReceived @@ -123,9 +123,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stat := s.Stats().UDP.PacketsReceived @@ -637,9 +637,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().UDP.PacketsReceived @@ -1469,9 +1469,9 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 64239ce9a..fe159b24f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,9 +136,9 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -380,9 +380,9 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) if test.nsInvalid { if got := invalid.Value(); got != 1 { @@ -410,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.naSrc), checker.DstAddr(test.naDst), checker.TTL(header.NDPHopLimit), @@ -497,9 +497,9 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -560,7 +560,11 @@ func TestNDPValidation(t *testing.T) { nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + Data: payload.ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, @@ -571,10 +575,7 @@ func TestNDPValidation(t *testing.T) { if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - ep.HandlePacket(r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: payload.ToVectorisedView(), - }) + ep.HandlePacket(r, pkt) } var tllData [header.NDPLinkLayerAddressSize]byte @@ -885,9 +886,9 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) if got := rxRA.Value(); got != 1 { t.Fatalf("got rxRA = %d, want = 1", got) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bfc7a0c7c..900938dd1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -57,6 +57,7 @@ go_library( "conntrack.go", "dhcpv6configurationfromndpra_string.go", "forwarder.go", + "headertype_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -143,6 +144,7 @@ go_test( "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", + "packet_buffer_test.go", ], library = ":stack", deps = [ diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 470c265aa..7dd344b4f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -199,12 +199,12 @@ type bucket struct { func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader) - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol } @@ -344,8 +344,8 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -377,8 +377,8 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -396,8 +396,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -423,7 +422,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return false } @@ -433,8 +432,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou return true } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return false } @@ -455,7 +454,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) return false } @@ -474,7 +473,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return } @@ -486,7 +485,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) ct.insertConn(conn) } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c962693f5..944f622fd 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -75,7 +75,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -97,7 +97,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) b[dstAddrOffset] = r.RemoteAddress[0] b[srcAddrOffset] = f.id.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) @@ -144,13 +144,11 @@ func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Add } func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = netHeader - pkt.Data.TrimFront(fwdTestNetHeaderLen) - return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true + return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -290,7 +288,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: &PacketBuffer{Data: vv}, + Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), } select { @@ -382,9 +380,9 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -419,9 +417,9 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -450,9 +448,9 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) select { case <-ep2.C: @@ -480,17 +478,17 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -500,8 +498,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -529,9 +527,9 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < 2; i++ { @@ -543,8 +541,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -575,9 +573,9 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < maxPendingPacketsPerResolution; i++ { @@ -589,13 +587,14 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) } - seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). - if !ok { - t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) } + seqNumBuf := b[fwdTestNetHeaderLen:] // The first 5 packets should not be forwarded so the sequence number should // start with 5. @@ -632,9 +631,9 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < maxPendingResolutions; i++ { @@ -648,8 +647,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go new file mode 100644 index 000000000..5efddfaaf --- /dev/null +++ b/pkg/tcpip/stack/headertype_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type headerType ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[linkHeader-0] + _ = x[networkHeader-1] + _ = x[transportHeader-2] + _ = x[numHeaderType-3] +} + +const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType" + +var _headerType_index = [...]uint8{0, 10, 23, 38, 51} + +func (i headerType) String() string { + if i < 0 || i >= headerType(len(_headerType_index)-1) { + return "headerType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _headerType_name[_headerType_index[i]:_headerType_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 110ba073d..c37da814f 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -394,7 +394,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index dc88033c7..5f1b2af64 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -99,7 +99,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso } // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { return RuleDrop, 0 } @@ -118,17 +118,16 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) switch protocol := netHeader.TransportProtocol(); protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) udpHeader.SetDestinationPort(rt.MinPort) // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 5174e639c..93567806b 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -746,12 +746,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmpData.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) ns.SetTargetAddress(addr) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, @@ -759,7 +763,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() return err @@ -1897,12 +1901,16 @@ func (ndp *ndpState) startSolicitingRouters() { } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) - pkt := header.ICMPv6(hdr.Prepend(payloadSize)) - pkt.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(payloadSize)) + icmpData.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(icmpData.NDPPayload()) rs.Options().Serialize(optsSerializer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, @@ -1910,7 +1918,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5d286ccbc..21bf53010 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -541,7 +541,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), @@ -550,8 +550,8 @@ func TestDADResolve(t *testing.T) { checker.NDPNSOptions(nil), )) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } }) @@ -667,9 +667,10 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(header.IPv6ProtocolNumber, pkt) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -1024,7 +1025,9 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) } // raBufWithOpts returns a valid NDP Router Advertisement with options. @@ -5134,16 +5137,15 @@ func TestRouterSolicitation(t *testing.T) { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } - checker.IPv6(t, - p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } waitForNothing := func(timeout time.Duration) { @@ -5288,7 +5290,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index eaaf756cd..2315ea5b9 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1299,7 +1299,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader) + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1401,24 +1401,19 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this - // by having lower layers explicity write each header instead of just - // pkt.Header. - // pkt may have set its NetworkHeader and TransportHeader. If we're - // forwarding, we'll have to copy them into pkt.Header. - pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) - if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) - } - if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) - } + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } @@ -1443,34 +1438,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. - if pkt.TransportHeader == nil { + if pkt.TransportHeader().View().IsEmpty() { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - pkt.TransportHeader = transHeader - pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) } } - if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { + if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index a70792b50..0870c8d9c 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -311,7 +311,9 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ + Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), + })) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9e871f968..17b8beebb 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,16 +14,43 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) +type headerType int + +const ( + linkHeader headerType = iota + networkHeader + transportHeader + numHeaderType +) + +// PacketBufferOptions specifies options for PacketBuffer creation. +type PacketBufferOptions struct { + // ReserveHeaderBytes is the number of bytes to reserve for headers. Total + // number of bytes pushed onto the headers must not exceed this value. + ReserveHeaderBytes int + + // Data is the initial unparsed data for the new packet. If set, it will be + // owned by the new packet. + Data buffer.VectorisedView +} + // A PacketBuffer contains all the data of a network packet. // // As a PacketBuffer traverses up the stack, it may be necessary to pass it to -// multiple endpoints. Clone() should be called in such cases so that -// modifications to the Data field do not affect other copies. +// multiple endpoints. +// +// The whole packet is expected to be a series of bytes in the following order: +// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be +// empty. Use of PacketBuffer in any other order is unsupported. +// +// PacketBuffer must be created with NewPacketBuffer. type PacketBuffer struct { _ sync.NoCopy @@ -31,36 +58,27 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. For inbound packets, it also - // holds the headers, which are consumed as the packet moves up the - // stack. Headers are guaranteed not to be split across views. + // Data holds the payload of the packet. + // + // For inbound packets, Data is initially the whole packet. Then gets moved to + // headers via PacketHeader.Consume, when the packet is being parsed. // - // The bytes backing Data are immutable, but Data itself may be trimmed - // or otherwise modified. + // For outbound packets, Data is the innermost layer, defined by the protocol. + // Headers are pushed in front of it via PacketHeader.Push. + // + // The bytes backing Data are immutable, a.k.a. users shouldn't write to its + // backing storage. Data buffer.VectorisedView - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. - Header buffer.Prependable + // headers stores metadata about each header. + headers [numHeaderType]headerInfo - // These fields are used by both inbound and outbound packets. They - // typically overlap with the Data and Header fields. + // header is the internal storage for outbound packets. Headers will be pushed + // (prepended) on this storage as the packet is being constructed. // - // The bytes backing these views are immutable. Each field may be nil - // if either it has not been set yet or no such header exists (e.g. - // packets sent via loopback may not have a link header). - // - // These fields may be Views into other slices (either Data or Header). - // SR dosen't support this, so deep copies are necessary in some cases. - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View + // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and + // data are held in the same underlying buffer storage. + header buffer.Prependable // NetworkProtocol is only valid when NetworkHeader is set. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol @@ -89,20 +107,137 @@ type PacketBuffer struct { PktType tcpip.PacketType } -// Clone makes a copy of pk. It clones the Data field, which creates a new -// VectorisedView but does not deep copy the underlying bytes. -// -// Clone also does not deep copy any of its other fields. +// NewPacketBuffer creates a new PacketBuffer with opts. +func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { + pk := &PacketBuffer{ + Data: opts.Data, + } + if opts.ReserveHeaderBytes != 0 { + pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + } + return pk +} + +// ReservedHeaderBytes returns the number of bytes initially reserved for +// headers. +func (pk *PacketBuffer) ReservedHeaderBytes() int { + return pk.header.UsedLength() + pk.header.AvailableLength() +} + +// AvailableHeaderBytes returns the number of bytes currently available for +// headers. This is relevant to PacketHeader.Push method only. +func (pk *PacketBuffer) AvailableHeaderBytes() int { + return pk.header.AvailableLength() +} + +// LinkHeader returns the handle to link-layer header. +func (pk *PacketBuffer) LinkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: linkHeader, + } +} + +// NetworkHeader returns the handle to network-layer header. +func (pk *PacketBuffer) NetworkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: networkHeader, + } +} + +// TransportHeader returns the handle to transport-layer header. +func (pk *PacketBuffer) TransportHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: transportHeader, + } +} + +// HeaderSize returns the total size of all headers in bytes. +func (pk *PacketBuffer) HeaderSize() int { + // Note for inbound packets (Consume called), headers are not stored in + // pk.header. Thus, calculation of size of each header is needed. + var size int + for i := range pk.headers { + size += len(pk.headers[i].buf) + } + return size +} + +// Size returns the size of packet in bytes. +func (pk *PacketBuffer) Size() int { + return pk.HeaderSize() + pk.Data.Size() +} + +// Views returns the underlying storage of the whole packet. +func (pk *PacketBuffer) Views() []buffer.View { + // Optimization for outbound packets that headers are in pk.header. + useHeader := true + for i := range pk.headers { + if !canUseHeader(&pk.headers[i]) { + useHeader = false + break + } + } + + dataViews := pk.Data.Views() + + var vs []buffer.View + if useHeader { + vs = make([]buffer.View, 0, 1+len(dataViews)) + vs = append(vs, pk.header.View()) + } else { + vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) + for i := range pk.headers { + if v := pk.headers[i].buf; len(v) > 0 { + vs = append(vs, v) + } + } + } + return append(vs, dataViews...) +} + +func canUseHeader(h *headerInfo) bool { + // h.offset will be negative if the header was pushed in to prependable + // portion, or doesn't matter when it's empty. + return len(h.buf) == 0 || h.offset < 0 +} + +func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + } + h.buf = buffer.View(pk.header.Prepend(size)) + h.offset = -pk.header.UsedLength() + return h.buf +} + +func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) + } + v, ok := pk.Data.PullUp(size) + if !ok { + return + } + pk.Data.TrimFront(size) + h.buf = v + return h.buf, true +} + +// Clone makes a shallow copy of pk. // -// FIXME(b/153685824): Data gets copied but not other header references. +// Clone should be called in such cases so that no modifications is done to +// underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - return &PacketBuffer{ + newPk := &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, Data: pk.Data.Clone(nil), - Header: pk.Header, - LinkHeader: pk.LinkHeader, - NetworkHeader: pk.NetworkHeader, - TransportHeader: pk.TransportHeader, + headers: pk.headers, + header: pk.header, Hash: pk.Hash, Owner: pk.Owner, EgressRoute: pk.EgressRoute, @@ -110,4 +245,55 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NetworkProtocolNumber: pk.NetworkProtocolNumber, NatDone: pk.NatDone, } + return newPk +} + +// headerInfo stores metadata about a header in a packet. +type headerInfo struct { + // buf is the memorized slice for both prepended and consumed header. + // When header is prepended, buf serves as memorized value, which is a slice + // of pk.header. When header is consumed, buf is the slice pulled out from + // pk.Data, which is the only place to hold this header. + buf buffer.View + + // offset will be a negative number denoting the offset where this header is + // from the end of pk.header, if it is prepended. Otherwise, zero. + offset int +} + +// PacketHeader is a handle object to a header in the underlying packet. +type PacketHeader struct { + pk *PacketBuffer + typ headerType +} + +// View returns the underlying storage of h. +func (h PacketHeader) View() buffer.View { + return h.pk.headers[h.typ].buf +} + +// Push pushes size bytes in the front of its residing packet, and returns the +// backing storage. Callers may only call one of Push or Consume once on each +// header in the lifetime of the underlying packet. +func (h PacketHeader) Push(size int) buffer.View { + return h.pk.push(h.typ, size) +} + +// Consume moves the first size bytes of the unparsed data portion in the packet +// to h, and returns the backing storage. In the case of data is shorter than +// size, consumed will be false, and the state of h will not be affected. +// Callers may only call one of Push or Consume once on each header in the +// lifetime of the underlying packet. +func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { + return h.pk.consume(h.typ, size) +} + +// PayloadSince returns packet payload starting from and including a particular +// header. This method isn't optimized and should be used in test only. +func PayloadSince(h PacketHeader) buffer.View { + var v buffer.View + for _, hinfo := range h.pk.headers[h.typ:] { + v = append(v, hinfo.buf...) + } + return append(v, h.pk.Data.ToView()...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go new file mode 100644 index 000000000..c6fa8da5f --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestPacketHeaderPush(t *testing.T) { + for _, test := range []struct { + name string + reserved int + link []byte + network []byte + transport []byte + data []byte + }{ + { + name: "construct empty packet", + }, + { + name: "construct link header only packet", + reserved: 60, + link: makeView(10), + }, + { + name: "construct link and network header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + }, + { + name: "construct header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + }, + { + name: "construct data only packet", + data: makeView(40), + }, + { + name: "construct L3 packet", + reserved: 60, + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + { + name: "construct L2 packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + allHdrSize := len(test.link) + len(test.network) + len(test.transport) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Push headers. + if v := test.transport; len(v) > 0 { + copy(pk.TransportHeader().Push(len(v)), v) + } + if v := test.network; len(v) > 0 { + copy(pk.NetworkHeader().Push(len(v)), v) + } + if v := test.link; len(v) > 0 { + copy(pk.LinkHeader().Push(len(v)), v) + } + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), allHdrSize+len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), + concatViews(test.link, test.network, test.transport, test.data)) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(test.link, test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(test.transport, test.data)) + }) + } +} + +func TestPacketHeaderConsume(t *testing.T) { + for _, test := range []struct { + name string + data []byte + link int + network int + transport int + }{ + { + name: "parse L2 packet", + data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), + link: 10, + network: 20, + transport: 30, + }, + { + name: "parse L3 packet", + data: concatViews(makeView(20), makeView(30), makeView(40)), + network: 20, + transport: 30, + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Consume headers. + if size := test.link; size > 0 { + if _, ok := pk.LinkHeader().Consume(size); !ok { + t.Fatalf("pk.LinkHeader().Consume() = false, want true") + } + } + if size := test.network; size > 0 { + if _, ok := pk.NetworkHeader().Consume(size); !ok { + t.Fatalf("pk.NetworkHeader().Consume() = false, want true") + } + } + if size := test.transport; size > 0 { + if _, ok := pk.TransportHeader().Consume(size); !ok { + t.Fatalf("pk.TransportHeader().Consume() = false, want true") + } + } + + allHdrSize := test.link + test.network + test.transport + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), 0; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), 0; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + // After state of pk. + var ( + link = test.data[:test.link] + network = test.data[test.link:][:test.network] + transport = test.data[test.link+test.network:][:test.transport] + payload = test.data[allHdrSize:] + ) + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(link, network, transport, payload)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(network, transport, payload)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(transport, payload)) + }) + } +} + +func TestPacketHeaderConsumeDataTooShort(t *testing.T) { + data := makeView(10) + + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + // Consume should fail if pkt.Data is too short. + if _, ok := pk.LinkHeader().Consume(11); ok { + t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.NetworkHeader().Consume(11); ok { + t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.TransportHeader().Consume(11); ok { + t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") + } + + // Check packet should look the same as initial packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(data).ToVectorisedView(), + }) +} + +func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Second push should have panicked") + }) + } +} + +func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { + if _, ok := h.Consume(headerSize); !ok { + t.Fatal("First consume should succeed") + } + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Second consume should have panicked") + }) + } +} + +func TestPacketHeaderPushThenConsumePanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Consume should have panicked") + }) + } +} + +func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Consume(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Push should have panicked") + }) + } +} + +func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { + t.Helper() + reserved := opts.ReserveHeaderBytes + if got, want := pk.ReservedHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), 0; got != want { + t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) + } + data := opts.Data.ToView() + if got, want := pk.Size(), len(data); got != want { + t.Errorf("Initial pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) + // Check the initial values for each header. + checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) + checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) + checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) + // Check the initial valies for PayloadSince. + checkViewEqual(t, "Initial PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), data) +} + +func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { + t.Helper() + checkViewEqual(t, name+".View()", h.View(), want) +} + +func checkViewEqual(t *testing.T, what string, got, want buffer.View) { + t.Helper() + if !bytes.Equal(got, want) { + t.Errorf("%s = %x, want %x", what, got, want) + } +} + +func makeView(size int) buffer.View { + b := byte(size) + return bytes.Repeat([]byte{b}, size) +} + +func concatViews(views ...buffer.View) buffer.View { + var all buffer.View + for _, v := range views { + all = append(all, v...) + } + return all +} diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9ce0a2c22..e267bebb0 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -173,7 +173,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf } // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + numBytes := pkt.Size() err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { @@ -203,8 +203,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Header.UsedLength() - writtenBytes += pb.Data.Size() + writtenBytes += pb.Size() } r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index fe1c1b8a4..0273b3c63 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -102,7 +102,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -118,7 +118,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -143,10 +143,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. - pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) - pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] - pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] - pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) + hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + hdr[dstAddrOffset] = r.RemoteAddress[0] + hdr[srcAddrOffset] = f.id.LocalAddress[0] + hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { f.HandlePacket(r, pkt) @@ -249,12 +249,10 @@ func (*fakeNetworkProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(fakeNetHeaderLen) return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -315,9 +313,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -327,9 +325,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -339,9 +337,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -350,9 +348,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -362,9 +360,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -383,11 +381,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro } func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + })) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -442,9 +439,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -2285,9 +2282,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -2367,9 +2364,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) pkt, ok := ep2.Read() if !ok { @@ -2377,8 +2374,8 @@ func TestNICForwarding(t *testing.T) { } // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { - t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { + t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) } // Test that forwarding increments Tx stats correctly. diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 73dada928..1339edc2d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -128,11 +128,10 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) } func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { @@ -166,11 +165,10 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) } func TestTransportDemuxerRegister(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 7e8b84867..6c6e44468 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -84,16 +84,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) - hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(v).ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, + Data: buffer.View(v).ToVectorisedView(), + }) + _ = pkt.TransportHeader().Push(fakeTransHeaderLen) + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, nil, err } @@ -328,13 +328,8 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) - if !ok { - return false - } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(fakeTransHeaderLen) - return true + _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) + return ok } func fakeTransFactory() stack.TransportProtocol { @@ -382,9 +377,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -393,9 +388,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -404,9 +399,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -459,9 +454,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -470,9 +465,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -481,9 +476,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -636,9 +631,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: req.ToVectorisedView(), - }) + })) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -655,10 +650,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.NetworkHeader[0]; dst != 3 { + nh := stack.PayloadSince(p.Pkt.NetworkHeader()) + if dst := nh[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.NetworkHeader[1]; src != 1 { + if src := nh[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 0ff3a2b89..9f0dd4d6d 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -80,9 +80,9 @@ func TestPingMulticastBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { @@ -102,9 +102,9 @@ func TestPingMulticastBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } tests := []struct { @@ -204,7 +204,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader) + src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } @@ -252,9 +252,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { @@ -280,9 +280,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } tests := []struct { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 4612be4e7..bd6f49eb8 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -430,9 +430,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + }) + pkt.Owner = owner - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -447,15 +450,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + pkt.Data = data.ToVectorisedView() + if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: data.ToVectorisedView(), - TransportHeader: buffer.View(icmpv4), - Owner: owner, - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -463,9 +463,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + }) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) @@ -477,15 +479,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err dataVV := data.ToVectorisedView() icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) + pkt.Data = dataVV if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: dataVV, - TransportHeader: buffer.View(icmpv6), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } // checkV4MappedLocked determines the effective network protocol and converts @@ -748,14 +747,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.TransportHeader) + h := header.ICMPv4(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.TransportHeader) + h := header.ICMPv6(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -791,7 +794,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // ICMP socket's data includes ICMP header. - packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data = pkt.TransportHeader().View().ToVectorisedView() packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index df478115d..1b03ad6bb 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -433,9 +433,9 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet // TODO(gvisor.dev/issue/173): Return network protocol. - if len(pkt.LinkHeader) > 0 { + if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. - hdr := header.Ethernet(pkt.LinkHeader) + hdr := header.Ethernet(pkt.LinkHeader().View()) packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), @@ -458,9 +458,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, case tcpip.PacketHost: packet.data = pkt.Data case tcpip.PacketOutgoing: - // Strip Link Header from the Header. - pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) - combinedVV := pkt.Header.View().ToVectorisedView() + // Strip Link Header. + var combinedVV buffer.VectorisedView + if v := pkt.NetworkHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } + if v := pkt.TransportHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } combinedVV.Append(pkt.Data) packet.data = combinedVV default: @@ -471,9 +476,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - var combinedVV buffer.VectorisedView if pkt.PktType != tcpip.PacketOutgoing { - if len(pkt.LinkHeader) == 0 { + if pkt.LinkHeader().View().IsEmpty() { // We weren't provided with an actual ethernet header, // so fake one. ethFields := header.EthernetFields{ @@ -485,19 +489,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, fakeHeader.Encode(ðFields) linkHeader = buffer.View(fakeHeader) } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) } - combinedVV = linkHeader.ToVectorisedView() - } - if pkt.PktType == tcpip.PacketOutgoing { - // For outgoing packets the Link, Network and Transport - // headers are in the pkt.Header fields normally unless - // a Raw socket is in use in which case pkt.Header could - // be nil. - combinedVV.AppendView(pkt.Header.View()) + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } else { + packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } - combinedVV.Append(pkt.Data) - packet.data = combinedVV } packet.timestampNS = ep.stack.Clock().NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index f85a68554..edc2b5b61 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -352,18 +352,23 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } if e.hdrIncluded { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { + }) + if err := route.WriteHeaderIncludedPacket(pkt); err != nil { return 0, nil, err } } else { - hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(payloadBytes).ToVectorisedView(), - Owner: e.owner, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()), + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) + pkt.Owner = e.owner + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: e.TransProto, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, pkt); err != nil { return 0, nil, err } } @@ -691,12 +696,13 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // slice. Save/restore doesn't support overlapping slices and will fail. var combinedVV buffer.VectorisedView if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { - headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) - headers = append(headers, pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(network)+len(transport)) + headers = append(headers, network...) + headers = append(headers, transport...) combinedVV = headers.ToVectorisedView() } else { - combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() } combinedVV.Append(pkt.Data) packet.data = combinedVV diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 46702906b..290172ac9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -746,11 +746,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) - hdr := &pkt.Header - packetSize := pkt.Data.Size() - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) - pkt.TransportHeader = buffer.View(tcp) + tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) tcp.Encode(&header.TCPFields{ SrcPort: tf.id.LocalPort, DstPort: tf.id.RemotePort, @@ -762,8 +758,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta }) copy(tcp[header.TCPMinimumSize:], tf.opts) - length := uint16(hdr.UsedLength() + packetSize) - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. We @@ -801,17 +796,18 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso packetSize = size } size -= packetSize - var pkt stack.PacketBuffer - pkt.Header = buffer.NewPrependable(hdrSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrSize, + }) pkt.Hash = tf.txHash pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() data.ReadToVV(&pkt.Data, packetSize) - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) - pkts.PushBack(&pkt) + pkts.PushBack(pkt) } if tf.ttl == 0 { @@ -837,12 +833,12 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := &stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - Data: data, - Hash: tf.txHash, - Owner: owner, - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, + Data: data, + }) + pkt.Hash = tf.txHash + pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 49a673b42..c5afa2680 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,7 +21,6 @@ package tcp import ( - "fmt" "runtime" "strings" "time" @@ -547,22 +546,22 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) if !ok { return false } // If the header has options, pull those up as well. if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(offset) - if !ok { - panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) - } + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(len(hdr)) - return true + _, ok = pkt.TransportHeader().Consume(hdrLen) + return ok } // NewProtocol returns a TCP transport protocol. diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index bb60dc29d..94307d31a 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -68,7 +68,7 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) - s.hdr = header.TCP(pkt.TransportHeader) + s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() return s } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 37e7767d6..927bc71e0 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -257,8 +257,8 @@ func (c *Context) GetPacket() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -284,8 +284,8 @@ func (c *Context) GetPacketNonBlocking() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b @@ -318,9 +318,10 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -374,26 +375,29 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: s, }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegment(payload, h), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacketWithAddrs builds and sends a TCP segment(with the provided payload // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendAck sends an ACK packet. @@ -514,9 +518,8 @@ func (c *Context) GetV6Packet() []byte { if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b @@ -566,9 +569,10 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) } // CreateConnected creates a connected TCP endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4a2b6c03a..73608783c 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -986,13 +986,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { - // Allocate a buffer for the UDP header. - hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: data, + }) + pkt.Owner = owner - // Initialize the header. - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ SrcPort: localPort, DstPort: remotePort, @@ -1019,12 +1022,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Protocol: ProtocolNumber, TTL: ttl, TOS: tos, - }, &stack.PacketBuffer{ - Header: hdr, - Data: data, - TransportHeader: buffer.View(udp), - Owner: owner, - }); err != nil { + }, pkt); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } @@ -1372,7 +1370,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() @@ -1443,9 +1441,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Save any useful information from the network header to the packet. switch r.NetProto { case header.IPv4ProtocolNumber: - packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: - packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS() } // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 0e7464e3a..63d4bed7c 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -82,7 +82,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() @@ -130,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -139,22 +139,21 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // For example, a raw or packet socket may use what UDP // considers an unreachable destination. Thus we deep copy pkt // to prevent multiple ownership and SR errors. - newHeader := append(buffer.View(nil), pkt.NetworkHeader...) - newHeader = append(newHeader, pkt.TransportHeader...) + newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...) + newHeader = append(newHeader, pkt.TransportHeader().View()...) payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4DstUnreachable) - pkt.SetCode(header.ICMPv4PortUnreachable) - pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, }) + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -175,24 +174,24 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + payloadLen := len(network) + len(transport) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) + payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport}) payload.Append(pkt.Data) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) - pkt.SetType(header.ICMPv6DstUnreachable) - pkt.SetCode(header.ICMPv6PortUnreachable) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, }) + icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) } return true } @@ -215,14 +214,8 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Packet is too small - return false - } - pkt.TransportHeader = h - pkt.Data.TrimFront(header.UDPMinimumSize) - return true + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + return ok } // NewProtocol returns a UDP transport protocol. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 1a32622ca..71776d6db 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -388,8 +388,8 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() h := flow.header4Tuple(outgoing) checkers = append( @@ -410,14 +410,14 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } else { buf := c.buildV6Packet(payload, &h) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } } @@ -804,9 +804,9 @@ func TestV4ReadSelfSource(t *testing.T) { h.srcAddr = h.dstAddr buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) @@ -1766,9 +1766,8 @@ func TestV4UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1844,9 +1843,8 @@ func TestV6UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1897,9 +1895,9 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetLength(u.Length() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { @@ -1952,9 +1950,9 @@ func TestShortHeader(t *testing.T) { copy(buf[header.IPv6MinimumSize:], udpHdr) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) @@ -1986,9 +1984,9 @@ func TestIncrementChecksumErrorsV4(t *testing.T) { } } - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2019,9 +2017,9 @@ func TestIncrementChecksumErrorsV6(t *testing.T) { u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(u.Checksum() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2049,9 +2047,9 @@ func TestPayloadModifiedV4(t *testing.T) { buf := c.buildV4Packet(payload, &h) // Modify the payload so that the checksum value in the UDP header will be incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2079,9 +2077,9 @@ func TestPayloadModifiedV6(t *testing.T) { buf := c.buildV6Packet(payload, &h) // Modify the payload so that the checksum value in the UDP header will be incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2110,9 +2108,9 @@ func TestChecksumZeroV4(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv4MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 0 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2141,9 +2139,9 @@ func TestChecksumZeroV6(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { -- cgit v1.2.3