summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--WORKSPACE8
-rw-r--r--kokoro/swgso_tests.cfg9
-rw-r--r--pkg/abi/BUILD4
-rw-r--r--pkg/abi/linux/BUILD7
-rw-r--r--pkg/abi/linux/socket.go12
-rw-r--r--pkg/bits/BUILD3
-rw-r--r--pkg/bpf/BUILD3
-rw-r--r--pkg/cpuid/BUILD3
-rw-r--r--pkg/fd/BUILD3
-rw-r--r--pkg/fd/fd.go8
-rw-r--r--pkg/refs/BUILD5
-rw-r--r--pkg/segment/BUILD4
-rw-r--r--pkg/segment/test/BUILD3
-rw-r--r--pkg/sentry/arch/BUILD3
-rw-r--r--pkg/sentry/context/contexttest/BUILD4
-rw-r--r--pkg/sentry/device/BUILD3
-rw-r--r--pkg/sentry/fs/BUILD5
-rw-r--r--pkg/sentry/fs/dev/BUILD4
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD3
-rw-r--r--pkg/sentry/fs/filetest/BUILD4
-rw-r--r--pkg/sentry/fs/fsutil/BUILD5
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go26
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go19
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go18
-rw-r--r--pkg/sentry/fs/gofer/BUILD3
-rw-r--r--pkg/sentry/fs/gofer/file.go66
-rw-r--r--pkg/sentry/fs/gofer/file_state.go2
-rw-r--r--pkg/sentry/fs/gofer/fs.go11
-rw-r--r--pkg/sentry/fs/gofer/handles.go12
-rw-r--r--pkg/sentry/fs/gofer/inode.go108
-rw-r--r--pkg/sentry/fs/gofer/session.go5
-rw-r--r--pkg/sentry/fs/host/BUILD3
-rw-r--r--pkg/sentry/fs/host/socket.go3
-rw-r--r--pkg/sentry/fs/lock/BUILD5
-rw-r--r--pkg/sentry/fs/proc/BUILD3
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD3
-rw-r--r--pkg/sentry/fs/ramfs/BUILD3
-rw-r--r--pkg/sentry/fs/sys/BUILD4
-rw-r--r--pkg/sentry/fs/timerfd/BUILD4
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD3
-rw-r--r--pkg/sentry/fs/tty/BUILD3
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD5
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go4
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD3
-rw-r--r--pkg/sentry/fsimpl/ext/file_description.go3
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/ext/symlink.go4
-rw-r--r--pkg/sentry/fsimpl/memfs/BUILD20
-rw-r--r--pkg/sentry/fsimpl/memfs/filesystem.go55
-rw-r--r--pkg/sentry/fsimpl/memfs/memfs.go2
-rw-r--r--pkg/sentry/fsimpl/memfs/named_pipe.go59
-rw-r--r--pkg/sentry/fsimpl/memfs/pipe_test.go233
-rw-r--r--pkg/sentry/inet/BUILD4
-rw-r--r--pkg/sentry/kernel/BUILD5
-rw-r--r--pkg/sentry/kernel/auth/BUILD4
-rw-r--r--pkg/sentry/kernel/contexttest/BUILD4
-rw-r--r--pkg/sentry/kernel/epoll/BUILD5
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD3
-rw-r--r--pkg/sentry/kernel/fasync/BUILD4
-rw-r--r--pkg/sentry/kernel/fd_table.go22
-rw-r--r--pkg/sentry/kernel/fd_table_test.go36
-rw-r--r--pkg/sentry/kernel/futex/BUILD5
-rw-r--r--pkg/sentry/kernel/pipe/BUILD8
-rw-r--r--pkg/sentry/kernel/pipe/node.go72
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go24
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go213
-rw-r--r--pkg/sentry/kernel/pipe/reader_writer.go111
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go220
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD5
-rw-r--r--pkg/sentry/kernel/shm/BUILD4
-rw-r--r--pkg/sentry/kernel/time/BUILD4
-rw-r--r--pkg/sentry/limits/BUILD3
-rw-r--r--pkg/sentry/loader/BUILD3
-rw-r--r--pkg/sentry/memmap/BUILD5
-rw-r--r--pkg/sentry/mm/BUILD5
-rw-r--r--pkg/sentry/pgalloc/BUILD5
-rw-r--r--pkg/sentry/platform/BUILD4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go16
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go3
-rw-r--r--pkg/sentry/platform/ring0/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD3
-rw-r--r--pkg/sentry/socket/BUILD4
-rw-r--r--pkg/sentry/socket/control/BUILD4
-rw-r--r--pkg/sentry/socket/hostinet/BUILD4
-rw-r--r--pkg/sentry/socket/netfilter/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD3
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD4
-rw-r--r--pkg/sentry/socket/netstack/BUILD4
-rw-r--r--pkg/sentry/socket/netstack/netstack.go40
-rw-r--r--pkg/sentry/socket/netstack/provider.go53
-rw-r--r--pkg/sentry/socket/unix/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go5
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go16
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go9
-rw-r--r--pkg/sentry/socket/unix/unix.go3
-rw-r--r--pkg/sentry/syscalls/linux/BUILD4
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go11
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go8
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go48
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go1
-rw-r--r--pkg/sentry/time/BUILD3
-rw-r--r--pkg/sentry/usage/BUILD4
-rw-r--r--pkg/sentry/usermem/BUILD5
-rw-r--r--pkg/sentry/vfs/file_description.go4
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go4
-rw-r--r--pkg/sentry/vfs/syscalls.go22
-rw-r--r--pkg/state/BUILD3
-rw-r--r--pkg/tcpip/buffer/prependable.go5
-rw-r--r--pkg/tcpip/checker/checker.go61
-rw-r--r--pkg/tcpip/header/BUILD19
-rw-r--r--pkg/tcpip/header/checksum.go50
-rw-r--r--pkg/tcpip/header/checksum_test.go109
-rw-r--r--pkg/tcpip/header/eth.go62
-rw-r--r--pkg/tcpip/header/eth_test.go68
-rw-r--r--pkg/tcpip/header/icmpv6.go32
-rw-r--r--pkg/tcpip/header/ipv6.go12
-rw-r--r--pkg/tcpip/header/ndp_neighbor_advert.go110
-rw-r--r--pkg/tcpip/header/ndp_neighbor_solicit.go52
-rw-r--r--pkg/tcpip/header/ndp_options.go172
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go112
-rw-r--r--pkg/tcpip/header/ndp_test.go199
-rw-r--r--pkg/tcpip/link/channel/channel.go48
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go143
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go21
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go5
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go22
-rw-r--r--pkg/tcpip/link/loopback/BUILD1
-rw-r--r--pkg/tcpip/link/loopback/loopback.go27
-rw-r--r--pkg/tcpip/link/muxed/injectable.go34
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go2
-rw-r--r--pkg/tcpip/link/rawfile/BUILD5
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go11
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go22
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go16
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go85
-rw-r--r--pkg/tcpip/link/waitable/waitable.go28
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go19
-rw-r--r--pkg/tcpip/network/arp/arp.go9
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD1
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go16
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go10
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go10
-rw-r--r--pkg/tcpip/network/ip_test.go9
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go35
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go14
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go132
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go12
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go30
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go4
-rw-r--r--pkg/tcpip/stack/BUILD5
-rw-r--r--pkg/tcpip/stack/ndp.go319
-rw-r--r--pkg/tcpip/stack/ndp_test.go443
-rw-r--r--pkg/tcpip/stack/nic.go309
-rw-r--r--pkg/tcpip/stack/registration.go83
-rw-r--r--pkg/tcpip/stack/route.go48
-rw-r--r--pkg/tcpip/stack/stack.go231
-rw-r--r--pkg/tcpip/stack/stack_test.go302
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go2
-rw-r--r--pkg/tcpip/tcpip.go13
-rw-r--r--pkg/tcpip/transport/packet/BUILD46
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go363
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go72
-rw-r--r--pkg/tcpip/transport/raw/BUILD15
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go30
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go14
-rw-r--r--pkg/tcpip/transport/raw/protocol.go12
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/connect.go77
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go18
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go11
-rw-r--r--runsc/boot/config.go14
-rw-r--r--runsc/boot/controller.go4
-rw-r--r--runsc/boot/filter/config.go10
-rw-r--r--runsc/boot/fs.go16
-rw-r--r--runsc/boot/loader.go2
-rw-r--r--runsc/boot/network.go14
-rw-r--r--runsc/cmd/gofer.go29
-rw-r--r--runsc/container/container.go6
-rw-r--r--runsc/container/container_test.go37
-rw-r--r--runsc/fsgofer/filter/config.go12
-rw-r--r--runsc/fsgofer/fsgofer.go105
-rw-r--r--runsc/main.go32
-rw-r--r--runsc/sandbox/BUILD1
-rw-r--r--runsc/sandbox/network.go12
-rw-r--r--runsc/specutils/BUILD1
-rw-r--r--runsc/specutils/cri.go110
-rw-r--r--runsc/specutils/specutils.go58
-rw-r--r--runsc/testutil/BUILD1
-rw-r--r--runsc/testutil/testutil.go4
-rwxr-xr-xscripts/swgso_tests.sh21
-rw-r--r--test/README.md28
-rw-r--r--test/e2e/exec_test.go2
-rw-r--r--test/syscalls/BUILD7
-rw-r--r--test/syscalls/build_defs.bzl10
-rw-r--r--test/syscalls/linux/BUILD18
-rw-r--r--test/syscalls/linux/accept_bind.cc12
-rw-r--r--test/syscalls/linux/connect_external.cc164
-rw-r--r--test/syscalls/linux/exec.cc251
-rw-r--r--test/syscalls/linux/packet_socket.cc28
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc27
-rw-r--r--test/syscalls/linux/proc_net.cc140
-rw-r--r--test/syscalls/linux/sendfile_socket.cc8
-rw-r--r--test/syscalls/linux/socket_unix_non_stream.cc14
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc46
-rw-r--r--test/syscalls/syscall_test_runner.go266
-rw-r--r--test/uds/BUILD17
-rw-r--r--test/uds/uds.go228
-rw-r--r--test/util/fs_util.cc20
-rw-r--r--test/util/fs_util.h9
-rw-r--r--test/util/multiprocess_util.cc46
-rw-r--r--test/util/multiprocess_util.h7
-rw-r--r--third_party/gvsync/BUILD3
-rw-r--r--third_party/gvsync/atomicptrtest/BUILD3
-rw-r--r--third_party/gvsync/seqatomictest/BUILD3
-rw-r--r--tools/go_generics/rules_tests/BUILD3
220 files changed, 6850 insertions, 1148 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 8f50a3e57..57e6f3558 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -3,10 +3,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "io_bazel_rules_go",
- sha256 = "078f2a9569fa9ed846e60805fb5fb167d6f6c4ece48e6d409bf5fb2154eaf0d8",
+ sha256 = "842ec0e6b4fbfdd3de6150b61af92901eeb73681fd4d185746644c338f51d4c0",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.20.0/rules_go-v0.20.0.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.20.0/rules_go-v0.20.0.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.20.1/rules_go-v0.20.1.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.20.1/rules_go-v0.20.1.tar.gz",
],
)
@@ -24,7 +24,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.13.1",
+ go_version = "1.13.3",
nogo = "@//:nogo",
)
diff --git a/kokoro/swgso_tests.cfg b/kokoro/swgso_tests.cfg
new file mode 100644
index 000000000..101a9c607
--- /dev/null
+++ b/kokoro/swgso_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/swgso_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD
index 32c601a03..f5c08ea06 100644
--- a/pkg/abi/BUILD
+++ b/pkg/abi/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "abi",
srcs = [
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index f45934466..7c17109a6 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -1,13 +1,12 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
# Package linux contains the constants and types needed to interface with a
# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix
# when the host OS may not be Linux.
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "linux",
srcs = [
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index d5b731390..2e2cc6be7 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -256,6 +256,17 @@ type SockAddrInet6 struct {
Scope_id uint32
}
+// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h.
+type SockAddrLink struct {
+ Family uint16
+ Protocol uint16
+ InterfaceIndex int32
+ ARPHardwareType uint16
+ PacketType byte
+ HardwareAddrLen byte
+ HardwareAddr [8]byte
+}
+
// UnixPathMax is the maximum length of the path in an AF_UNIX socket.
//
// From uapi/linux/un.h.
@@ -278,6 +289,7 @@ type SockAddr interface {
func (s *SockAddrInet) implementsSockAddr() {}
func (s *SockAddrInet6) implementsSockAddr() {}
+func (s *SockAddrLink) implementsSockAddr() {}
func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD
index 1b5dac99a..93b88a29a 100644
--- a/pkg/bits/BUILD
+++ b/pkg/bits/BUILD
@@ -1,10 +1,9 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-
go_library(
name = "bits",
srcs = [
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
index 8d31e068c..fba5643e8 100644
--- a/pkg/bpf/BUILD
+++ b/pkg/bpf/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "bpf",
srcs = [
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
index 32422f9e2..ed111fd2a 100644
--- a/pkg/cpuid/BUILD
+++ b/pkg/cpuid/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "cpuid",
srcs = [
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index c7f549428..afa8f7659 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -8,9 +8,6 @@ go_library(
srcs = ["fd.go"],
importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
- deps = [
- "//pkg/unet",
- ],
)
go_test(
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 7691b477b..83bcfe220 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -22,8 +22,6 @@ import (
"runtime"
"sync/atomic"
"syscall"
-
- "gvisor.dev/gvisor/pkg/unet"
)
// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
@@ -187,12 +185,6 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
return New(f), nil
}
-// DialUnix connects to a Unix Domain Socket and return the file descriptor.
-func DialUnix(path string) (*FD, error) {
- socket, err := unet.Connect(path, false)
- return New(socket.FD()), err
-}
-
// Close closes the file descriptor contained in the FD.
//
// Close is safe to call multiple times, but will return an error after the
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 827385139..7ad59dfd7 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "weak_ref_list",
out = "weak_ref_list.go",
diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD
index 700385907..1b487b887 100644
--- a/pkg/segment/BUILD
+++ b/pkg/segment/BUILD
@@ -1,10 +1,10 @@
+load("//tools/go_generics:defs.bzl", "go_template")
+
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
)
-load("//tools/go_generics:defs.bzl", "go_template")
-
go_template(
name = "generic_range",
srcs = ["range.go"],
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index 12d7c77d2..a27c35e21 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -1,13 +1,12 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "int_range",
out = "int_range.go",
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index c71cff9f3..18c73cc24 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@rules_cc//cc:defs.bzl", "cc_proto_library")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "arch",
srcs = [
diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD
index 3b6841b7e..581e7aa96 100644
--- a/pkg/sentry/context/contexttest/BUILD
+++ b/pkg/sentry/context/contexttest/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "contexttest",
testonly = 1,
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 0c86197f7..1098ed777 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "device",
srcs = ["device.go"],
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index 3119a61b6..378602cc9 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "fs",
srcs = [
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 80e106e6f..a0d9e8496 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "dev",
srcs = [
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index b9bd9ed17..277ee4c31 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "fdpipe",
srcs = [
diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD
index a9d6d9301..358dc2be3 100644
--- a/pkg/sentry/fs/filetest/BUILD
+++ b/pkg/sentry/fs/filetest/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "filetest",
testonly = 1,
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index b4ac83dc4..b2e8d9c77 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "dirty_set_impl",
out = "dirty_set_impl.go",
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index e239f12a5..b06a71cc2 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -209,3 +209,29 @@ func (f *HostFileMapper) unmapAndRemoveLocked(chunkStart uint64, m mapping) {
}
delete(f.mappings, chunkStart)
}
+
+// RegenerateMappings must be called when the file description mapped by f
+// changes, to replace existing mappings of the previous file description.
+func (f *HostFileMapper) RegenerateMappings(fd int) error {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+
+ for chunkStart, m := range f.mappings {
+ prot := syscall.PROT_READ
+ if m.writable {
+ prot |= syscall.PROT_WRITE
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ m.addr,
+ chunkSize,
+ uintptr(prot),
+ syscall.MAP_SHARED|syscall.MAP_FIXED,
+ uintptr(fd),
+ uintptr(chunkStart))
+ if errno != 0 {
+ return errno
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index 693625ddc..30475f340 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -100,13 +100,30 @@ func (h *HostMappable) Translate(ctx context.Context, required, optional memmap.
}
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
-func (h *HostMappable) InvalidateUnsavable(ctx context.Context) error {
+func (h *HostMappable) InvalidateUnsavable(_ context.Context) error {
h.mu.Lock()
h.mappings.InvalidateAll(memmap.InvalidateOpts{})
h.mu.Unlock()
return nil
}
+// NotifyChangeFD must be called after the file description represented by
+// CachedFileObject.FD() changes.
+func (h *HostMappable) NotifyChangeFD() error {
+ // Update existing sentry mappings to refer to the new file description.
+ if err := h.hostFileMapper.RegenerateMappings(h.backingFile.FD()); err != nil {
+ return err
+ }
+
+ // Shoot down existing application mappings of the old file description;
+ // they will be remapped with the new file description on demand.
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ h.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ return nil
+}
+
// MapInternal implements platform.File.MapInternal.
func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index dd80757dc..798920d18 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -959,6 +959,23 @@ func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error
return nil
}
+// NotifyChangeFD must be called after the file description represented by
+// CachedFileObject.FD() changes.
+func (c *CachingInodeOperations) NotifyChangeFD() error {
+ // Update existing sentry mappings to refer to the new file description.
+ if err := c.hostFileMapper.RegenerateMappings(c.backingFile.FD()); err != nil {
+ return err
+ }
+
+ // Shoot down existing application mappings of the old file description;
+ // they will be remapped with the new file description on demand.
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+
+ c.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ return nil
+}
+
// Evict implements pgalloc.EvictableMemoryUser.Evict.
func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.EvictableRange) {
c.mapsMu.Lock()
@@ -1027,7 +1044,6 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
}
c.refs.MergeAdjacent(fr)
c.dataMu.Unlock()
-
}
// MapInternal implements platform.File.MapInternal. This is used when we
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 2b71ca0e1..4a005c605 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "gofer",
srcs = [
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index 9e2e412cd..7960b9c7b 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -214,28 +214,64 @@ func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr
return entries, nil
}
+// maybeSync will call FSync on the file if either the cache policy or file
+// flags require it.
+func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error {
+ if n == 0 {
+ // Nothing to sync.
+ return nil
+ }
+
+ if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) {
+ // Call WriteOut directly, as some "writethrough" filesystems
+ // do not support sync.
+ return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode)
+ }
+
+ flags := file.Flags()
+ var syncType fs.SyncType
+ switch {
+ case flags.Direct || flags.Sync:
+ syncType = fs.SyncAll
+ case flags.DSync:
+ syncType = fs.SyncData
+ default:
+ // No need to sync.
+ return nil
+ }
+
+ return f.Fsync(ctx, file, offset, offset+n, syncType)
+}
+
// Write implements fs.FileOperations.Write.
func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
if fs.IsDir(file.Dirent.Inode.StableAttr) {
// Not all remote file systems enforce this so this client does.
return 0, syserror.EISDIR
}
- cp := f.inodeOperations.session().cachePolicy
- if cp.useCachingInodeOps(file.Dirent.Inode) {
- n, err := f.inodeOperations.cachingInodeOps.Write(ctx, src, offset)
- if err != nil {
- return n, err
- }
- if cp.writeThrough(file.Dirent.Inode) {
- // Write out the file.
- err = f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode)
- }
- return n, err
+
+ var (
+ n int64
+ err error
+ )
+ // The write is handled in different ways depending on the cache policy
+ // and availability of a host-mappable FD.
+ if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
+ n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset)
+ } else if f.inodeOperations.fileState.hostMappable != nil {
+ n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset)
+ } else {
+ n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
}
- if f.inodeOperations.fileState.hostMappable != nil {
- return f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset)
+
+ // We may need to sync the written bytes.
+ if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil {
+ // Sync failed. Report 0 bytes written, since none of them are
+ // guaranteed to have been synced.
+ return 0, syncErr
}
- return src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
+
+ return n, err
}
// incrementReadCounters increments the read counters for the read starting at the given time. We
@@ -273,7 +309,7 @@ func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO
}
// Fsync implements fs.FileOperations.Fsync.
-func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start int64, end int64, syncType fs.SyncType) error {
+func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start, end int64, syncType fs.SyncType) error {
switch syncType {
case fs.SyncAll, fs.SyncData:
if err := file.Dirent.Inode.WriteOut(ctx); err != nil {
diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go
index 9aa68a70e..c2fbb4be9 100644
--- a/pkg/sentry/fs/gofer/file_state.go
+++ b/pkg/sentry/fs/gofer/file_state.go
@@ -29,7 +29,7 @@ func (f *fileOperations) afterLoad() {
// Manually load the open handles.
var err error
// TODO(b/38173783): Context is not plumbed to save/restore.
- f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags)
+ f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags, f.inodeOperations.cachingInodeOps)
if err != nil {
return fmt.Errorf("failed to re-open handle: %v", err)
}
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
index 8f8ab5d29..cf96dd9fa 100644
--- a/pkg/sentry/fs/gofer/fs.go
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -58,6 +58,11 @@ const (
// If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to
// true.
limitHostFDTranslationKey = "limit_host_fd_translation"
+
+ // overlayfsStaleRead if present closes cached readonly file after the first
+ // write. This is done to workaround a limitation of overlayfs in kernels
+ // before 4.19 where open FDs are not updated after the file is copied up.
+ overlayfsStaleRead = "overlayfs_stale_read"
)
// defaultAname is the default attach name.
@@ -145,6 +150,7 @@ type opts struct {
version string
privateunixsocket bool
limitHostFDTranslation bool
+ overlayfsStaleRead bool
}
// options parses mount(2) data into structured options.
@@ -247,6 +253,11 @@ func options(data string) (opts, error) {
delete(options, limitHostFDTranslationKey)
}
+ if _, ok := options[overlayfsStaleRead]; ok {
+ o.overlayfsStaleRead = true
+ delete(options, overlayfsStaleRead)
+ }
+
// Fail to attach if the caller wanted us to do something that we
// don't support.
if len(options) > 0 {
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
index 27eeae3d9..39c8ec33d 100644
--- a/pkg/sentry/fs/gofer/handles.go
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -39,14 +39,22 @@ type handles struct {
// Host is an *fd.FD handle. May be nil.
Host *fd.FD
+
+ // isHostBorrowed tells whether 'Host' is owned or borrowed. If owned, it's
+ // closed on destruction, otherwise it's released.
+ isHostBorrowed bool
}
// DecRef drops a reference on handles.
func (h *handles) DecRef() {
h.DecRefWithDestructor(func() {
if h.Host != nil {
- if err := h.Host.Close(); err != nil {
- log.Warningf("error closing host file: %v", err)
+ if h.isHostBorrowed {
+ h.Host.Release()
+ } else {
+ if err := h.Host.Close(); err != nil {
+ log.Warningf("error closing host file: %v", err)
+ }
}
}
// FIXME(b/38173783): Context is not plumbed here.
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index d918d6620..99910388f 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -100,7 +100,7 @@ type inodeFileState struct {
// true.
//
// Once readHandles becomes non-nil, it can't be changed until
- // inodeFileState.Release(), because of a defect in the
+ // inodeFileState.Release()*, because of a defect in the
// fsutil.CachedFileObject interface: there's no way for the caller of
// fsutil.CachedFileObject.FD() to keep the returned FD open, so if we
// racily replace readHandles after inodeFileState.FD() has returned
@@ -108,6 +108,9 @@ type inodeFileState struct {
// FD. writeHandles can be changed if writeHandlesRW is false, since
// inodeFileState.FD() can't return a write-only FD, but can't be changed
// if writeHandlesRW is true for the same reason.
+ //
+ // * There is one notable exception in recreateReadHandles(), where it dup's
+ // the FD and invalidates the page cache.
readHandles *handles `state:"nosave"`
writeHandles *handles `state:"nosave"`
writeHandlesRW bool `state:"nosave"`
@@ -175,43 +178,124 @@ func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles)
// getHandles returns a set of handles for a new file using i opened with the
// given flags.
-func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags) (*handles, error) {
+func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) {
if !i.canShareHandles() {
return newHandles(ctx, i.file, flags)
}
+
i.handlesMu.Lock()
- defer i.handlesMu.Unlock()
+ h, invalidate, err := i.getHandlesLocked(ctx, flags)
+ i.handlesMu.Unlock()
+
+ if invalidate {
+ cache.NotifyChangeFD()
+ if i.hostMappable != nil {
+ i.hostMappable.NotifyChangeFD()
+ }
+ }
+
+ return h, err
+}
+
+// getHandlesLocked returns a pointer to cached handles and a boolean indicating
+// whether previously open read handle was recreated. Host mappings must be
+// invalidated if so.
+func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) {
// Do we already have usable shared handles?
if flags.Write {
if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) {
i.writeHandles.IncRef()
- return i.writeHandles, nil
+ return i.writeHandles, false, nil
}
} else if i.readHandles != nil {
i.readHandles.IncRef()
- return i.readHandles, nil
+ return i.readHandles, false, nil
}
+
// No; get new handles and cache them for future sharing.
h, err := newHandles(ctx, i.file, flags)
if err != nil {
- return nil, err
+ return nil, false, err
+ }
+
+ // Read handles invalidation is needed if:
+ // - Mount option 'overlayfs_stale_read' is set
+ // - Read handle is open: nothing to invalidate otherwise
+ // - Write handle is not open: file was not open for write and is being open
+ // for write now (will trigger copy up in overlayfs).
+ invalidate := false
+ if i.s.overlayfsStaleRead && i.readHandles != nil && i.writeHandles == nil && flags.Write {
+ if err := i.recreateReadHandles(ctx, h, flags); err != nil {
+ return nil, false, err
+ }
+ invalidate = true
}
i.setSharedHandlesLocked(flags, h)
- return h, nil
+ return h, invalidate, nil
+}
+
+func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handles, flags fs.FileFlags) error {
+ h := writer
+ if !flags.Read {
+ // Writer can't be used for read, must create a new handle.
+ var err error
+ h, err = newHandles(ctx, i.file, fs.FileFlags{Read: true})
+ if err != nil {
+ return err
+ }
+ defer h.DecRef()
+ }
+
+ if i.readHandles.Host == nil {
+ // If current readHandles doesn't have a host FD, it can simply be replaced.
+ i.readHandles.DecRef()
+
+ h.IncRef()
+ i.readHandles = h
+ return nil
+ }
+
+ if h.Host == nil {
+ // Current read handle has a host FD and can't be replaced with one that
+ // doesn't, because it breaks fsutil.CachedFileObject.FD() contract.
+ log.Warningf("Read handle can't be invalidated, reads may return stale data")
+ return nil
+ }
+
+ // Due to a defect in the fsutil.CachedFileObject interface,
+ // readHandles.Host.FD() may be used outside locks, making it impossible to
+ // reliably close it. To workaround it, we dup the new FD into the old one, so
+ // operations on the old will see the new data. Then, make the new handle take
+ // ownereship of the old FD and mark the old readHandle to not close the FD
+ // when done.
+ if err := syscall.Dup2(h.Host.FD(), i.readHandles.Host.FD()); err != nil {
+ return err
+ }
+
+ h.Host.Close()
+ h.Host = fd.New(i.readHandles.Host.FD())
+ i.readHandles.isHostBorrowed = true
+ i.readHandles.DecRef()
+
+ h.IncRef()
+ i.readHandles = h
+ return nil
}
// ReadToBlocksAt implements fsutil.CachedFileObject.ReadToBlocksAt.
func (i *inodeFileState) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
i.handlesMu.RLock()
- defer i.handlesMu.RUnlock()
- return i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts)
+ n, err := i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts)
+ i.handlesMu.RUnlock()
+ return n, err
}
// WriteFromBlocksAt implements fsutil.CachedFileObject.WriteFromBlocksAt.
func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
i.handlesMu.RLock()
- defer i.handlesMu.RUnlock()
- return i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs)
+ n, err := i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs)
+ i.handlesMu.RUnlock()
+ return n, err
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
@@ -449,7 +533,7 @@ func (i *inodeOperations) NonBlockingOpen(ctx context.Context, p fs.PermMask) (*
}
func (i *inodeOperations) getFileDefault(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
- h, err := i.fileState.getHandles(ctx, flags)
+ h, err := i.fileState.getHandles(ctx, flags, i.cachingInodeOps)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 50da865c1..0da608548 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -122,6 +122,10 @@ type session struct {
// CachingInodeOperations created by the session.
limitHostFDTranslation bool
+ // overlayfsStaleRead when set causes the readonly handle to be invalidated
+ // after file is open for write.
+ overlayfsStaleRead bool
+
// connID is a unique identifier for the session connection.
connID string `state:"wait"`
@@ -257,6 +261,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
aname: o.aname,
superBlockFlags: superBlockFlags,
limitHostFDTranslation: o.limitHostFDTranslation,
+ overlayfsStaleRead: o.overlayfsStaleRead,
mounter: mounter,
}
s.EnableLeakCheck("gofer.session")
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 3e532332e..1cbed07ae 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "host",
srcs = [
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 2392787cb..107336a3e 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -385,3 +385,6 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
func (c *ConnectedEndpoint) Release() {
c.ref.DecRefWithDestructor(c.close)
}
+
+// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
+func (c *ConnectedEndpoint) CloseUnread() {}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 5a7a5b8cd..8d62642e7 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "lock_range",
out = "lock_range.go",
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index c307603a6..75cbb0622 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "proc",
srcs = [
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index 76433c7d0..fe7067be1 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "seqfile",
srcs = ["seqfile.go"],
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index d0f351e5a..012cb3e44 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "ramfs",
srcs = [
diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD
index 70fa3af89..25f0f124e 100644
--- a/pkg/sentry/fs/sys/BUILD
+++ b/pkg/sentry/fs/sys/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "sys",
srcs = [
diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD
index 1d80daeaf..a215c1b95 100644
--- a/pkg/sentry/fs/timerfd/BUILD
+++ b/pkg/sentry/fs/timerfd/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "timerfd",
srcs = ["timerfd.go"],
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 11b680929..59ce400c2 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "tmpfs",
srcs = [
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 25811f668..95ad98cb0 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "tty",
srcs = [
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index b0c286b7a..7ccff8b0d 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
+package(licenses = ["notice"])
+
go_template_instance(
name = "dirent_list",
out = "dirent_list.go",
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 0b471d121..91802dc1e 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -301,8 +301,8 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
return offset, nil
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
// mmap(2) specifies that EACCESS should be returned for non-regular file fds.
return syserror.EACCES
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 2d50e30aa..fcfaf5c3e 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "disklayout",
srcs = [
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
index a0065343b..4d18b28cb 100644
--- a/pkg/sentry/fsimpl/ext/file_description.go
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -43,9 +43,6 @@ func (fd *fileDescription) inode() *inode {
return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
}
-// OnClose implements vfs.FileDescriptionImpl.OnClose.
-func (fd *fileDescription) OnClose() error { return nil }
-
// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
return fd.flags, nil
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index ffc76ba5b..aec33e00a 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -152,8 +152,8 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
return offset, nil
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
// TODO(b/134676337): Implement mmap(2).
return syserror.ENODEV
}
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
index e06548a98..bdf8705c1 100644
--- a/pkg/sentry/fsimpl/ext/symlink.go
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -105,7 +105,7 @@ func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int6
return 0, syserror.EBADF
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return syserror.EBADF
}
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index 7e364c5fd..04d667273 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -24,14 +24,18 @@ go_library(
"directory.go",
"filesystem.go",
"memfs.go",
+ "named_pipe.go",
"regular_file.go",
"symlink.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs",
deps = [
"//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/sentry/arch",
"//pkg/sentry/context",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
@@ -54,3 +58,19 @@ go_test(
"//pkg/syserror",
],
)
+
+go_test(
+ name = "memfs_test",
+ size = "small",
+ srcs = ["pipe_test.go"],
+ embed = [":memfs"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go
index f79e2d9c8..f006c40cd 100644
--- a/pkg/sentry/fsimpl/memfs/filesystem.go
+++ b/pkg/sentry/fsimpl/memfs/filesystem.go
@@ -233,7 +233,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if err != nil {
return err
}
- _, err = checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -241,8 +241,40 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- // TODO: actually implement mknod
- return syserror.EPERM
+
+ switch opts.Mode.FileType() {
+ case 0:
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ fallthrough
+ case linux.ModeRegular:
+ // TODO(b/138862511): Implement.
+ return syserror.EINVAL
+
+ case linux.ModeNamedPipe:
+ child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode))
+ parentVFSD.InsertChild(&child.vfsd, pc)
+ parentInode.impl.(*directory).childList.PushBack(child)
+ return nil
+
+ case linux.ModeSocket:
+ // TODO(b/138862511): Implement.
+ return syserror.EINVAL
+
+ case linux.ModeCharacterDevice:
+ fallthrough
+ case linux.ModeBlockDevice:
+ // TODO(b/72101894): We don't support creating block or character
+ // devices at the moment.
+ //
+ // When we start supporting block and character devices, we'll
+ // need to check for CAP_MKNOD here.
+ return syserror.EPERM
+
+ default:
+ // "EINVAL - mode requested creation of something other than a
+ // regular file, device special file, FIFO or socket." - mknod(2)
+ return syserror.EINVAL
+ }
}
// OpenAt implements vfs.FilesystemImpl.OpenAt.
@@ -250,8 +282,9 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// Filter out flags that are not supported by memfs. O_DIRECTORY and
// O_NOFOLLOW have no effect here (they're handled by VFS by setting
// appropriate bits in rp), but are returned by
- // FileDescriptionImpl.StatusFlags().
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW
+ // FileDescriptionImpl.StatusFlags(). O_NONBLOCK is supported only by
+ // pipes.
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK
if opts.Flags&linux.O_CREAT == 0 {
fs.mu.RLock()
@@ -260,7 +293,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return nil, err
}
- return inode.open(rp, vfsd, opts.Flags, false)
+ return inode.open(ctx, rp, vfsd, opts.Flags, false)
}
mustCreate := opts.Flags&linux.O_EXCL != 0
@@ -275,7 +308,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
- return inode.open(rp, vfsd, opts.Flags, false)
+ return inode.open(ctx, rp, vfsd, opts.Flags, false)
}
afterTrailingSymlink:
// Walk to the parent directory of the last path component.
@@ -320,7 +353,7 @@ afterTrailingSymlink:
child := fs.newDentry(childInode)
vfsd.InsertChild(&child.vfsd, pc)
inode.impl.(*directory).childList.PushBack(child)
- return childInode.open(rp, &child.vfsd, opts.Flags, true)
+ return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true)
}
// Open existing file or follow symlink.
if mustCreate {
@@ -336,10 +369,10 @@ afterTrailingSymlink:
// symlink target.
goto afterTrailingSymlink
}
- return childInode.open(rp, childVFSD, opts.Flags, false)
+ return childInode.open(ctx, rp, childVFSD, opts.Flags, false)
}
-func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
+func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(flags)
if !afterCreate {
if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil {
@@ -378,6 +411,8 @@ func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte
case *symlink:
// Can't open symlinks without O_PATH (which is unimplemented).
return nil, syserror.ELOOP
+ case *namedPipe:
+ return newNamedPipeFD(ctx, impl, rp, vfsd, flags)
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go
index b78471c0f..64c851c1a 100644
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ b/pkg/sentry/fsimpl/memfs/memfs.go
@@ -227,6 +227,8 @@ func (i *inode) statTo(stat *linux.Statx) {
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
stat.Size = uint64(len(impl.target))
stat.Blocks = allocatedBlocksForSize(stat.Size)
+ case *namedPipe:
+ stat.Mode |= linux.S_IFIFO
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/memfs/named_pipe.go
new file mode 100644
index 000000000..732ed7c58
--- /dev/null
+++ b/pkg/sentry/fsimpl/memfs/named_pipe.go
@@ -0,0 +1,59 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type namedPipe struct {
+ inode inode
+
+ pipe *pipe.VFSPipe
+}
+
+// Preconditions:
+// * fs.mu must be locked.
+// * rp.Mount().CheckBeginWrite() has been called successfully.
+func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
+ file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)}
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // Only the parent has a link.
+ return &file.inode
+}
+
+// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented
+// entirely via struct embedding.
+type namedPipeFD struct {
+ fileDescription
+
+ *pipe.VFSPipeFD
+}
+
+func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ var err error
+ var fd namedPipeFD
+ fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, rp, vfsd, &fd.vfsfd, flags)
+ if err != nil {
+ return nil, err
+ }
+ fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
+ return &fd.vfsfd, nil
+}
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go
new file mode 100644
index 000000000..0674b81a3
--- /dev/null
+++ b/pkg/sentry/fsimpl/memfs/pipe_test.go
@@ -0,0 +1,233 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memfs
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const fileName = "mypipe"
+
+func TestSeparateFDs(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the read side. This is done in a concurrently because opening
+ // One end the pipe blocks until the other end is opened.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ rfdchan := make(chan *vfs.FileDescription)
+ go func() {
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY}
+ rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ rfdchan <- rfd
+ }()
+
+ // Open the write side.
+ openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY}
+ wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer wfd.DecRef()
+
+ rfd, ok := <-rfdchan
+ if !ok {
+ t.Fatalf("failed to open pipe for reading %q", fileName)
+ }
+ defer rfd.DecRef()
+
+ const msg = "vamos azul"
+ checkEmpty(ctx, t, rfd)
+ checkWrite(ctx, t, wfd, msg)
+ checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingRead(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the read side as nonblocking.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK}
+ rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for reading %q: %v", fileName, err)
+ }
+ defer rfd.DecRef()
+
+ // Open the write side.
+ openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY}
+ wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer wfd.DecRef()
+
+ const msg = "geh blau"
+ checkEmpty(ctx, t, rfd)
+ checkWrite(ctx, t, wfd, msg)
+ checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingWriteError(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the write side as nonblocking, which should return ENXIO.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK}
+ _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != syserror.ENXIO {
+ t.Fatalf("expected ENXIO, but got error: %v", err)
+ }
+}
+
+func TestSingleFD(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the pipe as readable and writable.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDWR}
+ fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer fd.DecRef()
+
+ const msg = "forza blu"
+ checkEmpty(ctx, t, fd)
+ checkWrite(ctx, t, fd, msg)
+ checkRead(ctx, t, fd, msg)
+}
+
+// setup creates a VFS with a pipe in the root directory at path fileName. The
+// returned VirtualDentry must be DecRef()'d be the caller. It calls t.Fatal
+// upon failure.
+func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+
+ // Create the pipe.
+ root := mntns.Root()
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644}
+ if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil {
+ t.Fatalf("failed to create file %q: %v", fileName, err)
+ }
+
+ // Sanity check: the file pipe exists and has the correct mode.
+ stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat(%q) failed: %v", fileName, err)
+ }
+ if stat.Mode&^linux.S_IFMT != 0644 {
+ t.Errorf("got wrong permissions (%0o)", stat.Mode)
+ }
+ if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe {
+ t.Errorf("got wrong file type (%0o)", stat.Mode)
+ }
+
+ return ctx, creds, vfsObj, root
+}
+
+// checkEmpty calls t.Fatal if the pipe in fd is not empty.
+func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+ readData := make([]byte, 1)
+ dst := usermem.BytesIOSequence(readData)
+ bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{})
+ if err != syserror.ErrWouldBlock {
+ t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err)
+ }
+ if bytesRead != 0 {
+ t.Fatalf("expected to read 0 bytes, but got %d", bytesRead)
+ }
+}
+
+// checkWrite calls t.Fatal if it fails to write all of msg to fd.
+func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+ writeData := []byte(msg)
+ src := usermem.BytesIOSequence(writeData)
+ bytesWritten, err := fd.Impl().Write(ctx, src, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("error writing to pipe %q: %v", fileName, err)
+ }
+ if bytesWritten != int64(len(writeData)) {
+ t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten)
+ }
+}
+
+// checkRead calls t.Fatal if it fails to read msg from fd.
+func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+ readData := make([]byte, len(msg))
+ dst := usermem.BytesIOSequence(readData)
+ bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("error reading from pipe %q: %v", fileName, err)
+ }
+ if bytesRead != int64(len(msg)) {
+ t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead)
+ }
+ if !bytes.Equal(readData, []byte(msg)) {
+ t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData))
+ }
+}
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 184b566d9..d5284f0d9 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -1,10 +1,10 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
)
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "inet",
srcs = [
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index aba2414d4..e041c51b3 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,12 +1,11 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("@rules_cc//cc:defs.bzl", "cc_proto_library")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "pending_signals_list",
out = "pending_signals_list.go",
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 1d00a6310..51de4568a 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "atomicptr_credentials",
out = "atomicptr_credentials_unsafe.go",
diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD
index bec13a3d9..3a88a585c 100644
--- a/pkg/sentry/kernel/contexttest/BUILD
+++ b/pkg/sentry/kernel/contexttest/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "contexttest",
testonly = 1,
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index 65427b112..3361e8b7d 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "epoll_list",
out = "epoll_list.go",
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index 983ca67ed..e65b961e8 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "eventfd",
srcs = ["eventfd.go"],
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 5eddca115..49d81b712 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "fasync",
srcs = ["fasync.go"],
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index cc3f43a45..11f613a11 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -81,6 +81,9 @@ type FDTable struct {
// mu protects below.
mu sync.Mutex `state:"nosave"`
+ // next is start position to find fd.
+ next int32
+
// used contains the number of non-nil entries. It must be accessed
// atomically. It may be read atomically without holding mu (but not
// written).
@@ -226,6 +229,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
f.mu.Lock()
defer f.mu.Unlock()
+ // From f.next to find available fd.
+ if fd < f.next {
+ fd = f.next
+ }
+
// Install all entries.
for i := fd; i < end && len(fds) < len(files); i++ {
if d, _, _ := f.get(i); d == nil {
@@ -242,6 +250,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return nil, syscall.EMFILE
}
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fds[len(fds)-1] + 1
+ }
+
return fds, nil
}
@@ -361,6 +374,11 @@ func (f *FDTable) Remove(fd int32) *fs.File {
f.mu.Lock()
defer f.mu.Unlock()
+ // Update current available position.
+ if fd < f.next {
+ f.next = fd
+ }
+
orig, _, _ := f.get(fd)
if orig != nil {
orig.IncRef() // Reference for caller.
@@ -377,6 +395,10 @@ func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) {
f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
if cond(file, flags) {
f.set(fd, nil, FDFlags{}) // Clear from table.
+ // Update current available position.
+ if fd < f.next {
+ f.next = fd
+ }
}
})
}
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index 2413788e7..2bcb6216a 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -70,6 +70,42 @@ func TestFDTableMany(t *testing.T) {
if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil {
t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
}
+
+ i := int32(2)
+ fdTable.Remove(i)
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i {
+ t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err)
+ }
+ })
+}
+
+func TestFDTableOverLimit(t *testing.T) {
+ runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+ if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error")
+ }
+
+ if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error")
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil {
+ t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err)
+ } else {
+ for _, fd := range fds {
+ fdTable.Remove(fd)
+ }
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 {
+ t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil {
+ t.Fatalf("Adding an FD to a resized map: got %v, want nil", err)
+ } else if len(fds) != 1 || fds[0] != 0 {
+ t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds)
+ }
})
}
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 41f44999c..34286c7a8 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "atomicptr_bucket",
out = "atomicptr_bucket_unsafe.go",
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 2ce8952e2..9d34f6d4d 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "buffer_list",
out = "buffer_list.go",
@@ -25,8 +24,10 @@ go_library(
"device.go",
"node.go",
"pipe.go",
+ "pipe_util.go",
"reader.go",
"reader_writer.go",
+ "vfs.go",
"writer.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe",
@@ -41,6 +42,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/safemem",
"//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index a2dc72204..4a19ab7ce 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -18,7 +18,6 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -91,10 +90,10 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
switch {
case flags.Read && !flags.Write: // O_RDONLY.
r := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.rWakeup)
+ newHandleLocked(&i.rWakeup)
if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
- if !i.waitFor(&i.wWakeup, ctx) {
+ if !waitFor(&i.mu, &i.wWakeup, ctx) {
r.DecRef()
return nil, syserror.ErrInterrupted
}
@@ -107,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
case flags.Write && !flags.Read: // O_WRONLY.
w := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.wWakeup)
+ newHandleLocked(&i.wWakeup)
if i.p.isNamed && !i.p.HasReaders() {
// On a nonblocking, write-only open, the open fails with ENXIO if the
@@ -117,7 +116,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
return nil, syserror.ENXIO
}
- if !i.waitFor(&i.rWakeup, ctx) {
+ if !waitFor(&i.mu, &i.rWakeup, ctx) {
w.DecRef()
return nil, syserror.ErrInterrupted
}
@@ -127,8 +126,8 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
case flags.Read && flags.Write: // O_RDWR.
// Pipes opened for read-write always succeeds without blocking.
rw := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.rWakeup)
- i.newHandleLocked(&i.wWakeup)
+ newHandleLocked(&i.rWakeup)
+ newHandleLocked(&i.wWakeup)
return rw, nil
default:
@@ -136,65 +135,6 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
}
}
-// waitFor blocks until the underlying pipe has at least one reader/writer is
-// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this
-// function will block for either readers or writers, depending on where
-// 'wakeupChan' points.
-//
-// f.mu must be held by the caller. waitFor returns with f.mu held, but it will
-// drop f.mu before blocking for any reader/writers.
-func (i *inodeOperations) waitFor(wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
- // Ideally this function would simply use a condition variable. However, the
- // wait needs to be interruptible via 'sleeper', so we must sychronize via a
- // channel. The synchronization below relies on the fact that closing a
- // channel unblocks all receives on the channel.
-
- // Does an appropriate wakeup channel already exist? If not, create a new
- // one. This is all done under f.mu to avoid races.
- if *wakeupChan == nil {
- *wakeupChan = make(chan struct{})
- }
-
- // Grab a local reference to the wakeup channel since it may disappear as
- // soon as we drop f.mu.
- wakeup := *wakeupChan
-
- // Drop the lock and prepare to sleep.
- i.mu.Unlock()
- cancel := sleeper.SleepStart()
-
- // Wait for either a new reader/write to be signalled via 'wakeup', or
- // for the sleep to be cancelled.
- select {
- case <-wakeup:
- sleeper.SleepFinish(true)
- case <-cancel:
- sleeper.SleepFinish(false)
- }
-
- // Take the lock and check if we were woken. If we were woken and
- // interrupted, the former takes priority.
- i.mu.Lock()
- select {
- case <-wakeup:
- return true
- default:
- return false
- }
-}
-
-// newHandleLocked signals a new pipe reader or writer depending on where
-// 'wakeupChan' points. This unblocks any corresponding reader or writer
-// waiting for the other end of the channel to be opened, see Fifo.waitFor.
-//
-// i.mu must be held.
-func (*inodeOperations) newHandleLocked(wakeupChan *chan struct{}) {
- if *wakeupChan != nil {
- close(*wakeupChan)
- *wakeupChan = nil
- }
-}
-
func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
return syserror.EPIPE
}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 8e4e8e82e..1a1b38f83 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -111,11 +111,27 @@ func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
if atomicIOBytes > sizeBytes {
atomicIOBytes = sizeBytes
}
- return &Pipe{
- isNamed: isNamed,
- max: sizeBytes,
- atomicIOBytes: atomicIOBytes,
+ var p Pipe
+ initPipe(&p, isNamed, sizeBytes, atomicIOBytes)
+ return &p
+}
+
+func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) {
+ if sizeBytes < MinimumPipeSize {
+ sizeBytes = MinimumPipeSize
+ }
+ if sizeBytes > MaximumPipeSize {
+ sizeBytes = MaximumPipeSize
+ }
+ if atomicIOBytes <= 0 {
+ atomicIOBytes = 1
+ }
+ if atomicIOBytes > sizeBytes {
+ atomicIOBytes = sizeBytes
}
+ pipe.isNamed = isNamed
+ pipe.max = sizeBytes
+ pipe.atomicIOBytes = atomicIOBytes
}
// NewConnectedPipe initializes a pipe and returns a pair of objects
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
new file mode 100644
index 000000000..ef9641e6a
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -0,0 +1,213 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "io"
+ "math"
+ "sync"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains Pipe file functionality that is tied to neither VFS nor
+// the old fs architecture.
+
+// Release cleans up the pipe's state.
+func (p *Pipe) Release() {
+ p.rClose()
+ p.wClose()
+
+ // Wake up readers and writers.
+ p.Notify(waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads from the Pipe into dst.
+func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ n, err := p.read(ctx, readOps{
+ left: func() int64 {
+ return dst.NumBytes()
+ },
+ limit: func(l int64) {
+ dst = dst.TakeFirst64(l)
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, buf)
+ dst = dst.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// WriteTo writes to w from the Pipe.
+func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) {
+ ops := readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := buf.ReadToWriter(w, count, dup)
+ count -= n
+ return n, err
+ },
+ }
+ if dup {
+ // There is no notification for dup operations.
+ return p.dup(ctx, ops)
+ }
+ n, err := p.read(ctx, ops)
+ if n > 0 {
+ p.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// Write writes to the Pipe from src.
+func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ n, err := p.write(ctx, writeOps{
+ left: func() int64 {
+ return src.NumBytes()
+ },
+ limit: func(l int64) {
+ src = src.TakeFirst64(l)
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := src.CopyInTo(ctx, buf)
+ src = src.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// ReadFrom reads from r to the Pipe.
+func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) {
+ n, err := p.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := buf.WriteFromReader(r, count)
+ count -= n
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (p *Pipe) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return p.rwReadiness() & mask
+}
+
+// Ioctl implements ioctls on the Pipe.
+func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Switch on ioctl request.
+ switch int(args[1].Int()) {
+ case linux.FIONREAD:
+ v := p.queued()
+ if v > math.MaxInt32 {
+ v = math.MaxInt32 // Silently truncate.
+ }
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ default:
+ return 0, syscall.ENOTTY
+ }
+}
+
+// waitFor blocks until the underlying pipe has at least one reader/writer is
+// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this
+// function will block for either readers or writers, depending on where
+// 'wakeupChan' points.
+//
+// mu must be held by the caller. waitFor returns with mu held, but it will
+// drop mu before blocking for any reader/writers.
+func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
+ // Ideally this function would simply use a condition variable. However, the
+ // wait needs to be interruptible via 'sleeper', so we must sychronize via a
+ // channel. The synchronization below relies on the fact that closing a
+ // channel unblocks all receives on the channel.
+
+ // Does an appropriate wakeup channel already exist? If not, create a new
+ // one. This is all done under f.mu to avoid races.
+ if *wakeupChan == nil {
+ *wakeupChan = make(chan struct{})
+ }
+
+ // Grab a local reference to the wakeup channel since it may disappear as
+ // soon as we drop f.mu.
+ wakeup := *wakeupChan
+
+ // Drop the lock and prepare to sleep.
+ mu.Unlock()
+ cancel := sleeper.SleepStart()
+
+ // Wait for either a new reader/write to be signalled via 'wakeup', or
+ // for the sleep to be cancelled.
+ select {
+ case <-wakeup:
+ sleeper.SleepFinish(true)
+ case <-cancel:
+ sleeper.SleepFinish(false)
+ }
+
+ // Take the lock and check if we were woken. If we were woken and
+ // interrupted, the former takes priority.
+ mu.Lock()
+ select {
+ case <-wakeup:
+ return true
+ default:
+ return false
+ }
+}
+
+// newHandleLocked signals a new pipe reader or writer depending on where
+// 'wakeupChan' points. This unblocks any corresponding reader or writer
+// waiting for the other end of the channel to be opened, see Fifo.waitFor.
+//
+// Precondition: the mutex protecting wakeupChan must be held.
+func newHandleLocked(wakeupChan *chan struct{}) {
+ if *wakeupChan != nil {
+ close(*wakeupChan)
+ *wakeupChan = nil
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index 7c307f013..b4d29fc77 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -16,16 +16,12 @@ package pipe
import (
"io"
- "math"
- "syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
)
// ReaderWriter satisfies the FileOperations interface and services both
@@ -45,124 +41,27 @@ type ReaderWriter struct {
*Pipe
}
-// Release implements fs.FileOperations.Release.
-func (rw *ReaderWriter) Release() {
- rw.Pipe.rClose()
- rw.Pipe.wClose()
-
- // Wake up readers and writers.
- rw.Pipe.Notify(waiter.EventIn | waiter.EventOut)
-}
-
// Read implements fs.FileOperations.Read.
func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.read(ctx, readOps{
- left: func() int64 {
- return dst.NumBytes()
- },
- limit: func(l int64) {
- dst = dst.TakeFirst64(l)
- },
- read: func(buf *buffer) (int64, error) {
- n, err := dst.CopyOutFrom(ctx, buf)
- dst = dst.DropFirst64(n)
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventOut)
- }
- return n, err
+ return rw.Pipe.Read(ctx, dst)
}
// WriteTo implements fs.FileOperations.WriteTo.
func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) {
- ops := readOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- read: func(buf *buffer) (int64, error) {
- n, err := buf.ReadToWriter(w, count, dup)
- count -= n
- return n, err
- },
- }
- if dup {
- // There is no notification for dup operations.
- return rw.Pipe.dup(ctx, ops)
- }
- n, err := rw.Pipe.read(ctx, ops)
- if n > 0 {
- rw.Pipe.Notify(waiter.EventOut)
- }
- return n, err
+ return rw.Pipe.WriteTo(ctx, w, count, dup)
}
// Write implements fs.FileOperations.Write.
func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, writeOps{
- left: func() int64 {
- return src.NumBytes()
- },
- limit: func(l int64) {
- src = src.TakeFirst64(l)
- },
- write: func(buf *buffer) (int64, error) {
- n, err := src.CopyInTo(ctx, buf)
- src = src.DropFirst64(n)
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventIn)
- }
- return n, err
+ return rw.Pipe.Write(ctx, src)
}
// ReadFrom implements fs.FileOperations.WriteTo.
func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, writeOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- write: func(buf *buffer) (int64, error) {
- n, err := buf.WriteFromReader(r, count)
- count -= n
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventIn)
- }
- return n, err
-}
-
-// Readiness returns the ready events in the underlying pipe.
-func (rw *ReaderWriter) Readiness(mask waiter.EventMask) waiter.EventMask {
- return rw.Pipe.rwReadiness() & mask
+ return rw.Pipe.ReadFrom(ctx, r, count)
}
// Ioctl implements fs.FileOperations.Ioctl.
func (rw *ReaderWriter) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- // Switch on ioctl request.
- switch int(args[1].Int()) {
- case linux.FIONREAD:
- v := rw.queued()
- if v > math.MaxInt32 {
- v = math.MaxInt32 // Silently truncate.
- }
- // Copy result to user-space.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
- return 0, err
- default:
- return 0, syscall.ENOTTY
- }
+ return rw.Pipe.Ioctl(ctx, io, args)
}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
new file mode 100644
index 000000000..6416e0dd8
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -0,0 +1,220 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains types enabling the pipe package to be used with the vfs
+// package.
+
+// VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should
+// not be copied.
+type VFSPipe struct {
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // pipe is the underlying pipe.
+ pipe Pipe
+
+ // Channels for synchronizing the creation of new readers and writers
+ // of this fifo. See waitFor and newHandleLocked.
+ //
+ // These are not saved/restored because all waiters are unblocked on
+ // save, and either automatically restart (via ERESTARTSYS) or return
+ // EINTR on resume. On restarts via ERESTARTSYS, the appropriate
+ // channel will be recreated.
+ rWakeup chan struct{} `state:"nosave"`
+ wWakeup chan struct{} `state:"nosave"`
+}
+
+// NewVFSPipe returns an initialized VFSPipe.
+func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe {
+ var vp VFSPipe
+ initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes)
+ return &vp
+}
+
+// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics
+// during open:
+//
+// "Normally, opening the FIFO blocks until the other end is opened also. A
+// process can open a FIFO in nonblocking mode. In this case, opening for
+// read-only will succeed even if no-one has opened on the write side yet,
+// opening for write-only will fail with ENXIO (no such device or address)
+// unless the other end has already been opened. Under Linux, opening a FIFO
+// for read and write will succeed both in blocking and nonblocking mode. POSIX
+// leaves this behavior undefined. This can be used to open a FIFO for writing
+// while there are no readers available." - fifo(7)
+func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+ vp.mu.Lock()
+ defer vp.mu.Unlock()
+
+ readable := vfs.MayReadFileWithOpenFlags(flags)
+ writable := vfs.MayWriteFileWithOpenFlags(flags)
+ if !readable && !writable {
+ return nil, syserror.EINVAL
+ }
+
+ vfd, err := vp.open(rp, vfsd, vfsfd, flags)
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case readable && writable:
+ // Pipes opened for read-write always succeed without blocking.
+ newHandleLocked(&vp.rWakeup)
+ newHandleLocked(&vp.wWakeup)
+
+ case readable:
+ newHandleLocked(&vp.rWakeup)
+ // If this pipe is being opened as nonblocking and there's no
+ // writer, we have to wait for a writer to open the other end.
+ if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ return nil, syserror.EINTR
+ }
+
+ case writable:
+ newHandleLocked(&vp.wWakeup)
+
+ if !vp.pipe.HasReaders() {
+ // Nonblocking, write-only opens fail with ENXIO when
+ // the read side isn't open yet.
+ if flags&linux.O_NONBLOCK != 0 {
+ return nil, syserror.ENXIO
+ }
+ // Wait for a reader to open the other end.
+ if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
+ return nil, syserror.EINTR
+ }
+ }
+
+ default:
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ return vfd, nil
+}
+
+// Preconditions: vp.mu must be held.
+func (vp *VFSPipe) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+ var fd VFSPipeFD
+ fd.flags = flags
+ fd.readable = vfs.MayReadFileWithOpenFlags(flags)
+ fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
+ fd.vfsfd = vfsfd
+ fd.pipe = &vp.pipe
+ if fd.writable {
+ // The corresponding Mount.EndWrite() is in VFSPipe.Release().
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ }
+
+ switch {
+ case fd.readable && fd.writable:
+ vp.pipe.rOpen()
+ vp.pipe.wOpen()
+ case fd.readable:
+ vp.pipe.rOpen()
+ case fd.writable:
+ vp.pipe.wOpen()
+ default:
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ return &fd, nil
+}
+
+// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is
+// expected that filesystesm will use this in a struct implementing
+// vfs.FileDescriptionImpl.
+type VFSPipeFD struct {
+ pipe *Pipe
+ flags uint32
+ readable bool
+ writable bool
+ vfsfd *vfs.FileDescription
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *VFSPipeFD) Release() {
+ var event waiter.EventMask
+ if fd.readable {
+ fd.pipe.rClose()
+ event |= waiter.EventIn
+ }
+ if fd.writable {
+ fd.pipe.wClose()
+ event |= waiter.EventOut
+ }
+ if event == 0 {
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ if fd.writable {
+ fd.vfsfd.VirtualDentry().Mount().EndWrite()
+ }
+
+ fd.pipe.Notify(event)
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *VFSPipeFD) OnClose(_ context.Context) error {
+ return nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ if !fd.readable {
+ return 0, syserror.EINVAL
+ }
+
+ return fd.pipe.Read(ctx, dst)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ if !fd.writable {
+ return 0, syserror.EINVAL
+ }
+
+ return fd.pipe.Write(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return fd.pipe.Ioctl(ctx, uio, args)
+}
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index 80e5e5da3..f4c00cd86 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "waiter_list",
out = "waiter_list.go",
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index aa7471eb6..cd48945e6 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "shm",
srcs = [
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 9beae4b31..31847e1df 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "time",
srcs = [
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 59649c770..156e67bf8 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "limits",
srcs = [
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index 3b322f5f3..2890393bd 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_embed_data")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_embed_data(
name = "vdso_bin",
src = "//vdso:vdso.so",
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 9687e7e76..3ef84245b 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "mappable_range",
out = "mappable_range.go",
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index b35c8c673..a804b8b5c 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "file_refcount_set",
out = "file_refcount_set.go",
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 3fd904c67..f404107af 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "evictable_range",
out = "evictable_range.go",
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 9aa6ec507..157bffa81 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "file_range",
out = "file_range.go",
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 9f0ecfbe4..b699b057d 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -327,6 +327,19 @@ func (t *thread) dumpAndPanic(message string) {
panic(message)
}
+func (t *thread) unexpectedStubExit() {
+ msg, err := t.getEventMessage()
+ status := syscall.WaitStatus(msg)
+ if status.Signaled() && status.Signal() == syscall.SIGKILL {
+ // SIGKILL can be only sent by an user or OOM-killer. In both
+ // these cases, we don't need to panic. There is no reasons to
+ // think that something wrong in gVisor.
+ log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid)
+ syscall.Kill(os.Getpid(), syscall.SIGKILL)
+ }
+ t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+}
+
// wait waits for a stop event.
//
// Precondition: outcome is a valid waitOutcome.
@@ -355,8 +368,7 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
}
if stopSig == syscall.SIGTRAP {
if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
- msg, err := t.getEventMessage()
- t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+ t.unexpectedStubExit()
}
// Re-encode the trap cause the way it's expected.
return stopSig | syscall.Signal(status.TrapCause()<<8)
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index c075b5f91..3782d4332 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -129,6 +129,9 @@ func createStub() (*thread, error) {
// transitively) will be killed as well. It's simply not possible to
// safely handle a single stub getting killed: the exact state of
// execution is unknown and not recoverable.
+ //
+ // In addition, we set the PTRACE_O_TRACEEXIT option to log more
+ // information about a stub process when it receives a fatal signal.
return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction)
}
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 8ed6c7652..48b0ceaec 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -1,9 +1,8 @@
load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-
go_template(
name = "defs",
srcs = [
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index d7029d5a9..780bf9a66 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "defs_impl",
out = "defs_impl.go",
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index ea090b686..934a90378 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -1,10 +1,9 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-
go_template(
name = "generic_walker",
srcs = [
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 3300f9a6b..26176b10d 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "socket",
srcs = ["socket.go"],
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 81dbd7309..4a6e83a8b 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "control",
srcs = ["control.go"],
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index a951f1bb0..4d174dda4 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "hostinet",
srcs = [
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 354a0d6ee..5eb06bbf4 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "netfilter",
srcs = [
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 7da68384e..f95803f91 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "netlink",
srcs = [
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 445080aa4..463544c1a 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "port",
srcs = ["port.go"],
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
index 5dc8533ec..1d4912753 100644
--- a/pkg/sentry/socket/netlink/route/BUILD
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "route",
srcs = ["protocol.go"],
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 60523f79a..e414d8055 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "netstack",
srcs = [
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 6fd43fcbd..69dbfd197 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -53,6 +53,7 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -298,6 +299,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
// netstack representation taking any addresses into account.
@@ -309,12 +311,12 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
}
// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
-// AF_INET6 addresses.
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
//
// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
//
-// AddressAndFamily returns an address, its family.
+// AddressAndFamily returns an address and its family.
func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
@@ -373,6 +375,22 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
}
return out, family, nil
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(b/129292371): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
case linux.AF_UNSPEC:
return tcpip.FullAddress{}, family, nil
@@ -1953,12 +1971,14 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
return &out, uint32(2 + l)
}
return &out, uint32(3 + l)
+
case linux.AF_INET:
var out linux.SockAddrInet
copy(out.Addr[:], addr.Addr)
out.Family = linux.AF_INET
out.Port = htons(addr.Port)
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInetSize)
+
case linux.AF_INET6:
var out linux.SockAddrInet6
if len(addr.Addr) == 4 {
@@ -1974,7 +1994,17 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
if isLinkLocal(addr.Addr) {
out.Scope_id = uint32(addr.NIC)
}
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(b/129292371): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
default:
return nil, 0
}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 357a664cc..2d2c1ba2a 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -62,6 +62,10 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
+ // TODO(b/142504697): "In order to create a raw socket, a
+ // process must have the CAP_NET_RAW capability in the user
+ // namespace that governs its network namespace." - raw(7)
+
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -85,7 +89,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
return 0, true, syserr.ErrProtocolNotSupported
}
-// Socket creates a new socket object for the AF_INET or AF_INET6 family.
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Fail right away if we don't have a stack.
stack := t.NetworkContext()
@@ -99,6 +104,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return nil, nil
}
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocket(t, eps, stype, protocol)
+ }
+
// Figure out the transport protocol.
transProto, associated, err := getTransportProtocol(t, stype, protocol)
if err != nil {
@@ -121,12 +132,47 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return New(t, p.family, stype, int(transProto), wq, ep)
}
+func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // TODO(b/142504697): "In order to create a packet socket, a process
+ // must have the CAP_NET_RAW capability in the user namespace that
+ // governs its network namespace." - packet(7)
+
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return New(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
// Pair just returns nil sockets (not supported).
func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
return nil, nil, nil
}
-// init registers socket providers for AF_INET and AF_INET6.
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
func init() {
// Providers backed by netstack.
p := []provider{
@@ -138,6 +184,9 @@ func init() {
family: linux.AF_INET6,
netProto: ipv6.ProtocolNumber,
},
+ {
+ family: linux.AF_PACKET,
+ },
}
for i := range p {
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 830f4da10..5b6a154f6 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "unix",
srcs = [
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 0b0240336..788ad70d2 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
+package(licenses = ["notice"])
+
go_template_instance(
name = "transport_message_list",
out = "transport_message_list.go",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 4bd15808a..dea11e253 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -220,6 +220,11 @@ func (e *connectionedEndpoint) Close() {
case e.Connected():
e.connected.CloseSend()
e.receiver.CloseRecv()
+ // Still have unread data? If yes, we set this into the write
+ // end so that the peer can get ECONNRESET) when it does read.
+ if e.receiver.RecvQueuedSize() > 0 {
+ e.connected.CloseUnread()
+ }
c = e.connected
r = e.receiver
e.connected = nil
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index 0415fae9a..e27b1c714 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -33,6 +33,7 @@ type queue struct {
mu sync.Mutex `state:"nosave"`
closed bool
+ unread bool
used int64
limit int64
dataList messageList
@@ -161,6 +162,9 @@ func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) {
err := syserr.ErrWouldBlock
if q.closed {
err = syserr.ErrClosedForReceive
+ if q.unread {
+ err = syserr.ErrConnectionReset
+ }
}
q.mu.Unlock()
@@ -188,7 +192,9 @@ func (q *queue) Peek() (*message, *syserr.Error) {
if q.dataList.Front() == nil {
err := syserr.ErrWouldBlock
if q.closed {
- err = syserr.ErrClosedForReceive
+ if err = syserr.ErrClosedForReceive; q.unread {
+ err = syserr.ErrConnectionReset
+ }
}
return nil, err
}
@@ -208,3 +214,11 @@ func (q *queue) QueuedSize() int64 {
func (q *queue) MaxQueueSize() int64 {
return q.limit
}
+
+// CloseUnread sets flag to indicate that the peer is closed (not shutdown)
+// with unread data. So if read on this queue shall return ECONNRESET error.
+func (q *queue) CloseUnread() {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.unread = true
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 1867b3a5c..529a7a7a9 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -608,6 +608,10 @@ type ConnectedEndpoint interface {
// Release releases any resources owned by the ConnectedEndpoint. It should
// be called before droping all references to a ConnectedEndpoint.
Release()
+
+ // CloseUnread sets the fact that this end is closed with unread data to
+ // the peer socket.
+ CloseUnread()
}
// +stateify savable
@@ -711,6 +715,11 @@ func (e *connectedEndpoint) Release() {
e.writeQueue.DecRef()
}
+// CloseUnread implements ConnectedEndpoint.CloseUnread.
+func (e *connectedEndpoint) CloseUnread() {
+ e.writeQueue.CloseUnread()
+}
+
// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
// unix domain socket Endpoint implementations.
//
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 50c308134..1aaae8487 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -595,7 +595,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
total += n
}
- if err != nil || !waitAll || isPacket || n >= dst.NumBytes() {
+ streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed {
if total > 0 {
err = nil
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index cf2a56bed..fb2c1777f 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "linux",
srcs = [
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index b64c49ff5..68589a377 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -16,7 +16,12 @@
package linux
const (
- _LINUX_SYSNAME = "Linux"
- _LINUX_RELEASE = "4.4"
- _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
+ // LinuxSysname is the OS name advertised by gVisor.
+ LinuxSysname = "Linux"
+
+ // LinuxRelease is the Linux release version number advertised by gVisor.
+ LinuxRelease = "4.4.0"
+
+ // LinuxVersion is the version info advertised by gVisor.
+ LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
)
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index e215ac049..aedb6d774 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -34,9 +34,9 @@ var AMD64 = &kernel.SyscallTable{
// guides the interface provided by this syscall table. The build
// version is that for a clean build with default kernel config, at 5
// minutes after v4.4 was tagged.
- Sysname: _LINUX_SYSNAME,
- Release: _LINUX_RELEASE,
- Version: _LINUX_VERSION,
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
},
AuditNumber: linux.AUDIT_ARCH_X86_64,
Table: map[uintptr]kernel.Syscall{
@@ -362,7 +362,7 @@ var AMD64 = &kernel.SyscallTable{
319: syscalls.Supported("memfd_create", MemfdCreate),
320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
- 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
+ 322: syscalls.PartiallySupported("execveat", Execveat, "No support for AT_EMPTY_PATH, AT_SYMLINK_FOLLOW.", nil),
323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index 1d3b63020..4cf7f836a 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -30,9 +30,9 @@ var ARM64 = &kernel.SyscallTable{
OS: abi.Linux,
Arch: arch.ARM64,
Version: kernel.Version{
- Sysname: _LINUX_SYSNAME,
- Release: _LINUX_RELEASE,
- Version: _LINUX_VERSION,
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
},
AuditNumber: linux.AUDIT_ARCH_AARCH64,
Table: map[uintptr]kernel.Syscall{
@@ -107,10 +107,10 @@ var ARM64 = &kernel.SyscallTable{
71: syscalls.Supported("sendfile", Sendfile),
72: syscalls.Supported("pselect", Pselect),
73: syscalls.Supported("ppoll", Ppoll),
- 74: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426)
+ 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 77: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 77: syscalls.Supported("tee", Tee),
78: syscalls.Supported("readlinkat", Readlinkat),
80: syscalls.Supported("fstat", Fstat),
81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
@@ -245,7 +245,7 @@ var ARM64 = &kernel.SyscallTable{
210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
211: syscalls.Supported("sendmsg", SendMsg),
212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
- 213: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341)
+ 213: syscalls.Supported("readahead", Readahead),
214: syscalls.Supported("brk", Brk),
215: syscalls.Supported("munmap", Munmap),
216: syscalls.Supported("mremap", Mremap),
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 8ab7ffa25..6e425f1ec 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -15,10 +15,12 @@
package linux
import (
+ "path"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usermem"
@@ -67,8 +69,22 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
argvAddr := args[1].Pointer()
envvAddr := args[2].Pointer()
- // Extract our arguments.
- filename, err := t.CopyInString(filenameAddr, linux.PATH_MAX)
+ return execveat(t, linux.AT_FDCWD, filenameAddr, argvAddr, envvAddr, 0)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+
+ return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
if err != nil {
return 0, nil, err
}
@@ -89,14 +105,38 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
}
+ if flags != 0 {
+ // TODO(b/128449944): Handle AT_EMPTY_PATH and AT_SYMLINK_NOFOLLOW.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
root := t.FSContext().RootDirectory()
defer root.DecRef()
- wd := t.FSContext().WorkingDirectory()
+
+ var wd *fs.Dirent
+ if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) {
+ // If pathname is absolute, LoadTaskImage() will ignore the wd.
+ wd = t.FSContext().WorkingDirectory()
+ } else {
+ // Need to extract the given FD.
+ f := t.GetFile(dirFD)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ wd = f.Dirent
+ wd.IncRef()
+ if !fs.IsDir(wd.Inode.StableAttr) {
+ return 0, nil, syserror.ENOTDIR
+ }
+ }
defer wd.DecRef()
// Load the new TaskContext.
maxTraversals := uint(linux.MaxSymlinkTraversals)
- tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet())
+ tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, pathname, nil, argv, envv, t.Arch().FeatureSet())
if se != nil {
return 0, nil, se.ToError()
}
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 27cd2c336..ad4b67806 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
// Pwritev2 implements linux syscall pwritev2(2).
-// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality.
func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
// While the syscall is
// pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index beb43ba13..d3a4cd943 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -1,10 +1,9 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "seqatomic_parameters",
out = "seqatomic_parameters_unsafe.go",
diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD
index a34c39540..c32fe3241 100644
--- a/pkg/sentry/usage/BUILD
+++ b/pkg/sentry/usage/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "usage",
srcs = [
diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD
index cc5d25762..684f59a6b 100644
--- a/pkg/sentry/usermem/BUILD
+++ b/pkg/sentry/usermem/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "addr_range",
out = "addr_range.go",
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 7eb2b2821..3a9665800 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -102,7 +102,7 @@ type FileDescriptionImpl interface {
// OnClose is called when a file descriptor representing the
// FileDescription is closed. Note that returning a non-nil error does not
// prevent the file descriptor from being closed.
- OnClose() error
+ OnClose(ctx context.Context) error
// StatusFlags returns file description status flags, as for
// fcntl(F_GETFL).
@@ -180,7 +180,7 @@ type FileDescriptionImpl interface {
// ConfigureMMap mutates opts to implement mmap(2) for the file. Most
// implementations that support memory mapping can call
// GenericConfigureMMap with the appropriate memmap.Mappable.
- ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error
+ ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error
// Ioctl implements the ioctl(2) syscall.
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index ba230da72..4fbad7840 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -45,7 +45,7 @@ type FileDescriptionDefaultImpl struct{}
// OnClose implements FileDescriptionImpl.OnClose analogously to
// file_operations::flush == NULL in Linux.
-func (FileDescriptionDefaultImpl) OnClose() error {
+func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error {
return nil
}
@@ -117,7 +117,7 @@ func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error {
// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to
// file_operations::mmap == NULL in Linux.
-func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return syserror.ENODEV
}
diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go
index 23f2b9e08..abde0feaa 100644
--- a/pkg/sentry/vfs/syscalls.go
+++ b/pkg/sentry/vfs/syscalls.go
@@ -96,6 +96,26 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
}
}
+// MknodAt creates a file of the given mode at the given path. It returns an
+// error from the syserror package.
+func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return nil
+ }
+ for {
+ if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ // Handle mount traversals.
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
// OpenAt returns a FileDescription providing access to the file at the given
// path. A reference is taken on the returned FileDescription.
func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
@@ -198,8 +218,6 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) err
//
// - VFS.LinkAt()
//
-// - VFS.MknodAt()
-//
// - VFS.ReadlinkAt()
//
// - VFS.RenameAt()
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 329904457..be93750bf 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,11 +1,10 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "addr_range",
out = "addr_range.go",
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
index 4287464f3..48a2a2713 100644
--- a/pkg/tcpip/buffer/prependable.go
+++ b/pkg/tcpip/buffer/prependable.go
@@ -41,6 +41,11 @@ func NewPrependableFromView(v View) Prependable {
return Prependable{buf: v, usedIdx: 0}
}
+// NewEmptyPrependableFromView creates a new prependable buffer from a View.
+func NewEmptyPrependableFromView(v View) Prependable {
+ return Prependable{buf: v, usedIdx: len(v)}
+}
+
// View returns a View of the backing buffer that contains all prepended
// data so far.
func (p Prependable) View() View {
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 096ad71ab..02137e1c9 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -686,3 +686,64 @@ func ICMPv6Code(want byte) TransportChecker {
}
}
}
+
+// NDP creates a checker that checks that the packet contains a valid NDP
+// message for type of ty, with potentially additional checks specified by
+// checkers.
+//
+// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDP message as far as the size of the message (minSize) is concerned. The
+// values within the message are up to checkers to validate.
+func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // Check normal ICMPv6 first.
+ ICMPv6(
+ ICMPv6Type(msgType),
+ ICMPv6Code(0))(t, h)
+
+ last := h[len(h)-1]
+
+ icmp := header.ICMPv6(last.Payload())
+ if got := len(icmp.NDPPayload()); got < minSize {
+ t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
+ }
+
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// NDPNS creates a checker that checks that the packet contains a valid NDP
+// Neighbor Solicitation message (as per the raw wire format), with potentially
+// additional checks specified by checkers.
+//
+// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNS message as far as the size of the messages concerned. The values within
+// the message are up to checkers to validate.
+func NDPNS(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
+}
+
+// NDPNSTargetAddress creates a checker that checks the Target Address field of
+// a header.NDPNeighborSolicit.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+
+ if got := ns.TargetAddress(); got != want {
+ t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index a255231a3..a3485b35c 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -16,6 +16,10 @@ go_library(
"ipv4.go",
"ipv6.go",
"ipv6_fragment.go",
+ "ndp_neighbor_advert.go",
+ "ndp_neighbor_solicit.go",
+ "ndp_options.go",
+ "ndp_router_advert.go",
"tcp.go",
"udp.go",
],
@@ -30,13 +34,26 @@ go_library(
)
go_test(
- name = "header_test",
+ name = "header_x_test",
size = "small",
srcs = [
+ "checksum_test.go",
"ipversion_test.go",
"tcp_test.go",
],
deps = [
":header",
+ "//pkg/tcpip/buffer",
+ ],
+)
+
+go_test(
+ name = "header_test",
+ size = "small",
+ srcs = [
+ "eth_test.go",
+ "ndp_test.go",
],
+ embed = [":header"],
+ deps = ["//pkg/tcpip"],
)
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 39a4d69be..9749c7f4d 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -23,11 +23,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
-func calculateChecksum(buf []byte, initial uint32) uint16 {
+func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
v := initial
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
l := len(buf)
- if l&1 != 0 {
+ odd = l&1 != 0
+ if odd {
l--
v += uint32(buf[l]) << 8
}
@@ -36,7 +42,7 @@ func calculateChecksum(buf []byte, initial uint32) uint16 {
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
}
- return ChecksumCombine(uint16(v), uint16(v>>16))
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
}
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
@@ -44,7 +50,8 @@ func calculateChecksum(buf []byte, initial uint32) uint16 {
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
- return calculateChecksum(buf, uint32(initial))
+ s, _ := calculateChecksum(buf, false, uint32(initial))
+ return s
}
// ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
@@ -52,19 +59,40 @@ func Checksum(buf []byte, initial uint16) uint16 {
//
// The initial checksum must have been computed on an even number of bytes.
func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
- var odd bool
+ return ChecksumVVWithOffset(vv, initial, 0, vv.Size())
+}
+
+// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the
+// bytes in the given VectorizedView.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 {
+ odd := false
sum := initial
for _, v := range vv.Views() {
if len(v) == 0 {
continue
}
- s := uint32(sum)
- if odd {
- s += uint32(v[0])
- v = v[1:]
+
+ if off >= len(v) {
+ off -= len(v)
+ continue
+ }
+ v = v[off:]
+
+ l := len(v)
+ if l > size {
+ l = size
+ }
+ v = v[:l]
+
+ sum, odd = calculateChecksum(v, odd, uint32(sum))
+
+ size -= len(v)
+ if size == 0 {
+ break
}
- odd = len(v)&1 != 0
- sum = calculateChecksum(v, s)
+ off = 0
}
return sum
}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
new file mode 100644
index 000000000..86b466c1c
--- /dev/null
+++ b/pkg/tcpip/header/checksum_test.go
@@ -0,0 +1,109 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package header provides the implementation of the encoding and decoding of
+// network protocol headers.
+package header_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestChecksumVVWithOffset(t *testing.T) {
+ testCases := []struct {
+ name string
+ vv buffer.VectorisedView
+ off, size int
+ initial uint16
+ want uint16
+ }{
+ {
+ name: "empty",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ }),
+ off: 0,
+ size: 0,
+ want: 0,
+ },
+ {
+ name: "OneView",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ }),
+ off: 0,
+ size: 5,
+ want: 1294,
+ },
+ {
+ name: "TwoViews",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 0,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "TwoViewsWithOffset",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 1,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "ThreeViewsWithOffset",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 7,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "ThreeViewsWithInitial",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}),
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}),
+ }),
+ initial: 77,
+ off: 7,
+ size: 11,
+ want: 33896,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want {
+ t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want)
+ }
+ v := tc.vv.ToView()
+ v.TrimFront(tc.off)
+ v.CapLength(tc.size)
+ if got, want := header.Checksum(v, tc.initial), tc.want; got != want {
+ t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index 4c3d3311f..f5d2c127f 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -48,8 +48,48 @@ const (
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
+
+ // unspecifiedEthernetAddress is the unspecified ethernet address
+ // (all bits set to 0).
+ unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+
+ // unicastMulticastFlagMask is the mask of the least significant bit in
+ // the first octet (in network byte order) of an ethernet address that
+ // determines whether the ethernet address is a unicast or multicast. If
+ // the masked bit is a 1, then the address is a multicast, unicast
+ // otherwise.
+ //
+ // See the IEEE Std 802-2001 document for more details. Specifically,
+ // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf:
+ // "A 48-bit universal address consists of two parts. The first 24 bits
+ // correspond to the OUI as assigned by the IEEE, expect that the
+ // assignee may set the LSB of the first octet to 1 for group addresses
+ // or set it to 0 for individual addresses."
+ unicastMulticastFlagMask = 1
+
+ // unicastMulticastFlagByteIdx is the byte that holds the
+ // unicast/multicast flag. See unicastMulticastFlagMask.
+ unicastMulticastFlagByteIdx = 0
+)
+
+const (
+ // EthernetProtocolAll is a catch-all for all protocols carried inside
+ // an ethernet frame. It is mainly used to create packet sockets that
+ // capture all traffic.
+ EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003
+
+ // EthernetProtocolPUP is the PARC Universial Packet protocol ethertype.
+ EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200
)
+// Ethertypes holds the protocol numbers describing the payload of an ethernet
+// frame. These types aren't necessarily supported by netstack, but can be used
+// to catch all traffic of a type via packet endpoints.
+var Ethertypes = []tcpip.NetworkProtocolNumber{
+ EthernetProtocolAll,
+ EthernetProtocolPUP,
+}
+
// SourceAddress returns the "MAC source" field of the ethernet frame header.
func (b Ethernet) SourceAddress() tcpip.LinkAddress {
return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
@@ -72,3 +112,25 @@ func (b Ethernet) Encode(e *EthernetFields) {
copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}
+
+// IsValidUnicastEthernetAddress returns true if addr is a valid unicast
+// ethernet address.
+func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
+ // Must be of the right length.
+ if len(addr) != EthernetAddressSize {
+ return false
+ }
+
+ // Must not be unspecified.
+ if addr == unspecifiedEthernetAddress {
+ return false
+ }
+
+ // Must not be a multicast.
+ if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
+ return false
+ }
+
+ // addr is a valid unicast ethernet address.
+ return true
+}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
new file mode 100644
index 000000000..6634c90f5
--- /dev/null
+++ b/pkg/tcpip/header/eth_test.go
@@ -0,0 +1,68 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestIsValidUnicastEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.LinkAddress
+ expected bool
+ }{
+ {
+ "Nil",
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty",
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "InvalidLength",
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Unspecified",
+ unspecifiedEthernetAddress,
+ false,
+ },
+ {
+ "Multicast",
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ false,
+ },
+ {
+ "Valid",
+ tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
+ true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected {
+ t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 1125a7d14..c2bfd8c79 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -25,6 +25,12 @@ import (
type ICMPv6 []byte
const (
+ // ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the
+ // sum of the size of the ICMPv6 Type, Code and Checksum fields, as
+ // per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6
+ // message body begins.
+ ICMPv6HeaderSize = 4
+
// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
ICMPv6MinimumSize = 8
@@ -37,10 +43,16 @@ const (
// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
// neighbor solicitation packet.
- ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize
+
+ // ICMPv6NeighborAdvertMinimumSize is the minimum size of a
+ // neighbor advertisement packet.
+ ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize
- // ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
- ICMPv6NeighborAdvertSize = 32
+ // ICMPv6NeighborAdvertSize is size of a neighbor advertisement
+ // including the NDP Target Link Layer option for an Ethernet
+ // address.
+ ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize
// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv6EchoMinimumSize = 8
@@ -68,6 +80,13 @@ const (
// icmpv6SequenceOffset is the offset of the sequence field
// in a ICMPv6 Echo Request/Reply message.
icmpv6SequenceOffset = 6
+
+ // NDPHopLimit is the expected IP hop limit value of 255 for received
+ // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
+ // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
+ // drop the NDP packet. All outgoing NDP packets must use this value for
+ // its IP hop limit field.
+ NDPHopLimit = 255
)
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
@@ -166,6 +185,13 @@ func (b ICMPv6) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
+// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6
+// packet's message body as defined by RFC 4443 section 2.1; the portion of the
+// ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
+func (b ICMPv6) NDPPayload() []byte {
+ return b[ICMPv6HeaderSize:]
+}
+
// Payload implements Transport.Payload.
func (b ICMPv6) Payload() []byte {
return b[ICMPv6PayloadOffset:]
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 9d3abc0e4..f1e60911b 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -87,7 +87,8 @@ const (
// section 5.
IPv6MinimumMTU = 1280
- // IPv6Any is the non-routable IPv6 "any" meta address.
+ // IPv6Any is the non-routable IPv6 "any" meta address. It is also
+ // known as the unspecified address.
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
)
@@ -100,6 +101,15 @@ var IPv6EmptySubnet = func() tcpip.Subnet {
return subnet
}()
+// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
+// by RFC 4291 section 2.5.6.
+//
+// The prefix is fe80::/64
+var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{
+ Address: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 64,
+}
+
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go
new file mode 100644
index 000000000..505c92668
--- /dev/null
+++ b/pkg/tcpip/header/ndp_neighbor_advert.go
@@ -0,0 +1,110 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will
+// only contain the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.4 for more details.
+type NDPNeighborAdvert []byte
+
+const (
+ // NDPNAMinimumSize is the minimum size of a valid NDP Neighbor
+ // Advertisement message (body of an ICMPv6 packet).
+ NDPNAMinimumSize = 20
+
+ // ndpNATargetAddressOffset is the start of the Target Address
+ // field within an NDPNeighborAdvert.
+ ndpNATargetAddressOffset = 4
+
+ // ndpNAOptionsOffset is the start of the NDP options in an
+ // NDPNeighborAdvert.
+ ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize
+
+ // ndpNAFlagsOffset is the offset of the flags within an
+ // NDPNeighborAdvert
+ ndpNAFlagsOffset = 0
+
+ // ndpNARouterFlagMask is the mask of the Router Flag field in
+ // the flags byte within in an NDPNeighborAdvert.
+ ndpNARouterFlagMask = (1 << 7)
+
+ // ndpNASolicitedFlagMask is the mask of the Solicited Flag field in
+ // the flags byte within in an NDPNeighborAdvert.
+ ndpNASolicitedFlagMask = (1 << 6)
+
+ // ndpNAOverrideFlagMask is the mask of the Override Flag field in
+ // the flags byte within in an NDPNeighborAdvert.
+ ndpNAOverrideFlagMask = (1 << 5)
+)
+
+// TargetAddress returns the value within the Target Address field.
+func (b NDPNeighborAdvert) TargetAddress() tcpip.Address {
+ return tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize])
+}
+
+// SetTargetAddress sets the value within the Target Address field.
+func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) {
+ copy(b[ndpNATargetAddressOffset:][:IPv6AddressSize], addr)
+}
+
+// RouterFlag returns the value of the Router Flag field.
+func (b NDPNeighborAdvert) RouterFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNARouterFlagMask != 0
+}
+
+// SetRouterFlag sets the value in the Router Flag field.
+func (b NDPNeighborAdvert) SetRouterFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNARouterFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNARouterFlagMask
+ }
+}
+
+// SolicitedFlag returns the value of the Solicited Flag field.
+func (b NDPNeighborAdvert) SolicitedFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNASolicitedFlagMask != 0
+}
+
+// SetSolicitedFlag sets the value in the Solicited Flag field.
+func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNASolicitedFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNASolicitedFlagMask
+ }
+}
+
+// OverrideFlag returns the value of the Override Flag field.
+func (b NDPNeighborAdvert) OverrideFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNAOverrideFlagMask != 0
+}
+
+// SetOverrideFlag sets the value in the Override Flag field.
+func (b NDPNeighborAdvert) SetOverrideFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNAOverrideFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNAOverrideFlagMask
+ }
+}
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPNeighborAdvert) Options() NDPOptions {
+ return NDPOptions(b[ndpNAOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go
new file mode 100644
index 000000000..3a1b8e139
--- /dev/null
+++ b/pkg/tcpip/header/ndp_neighbor_solicit.go
@@ -0,0 +1,52 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only
+// contain the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.3 for more details.
+type NDPNeighborSolicit []byte
+
+const (
+ // NDPNSMinimumSize is the minimum size of a valid NDP Neighbor
+ // Solicitation message (body of an ICMPv6 packet).
+ NDPNSMinimumSize = 20
+
+ // ndpNSTargetAddessOffset is the start of the Target Address
+ // field within an NDPNeighborSolicit.
+ ndpNSTargetAddessOffset = 4
+
+ // ndpNSOptionsOffset is the start of the NDP options in an
+ // NDPNeighborSolicit.
+ ndpNSOptionsOffset = ndpNSTargetAddessOffset + IPv6AddressSize
+)
+
+// TargetAddress returns the value within the Target Address field.
+func (b NDPNeighborSolicit) TargetAddress() tcpip.Address {
+ return tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize])
+}
+
+// SetTargetAddress sets the value within the Target Address field.
+func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) {
+ copy(b[ndpNSTargetAddessOffset:][:IPv6AddressSize], addr)
+}
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPNeighborSolicit) Options() NDPOptions {
+ return NDPOptions(b[ndpNSOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
new file mode 100644
index 000000000..b28bde15b
--- /dev/null
+++ b/pkg/tcpip/header/ndp_options.go
@@ -0,0 +1,172 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // NDPTargetLinkLayerAddressOptionType is the type of the Target
+ // Link-Layer Address option, as per RFC 4861 section 4.6.1.
+ NDPTargetLinkLayerAddressOptionType = 2
+
+ // ndpTargetEthernetLinkLayerAddressSize is the size of a Target
+ // Link Layer Option for an Ethernet address.
+ ndpTargetEthernetLinkLayerAddressSize = 8
+
+ // lengthByteUnits is the multiplier factor for the Length field of an
+ // NDP option. That is, the length field for NDP options is in units of
+ // 8 octets, as per RFC 4861 section 4.6.
+ lengthByteUnits = 8
+)
+
+// NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6.
+type NDPOptions []byte
+
+// Serialize serializes the provided list of NDP options into o.
+//
+// Note, b must be of sufficient size to hold all the options in s. See
+// NDPOptionsSerializer.Length for details on the getting the total size
+// of a serialized NDPOptionsSerializer.
+//
+// Serialize may panic if b is not of sufficient size to hold all the options
+// in s.
+func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
+ done := 0
+
+ for _, o := range s {
+ l := paddedLength(o)
+
+ if l == 0 {
+ continue
+ }
+
+ b[0] = o.Type()
+
+ // We know this safe because paddedLength would have returned
+ // 0 if o had an invalid length (> 255 * lengthByteUnits).
+ b[1] = uint8(l / lengthByteUnits)
+
+ // Serialize NDP option body.
+ used := o.serializeInto(b[2:])
+
+ // Zero out remaining (padding) bytes, if any exists.
+ for i := used + 2; i < l; i++ {
+ b[i] = 0
+ }
+
+ b = b[l:]
+ done += l
+ }
+
+ return done
+}
+
+// ndpOption is the set of functions to be implemented by all NDP option types.
+type ndpOption interface {
+ // Type returns the type of this ndpOption.
+ Type() uint8
+
+ // Length returns the length of the body of this ndpOption, in bytes.
+ Length() int
+
+ // serializeInto serializes this ndpOption into the provided byte
+ // buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // Length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize this ndpOption. Implementers must only use the number of
+ // bytes required to serialize this ndpOption. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto([]byte) int
+}
+
+// paddedLength returns the length of o, in bytes, with any padding bytes, if
+// required.
+func paddedLength(o ndpOption) int {
+ l := o.Length()
+
+ if l == 0 {
+ return 0
+ }
+
+ // Length excludes the 2 Type and Length bytes.
+ l += 2
+
+ // Add extra bytes if needed to make sure the option is
+ // lengthByteUnits-byte aligned. We do this by adding lengthByteUnits-1
+ // to l and then stripping off the last few LSBits from l. This will
+ // make sure that l is rounded up to the nearest unit of
+ // lengthByteUnits. This works since lengthByteUnits is a power of 2
+ // (= 8).
+ mask := lengthByteUnits - 1
+ l += mask
+ l &^= mask
+
+ if l/lengthByteUnits > 255 {
+ // Should never happen because an option can only have a max
+ // value of 255 for its Length field, so just return 0 so this
+ // option does not get serialized.
+ //
+ // Returning 0 here will make sure that this option does not get
+ // serialized when NDPOptions.Serialize is called with the
+ // NDPOptionsSerializer that holds this option, effectively
+ // skipping this option during serialization. Also note that
+ // a value of zero for the Length field in an NDP option is
+ // invalid so this is another sign to the caller that this NDP
+ // option is malformed, as per RFC 4861 section 4.6.
+ return 0
+ }
+
+ return l
+}
+
+// NDPOptionsSerializer is a serializer for NDP options.
+type NDPOptionsSerializer []ndpOption
+
+// Length returns the total number of bytes required to serialize.
+func (b NDPOptionsSerializer) Length() int {
+ l := 0
+
+ for _, o := range b {
+ l += paddedLength(o)
+ }
+
+ return l
+}
+
+// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option
+// as defined by RFC 4861 section 4.6.1.
+type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
+
+// Type implements ndpOption.Type.
+func (o NDPTargetLinkLayerAddressOption) Type() uint8 {
+ return NDPTargetLinkLayerAddressOptionType
+}
+
+// Length implements ndpOption.Length.
+func (o NDPTargetLinkLayerAddressOption) Length() int {
+ return len(o)
+}
+
+// serializeInto implements ndpOption.serializeInto.
+func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
new file mode 100644
index 000000000..bf7610863
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -0,0 +1,112 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "time"
+)
+
+// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.2 for more details.
+type NDPRouterAdvert []byte
+
+const (
+ // NDPRAMinimumSize is the minimum size of a valid NDP Router
+ // Advertisement message (body of an ICMPv6 packet).
+ NDPRAMinimumSize = 12
+
+ // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field
+ // within an NDPRouterAdvert.
+ ndpRACurrHopLimitOffset = 0
+
+ // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags
+ // within an NDPRouterAdvert.
+ ndpRAFlagsOffset = 1
+
+ // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address
+ // Configuration flag within the bit-field/flags byte of an
+ // NDPRouterAdvert.
+ ndpRAManagedAddrConfFlagMask = (1 << 7)
+
+ // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag
+ // within the bit-field/flags byte of an NDPRouterAdvert.
+ ndpRAOtherConfFlagMask = (1 << 6)
+
+ // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
+ // field within an NDPRouterAdvert.
+ ndpRARouterLifetimeOffset = 2
+
+ // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time
+ // field within an NDPRouterAdvert.
+ ndpRAReachableTimeOffset = 4
+
+ // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer
+ // field within an NDPRouterAdvert.
+ ndpRARetransTimerOffset = 8
+
+ // ndpRAOptionsOffset is the start of the NDP options in an
+ // NDPRouterAdvert.
+ ndpRAOptionsOffset = 12
+)
+
+// CurrHopLimit returns the value of the Curr Hop Limit field.
+func (b NDPRouterAdvert) CurrHopLimit() uint8 {
+ return b[ndpRACurrHopLimitOffset]
+}
+
+// ManagedAddrConfFlag returns the value of the Managed Address Configuration
+// flag.
+func (b NDPRouterAdvert) ManagedAddrConfFlag() bool {
+ return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0
+}
+
+// OtherConfFlag returns the value of the Other Configuration flag.
+func (b NDPRouterAdvert) OtherConfFlag() bool {
+ return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
+}
+
+// RouterLifetime returns the lifetime associated with the default router. A
+// value of 0 means the source of the Router Advertisement is not a default
+// router and SHOULD NOT appear on the default router list. Note, a value of 0
+// only means that the router should not be used as a default router, it does
+// not apply to other information contained in the Router Advertisement.
+func (b NDPRouterAdvert) RouterLifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 4861 section 4.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:]))
+}
+
+// ReachableTime returns the time that a node assumes a neighbor is reachable
+// after having received a reachability confirmation. A value of 0 means
+// that it is unspecified by the source of the Router Advertisement message.
+func (b NDPRouterAdvert) ReachableTime() time.Duration {
+ // The field is the time in milliseconds, as per RFC 4861 section 4.2.
+ return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:]))
+}
+
+// RetransTimer returns the time between retransmitted Neighbor Solicitation
+// messages. A value of 0 means that it is unspecified by the source of the
+// Router Advertisement message.
+func (b NDPRouterAdvert) RetransTimer() time.Duration {
+ // The field is the time in milliseconds, as per RFC 4861 section 4.2.
+ return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:]))
+}
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPRouterAdvert) Options() NDPOptions {
+ return NDPOptions(b[ndpRAOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
new file mode 100644
index 000000000..0aac14f43
--- /dev/null
+++ b/pkg/tcpip/header/ndp_test.go
@@ -0,0 +1,199 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bytes"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit.
+func TestNDPNeighborSolicit(t *testing.T) {
+ b := []byte{
+ 0, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ }
+
+ // Test getting the Target Address.
+ ns := NDPNeighborSolicit(b)
+ addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ if got := ns.TargetAddress(); got != addr {
+ t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr)
+ }
+
+ // Test updating the Target Address.
+ addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ ns.SetTargetAddress(addr2)
+ if got := ns.TargetAddress(); got != addr2 {
+ t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr2)
+ }
+ // Make sure the address got updated in the backing buffer.
+ if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 {
+ t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+ }
+}
+
+// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
+func TestNDPNeighborAdvert(t *testing.T) {
+ b := []byte{
+ 160, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ }
+
+ // Test getting the Target Address.
+ na := NDPNeighborAdvert(b)
+ addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ if got := na.TargetAddress(); got != addr {
+ t.Fatalf("got TargetAddress = %s, want %s", got, addr)
+ }
+
+ // Test getting the Router Flag.
+ if got := na.RouterFlag(); !got {
+ t.Fatalf("got RouterFlag = false, want = true")
+ }
+
+ // Test getting the Solicited Flag.
+ if got := na.SolicitedFlag(); got {
+ t.Fatalf("got SolicitedFlag = true, want = false")
+ }
+
+ // Test getting the Override Flag.
+ if got := na.OverrideFlag(); !got {
+ t.Fatalf("got OverrideFlag = false, want = true")
+ }
+
+ // Test updating the Target Address.
+ addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ na.SetTargetAddress(addr2)
+ if got := na.TargetAddress(); got != addr2 {
+ t.Fatalf("got TargetAddress = %s, want %s", got, addr2)
+ }
+ // Make sure the address got updated in the backing buffer.
+ if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 {
+ t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+ }
+
+ // Test updating the Router Flag.
+ na.SetRouterFlag(false)
+ if got := na.RouterFlag(); got {
+ t.Fatalf("got RouterFlag = true, want = false")
+ }
+
+ // Test updating the Solicited Flag.
+ na.SetSolicitedFlag(true)
+ if got := na.SolicitedFlag(); !got {
+ t.Fatalf("got SolicitedFlag = false, want = true")
+ }
+
+ // Test updating the Override Flag.
+ na.SetOverrideFlag(false)
+ if got := na.OverrideFlag(); got {
+ t.Fatalf("got OverrideFlag = true, want = false")
+ }
+
+ // Make sure flags got updated in the backing buffer.
+ if got := b[ndpNAFlagsOffset]; got != 64 {
+ t.Fatalf("got flags byte = %d, want = 64")
+ }
+}
+
+func TestNDPRouterAdvert(t *testing.T) {
+ b := []byte{
+ 64, 128, 1, 2,
+ 3, 4, 5, 6,
+ 7, 8, 9, 10,
+ }
+
+ ra := NDPRouterAdvert(b)
+
+ if got := ra.CurrHopLimit(); got != 64 {
+ t.Fatalf("got ra.CurrHopLimit = %d, want = 64", got)
+ }
+
+ if got := ra.ManagedAddrConfFlag(); !got {
+ t.Fatalf("got ManagedAddrConfFlag = false, want = true")
+ }
+
+ if got := ra.OtherConfFlag(); got {
+ t.Fatalf("got OtherConfFlag = true, want = false")
+ }
+
+ if got, want := ra.RouterLifetime(), time.Second*258; got != want {
+ t.Fatalf("got ra.RouterLifetime = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
+ t.Fatalf("got ra.ReachableTime = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
+ t.Fatalf("got ra.RetransTimer = %d, want = %d", got, want)
+ }
+}
+
+// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
+// NDPTargetLinkLayerAddressOption.
+func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedBuf []byte
+ addr tcpip.LinkAddress
+ }{
+ {
+ "Ethernet",
+ make([]byte, 8),
+ []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ "\x01\x02\x03\x04\x05\x06",
+ },
+ {
+ "Padding",
+ []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ "\x01\x02\x03\x04\x05\x06\x07\x08",
+ },
+ {
+ "Empty",
+ []byte{},
+ []byte{},
+ "",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ serializer := NDPOptionsSerializer{
+ NDPTargetLinkLayerAddressOption(test.addr),
+ }
+ if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
+ t.Fatalf("got Length = %d, want = %d", got, want)
+ }
+ opts.Serialize(serializer)
+ if !bytes.Equal(test.buf, test.expectedBuf) {
+ t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 18adb2085..14f197a77 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -72,7 +72,7 @@ func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.Vector
// InjectLinkAddr injects an inbound packet with a remote link address.
func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) {
- e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil))
+ e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil), nil /* linkHeader */)
}
// Attach saves the stack network-layer dispatcher for use later when packets
@@ -96,7 +96,7 @@ func (e *Endpoint) MTU() uint32 {
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
caps := stack.LinkEndpointCapabilities(0)
if e.GSO {
- caps |= stack.CapabilityGSO
+ caps |= stack.CapabilityHardwareGSO
}
return caps
}
@@ -134,5 +134,49 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+// WritePackets stores outbound packets into the channel.
+func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ payloadView := payload.ToView()
+ n := 0
+packetLoop:
+ for i := range hdrs {
+ hdr := &hdrs[i].Hdr
+ off := hdrs[i].Off
+ size := hdrs[i].Size
+ p := PacketInfo{
+ Header: hdr.View(),
+ Proto: protocol,
+ Payload: buffer.NewViewFromBytes(payloadView[off : off+size]),
+ GSO: gso,
+ }
+
+ select {
+ case e.C <- p:
+ n++
+ default:
+ break packetLoop
+ }
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ p := PacketInfo{
+ Header: packet.ToView(),
+ Proto: 0,
+ Payload: buffer.View{},
+ GSO: nil,
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index f80ac3435..ae4858529 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -165,6 +165,9 @@ type Options struct {
// disabled.
GSOMaxSize uint32
+ // SoftwareGSOEnabled indicates whether software GSO is enabled or not.
+ SoftwareGSOEnabled bool
+
// PacketDispatchMode specifies the type of inbound dispatcher to be
// used for this endpoint.
PacketDispatchMode PacketDispatchMode
@@ -242,7 +245,11 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
}
if isSocket {
if opts.GSOMaxSize != 0 {
- e.caps |= stack.CapabilityGSO
+ if opts.SoftwareGSOEnabled {
+ e.caps |= stack.CapabilitySoftwareGSO
+ } else {
+ e.caps |= stack.CapabilityHardwareGSO
+ }
e.gsoMaxSize = opts.GSOMaxSize
}
}
@@ -397,7 +404,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
eth.Encode(ethHdr)
}
- if e.Capabilities()&stack.CapabilityGSO != 0 {
+ if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
if gso != nil {
@@ -430,8 +437,130 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil)
}
-// WriteRawPacket writes a raw packet directly to the file descriptor.
-func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+// WritePackets writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ var ethHdrBuf []byte
+ // hdr + data
+ iovLen := 2
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ ethHdrBuf = make([]byte, header.EthernetMinimumSize)
+ eth := header.Ethernet(ethHdrBuf)
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ iovLen++
+ }
+
+ n := len(hdrs)
+
+ views := payload.Views()
+ /*
+ * Each bondary in views can add one more iovec.
+ *
+ * payload | | | |
+ * -----------------------------
+ * packets | | | | | | |
+ * -----------------------------
+ * iovecs | | | | | | | | |
+ */
+ iovec := make([]syscall.Iovec, n*iovLen+len(views)-1)
+ mmsgHdrs := make([]rawfile.MMsgHdr, n)
+
+ iovecIdx := 0
+ viewIdx := 0
+ viewOff := 0
+ off := 0
+ nextOff := 0
+ for i := range hdrs {
+ prevIovecIdx := iovecIdx
+ mmsgHdr := &mmsgHdrs[i]
+ mmsgHdr.Msg.Iov = &iovec[iovecIdx]
+ packetSize := hdrs[i].Size
+ hdr := &hdrs[i].Hdr
+
+ off = hdrs[i].Off
+ if off != nextOff {
+ // We stop in a different point last time.
+ size := packetSize
+ viewIdx = 0
+ viewOff = 0
+ for size > 0 {
+ if size >= len(views[viewIdx]) {
+ viewIdx++
+ viewOff = 0
+ size -= len(views[viewIdx])
+ } else {
+ viewOff = size
+ size = 0
+ }
+ }
+ }
+ nextOff = off + packetSize
+
+ if ethHdrBuf != nil {
+ v := &iovec[iovecIdx]
+ v.Base = &ethHdrBuf[0]
+ v.Len = uint64(len(ethHdrBuf))
+ iovecIdx++
+ }
+
+ v := &iovec[iovecIdx]
+ hdrView := hdr.View()
+ v.Base = &hdrView[0]
+ v.Len = uint64(len(hdrView))
+ iovecIdx++
+
+ for packetSize > 0 {
+ vec := &iovec[iovecIdx]
+ iovecIdx++
+
+ v := views[viewIdx]
+ vec.Base = &v[viewOff]
+ s := len(v) - viewOff
+ if s <= packetSize {
+ viewIdx++
+ viewOff = 0
+ } else {
+ s = packetSize
+ viewOff += s
+ }
+ vec.Len = uint64(s)
+ packetSize -= s
+ }
+
+ mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx)
+ }
+
+ packets := 0
+ for packets < n {
+ sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs)
+ if err != nil {
+ return packets, err
+ }
+ packets += sent
+ mmsgHdrs = mmsgHdrs[sent:]
+ }
+ return packets, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ return rawfile.NonBlockingWrite(e.fds[0], packet.ToView())
+}
+
+// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
+func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
}
@@ -468,9 +597,9 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
-// Inject injects an inbound packet.
-func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+// InjectInbound injects an inbound packet.
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
}
// NewInjectable creates a new fd-based InjectableEndpoint.
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 04406bc9a..59378b96c 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -43,9 +43,10 @@ const (
)
type packetInfo struct {
- raddr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- contents buffer.View
+ raddr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ contents buffer.View
+ linkHeader buffer.View
}
type context struct {
@@ -92,8 +93,8 @@ func (c *context) cleanup() {
syscall.Close(c.fds[1])
}
-func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- c.ch <- packetInfo{remote, protocol, vv.ToView()}
+func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
+ c.ch <- packetInfo{remote, protocol, vv.ToView(), linkHeader}
}
func TestNoEthernetProperties(t *testing.T) {
@@ -293,11 +294,12 @@ func TestDeliverPacket(t *testing.T) {
b[i] = uint8(rand.Intn(256))
}
+ var hdr header.Ethernet
if !eth {
// So that it looks like an IPv4 packet.
b[0] = 0x40
} else {
- hdr := make(header.Ethernet, header.EthernetMinimumSize)
+ hdr = make(header.Ethernet, header.EthernetMinimumSize)
hdr.Encode(&header.EthernetFields{
SrcAddr: raddr,
DstAddr: laddr,
@@ -315,9 +317,10 @@ func TestDeliverPacket(t *testing.T) {
select {
case pi := <-c.ch:
want := packetInfo{
- raddr: raddr,
- proto: proto,
- contents: b,
+ raddr: raddr,
+ proto: proto,
+ contents: b,
+ linkHeader: buffer.View(hdr),
}
if !eth {
want.proto = header.IPv4ProtocolNumber
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 8bfeb97e4..554d45715 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -169,9 +169,10 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
+ eth = header.Ethernet(pkt)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -189,6 +190,6 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}), buffer.View(eth))
return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 7ca217e5b..12168a1dc 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -53,7 +53,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
d := &readVDispatcher{fd: fd, e: e}
d.views = make([]buffer.View, len(BufConfig))
iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
iovLen++
}
d.iovecs = make([]syscall.Iovec, iovLen)
@@ -63,7 +63,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
func (d *readVDispatcher) allocateViews(bufConfig []int) {
var vnetHdr [virtioNetHdrSize]byte
vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// The kernel adds virtioNetHdr before each packet, but
// we don't use it, so so we allocate a buffer for it,
// add it in iovecs but don't add it in a view.
@@ -106,7 +106,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
if err != nil {
return false, err
}
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// Skip virtioNetHdr which is added before each packet, it
// isn't used and it isn't in a view.
n -= virtioNetHdrSize
@@ -118,9 +118,10 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(d.views[0])
+ eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize])
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -141,7 +142,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
vv := buffer.NewVectorisedView(n, d.views[:used])
vv.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth))
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
@@ -194,7 +195,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
}
d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv)
iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// virtioNetHdr is prepended before each packet.
iovLen++
}
@@ -225,7 +226,7 @@ func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) {
for k := 0; k < len(d.views); k++ {
var vnetHdr [virtioNetHdrSize]byte
vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// The kernel adds virtioNetHdr before each packet, but
// we don't use it, so so we allocate a buffer for it,
// add it in iovecs but don't add it in a view.
@@ -261,7 +262,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
// Process each of received packets.
for k := 0; k < nMsgs; k++ {
n := int(d.msgHdrs[k].Len)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
n -= virtioNetHdrSize
}
if n <= d.e.hdrSize {
@@ -271,9 +272,10 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(d.views[k][0])
+ eth = header.Ethernet(d.views[k][0])
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -293,7 +295,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
used := d.capViews(k, int(n), BufConfig)
vv := buffer.NewVectorisedView(int(n), d.views[k][:used])
vv.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth))
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
index 47a54845c..23e4d1418 100644
--- a/pkg/tcpip/link/loopback/BUILD
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -10,6 +10,7 @@ go_library(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index b36629d2c..a3b48fa73 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -23,6 +23,7 @@ package loopback
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,6 +71,9 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
@@ -81,10 +85,27 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa
// Because we're immediately turning around and writing the packet back to the
// rx path, we intentionally don't preserve the remote and local link
// addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
return nil
}
-// Wait implements stack.LinkEndpoint.Wait.
-func (*endpoint) Wait() {}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ // Reject the packet if it's shorter than an ethernet header.
+ if packet.Size() < header.EthernetMinimumSize {
+ return tcpip.ErrBadAddress
+ }
+
+ // There should be an ethernet header at the beginning of packet.
+ linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize])
+ packet.TrimFront(len(linkHeader))
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), packet, buffer.View(linkHeader))
+
+ return nil
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 7c946101d..682b60291 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -79,29 +79,47 @@ func (m *InjectableEndpoint) IsAttached() bool {
return m.dispatcher != nil
}
-// Inject implements stack.InjectableLinkEndpoint.
-func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv)
+// InjectInbound implements stack.InjectableLinkEndpoint.
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
+}
+
+// WritePackets writes outbound packets to the appropriate
+// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
+// r.RemoteAddress has a route registered in this endpoint.
+func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ endpoint, ok := m.routes[r.RemoteAddress]
+ if !ok {
+ return 0, tcpip.ErrNoRoute
+ }
+ return endpoint.WritePackets(r, gso, hdrs, payload, protocol)
}
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
- return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol)
+ return endpoint.WritePacket(r, gso, hdr, payload, protocol)
}
return tcpip.ErrNoRoute
}
-// WriteRawPacket writes outbound packets to the appropriate
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (m *InjectableEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ // WriteRawPacket doesn't get a route or network address, so there's
+ // nowhere to write this.
+ return tcpip.ErrNoRoute
+}
+
+// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
-func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
endpoint, ok := m.routes[dest]
if !ok {
return tcpip.ErrNoRoute
}
- return endpoint.WriteRawPacket(dest, packet)
+ return endpoint.InjectOutbound(dest, packet)
}
// Wait implements stack.LinkEndpoint.Wait.
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 3086fec00..9cd300af8 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -31,7 +31,7 @@ import (
func TestInjectableEndpointRawDispatch(t *testing.T) {
endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
- endpoint.WriteRawPacket(dstIP, []byte{0xFA})
+ endpoint.InjectOutbound(dstIP, []byte{0xFA})
buf := make([]byte, ipv4.MaxTotalSize)
bytesRead, err := sock.Read(buf)
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 2e8bc772a..05c7b8024 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -16,5 +16,8 @@ go_library(
visibility = [
"//visibility:public",
],
- deps = ["//pkg/tcpip"],
+ deps = [
+ "//pkg/tcpip",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
)
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 7e286a3a6..44e25d475 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -22,6 +22,7 @@ import (
"syscall"
"unsafe"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -101,6 +102,16 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
return nil
}
+// NonBlockingSendMMsg sends multiple messages on a socket.
+func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+ n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
+ if e != 0 {
+ return 0, TranslateErrno(e)
+ }
+
+ return int(n), nil
+}
+
// PollEvent represents the pollfd structure passed to a poll() system call.
type PollEvent struct {
FD int32
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 9e71d4edf..279e2b457 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -212,6 +212,26 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa
return nil
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ v := packet.ToView()
+ // Transmit the packet.
+ e.mu.Lock()
+ ok := e.tx.transmit(v, buffer.View{})
+ e.mu.Unlock()
+
+ if !ok {
+ return tcpip.ErrWouldBlock
+ }
+
+ return nil
+}
+
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
// to the network stack.
func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
@@ -254,7 +274,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
// Send packet up the stack.
eth := header.Ethernet(b)
- d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView())
+ d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), buffer.View(eth))
}
// Clean state.
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 0e9ba0846..f3e9705c9 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -78,9 +78,10 @@ func (q *queueBuffers) cleanup() {
}
type packetInfo struct {
- addr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- vv buffer.VectorisedView
+ addr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ vv buffer.VectorisedView
+ linkHeader buffer.View
}
type testContext struct {
@@ -130,12 +131,13 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
return c
}
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
- addr: remoteLinkAddr,
- proto: proto,
- vv: vv.Clone(nil),
+ addr: remoteLinkAddr,
+ proto: proto,
+ vv: vv.Clone(nil),
+ linkHeader: linkHeader,
})
c.mu.Unlock()
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index e401dce44..39757ea2a 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -116,7 +116,7 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
logPacket("recv", protocol, vv.First(), nil)
}
@@ -147,7 +147,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local
panic(err)
}
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv)
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader)
}
// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
@@ -193,10 +193,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-// WritePacket implements the stack.LinkEndpoint interface. It is called by
-// higher-level protocols to write packets; it just logs the packet and forwards
-// the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
logPacket("send", protocol, hdr.View(), gso)
}
@@ -218,28 +215,74 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
panic(err)
}
length -= len(hdrBuf)
- if length > 0 {
- for _, v := range payload.Views() {
- if len(v) > length {
- v = v[:length]
- }
- n, err := buf.Write(v)
- if err != nil {
- panic(err)
- }
- length -= n
- if length == 0 {
- break
- }
- }
- }
+ logVectorisedView(payload, length, buf)
if _, err := e.file.Write(buf.Bytes()); err != nil {
panic(err)
}
}
+}
+
+// WritePacket implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and
+// forwards the request to the lower endpoint.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ e.dumpPacket(gso, hdr, payload, protocol)
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
+// WritePackets implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and
+// forwards the request to the lower endpoint.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ view := payload.ToView()
+ for _, d := range hdrs {
+ e.dumpPacket(gso, d.Hdr, buffer.NewVectorisedView(d.Size, []buffer.View{view[d.Off:][:d.Size]}), protocol)
+ }
+ return e.lower.WritePackets(r, gso, hdrs, payload, protocol)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */)
+ }
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ length := packet.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(packet.Size()))); err != nil {
+ panic(err)
+ }
+ logVectorisedView(packet, length, buf)
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
+ }
+ }
+ return e.lower.WriteRawPacket(packet)
+}
+
+func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) {
+ if length <= 0 {
+ return
+ }
+ for _, v := range vv.Views() {
+ if len(v) > length {
+ v = v[:length]
+ }
+ n, err := buf.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ length -= n
+ if length == 0 {
+ return
+ }
+ }
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 5a1791cb5..a04fc1062 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
if !e.dispatchGate.Enter() {
return
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv)
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader)
e.dispatchGate.Leave()
}
@@ -109,6 +109,30 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return err
}
+// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if Wait or WaitWrite haven't been called.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if !e.writeGate.Enter() {
+ return len(hdrs), nil
+ }
+
+ n, err := e.lower.WritePackets(r, gso, hdrs, payload, protocol)
+ e.writeGate.Leave()
+ return n, err
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WriteRawPacket(packet)
+ e.writeGate.Leave()
+ return err
+}
+
// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
// and waits for inflight ones to finish before returning.
func (e *Endpoint) WaitWrite() {
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index ae23c96b7..5f0f8fa2d 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -35,7 +35,7 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
e.dispatchCount++
}
@@ -70,6 +70,17 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P
return nil
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += len(hdrs)
+ return len(hdrs), nil
+}
+
+func (e *countedEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
@@ -109,21 +120,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 922181ac0..46178459e 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -83,6 +83,11 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buf
return tcpip.ErrNotSupported
}
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketDescriptor, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -109,7 +114,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ fallthrough // also fill the cache from requests
case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 825ff3392..2cad0a0b6 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -28,6 +28,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/tcpip",
"//pkg/tcpip/buffer",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 1628a82be..6da5238ec 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -17,6 +17,7 @@
package fragmentation
import (
+ "fmt"
"log"
"sync"
"time"
@@ -82,7 +83,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
// Process processes an incoming fragment belonging to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
f.mu.Lock()
r, ok := f.reassemblers[id]
if ok && r.tooOld(f.timeout) {
@@ -97,8 +98,15 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
f.mu.Unlock()
- res, done, consumed := r.process(first, last, more, vv)
-
+ res, done, consumed, err := r.process(first, last, more, vv)
+ if err != nil {
+ // We probably got an invalid sequence of fragments. Just
+ // discard the reassembler and move on.
+ f.mu.Lock()
+ f.release(r)
+ f.mu.Unlock()
+ return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ }
f.mu.Lock()
f.size += consumed
if done {
@@ -114,7 +122,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
}
f.mu.Unlock()
- return res, done
+ return res, done, nil
}
func (f *Fragmentation) release(r *reassembler) {
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 799798544..72c0f53be 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -83,7 +83,10 @@ func TestFragmentationProcess(t *testing.T) {
t.Run(c.comment, func(t *testing.T) {
f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
for i, in := range c.in {
- vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if err != nil {
+ t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ }
if !reflect.DeepEqual(vv, c.out[i].vv) {
t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
}
@@ -114,7 +117,10 @@ func TestReassemblingTimeout(t *testing.T) {
time.Sleep(2 * timeout)
// Send another fragment that completes a packet.
// However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done := f.Process(0, 1, 1, false, vv(1, "1"))
+ _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
+ if err != nil {
+ t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
+ }
if done {
t.Errorf("Fragmentation does not respect the reassembling timeout.")
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 8037f734b..9e002e396 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -86,7 +86,7 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, false, consumed
+ return buffer.VectorisedView{}, false, consumed, nil
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
@@ -96,13 +96,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
}
// Check if all the holes have been deleted and we are ready to reassamble.
if r.deleted < len(r.holes) {
- return buffer.VectorisedView{}, false, consumed
+ return buffer.VectorisedView{}, false, consumed, nil
}
res, err := r.heap.reassemble()
if err != nil {
- panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+ return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
}
- return res, true, consumed
+ return res, true, consumed, nil
}
func (r *reassembler) tooOld(timeout time.Duration) bool {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index f644a8b08..8d74497ba 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -171,6 +171,15 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen
return nil
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, hdr []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index df1a08113..1339f8474 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -198,10 +198,9 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff
return nil
}
-// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- length := uint16(hdr.UsedLength() + payload.Size())
+ length := uint16(hdr.UsedLength() + payloadSize)
id := uint32(0)
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
@@ -219,6 +218,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+ e.addIPHeader(r, &hdr, payload.Size(), params)
if loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(payload.Views()))
@@ -242,6 +246,23 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ if loop&stack.PacketLoop != 0 {
+ panic("multiple packets in local loop")
+ }
+ if loop&stack.PacketOut == 0 {
+ return len(hdrs), nil
+ }
+
+ for i := range hdrs {
+ e.addIPHeader(r, &hdrs[i].Hdr, hdrs[i].Size, params)
+ }
+ n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+}
+
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
@@ -326,7 +347,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
return
}
var ready bool
- vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ var err error
+ vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
if !ready {
return
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 560638ce8..99f84acd7 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -47,10 +47,6 @@ func TestExcludeBroadcast(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
NIC: 1,
@@ -441,6 +437,16 @@ func TestInvalidFragments(t *testing.T) {
1,
1,
},
+ {
+ "multiple_fragments_with_more_fragments_set_to_false",
+ [][]byte{
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ },
+ 1,
+ 1,
+ },
}
for _, tc := range testCases {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index b5df85455..b289e902f 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -21,15 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-const (
- // ndpHopLimit is the expected IP hop limit value of 255 for received
- // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
- // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
- // drop the NDP packet. All outgoing NDP packets must use this value for
- // its IP hop limit field.
- ndpHopLimit = 255
-)
-
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
@@ -79,6 +70,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
return
}
h := header.ICMPv6(v)
+ iph := header.IPv6(netHeader)
// As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
// 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
@@ -89,7 +81,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
header.ICMPv6RouterSolicit,
header.ICMPv6RouterAdvert,
header.ICMPv6RedirectMsg:
- if header.IPv6(netHeader).HopLimit() != ndpHopLimit {
+ if iph.HopLimit() != header.NDPHopLimit {
received.Invalid.Increment()
return
}
@@ -121,25 +113,72 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
+
+ ns := header.NDPNeighborSolicit(h.NDPPayload())
+ targetAddr := ns.TargetAddress()
+ s := r.Stack()
+ rxNICID := r.NICID()
+
+ isTentative, err := s.IsAddrTentative(rxNICID, targetAddr)
+ if err != nil {
+ // We will only get an error if rxNICID is unrecognized,
+ // which should not happen. For now short-circuit this
+ // packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ }
+
+ if isTentative {
+ // If the target address is tentative and the source
+ // of the packet is a unicast (specified) address, then
+ // the source of the packet is attempting to perform
+ // address resolution on the target. In this case, the
+ // solicitation is silently ignored, as per RFC 4862
+ // section 5.4.3.
+ //
+ // If the target address is tentative and the source of
+ // the packet is the unspecified address (::), then we
+ // know another node is also performing DAD for the
+ // same address (since targetAddr is tentative for us,
+ // we know we are also performing DAD on it). In this
+ // case we let the stack know so it can handle such a
+ // scenario and do nothing further with the NDP NS.
+ if iph.SourceAddress() == header.IPv6Any {
+ s.DupTentativeAddrDetected(rxNICID, targetAddr)
+ }
+
+ // Do not handle neighbor solicitations targeted
+ // to an address that is tentative on the received
+ // NIC any further.
+ return
+ }
+
+ // At this point we know that targetAddr is not tentative on
+ // rxNICID so the packet is processed as defined in RFC 4861,
+ // as per RFC 4862 section 5.4.3.
+
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize)
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
+ }
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag
- copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr)
- pkt[icmpV6OptOffset] = ndpOptDstLinkAddr
- pkt[icmpV6LengthOffset] = 1
- copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(targetAddr)
+ opts := na.Options()
+ opts.Serialize(optsSerializer)
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
@@ -154,7 +193,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
r.LocalAddress = targetAddr
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+ // TODO(tamird/ghanan): there exists an explicit NDP option that is
+ // used to update the neighbor table with link addresses for a
+ // neighbor from an NS (see the Source Link Layer option RFC
+ // 4861 section 4.6.1 and section 7.2.3).
+ //
+ // Furthermore, the entirety of NDP handling here seems to be
+ // contradicted by RFC 4861.
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+
+ // RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
+ //
+ // 7.1.2. Validation of Neighbor Advertisements
+ //
+ // The IP Hop Limit field has a value of 255, i.e., the packet
+ // could not possibly have been forwarded by a router.
+ if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}); err != nil {
sent.Dropped.Increment()
return
}
@@ -166,7 +220,42 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
+
+ na := header.NDPNeighborAdvert(h.NDPPayload())
+ targetAddr := na.TargetAddress()
+ stack := r.Stack()
+ rxNICID := r.NICID()
+
+ isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr)
+ if err != nil {
+ // We will only get an error if rxNICID is unrecognized,
+ // which should not happen. For now short-circuit this
+ // packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ }
+
+ if isTentative {
+ // We just got an NA from a node that owns an address we
+ // are performing DAD on, implying the address is not
+ // unique. In this case we let the stack know so it can
+ // handle such a scenario and do nothing furthur with
+ // the NDP NA.
+ stack.DupTentativeAddrDetected(rxNICID, targetAddr)
+ return
+ }
+
+ // At this point we know that the targetAddress is not tentative
+ // on rxNICID. However, targetAddr may still be assigned to
+ // rxNICID but not tentative (it could be permanent). Such a
+ // scenario is beyond the scope of RFC 4862. As such, we simply
+ // ignore such a scenario for now and proceed as normal.
+ //
+ // TODO(b/143147598): Handle the scenario described above. Also
+ // inform the netstack integration that a duplicate address was
+ // detected outside of DAD.
+
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
@@ -178,7 +267,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
-
vv.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
@@ -262,7 +350,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: ndpHopLimit,
+ HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 501be208e..7c11dde55 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,7 +15,6 @@
package ipv6
import (
- "fmt"
"reflect"
"strings"
"testing"
@@ -144,7 +143,7 @@ func TestICMPCounts(t *testing.T) {
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: ndpHopLimit,
+ HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
@@ -179,13 +178,10 @@ func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
t := v.Type()
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
- switch v.Kind() {
- case reflect.Ptr:
- f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter))
- case reflect.Struct:
+ if s, ok := v.Interface().(*tcpip.StatCounter); ok {
+ f(t.Field(i).Name, s)
+ } else {
visitStats(v, f)
- default:
- panic(fmt.Sprintf("unexpected type %s", v.Type()))
}
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index cd1e34085..5898f8f9e 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -97,9 +97,8 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
- length := uint16(hdr.UsedLength() + payload.Size())
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) {
+ length := uint16(hdr.UsedLength() + payloadSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
@@ -109,6 +108,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+ e.addIPHeader(r, &hdr, payload.Size(), params)
if loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(payload.Views()))
@@ -127,6 +131,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ if loop&stack.PacketLoop != 0 {
+ panic("not implemented")
+ }
+ if loop&stack.PacketOut == 0 {
+ return len(hdrs), nil
+ }
+
+ for i := range hdrs {
+ hdr := &hdrs[i].Hdr
+ size := hdrs[i].Size
+ e.addIPHeader(r, hdr, size, params)
+ }
+
+ n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+}
+
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index e30791fe3..c32716f2e 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -150,7 +150,7 @@ func TestHopLimitValidation(t *testing.T) {
// Receive the NDP packet with an invalid hop limit
// value.
- handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r)
+ handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r)
// Invalid count should have increased.
if got := invalid.Value(); got != 1 {
@@ -164,7 +164,7 @@ func TestHopLimitValidation(t *testing.T) {
}
// Receive the NDP packet with a valid hop limit value.
- handleIPv6Payload(hdr, ndpHopLimit, ep, &r)
+ handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r)
// Rx count of NDP packet of type typ.typ should have
// increased.
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 6a78432c9..460db3cf8 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -22,6 +22,7 @@ go_library(
"icmp_rate_limit.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
+ "ndp.go",
"nic.go",
"registration.go",
"route.go",
@@ -53,6 +54,7 @@ go_test(
name = "stack_x_test",
size = "small",
srcs = [
+ "ndp_test.go",
"stack_test.go",
"transport_demuxer_test.go",
"transport_test.go",
@@ -61,14 +63,17 @@ go_test(
":stack",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
new file mode 100644
index 000000000..ea2dbed2e
--- /dev/null
+++ b/pkg/tcpip/stack/ndp.go
@@ -0,0 +1,319 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "log"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ // defaultDupAddrDetectTransmits is the default number of NDP Neighbor
+ // Solicitation messages to send when doing Duplicate Address Detection
+ // for a tentative address.
+ //
+ // Default = 1 (from RFC 4862 section 5.1)
+ defaultDupAddrDetectTransmits = 1
+
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending NDP Neighbor solicitation messages.
+ //
+ // Default = 1s (from RFC 4861 section 10).
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending NDP Neighbor solicitation messages. Note, RFC 4861 does
+ // not impose a minimum Retransmit Timer, but we do here to make sure
+ // the messages are not sent all at once. We also come to this value
+ // because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1.
+ // Note, the unit of the RetransmitTimer field in the Router
+ // Advertisement is milliseconds.
+ //
+ // Min = 1ms.
+ minimumRetransmitTimer = time.Millisecond
+)
+
+// NDPDispatcher is the interface integrators of netstack must implement to
+// receive and handle NDP related events.
+type NDPDispatcher interface {
+ // OnDuplicateAddressDetectionStatus will be called when the DAD process
+ // for an address (addr) on a NIC (with ID nicid) completes. resolved
+ // will be set to true if DAD completed successfully (no duplicate addr
+ // detected); false otherwise (addr was detected to be a duplicate on
+ // the link the NIC is a part of, or it was stopped for some other
+ // reason, such as the address being removed). If an error occured
+ // during DAD, err will be set and resolved must be ignored.
+ //
+ // This function is permitted to block indefinitely without interfering
+ // with the stack's operation.
+ OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+}
+
+// NDPConfigurations is the NDP configurations for the netstack.
+type NDPConfigurations struct {
+ // The number of Neighbor Solicitation messages to send when doing
+ // Duplicate Address Detection for a tentative address.
+ //
+ // Note, a value of zero effectively disables DAD.
+ DupAddrDetectTransmits uint8
+
+ // The amount of time to wait between sending Neighbor solicitation
+ // messages.
+ //
+ // Must be greater than 0.5s.
+ RetransmitTimer time.Duration
+}
+
+// DefaultNDPConfigurations returns an NDPConfigurations populated with
+// default values.
+func DefaultNDPConfigurations() NDPConfigurations {
+ return NDPConfigurations{
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ }
+}
+
+// validate modifies an NDPConfigurations with valid values. If invalid values
+// are present in c, the corresponding default values will be used instead.
+//
+// If RetransmitTimer is less than minimumRetransmitTimer, then a value of
+// defaultRetransmitTimer will be used.
+func (c *NDPConfigurations) validate() {
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+}
+
+// ndpState is the per-interface NDP state.
+type ndpState struct {
+ // The NIC this ndpState is for.
+ nic *NIC
+
+ // The DAD state to send the next NS message, or resolve the address.
+ dad map[tcpip.Address]dadState
+}
+
+// dadState holds the Duplicate Address Detection timer and channel to signal
+// to the DAD goroutine that DAD should stop.
+type dadState struct {
+ // The DAD timer to send the next NS message, or resolve the address.
+ timer *time.Timer
+
+ // Used to let the DAD timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this dadState is associated with.
+ done *bool
+}
+
+// startDuplicateAddressDetection performs Duplicate Address Detection.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+ // addr must be a valid unicast IPv6 address.
+ if !header.IsV6UnicastAddress(addr) {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in
+ // the DAD process.
+ if _, ok := ndp.dad[addr]; ok {
+ // Should never happen because we should only ever call this
+ // function for newly created addresses. If we attemped to
+ // "add" an address that already existed, we would returned an
+ // error since we attempted to add a duplicate address, or its
+ // reference count would have been increased without doing the
+ // work that would have been done for an address that was brand
+ // new. See NIC.addPermanentAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ remaining := ndp.nic.stack.ndpConfigs.DupAddrDetectTransmits
+
+ {
+ done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref)
+ if err != nil {
+ return err
+ }
+ if done {
+ return nil
+ }
+ }
+
+ remaining--
+
+ var done bool
+ var timer *time.Timer
+ timer = time.AfterFunc(ndp.nic.stack.ndpConfigs.RetransmitTimer, func() {
+ var d bool
+ var err *tcpip.Error
+
+ // doDadIteration does a single iteration of the DAD loop.
+ //
+ // Returns true if the integrator needs to be informed of DAD
+ // completing.
+ doDadIteration := func() bool {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if done {
+ // If we reach this point, it means that the DAD
+ // timer fired after another goroutine already
+ // obtained the NIC lock and stopped DAD before
+ // this function obtained the NIC lock. Simply
+ // return here and do nothing further.
+ return false
+ }
+
+ ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ // This should never happen.
+ // We should have an endpoint for addr since we
+ // are still performing DAD on it. If the
+ // endpoint does not exist, but we are doing DAD
+ // on it, then we started DAD at some point, but
+ // forgot to stop it when the endpoint was
+ // deleted.
+ panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref)
+ if err != nil || d {
+ delete(ndp.dad, addr)
+
+ if err != nil {
+ log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
+ }
+
+ // Let the integrator know DAD has completed.
+ return true
+ }
+
+ remaining--
+ timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ return false
+ }
+
+ if doDadIteration() && ndp.nic.stack.ndpDisp != nil {
+ ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err)
+ }
+
+ })
+
+ ndp.dad[addr] = dadState{
+ timer: timer,
+ done: &done,
+ }
+
+ return nil
+}
+
+// doDuplicateAddressDetection is called on every iteration of the timer, and
+// when DAD starts.
+//
+// It handles resolving the address (if there are no more NS to send), or
+// sending the next NS if there are more NS to send.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The NIC that ndp belongs to (n) MUST be locked.
+//
+// Returns true if DAD has resolved; false if DAD is still ongoing.
+func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
+ if ref.getKind() != permanentTentative {
+ // The endpoint should still be marked as tentative
+ // since we are still performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ if remaining == 0 {
+ // DAD has resolved.
+ ref.setKind(permanent)
+ return true, nil
+ }
+
+ // Send a new NS.
+ snmc := header.SolicitedNodeAddr(addr)
+ snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}]
+ if !ok {
+ // This should never happen as if we have the
+ // address, we should have the solicited-node
+ // address.
+ panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
+ }
+
+ // Use the unspecified address as the source address when performing
+ // DAD.
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(addr)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}); err != nil {
+ sent.Dropped.Increment()
+ return false, err
+ }
+ sent.NeighborSolicit.Increment()
+
+ return false, nil
+}
+
+// stopDuplicateAddressDetection ends a running Duplicate Address Detection
+// process. Note, this may leave the DAD process for a tentative address in
+// such a state forever, unless some other external event resolves the DAD
+// process (receiving an NA from the true owner of addr, or an NS for addr
+// (implying another node is attempting to use addr)). It is up to the caller
+// of this function to handle such a scenario. Normally, addr will be removed
+// from n right after this function returns or the address successfully
+// resolved.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
+ dad, ok := ndp.dad[addr]
+ if !ok {
+ // Not currently performing DAD on addr, just return.
+ return
+ }
+
+ if dad.timer != nil {
+ dad.timer.Stop()
+ dad.timer = nil
+
+ *dad.done = true
+ dad.done = nil
+ }
+
+ delete(ndp.dad, addr)
+
+ // Let the integrator know DAD did not resolve.
+ if ndp.nic.stack.ndpDisp != nil {
+ go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
new file mode 100644
index 000000000..b089ce2ae
--- /dev/null
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -0,0 +1,443 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = "\x02\x02\x03\x04\x05\x06"
+)
+
+// TestDADDisabled tests that an address successfully resolves immediately
+// when DAD is not enabled (the default for an empty stack.Options).
+func TestDADDisabled(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Should get the address immediately since we should not have performed
+ // DAD on it.
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+ }
+
+ // We should not have sent any NDP NS messages.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ t.Fatalf("got NeighborSolicit = %d, want = 0", got)
+ }
+}
+
+// ndpDADEvent is a set of parameters that was passed to
+// ndpDispatcher.OnDuplicateAddressDetectionStatus.
+type ndpDADEvent struct {
+ nicid tcpip.NICID
+ addr tcpip.Address
+ resolved bool
+ err *tcpip.Error
+}
+
+var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+
+// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
+// related events happen for test purposes.
+type ndpDispatcher struct {
+ dadC chan ndpDADEvent
+}
+
+// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+//
+// If the DAD event matches what we are expecting, send signal on n.dadC.
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+ n.dadC <- ndpDADEvent{
+ nicid,
+ addr,
+ resolved,
+ err,
+ }
+}
+
+// TestDADResolve tests that an address successfully resolves after performing
+// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
+// Included in the subtests is a test to make sure that an invalid
+// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+func TestDADResolve(t *testing.T) {
+ tests := []struct {
+ name string
+ dupAddrDetectTransmits uint8
+ retransTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {"1:1s:1s", 1, time.Second, time.Second},
+ {"2:1s:1s", 2, time.Second, time.Second},
+ {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ // 0s is an invalid RetransmitTimer timer and will be fixed to
+ // the default RetransmitTimer value of 1s.
+ {"1:0s:1s", 1, 0, time.Second},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ }
+ opts.NDPConfigs.RetransmitTimer = test.retransTimer
+ opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit
+
+ // Should have sent an NDP NS immediately.
+ if got := stat.Value(); got != 1 {
+ t.Fatalf("got NeighborSolicit = %d, want = 1", got)
+
+ }
+
+ // Address should not be considered bound to the NIC yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for the remaining time - some delta (500ms), to
+ // make sure the address is still not resolved.
+ const delta = 500 * time.Millisecond
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+ }
+
+ // Should not have sent any more NS messages.
+ if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
+ }
+
+ // Validate the sent Neighbor Solicitation messages.
+ for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
+ p := <-e.C
+
+ // Make sure its an IPv6 packet.
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ // Check NDP packet.
+ checker.IPv6(t, p.Header.ToVectorisedView().First(),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr1)))
+ }
+ })
+ }
+
+}
+
+// TestDADFail tests to make sure that the DAD process fails if another node is
+// detected to be performing DAD on the same address (receive an NS message from
+// a node doing DAD for the same address), or if another node is detected to own
+// the address already (receive an NA message for the tentative address).
+func TestDADFail(t *testing.T) {
+ tests := []struct {
+ name string
+ makeBuf func(tgt tcpip.Address) buffer.Prependable
+ getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ "RxSolicit",
+ func(tgt tcpip.Address) buffer.Prependable {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(tgt)
+ snmc := header.SolicitedNodeAddr(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
+ })
+
+ return hdr
+
+ },
+ func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return s.NeighborSolicit
+ },
+ },
+ {
+ "RxAdvert",
+ func(tgt tcpip.Address) buffer.Prependable {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ return hdr
+
+ },
+ func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return s.NeighborAdvert
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ }
+ opts.NDPConfigs.RetransmitTimer = time.Second * 2
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Receive a packet to simulate multiple nodes owning or
+ // attempting to own the same address.
+ hdr := test.makeBuf(addr1)
+ e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ if got := stat.Value(); got != 1 {
+ t.Fatalf("got stat = %d, want = 1", got)
+ }
+
+ // Wait for DAD to fail and make sure the address did
+ // not get resolved.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the
+ // expected resolution time + extra 1s buffer,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ })
+ }
+}
+
+// TestDADStop tests to make sure that the DAD process stops when an address is
+// removed.
+func TestDADStop(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Remove the address. This should stop DAD.
+ if err := s.RemoveAddress(1, addr1); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err)
+ }
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 5993fe582..2d29fa88e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -19,7 +19,6 @@ import (
"sync"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/ilist"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -37,12 +36,17 @@ type NIC struct {
mu sync.RWMutex
spoofing bool
promiscuous bool
- primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
addressRanges []tcpip.Subnet
mcastJoins map[NetworkEndpointID]int32
+ // packetEPs is protected by mu, but the contained PacketEndpoint
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
stats NICStats
+
+ ndp ndpState
}
// NICStats includes transmitted and received stats.
@@ -77,15 +81,19 @@ const (
)
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
- return &NIC{
+ // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
+ // example, make sure that the link address it provides is a valid
+ // unicast ethernet address.
+ nic := &NIC{
stack: stack,
id: id,
name: name,
linkEP: ep,
loopback: loopback,
- primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
+ packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint),
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -96,7 +104,21 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
Bytes: &tcpip.StatCounter{},
},
},
+ ndp: ndpState{
+ dad: make(map[tcpip.Address]dadState),
+ },
+ }
+ nic.ndp.nic = nic
+
+ // Register supported packet endpoint protocols.
+ for _, netProto := range header.Ethertypes {
+ nic.packetEPs[netProto] = []PacketEndpoint{}
}
+ for _, netProto := range stack.networkProtocols {
+ nic.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ }
+
+ return nic
}
// enable enables the NIC. enable will attach the link to its LinkEndpoint and
@@ -121,11 +143,50 @@ func (n *NIC) enable() *tcpip.Error {
// when we perform Duplicate Address Detection, or Router Advertisement
// when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
// section 4.2 for more information.
- if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
- return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress)
+ //
+ // Also auto-generate an IPv6 link-local address based on the NIC's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+ _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]
+ if !ok {
+ return nil
}
- return nil
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ if !n.stack.autoGenIPv6LinkLocal {
+ return nil
+ }
+
+ l2addr := n.linkEP.LinkAddress()
+
+ // Only attempt to generate the link-local address if we have a
+ // valid MAC address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address
+ // (provided by LinkEndpoint.LinkAddress) before reaching this
+ // point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
+
+ addr := header.LinkLocalAddr(l2addr)
+
+ _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ },
+ }, CanBePrimaryEndpoint)
+
+ return err
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
@@ -161,18 +222,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
n.mu.RLock()
defer n.mu.RUnlock()
- list := n.primary[protocol]
- if list == nil {
- return nil
- }
-
- for e := list.Front(); e != nil; e = e.Next() {
- r := e.(*referencedNetworkEndpoint)
- // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set.
- switch r.ep.ID().LocalAddress {
- case header.IPv4Broadcast, header.IPv4Any:
- continue
- }
+ for _, r := range n.primary[protocol] {
if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
@@ -282,13 +332,35 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
if ref, ok := n.endpoints[id]; ok {
switch ref.getKind() {
- case permanent:
+ case permanentTentative, permanent:
// The NIC already have a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
case permanentExpired, temporary:
- // Promote the endpoint to become permanent.
+ // Promote the endpoint to become permanent and respect
+ // the new peb.
if ref.tryIncRef() {
ref.setKind(permanent)
+
+ refs := n.primary[ref.protocol]
+ for i, r := range refs {
+ if r == ref {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return ref, nil
+ case FirstPrimaryEndpoint:
+ if i == 0 {
+ return ref, nil
+ }
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ case NeverPrimaryEndpoint:
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ return ref, nil
+ }
+ }
+ }
+
+ n.insertPrimaryEndpointLocked(ref, peb)
+
return ref, nil
}
// tryIncRef failing means the endpoint is scheduled to be removed once
@@ -298,6 +370,7 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
n.removeEndpointLocked(ref)
}
}
+
return n.addAddressLocked(protocolAddress, peb, permanent)
}
@@ -321,6 +394,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
if err != nil {
return nil, err
}
+
+ isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
+
+ // If the address is an IPv6 address and it is a permanent address,
+ // mark it as tentative so it goes through the DAD process.
+ if isIPv6Unicast && kind == permanent {
+ kind = permanentTentative
+ }
+
ref := &referencedNetworkEndpoint{
refs: 1,
ep: ep,
@@ -338,7 +420,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
// If we are adding an IPv6 unicast address, join the solicited-node
// multicast address.
- if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) {
+ if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
return nil, err
@@ -347,17 +429,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
n.endpoints[id] = ref
- l, ok := n.primary[protocolAddress.Protocol]
- if !ok {
- l = &ilist.List{}
- n.primary[protocolAddress.Protocol] = l
- }
+ n.insertPrimaryEndpointLocked(ref, peb)
- switch peb {
- case CanBePrimaryEndpoint:
- l.PushBack(ref)
- case FirstPrimaryEndpoint:
- l.PushFront(ref)
+ // If we are adding a tentative IPv6 address, start DAD.
+ if isIPv6Unicast && kind == permanentTentative {
+ if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ return nil, err
+ }
}
return ref, nil
@@ -382,10 +460,12 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
- // Don't include expired or temporary endpoints to avoid confusion and
- // prevent the caller from using those.
+ // Don't include tentative, expired or temporary endpoints to
+ // avoid confusion and prevent the caller from using those.
switch ref.getKind() {
- case permanentExpired, temporary:
+ case permanentTentative, permanentExpired, temporary:
+ // TODO(b/140898488): Should tentative addresses be
+ // returned?
continue
}
addrs = append(addrs, tcpip.ProtocolAddress{
@@ -406,12 +486,12 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for proto, list := range n.primary {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
- // Don't include expired or tempory endpoints to avoid confusion and
- // prevent the caller from using those.
+ for _, ref := range list {
+ // Don't include tentative, expired or tempory endpoints
+ // to avoid confusion and prevent the caller from using
+ // those.
switch ref.getKind() {
- case permanentExpired, temporary:
+ case permanentTentative, permanentExpired, temporary:
continue
}
@@ -471,6 +551,19 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
return append(sns, n.addressRanges...)
}
+// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
+// by peb.
+//
+// n MUST be locked.
+func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ n.primary[r.protocol] = append(n.primary[r.protocol], r)
+ case FirstPrimaryEndpoint:
+ n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...)
+ }
+}
+
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
@@ -488,9 +581,12 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
}
delete(n.endpoints, id)
- wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
- if wasInList {
- n.primary[r.protocol].Remove(r)
+ refs := n.primary[r.protocol]
+ for i, ref := range refs {
+ if ref == r {
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ break
+ }
}
r.ep.Close()
@@ -504,10 +600,22 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
r, ok := n.endpoints[NetworkEndpointID{addr}]
- if !ok || r.getKind() != permanent {
+ if !ok {
return tcpip.ErrBadLocalAddress
}
+ kind := r.getKind()
+ if kind != permanent && kind != permanentTentative {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
+
+ // If we are removing a tentative IPv6 unicast address, stop DAD.
+ if isIPv6Unicast && kind == permanentTentative {
+ n.ndp.stopDuplicateAddressDetection(addr)
+ }
+
r.setKind(permanentExpired)
if !r.decRefLocked() {
// The endpoint still has references to it.
@@ -518,7 +626,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing an IPv6 unicast address, leave the solicited-node
// multicast address.
- if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) {
+ if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(addr)
if err := n.leaveGroupLocked(snmc); err != nil {
return err
@@ -548,6 +656,11 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
// exists yet. Otherwise it just increments its count. n MUST be locked before
// joinGroupLocked is called.
func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ // TODO(b/143102137): When implementing MLD, make sure MLD packets are
+ // not sent unless a valid link-local address is available for use on n
+ // as an MLD packet's source address must be a link-local address as
+ // outlined in RFC 3810 section 5.
+
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
@@ -611,7 +724,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size()))
@@ -621,6 +734,26 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
return
}
+ // If no local link layer address is provided, assume it was sent
+ // directly to this NIC.
+ if local == "" {
+ local = n.linkEP.LinkAddress()
+ }
+
+ // Are any packet sockets listening for this network protocol?
+ n.mu.RLock()
+ packetEPs := n.packetEPs[protocol]
+ // Check whether there are packet sockets listening for every protocol.
+ // If we received a packet with protocol EthernetProtocolAll, then the
+ // previous for loop will have handled it.
+ if protocol != header.EthernetProtocolAll {
+ packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...)
+ }
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ ep.HandlePacket(n.id, local, protocol, vv, linkHeader)
+ }
+
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
@@ -632,8 +765,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
src, dst := netProto.ParseAddresses(vv.First())
- n.stack.AddLinkAddress(n.id, src, remote)
-
if ref := n.getRef(protocol, dst); ref != nil {
handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
@@ -682,7 +813,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
return
}
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ // If a packet socket handled the packet, don't treat it as invalid.
+ if len(packetEPs) == 0 {
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ }
}
// DeliverTransportPacket delivers the packets to the appropriate transport
@@ -769,14 +903,58 @@ func (n *NIC) Stack() *Stack {
return n.stack
}
+// isAddrTentative returns true if addr is tentative on n.
+//
+// Note that if addr is not associated with n, then this function will return
+// false. It will only return true if the address is associated with the NIC
+// AND it is tentative.
+func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
+ ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return false
+ }
+
+ return ref.getKind() == permanentTentative
+}
+
+// dupTentativeAddrDetected attempts to inform n that a tentative addr
+// is a duplicate on a link.
+//
+// dupTentativeAddrDetected will delete the tentative address if it exists.
+func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return tcpip.ErrBadAddress
+ }
+
+ if ref.getKind() != permanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ return n.removePermanentAddressLocked(addr)
+}
+
type networkEndpointKind int32
const (
+ // A permanentTentative endpoint is a permanent address that is not yet
+ // considered to be fully bound to an interface in the traditional
+ // sense. That is, the address is associated with a NIC, but packets
+ // destined to the address MUST NOT be accepted and MUST be silently
+ // dropped, and the address MUST NOT be used as a source address for
+ // outgoing packets. For IPv6, addresses will be of this kind until
+ // NDP's Duplicate Address Detection has resolved, or be deleted if
+ // the process results in detecting a duplicate address.
+ permanentTentative networkEndpointKind = iota
+
// A permanent endpoint is created by adding a permanent address (vs. a
// temporary one) to the NIC. Its reference count is biased by 1 to avoid
// removal when no route holds a reference to it. It is removed by explicitly
// removing the permanent address from the NIC.
- permanent networkEndpointKind = iota
+ permanent
// An expired permanent endoint is a permanent endoint that had its address
// removed from the NIC, and it is waiting to be removed once no more routes
@@ -794,8 +972,37 @@ const (
temporary
)
+func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.packetEPs[netProto]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+ n.packetEPs[netProto] = append(eps, ep)
+
+ return nil
+}
+
+func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.packetEPs[netProto]
+ if !ok {
+ return
+ }
+
+ for i, epOther := range eps {
+ if epOther == ep {
+ n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
+ return
+ }
+ }
+}
+
type referencedNetworkEndpoint struct {
- ilist.Entry
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 9d6157f22..0869fb084 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -71,8 +71,8 @@ type TransportEndpoint interface {
// RawTransportEndpoint is the interface that needs to be implemented by raw
// transport protocol endpoints. RawTransportEndpoints receive the entire
-// packet - including the link, network, and transport headers - as delivered
-// to netstack.
+// packet - including the network and transport headers - as delivered to
+// netstack.
type RawTransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint. The packet contains all data from the link
@@ -80,6 +80,22 @@ type RawTransportEndpoint interface {
HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
}
+// PacketEndpoint is the interface that needs to be implemented by packet
+// transport protocol endpoints. These endpoints receive link layer headers in
+// addition to whatever they contain (usually network and transport layer
+// headers and a payload).
+type PacketEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive that
+ // match the endpoint.
+ //
+ // Implementers should treat packet as immutable and should copy it
+ // before before modification.
+ //
+ // linkHeader may have a length of 0, in which case the PacketEndpoint
+ // should construct its own ethernet header for applications.
+ HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View)
+}
+
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
@@ -185,6 +201,10 @@ type NetworkEndpoint interface {
// protocol.
WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error
+ // WritePackets writes packets to the given destination address and
+ // protocol.
+ WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error)
+
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error
@@ -242,9 +262,10 @@ type NetworkProtocol interface {
// packets to the appropriate network endpoint after it has been handled by
// the data link layer.
type NetworkDispatcher interface {
- // DeliverNetworkPacket finds the appropriate network protocol
- // endpoint and hands the packet over for further processing.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // DeliverNetworkPacket finds the appropriate network protocol endpoint
+ // and hands the packet over for further processing. linkHeader may have
+ // length 0 when the caller does not have ethernet data.
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -266,7 +287,11 @@ const (
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
- CapabilityGSO
+ CapabilityHardwareGSO
+
+ // CapabilitySoftwareGSO indicates the link endpoint supports of sending
+ // multiple packets using a single call (LinkEndpoint.WritePackets).
+ CapabilitySoftwareGSO
)
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
@@ -301,6 +326,18 @@ type LinkEndpoint interface {
// r.LocalLinkAddress if it is provided.
WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+ // WritePackets writes packets with the given protocol through the
+ // given route.
+ //
+ // Right now, WritePackets is used only when the software segmentation
+ // offload is enabled. If it will be used for something else, it may
+ // require to change syscall filters.
+ WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link. The packet
+ // should already have an ethernet header.
+ WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error
+
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
Attach(dispatcher NetworkDispatcher)
@@ -324,13 +361,14 @@ type LinkEndpoint interface {
type InjectableLinkEndpoint interface {
LinkEndpoint
- // Inject injects an inbound packet.
- Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // InjectInbound injects an inbound packet.
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
- // WriteRawPacket writes a fully formed outbound packet directly to the link.
+ // InjectOutbound writes a fully formed outbound packet directly to the
+ // link.
//
// dest is used by endpoints with multiple raw destinations.
- WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error
+ InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error
}
// A LinkAddressResolver is an extension to a NetworkProtocol that
@@ -379,11 +417,16 @@ type LinkAddressCache interface {
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// UnassociatedEndpointFactory produces endpoints for writing packets not
-// associated with a particular transport protocol. Such endpoints can be used
-// to write arbitrary packets that include the IP header.
-type UnassociatedEndpointFactory interface {
- NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+// RawFactory produces endpoints for writing various types of raw packets.
+type RawFactory interface {
+ // NewUnassociatedEndpoint produces endpoints for writing packets not
+ // associated with a particular transport protocol. Such endpoints can
+ // be used to write arbitrary packets that include the network header.
+ NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // NewPacketEndpoint produces endpoints for reading and writing packets
+ // that include network and (when cooked is false) link layer headers.
+ NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
// GSOType is the type of GSO segments.
@@ -394,8 +437,14 @@ type GSOType int
// Types of gso segments.
const (
GSONone GSOType = iota
+
+ // Hardware GSO types:
GSOTCPv4
GSOTCPv6
+
+ // GSOSW is used for software GSO segments which have to be sent by
+ // endpoint.WritePackets.
+ GSOSW
)
// GSO contains generic segmentation offload properties.
@@ -423,3 +472,7 @@ type GSOEndpoint interface {
// GSOMaxSize returns the maximum GSO packet size.
GSOMaxSize() uint32
}
+
+// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment.
+// This isn't a hard limit, because it is never set into packet headers.
+const SoftwareGSOMaxSize = (1 << 16)
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index e72373964..1a0a51b57 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -47,8 +47,8 @@ type Route struct {
// starts.
ref *referencedNetworkEndpoint
- // loop controls where WritePacket should send packets.
- loop PacketLooping
+ // Loop controls where WritePacket should send packets.
+ Loop PacketLooping
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -69,7 +69,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
LocalLinkAddress: localLinkAddr,
RemoteAddress: remoteAddr,
ref: ref,
- loop: loop,
+ Loop: loop,
}
}
@@ -159,7 +159,7 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
return tcpip.ErrInvalidEndpointState
}
- err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop)
+ err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.Loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
@@ -169,6 +169,44 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
return err
}
+// PacketDescriptor is a packet descriptor which contains a packet header and
+// offset and size of packet data in a payload view.
+type PacketDescriptor struct {
+ Hdr buffer.Prependable
+ Off int
+ Size int
+}
+
+// NewPacketDescriptors allocates a set of packet descriptors.
+func NewPacketDescriptors(n int, hdrSize int) []PacketDescriptor {
+ buf := make([]byte, n*hdrSize)
+ hdrs := make([]PacketDescriptor, n)
+ for i := range hdrs {
+ hdrs[i].Hdr = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
+ }
+ return hdrs
+}
+
+// WritePackets writes the set of packets through the given route.
+func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams) (int, *tcpip.Error) {
+ if !r.ref.isValidForOutgoing() {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ n, err := r.ref.ep.WritePackets(r, gso, hdrs, payload, params, r.Loop)
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(hdrs) - n))
+ }
+ r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
+ payloadSize := 0
+ for i := 0; i < n; i++ {
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdrs[i].Hdr.UsedLength()))
+ payloadSize += hdrs[i].Size
+ }
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize))
+ return n, err
+}
+
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
@@ -176,7 +214,7 @@ func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.
return tcpip.ErrInvalidEndpointState
}
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.Loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index f67975525..5ea432a24 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -351,10 +351,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
- // unassociatedFactory creates unassociated endpoints. If nil, raw
- // endpoints are disabled. It is set during Stack creation and is
- // immutable.
- unassociatedFactory UnassociatedEndpointFactory
+ // rawFactory creates raw endpoints. If nil, raw endpoints are
+ // disabled. It is set during Stack creation and is immutable.
+ rawFactory RawFactory
demux *transportDemuxer
@@ -399,6 +398,18 @@ type Stack struct {
//
// TODO(gvisor.dev/issue/940): S/R this field.
portSeed uint32
+
+ // ndpConfigs is the NDP configurations used by interfaces.
+ ndpConfigs NDPConfigurations
+
+ // autoGenIPv6LinkLocal determines whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
}
// Options contains optional Stack configuration.
@@ -422,9 +433,32 @@ type Options struct {
// stack (false).
HandleLocal bool
- // UnassociatedFactory produces unassociated endpoints raw endpoints.
- // Raw endpoints are enabled only if this is non-nil.
- UnassociatedFactory UnassociatedEndpointFactory
+ // NDPConfigs is the NDP configurations used by interfaces.
+ //
+ // By default, NDPConfigs will have a zero value for its
+ // DupAddrDetectTransmits field, implying that DAD will not be performed
+ // before assigning an address to a NIC.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determins whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // Note, setting this to true does not mean that a link-local address
+ // will be assigned right away, or at all. If Duplicate Address
+ // Detection is enabled, an address will only be assigned if it
+ // successfully resolves. If it fails, no further attempt will be made
+ // to auto-generate an IPv6 link-local address.
+ //
+ // The generated link-local address will follow RFC 4291 Appendix A
+ // guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // RawFactory produces raw endpoints. Raw endpoints are enabled only if
+ // this is non-nil.
+ RawFactory RawFactory
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -458,6 +492,9 @@ func (*TransportEndpointInfo) IsEndpointInfo() {}
// New allocates a new networking stack with only the requested networking and
// transport protocols configured with default options.
//
+// Note, NDPConfigurations will be fixed before being used by the Stack. That
+// is, if an invalid value was provided, it will be reset to the default value.
+//
// Protocol options can be changed by calling the
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
@@ -468,18 +505,24 @@ func New(opts Options) *Stack {
clock = &tcpip.StdClock{}
}
+ // Make sure opts.NDPConfigs contains valid values only.
+ opts.NDPConfigs.validate()
+
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
- nics: make(map[tcpip.NICID]*NIC),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- icmpRateLimiter: NewICMPRateLimiter(),
- portSeed: generateRandUint32(),
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ portSeed: generateRandUint32(),
+ ndpConfigs: opts.NDPConfigs,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ ndpDisp: opts.NDPDisp,
}
// Add specified network protocols.
@@ -497,8 +540,8 @@ func New(opts Options) *Stack {
}
}
- // Add the factory for unassociated endpoints, if present.
- s.unassociatedFactory = opts.UnassociatedFactory
+ // Add the factory for raw endpoints, if present.
+ s.rawFactory = opts.RawFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -633,12 +676,12 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if s.unassociatedFactory == nil {
+ if s.rawFactory == nil {
return nil, tcpip.ErrNotPermitted
}
if !associated {
- return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue)
+ return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue)
}
t, ok := s.transportProtocols[transport]
@@ -649,6 +692,16 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return t.proto.NewRawEndpoint(s, network, waiterQueue)
}
+// NewPacketEndpoint creates a new packet endpoint listening for the given
+// netProto.
+func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if s.rawFactory == nil {
+ return nil, tcpip.ErrNotPermitted
+ }
+
+ return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
+}
+
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
@@ -1118,6 +1171,109 @@ func (s *Stack) Resume() {
}
}
+// RegisterPacketEndpoint registers ep with the stack, causing it to receive
+// all traffic of the specified netProto on the given NIC. If nicID is 0, it
+// receives traffic from every NIC.
+func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If no NIC is specified, capture on all devices.
+ if nicID == 0 {
+ // Register with each NIC.
+ for _, nic := range s.nics {
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ s.unregisterPacketEndpointLocked(0, netProto, ep)
+ return err
+ }
+ }
+ return nil
+ }
+
+ // Capture on a specific device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnregisterPacketEndpoint unregisters ep for packets of the specified
+// netProto from the specified NIC. If nicID is 0, ep is unregistered from all
+// NICs.
+func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.unregisterPacketEndpointLocked(nicID, netProto, ep)
+}
+
+func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ // If no NIC is specified, unregister on all devices.
+ if nicID == 0 {
+ // Unregister with each NIC.
+ for _, nic := range s.nics {
+ nic.unregisterPacketEndpoint(netProto, ep)
+ }
+ return
+ }
+
+ // Unregister in a single device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return
+ }
+ nic.unregisterPacketEndpoint(netProto, ep)
+}
+
+// WritePacket writes data directly to the specified NIC. It adds an ethernet
+// header based on the arguments.
+func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicid]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ // Add our own fake ethernet header.
+ ethFields := header.EthernetFields{
+ SrcAddr: nic.linkEP.LinkAddress(),
+ DstAddr: dst,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ ethHeader := buffer.View(fakeHeader).ToVectorisedView()
+ ethHeader.Append(payload)
+
+ if err := nic.linkEP.WriteRawPacket(ethHeader); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicid tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicid]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ return err
+ }
+
+ return nil
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1238,6 +1394,37 @@ func (s *Stack) AllowICMPMessage() bool {
return s.icmpRateLimiter.Allow()
}
+// IsAddrTentative returns true if addr is tentative on the NIC with ID id.
+//
+// Note that if addr is not associated with a NIC with id ID, then this
+// function will return false. It will only return true if the address is
+// associated with the NIC AND it is tentative.
+func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return false, tcpip.ErrUnknownNICID
+ }
+
+ return nic.isAddrTentative(addr), nil
+}
+
+// DupTentativeAddrDetected attempts to inform the NIC with ID id that a
+// tentative addr on it is a duplicate on a link.
+func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.dupTentativeAddrDetected(addr)
+}
+
// PortSeed returns a 32 bit value that can be used as a seed value for port
// picking.
//
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 10fd1065f..9dae853d0 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -24,11 +24,14 @@ import (
"sort"
"strings"
"testing"
+ "time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -144,6 +147,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -1864,3 +1872,297 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
+
+// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
+// (or lack there-of if disabled (default)). Note, DAD will be disabled in
+// these tests.
+func TestNICAutoGenAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ shouldGen bool
+ }{
+ {
+ "Disabled",
+ false,
+ linkAddr1,
+ false,
+ },
+ {
+ "Enabled",
+ true,
+ linkAddr1,
+ true,
+ },
+ {
+ "Nil MAC",
+ true,
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty MAC",
+ true,
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "Invalid MAC",
+ true,
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Multicast MAC",
+ true,
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ false,
+ },
+ {
+ "Unspecified MAC",
+ true,
+ tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ }
+
+ if test.autoGen {
+ // Only set opts.AutoGenIPv6LinkLocal when
+ // test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by
+ // default.
+ opts.AutoGenIPv6LinkLocal = true
+ }
+
+ e := channel.New(10, 1280, test.linkAddr)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+
+ if test.shouldGen {
+ // Should have auto-generated an address and
+ // resolved immediately (DAD is disabled).
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+ } else {
+ // Should not have auto-generated an address.
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ }
+ })
+ }
+}
+
+// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
+// link-local addresses will only be assigned after the DAD process resolves.
+func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Address should not be considered bound to the
+ // NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ linkLocalAddr := header.LinkLocalAddr(linkAddr1)
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // We should get a resolution event after 1s (default time to
+ // resolve as per default NDP configurations). Waiting for that
+ // resolution time + an extra 1s without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != linkLocalAddr {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+}
+
+// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected
+// when an address's kind gets "promoted" to permanent from permanentExpired.
+func TestNewPEBOnPromotionToPermanent(t *testing.T) {
+ pebs := []stack.PrimaryEndpointBehavior{
+ stack.NeverPrimaryEndpoint,
+ stack.CanBePrimaryEndpoint,
+ stack.FirstPrimaryEndpoint,
+ }
+
+ for _, pi := range pebs {
+ for _, ps := range pebs {
+ t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ // Add a permanent address with initial
+ // PrimaryEndpointBehavior (peb), pi. If pi is
+ // NeverPrimaryEndpoint, the address should not
+ // be returned by a call to GetMainNICAddress;
+ // else, it should.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ addr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if pi == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+
+ }
+ } else if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatalf("NewSubnet failed:", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // Take a route through the address so its ref
+ // count gets incremented and does not actually
+ // get deleted when RemoveAddress is called
+ // below. This is because we want to test that a
+ // new peb is respected when an address gets
+ // "promoted" to permanent from a
+ // permanentExpired kind.
+ r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(1, "\x01"); err != nil {
+ t.Fatalf("RemoveAddress failed:", err)
+ }
+
+ //
+ // At this point, the address should still be
+ // known by the NIC, but have its
+ // kind = permanentExpired.
+ //
+
+ // Add some other address with peb set to
+ // FirstPrimaryEndpoint.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+
+ }
+
+ // Add back the address we removed earlier and
+ // make sure the new peb was respected.
+ // (The address should just be promoted now).
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ var primaryAddrs []tcpip.Address
+ for _, pa := range s.NICInfo()[1].ProtocolAddresses {
+ primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
+ }
+ var expectedList []tcpip.Address
+ switch ps {
+ case stack.FirstPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x01",
+ "\x03",
+ }
+ case stack.CanBePrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ "\x01",
+ }
+ case stack.NeverPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ }
+ }
+ if !cmp.Equal(primaryAddrs, expectedList) {
+ t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList)
+ }
+
+ // Once we remove the other address, if the new
+ // peb, ps, was NeverPrimaryEndpoint, no address
+ // should be returned by a call to
+ // GetMainNICAddress; else, our original address
+ // should be returned.
+ if err := s.RemoveAddress(1, "\x03"); err != nil {
+ t.Fatalf("RemoveAddress failed:", err)
+ }
+ addr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if ps == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+
+ }
+ } else {
+ if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 92267ce4d..97a1aec4b 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -465,7 +465,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
- return nil
+ return tcpip.ErrNotSupported
}
eps.mu.Lock()
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 26f338d8d..353ecd49b 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -255,7 +255,7 @@ type FullAddress struct {
// This may not be used by all endpoint types.
NIC NICID
- // Addr is the network address.
+ // Addr is the network or link layer address.
Addr Address
// Port is the transport port.
@@ -1100,15 +1100,12 @@ func (*TransportEndpointStats) IsEndpointStats() {}
func fillIn(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
- switch v.Kind() {
- case reflect.Ptr:
- if s := v.Addr().Interface().(**StatCounter); *s == nil {
- *s = &StatCounter{}
+ if s, ok := v.Addr().Interface().(**StatCounter); ok {
+ if *s == nil {
+ *s = new(StatCounter)
}
- case reflect.Struct:
+ } else {
fillIn(v)
- default:
- panic(fmt.Sprintf("unexpected type %s", v.Type()))
}
}
}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
new file mode 100644
index 000000000..8ea2e6ee5
--- /dev/null
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -0,0 +1,46 @@
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "packet_list",
+ out = "packet_list.go",
+ package = "packet",
+ prefix = "packet",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*packet",
+ "Linker": "*packet",
+ },
+)
+
+go_library(
+ name = "packet",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "packet_list.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet",
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sleep",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "packet_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
new file mode 100644
index 000000000..73cdaa265
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -0,0 +1,363 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packet provides the implementation of packet sockets (see
+// packet(7)). Packet sockets allow applications to:
+//
+// * manually write and inspect link, network, and transport headers
+// * receive all traffic of a given network protocol, or all protocols
+//
+// Packet sockets are similar to raw sockets, but provide even more power to
+// users, letting them effectively talk directly to the network device.
+//
+// Packet sockets skip the input and output iptables chains.
+package packet
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is pre-allocated space to back data. As long as the packet is
+ // made up of fewer than 8 buffer.Views, no extra allocation is
+ // necessary to store packet data.
+ views [8]buffer.View `state:"nosave"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
+// to have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ stack.TransportEndpointInfo
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+ cooked bool
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+}
+
+// NewEndpoint returns a new packet endpoint.
+func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ },
+ cooked: cooked,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
+ return nil, err
+ }
+ return ep, nil
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ ep.closed = true
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (ep *endpoint) ModerateRecvBuf(copied int) {}
+
+// IPTables implements tcpip.Endpoint.IPTables.
+func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+ return ep.stack.IPTables(), nil
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ ep.stats.ReadErrors.ReadClosed.Increment()
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // TODO(b/129292371): Implement.
+ return 0, nil, tcpip.ErrInvalidOptionValue
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
+// disconnected, and this function always returns tpcip.ErrNotSupported.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
+// connected, and this function always returnes tcpip.ErrNotSupported.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
+// with Shutdown, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
+// Listen, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
+// Accept, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ // TODO(gvisor.dev/issue/173): Add Bind support.
+
+ // "By default, all packets of the specified protocol type are passed
+ // to a packet socket. To get packets only from a specific interface
+ // use bind(2) specifying an address in a struct sockaddr_ll to bind
+ // the packet socket to an interface. Fields used for binding are
+ // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
+ // - packet(7).
+
+ return tcpip.ErrNotSupported
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
+// used with SetSockOpt, and this function always returns
+// tcpip.ErrNotSupported.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// HandlePacket implements stack.PacketEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, ethHeader buffer.View) {
+ ep.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if ep.rcvClosed {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ var packet packet
+ // TODO(b/129292371): Return network protocol.
+ if len(ethHeader) > 0 {
+ // Get info directly from the ethernet header.
+ hdr := header.Ethernet(ethHeader)
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicid,
+ Addr: tcpip.Address(hdr.SourceAddress()),
+ }
+ } else {
+ // Guess the would-be ethernet header.
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicid,
+ Addr: tcpip.Address(localAddr),
+ }
+ }
+
+ if ep.cooked {
+ // Cooked packets can simply be queued.
+ packet.data = vv.Clone(packet.views[:])
+ } else {
+ // Raw packets need their ethernet headers prepended before
+ // queueing.
+ if len(ethHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ ethHeader = buffer.View(fakeHeader)
+ }
+ combinedVV := buffer.View(ethHeader).ToVectorisedView()
+ combinedVV.Append(vv)
+ packet.data = combinedVV.Clone(packet.views[:])
+ }
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(&packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+ ep.stats.PacketsReceived.Increment()
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// State implements socket.Socket.State.
+func (ep *endpoint) State() uint32 {
+ return 0
+}
+
+// Info returns a copy of the endpoint info.
+func (ep *endpoint) Info() tcpip.EndpointInfo {
+ ep.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := ep.TransportEndpointInfo
+ ep.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (ep *endpoint) Stats() tcpip.EndpointStats {
+ return &ep.stats
+}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
new file mode 100644
index 000000000..9b88f17e4
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package packet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+ // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
+ if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index fba598d51..4af49218c 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -4,14 +4,14 @@ load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
go_template_instance(
- name = "packet_list",
- out = "packet_list.go",
+ name = "raw_packet_list",
+ out = "raw_packet_list.go",
package = "raw",
- prefix = "packet",
+ prefix = "rawPacket",
template = "//pkg/ilist:generic_list",
types = {
- "Element": "*packet",
- "Linker": "*packet",
+ "Element": "*rawPacket",
+ "Linker": "*rawPacket",
},
)
@@ -20,8 +20,8 @@ go_library(
srcs = [
"endpoint.go",
"endpoint_state.go",
- "packet_list.go",
"protocol.go",
+ "raw_packet_list.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
@@ -34,6 +34,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
)
@@ -41,7 +42,7 @@ go_library(
filegroup(
name = "autogen",
srcs = [
- "packet_list.go",
+ "raw_packet_list.go",
],
visibility = ["//:sandbox"],
)
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b4c660859..308f10d24 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -17,8 +17,7 @@
//
// * manually write and inspect transport layer headers and payloads
// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
-// * optionally write and inspect network layer and link layer headers for
-// packets
+// * optionally write and inspect network layer headers of packets
//
// Raw sockets don't have any notion of ports, and incoming packets are
// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
@@ -38,8 +37,8 @@ import (
)
// +stateify savable
-type packet struct {
- packetEntry
+type rawPacket struct {
+ rawPacketEntry
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
@@ -72,7 +71,7 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ rcvList rawPacketList
rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
rcvClosed bool
@@ -90,7 +89,6 @@ type endpoint struct {
}
// NewEndpoint returns a raw endpoint for the given protocols.
-// TODO(b/129292371): IP_HDRINCL and AF_PACKET.
func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
}
@@ -187,17 +185,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
return buffer.View{}, tcpip.ControlMessages{}, err
}
- packet := e.rcvList.Front()
- e.rcvList.Remove(packet)
- e.rcvBufSize -= packet.data.Size()
+ pkt := e.rcvList.Front()
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
e.rcvMu.Unlock()
if addr != nil {
- *addr = packet.senderAddr
+ *addr = pkt.senderAddr
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+ return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -602,7 +600,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- packet := &packet{
+ pkt := &rawPacket{
senderAddr: tcpip.FullAddress{
NIC: route.NICID(),
Addr: route.RemoteAddress,
@@ -611,11 +609,11 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu
combinedVV := netHeader.ToVectorisedView()
combinedVV.Append(vv)
- packet.data = combinedVV.Clone(packet.views[:])
- packet.timestampNS = e.stack.NowNanoseconds()
+ pkt.data = combinedVV.Clone(pkt.views[:])
+ pkt.timestampNS = e.stack.NowNanoseconds()
- e.rcvList.PushBack(packet)
- e.rcvBufSize += packet.data.Size()
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += pkt.data.Size()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index a6c7cc43a..33bfb56cd 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -20,15 +20,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// saveData saves packet.data field.
-func (p *packet) saveData() buffer.VectorisedView {
+// saveData saves rawPacket.data field.
+func (p *rawPacket) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
// which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
-// loadData loads packet.data field.
-func (p *packet) loadData(data buffer.VectorisedView) {
+// loadData loads rawPacket.data field.
+func (p *rawPacket) loadData(data buffer.VectorisedView) {
// NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
// here because data.views is not guaranteed to be loaded by now. Plus,
// data.views will be allocated anyway so there really is little point
@@ -86,7 +86,9 @@ func (ep *endpoint) Resume(s *stack.Stack) {
}
}
- if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
- panic(err)
+ if ep.associated {
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ panic(err)
+ }
}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index a2512d666..f30aa2a4a 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -17,13 +17,19 @@ package raw
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/packet"
"gvisor.dev/gvisor/pkg/waiter"
)
-// EndpointFactory implements stack.UnassociatedEndpointFactory.
+// EndpointFactory implements stack.RawFactory.
type EndpointFactory struct{}
-// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
+func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
+
+// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index aed70e06f..f1dbc6f91 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -44,6 +44,7 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index b724d02bb..8db1cc028 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -607,17 +607,11 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu
return nil
}
-// sendTCP sends a TCP segment with the provided options via the provided
-// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDescriptor, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) {
optLen := len(opts)
- // Allocate a buffer for the TCP header.
- hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
-
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
- }
-
+ hdr := &d.Hdr
+ packetSize := d.Size
+ off := d.Off
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
tcp.Encode(&header.TCPFields{
@@ -631,7 +625,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
})
copy(tcp[header.TCPMinimumSize:], opts)
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(hdr.UsedLength() + packetSize)
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
@@ -641,14 +635,71 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVV(data, xsum)
+ xsum = header.ChecksumVVWithOffset(data, xsum, off, packetSize)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
+}
+
+func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ mss := int(gso.MSS)
+ n := (data.Size() + mss - 1) / mss
+
+ hdrs := stack.NewPacketDescriptors(n, header.TCPMinimumSize+int(r.MaxHeaderLength())+optLen)
+
+ size := data.Size()
+ off := 0
+ for i := 0; i < n; i++ {
+ packetSize := mss
+ if packetSize > size {
+ packetSize = size
+ }
+ size -= packetSize
+ hdrs[i].Off = off
+ hdrs[i].Size = packetSize
+ buildTCPHdr(r, id, &hdrs[i], data, flags, seq, ack, rcvWnd, opts, gso)
+ off += packetSize
+ seq = seq.Add(seqnum.Size(packetSize))
+ }
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ sent, err := r.WritePackets(gso, hdrs, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos})
+ if err != nil {
+ r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
+ }
+ r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent))
+ return err
+}
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
+ return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso)
+ }
+
+ d := &stack.PacketDescriptor{
+ Hdr: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
+ Off: 0,
+ Size: data.Size(),
+ }
+ buildTCPHdr(r, id, d, data, flags, seq, ack, rcvWnd, opts, gso)
+
if ttl == 0 {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 31a22c1eb..c6bc5528c 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2328,11 +2328,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
return s
}
-func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityGSO == 0 {
- return
- }
-
+func (e *endpoint) initHardwareGSO() {
gso := &stack.GSO{}
switch e.route.NetProto {
case header.IPv4ProtocolNumber:
@@ -2350,6 +2346,18 @@ func (e *endpoint) initGSO() {
e.gso = gso
}
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ e.initHardwareGSO()
+ } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ e.gso = &stack.GSO{
+ MaxSize: e.route.GSOMaxSize(),
+ Type: stack.GSOSW,
+ NeedsCsum: false,
+ }
+ }
+}
+
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6e87245b7..91c8487f3 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -140,7 +140,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
- TransProto: header.TCPProtocolNumber,
+ TransProto: header.UDPProtocolNumber,
},
waiterQueue: waiterQueue,
// RFC 1075 section 5.4 recommends a TTL of 1 for membership
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index de026880f..5c3358a5e 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -121,8 +121,15 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
- payload.Append(vv)
+ // The buffers used by vv and netHeader may be used elsewhere
+ // in the system. For example, a raw or packet socket may use
+ // what UDP considers an unreachable destination. Thus we deep
+ // copy vv and netHeader to prevent multiple ownership and SR
+ // errors.
+ newNetHeader := make(buffer.View, len(netHeader))
+ copy(newNetHeader, netHeader)
+ payload := buffer.NewVectorisedView(len(newNetHeader), []buffer.View{newNetHeader})
+ payload.Append(vv.ToView().ToVectorisedView())
payload.CapLength(payloadLen)
hdr := buffer.NewPrependable(headerLen)
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 38278d0a2..72a33534f 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -178,8 +178,11 @@ type Config struct {
// capabilities.
EnableRaw bool
- // GSO indicates that generic segmentation offload is enabled.
- GSO bool
+ // HardwareGSO indicates that hardware segmentation offload is enabled.
+ HardwareGSO bool
+
+ // SoftwareGSO indicates that software segmentation offload is enabled.
+ SoftwareGSO bool
// LogPackets indicates that all network packets should be logged.
LogPackets bool
@@ -231,6 +234,10 @@ type Config struct {
// ReferenceLeakMode sets reference leak check mode
ReferenceLeakMode refs.LeakMode
+ // OverlayfsStaleRead causes cached FDs to reopen after a file is opened for
+ // write to workaround overlayfs limitation on kernels before 4.19.
+ OverlayfsStaleRead bool
+
// TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
// tests. It allows runsc to start the sandbox process as the current
// user, and without chrooting the sandbox process. This can be
@@ -271,6 +278,9 @@ func (c *Config) ToFlags() []string {
"--rootless=" + strconv.FormatBool(c.Rootless),
"--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr),
"--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode),
+ "--gso=" + strconv.FormatBool(c.HardwareGSO),
+ "--software-gso=" + strconv.FormatBool(c.SoftwareGSO),
+ "--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead),
}
// Only include these if set since it is never to be used by users.
if c.TestOnlyAllowRunAsCurrentUserWithoutChroot {
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index a73c593ea..5f644b57e 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/specutils"
)
const (
@@ -237,6 +238,9 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error {
return fmt.Errorf("start arguments must contain stdin, stderr, and stdout followed by at least one file for the container root gofer")
}
+ // All validation passed, logs the spec for debugging.
+ specutils.LogSpec(args.Spec)
+
err := cm.l.startContainer(args.Spec, args.Conf, args.CID, args.FilePayload.Files)
if err != nil {
log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err)
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index a2ecc6bcb..5ad108261 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -44,6 +44,7 @@ var allowedSyscalls = seccomp.SyscallRules{
},
syscall.SYS_CLOSE: {},
syscall.SYS_DUP: {},
+ syscall.SYS_DUP2: {},
syscall.SYS_EPOLL_CREATE1: {},
syscall.SYS_EPOLL_CTL: {},
syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{
@@ -242,6 +243,15 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowValue(0),
},
},
+ unix.SYS_SENDMMSG: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.MSG_DONTWAIT),
+ seccomp.AllowValue(0),
+ },
+ },
syscall.SYS_RESTART_SYSCALL: {},
syscall.SYS_RT_SIGACTION: {},
syscall.SYS_RT_SIGPROCMASK: {},
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 393c2a88b..76036c147 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -703,6 +703,14 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (*
log.Infof("Mounting root over 9P, ioFD: %d", fd)
p9FS := mustFindFilesystem("9p")
opts := p9MountOptions(fd, conf.FileAccess)
+
+ if conf.OverlayfsStaleRead {
+ // We can't check for overlayfs here because sandbox is chroot'ed and gofer
+ // can only send mount options for specs.Mounts (specs.Root is missing
+ // Options field). So assume root is always on top of overlayfs.
+ opts = append(opts, "overlayfs_stale_read")
+ }
+
rootInode, err := p9FS.Mount(ctx, rootDevice, mf, strings.Join(opts, ","), nil)
if err != nil {
return nil, fmt.Errorf("creating root mount point: %v", err)
@@ -737,7 +745,6 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
fsName string
opts []string
useOverlay bool
- err error
)
switch m.Type {
@@ -747,7 +754,12 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
fsName = sysfs
case tmpfs:
fsName = m.Type
+
+ var err error
opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedOptions...)
+ if err != nil {
+ return "", nil, false, err
+ }
case bind:
fd := c.fds.remove()
@@ -763,7 +775,7 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
// for now.
log.Warningf("ignoring unknown filesystem type %q", m.Type)
}
- return fsName, opts, useOverlay, err
+ return fsName, opts, useOverlay, nil
}
// mountSubmount mounts volumes inside the container's root. Because mounts may
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index c8e5e86ee..0c0eba99e 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -922,7 +922,7 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
HandleLocal: true,
// Enable raw sockets for users with sufficient
// privileges.
- UnassociatedFactory: raw.EndpointFactory{},
+ RawFactory: raw.EndpointFactory{},
})}
// Enable SACK Recovery.
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 32cba5ac1..f98c5fd36 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -50,12 +50,13 @@ type DefaultRoute struct {
// FDBasedLink configures an fd-based link.
type FDBasedLink struct {
- Name string
- MTU int
- Addresses []net.IP
- Routes []Route
- GSOMaxSize uint32
- LinkAddress net.HardwareAddr
+ Name string
+ MTU int
+ Addresses []net.IP
+ Routes []Route
+ GSOMaxSize uint32
+ SoftwareGSOEnabled bool
+ LinkAddress net.HardwareAddr
// NumChannels controls how many underlying FD's are to be used to
// create this endpoint.
@@ -163,6 +164,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
Address: mac,
PacketDispatchMode: fdbased.RecvMMsg,
GSOMaxSize: link.GSOMaxSize,
+ SoftwareGSOEnabled: link.SoftwareGSOEnabled,
RXChecksumOffload: true,
})
if err != nil {
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 4c2fb80bf..4831210c0 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -27,6 +27,7 @@ import (
"flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/unet"
@@ -135,7 +136,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
//
// Note that all mount points have been mounted in the proper location in
// setupRootFS().
- cleanMounts, err := resolveMounts(spec.Mounts, root)
+ cleanMounts, err := resolveMounts(conf, spec.Mounts, root)
if err != nil {
Fatalf("Failure to resolve mounts: %v", err)
}
@@ -380,7 +381,7 @@ func setupMounts(mounts []specs.Mount, root string) error {
// Otherwise, it may follow symlinks to locations that would be overwritten
// with another mount point and return the wrong location. In short, make sure
// setupMounts() has been called before.
-func resolveMounts(mounts []specs.Mount, root string) ([]specs.Mount, error) {
+func resolveMounts(conf *boot.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) {
cleanMounts := make([]specs.Mount, 0, len(mounts))
for _, m := range mounts {
if m.Type != "bind" || !specutils.IsSupportedDevMount(m) {
@@ -395,8 +396,15 @@ func resolveMounts(mounts []specs.Mount, root string) ([]specs.Mount, error) {
if err != nil {
panic(fmt.Sprintf("%q could not be made relative to %q: %v", dst, root, err))
}
+
+ opts, err := adjustMountOptions(conf, filepath.Join(root, relDst), m.Options)
+ if err != nil {
+ return nil, err
+ }
+
cpy := m
cpy.Destination = filepath.Join("/", relDst)
+ cpy.Options = opts
cleanMounts = append(cleanMounts, cpy)
}
return cleanMounts, nil
@@ -453,3 +461,20 @@ func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, erro
}
return base, nil
}
+
+// adjustMountOptions adds 'overlayfs_stale_read' if mounting over overlayfs.
+func adjustMountOptions(conf *boot.Config, path string, opts []string) ([]string, error) {
+ rv := make([]string, len(opts))
+ copy(rv, opts)
+
+ if conf.OverlayfsStaleRead {
+ statfs := syscall.Statfs_t{}
+ if err := syscall.Statfs(path, &statfs); err != nil {
+ return nil, err
+ }
+ if statfs.Type == unix.OVERLAYFS_SUPER_MAGIC {
+ rv = append(rv, "overlayfs_stale_read")
+ }
+ }
+ return rv, nil
+}
diff --git a/runsc/container/container.go b/runsc/container/container.go
index a721c1c31..32510d427 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -1149,7 +1149,7 @@ func maybeLockRootContainer(spec *specs.Spec, rootDir string) (func() error, err
}
func isRoot(spec *specs.Spec) bool {
- return specutils.ShouldCreateSandbox(spec)
+ return specutils.SpecContainerType(spec) != specutils.ContainerTypeContainer
}
// runInCgroup executes fn inside the specified cgroup. If cg is nil, execute
@@ -1198,7 +1198,7 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool)
// Get the lowest score for all containers.
var lowScore int
scoreFound := false
- if len(containers) == 1 && len(containers[0].Spec.Annotations[specutils.ContainerdContainerTypeAnnotation]) == 0 {
+ if len(containers) == 1 && specutils.SpecContainerType(containers[0].Spec) == specutils.ContainerTypeUnspecified {
// This is a single-container sandbox. Set the oom_score_adj to
// the value specified in the OCI bundle.
if containers[0].Spec.Process.OOMScoreAdj != nil {
@@ -1214,7 +1214,7 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool)
//
// We will use OOMScoreAdj in the single-container case where the
// containerd container-type annotation is not present.
- if container.Spec.Annotations[specutils.ContainerdContainerTypeAnnotation] == specutils.ContainerdContainerTypeSandbox {
+ if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox {
continue
}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 519f5ed9b..c4c56b2e0 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -2074,6 +2074,43 @@ func TestNetRaw(t *testing.T) {
}
}
+// TestOverlayfsStaleRead most basic test that '--overlayfs-stale-read' works.
+func TestOverlayfsStaleRead(t *testing.T) {
+ conf := testutil.TestConfig()
+ conf.OverlayfsStaleRead = true
+
+ in, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.in")
+ if err != nil {
+ t.Fatalf("ioutil.TempFile() failed: %v", err)
+ }
+ defer in.Close()
+ if _, err := in.WriteString("stale data"); err != nil {
+ t.Fatalf("in.Write() failed: %v", err)
+ }
+
+ out, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.out")
+ if err != nil {
+ t.Fatalf("ioutil.TempFile() failed: %v", err)
+ }
+ defer out.Close()
+
+ const want = "foobar"
+ cmd := fmt.Sprintf("cat %q && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name())
+ spec := testutil.NewSpecWithArgs("/bin/bash", "-c", cmd)
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("Error running container: %v", err)
+ }
+
+ gotBytes, err := ioutil.ReadAll(out)
+ if err != nil {
+ t.Fatalf("out.Read() failed: %v", err)
+ }
+ got := strings.TrimSpace(string(gotBytes))
+ if want != got {
+ t.Errorf("Wrong content in out file, got: %q. want: %q", got, want)
+ }
+}
+
// executeSync synchronously executes a new process.
func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
pid, err := cont.Execute(args)
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index 0bf7507b7..2ea95f8fb 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -220,6 +220,18 @@ var udsSyscalls = seccomp.SyscallRules{
syscall.SYS_SOCKET: []seccomp.Rule{
{
seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_STREAM),
+ seccomp.AllowValue(0),
+ },
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_DGRAM),
+ seccomp.AllowValue(0),
+ },
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_SEQPACKET),
+ seccomp.AllowValue(0),
},
},
syscall.SYS_CONNECT: []seccomp.Rule{
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 29a82138e..3fceecb3d 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -21,7 +21,6 @@
package fsgofer
import (
- "errors"
"fmt"
"io"
"math"
@@ -126,13 +125,6 @@ func NewAttachPoint(prefix string, c Config) (p9.Attacher, error) {
// Attach implements p9.Attacher.
func (a *attachPoint) Attach() (p9.File, error) {
- // dirFD (1st argument) is ignored because 'prefix' is always absolute.
- stat, err := statAt(-1, a.prefix)
- if err != nil {
- return nil, fmt.Errorf("stat file %q, err: %v", a.prefix, err)
- }
-
- // Acquire the attach point lock.
a.attachedMu.Lock()
defer a.attachedMu.Unlock()
@@ -140,47 +132,24 @@ func (a *attachPoint) Attach() (p9.File, error) {
return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
}
- // Hold the file descriptor we are converting into a p9.File.
- var f *fd.FD
-
- // Apply the S_IFMT bitmask so we can detect file type appropriately.
- switch fmtStat := stat.Mode & syscall.S_IFMT; fmtStat {
- case syscall.S_IFSOCK:
- // Check to see if the CLI option has been set to allow the UDS mount.
- if !a.conf.HostUDS {
- return nil, errors.New("host UDS support is disabled")
- }
-
- // Attempt to open a connection. Bubble up the failures.
- f, err = fd.DialUnix(a.prefix)
- if err != nil {
- return nil, err
- }
-
- default:
- // Default to Read/Write permissions.
- mode := syscall.O_RDWR
-
- // If the configuration is Read Only or the mount point is a directory,
- // set the mode to Read Only.
- if a.conf.ROMount || fmtStat == syscall.S_IFDIR {
- mode = syscall.O_RDONLY
- }
+ f, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) {
+ return fd.Open(a.prefix, openFlags|mode, 0)
+ })
+ if err != nil {
+ return nil, fmt.Errorf("unable to open %q: %v", a.prefix, err)
+ }
- // Open the mount point & capture the FD.
- f, err = fd.Open(a.prefix, openFlags|mode, 0)
- if err != nil {
- return nil, fmt.Errorf("unable to open file %q, err: %v", a.prefix, err)
- }
+ stat, err := stat(f.FD())
+ if err != nil {
+ return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err)
}
- // Return a localFile object to the caller with the UDS FD included.
- rv, err := newLocalFile(a, f, a.prefix, stat)
+ lf, err := newLocalFile(a, f, a.prefix, stat)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err)
}
a.attached = true
- return rv, nil
+ return lf, nil
}
// makeQID returns a unique QID for the given stat buffer.
@@ -296,10 +265,10 @@ func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, erro
// actual file open and is customizable by the caller.
func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) {
// Attempt to open file in the following mode in order:
- // 1. RDONLY | NONBLOCK: for all files, works for directories and ro mounts too.
- // Use non-blocking to prevent getting stuck inside open(2) for FIFOs. This option
- // has no effect on regular files.
- // 2. PATH: for symlinks
+ // 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs.
+ // Use non-blocking to prevent getting stuck inside open(2) for
+ // FIFOs. This option has no effect on regular files.
+ // 2. PATH: for symlinks, sockets.
modes := []int{syscall.O_RDONLY | syscall.O_NONBLOCK, unix.O_PATH}
var err error
@@ -1063,12 +1032,48 @@ func (l *localFile) Flush() error {
}
// Connect implements p9.File.
-func (l *localFile) Connect(p9.ConnectFlags) (*fd.FD, error) {
- // Check to see if the CLI option has been set to allow the UDS mount.
+func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) {
if !l.attachPoint.conf.HostUDS {
return nil, syscall.ECONNREFUSED
}
- return fd.DialUnix(l.hostPath)
+
+ // TODO(gvisor.dev/issue/1003): Due to different app vs replacement
+ // mappings, the app path may have fit in the sockaddr, but we can't
+ // fit f.path in our sockaddr. We'd need to redirect through a shorter
+ // path in order to actually connect to this socket.
+ if len(l.hostPath) > linux.UnixPathMax {
+ return nil, syscall.ECONNREFUSED
+ }
+
+ var stype int
+ switch flags {
+ case p9.StreamSocket:
+ stype = syscall.SOCK_STREAM
+ case p9.DgramSocket:
+ stype = syscall.SOCK_DGRAM
+ case p9.SeqpacketSocket:
+ stype = syscall.SOCK_SEQPACKET
+ default:
+ return nil, syscall.ENXIO
+ }
+
+ f, err := syscall.Socket(syscall.AF_UNIX, stype, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := syscall.SetNonblock(f, true); err != nil {
+ syscall.Close(f)
+ return nil, err
+ }
+
+ sa := syscall.SockaddrUnix{Name: l.hostPath}
+ if err := syscall.Connect(f, &sa); err != nil {
+ syscall.Close(f)
+ return nil, err
+ }
+
+ return fd.New(f), nil
}
// Close implements p9.File.
diff --git a/runsc/main.go b/runsc/main.go
index 7dce9dc00..ae906c661 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -41,35 +41,37 @@ import (
var (
// Although these flags are not part of the OCI spec, they are used by
// Docker, and thus should not be changed.
- rootDir = flag.String("root", "", "root directory for storage of container state")
- logFilename = flag.String("log", "", "file path where internal debug information is written, default is stdout")
- logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s")
- debug = flag.Bool("debug", false, "enable debug logging")
- showVersion = flag.Bool("version", false, "show version and exit")
+ rootDir = flag.String("root", "", "root directory for storage of container state.")
+ logFilename = flag.String("log", "", "file path where internal debug information is written, default is stdout.")
+ logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.")
+ debug = flag.Bool("debug", false, "enable debug logging.")
+ showVersion = flag.Bool("version", false, "show version and exit.")
// These flags are unique to runsc, and are used to configure parts of the
// system that are not covered by the runtime spec.
// Debugging flags.
debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.")
- logPackets = flag.Bool("log-packets", false, "enable network packet logging")
+ logPackets = flag.Bool("log-packets", false, "enable network packet logging.")
logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.")
debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.")
- debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s")
- alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr")
+ debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
+ alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.")
// Debugging flags: strace related
- strace = flag.Bool("strace", false, "enable strace")
+ strace = flag.Bool("strace", false, "enable strace.")
straceSyscalls = flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.")
- straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs")
+ straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.")
// Flags that control sandbox runtime behavior.
- platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm")
+ platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.")
network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.")
- gso = flag.Bool("gso", true, "enable generic segmenation offload")
+ hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.")
+ softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware ofload can't be enabled.")
fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
- fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "Allow the gofer to mount Unix Domain Sockets.")
+ fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.")
overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
+ overlayfsStaleRead = flag.Bool("overlayfs-stale-read", false, "reopen cached FDs after a file is opened for write to workaround overlayfs limitation on kernels before 4.19.")
watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.")
panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.")
profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).")
@@ -199,7 +201,8 @@ func main() {
FSGoferHostUDS: *fsGoferHostUDS,
Overlay: *overlay,
Network: netType,
- GSO: *gso,
+ HardwareGSO: *hardwareGSO,
+ SoftwareGSO: *softwareGSO,
LogPackets: *logPackets,
Platform: platformType,
Strace: *strace,
@@ -212,6 +215,7 @@ func main() {
Rootless: *rootless,
AlsoLogToStderr: *alsoLogToStderr,
ReferenceLeakMode: refsLeakMode,
+ OverlayfsStaleRead: *overlayfsStaleRead,
TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot,
TestOnlyTestNameEnv: *testOnlyTestNameEnv,
diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD
index 7fdceaab6..27459e6d1 100644
--- a/runsc/sandbox/BUILD
+++ b/runsc/sandbox/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/platform",
+ "//pkg/tcpip/stack",
"//pkg/urpc",
"//runsc/boot",
"//runsc/boot/platforms",
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index 5634f0707..d42de0176 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -28,6 +28,7 @@ import (
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
@@ -61,7 +62,7 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi
// Build the path to the net namespace of the sandbox process.
// This is what we will copy.
nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net")
- if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.GSO, conf.NumNetworkChannels); err != nil {
+ if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.NumNetworkChannels); err != nil {
return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err)
}
case boot.NetworkHost:
@@ -136,7 +137,7 @@ func isRootNS() (bool, error) {
// createInterfacesAndRoutesFromNS scrapes the interface and routes from the
// net namespace with the given path, creates them in the sandbox, and removes
// them from the host.
-func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO bool, numNetworkChannels int) error {
+func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, numNetworkChannels int) error {
// Join the network namespace that we will be copying.
restore, err := joinNetNS(nsPath)
if err != nil {
@@ -232,7 +233,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO
// Create the socket for the device.
for i := 0; i < link.NumChannels; i++ {
log.Debugf("Creating Channel %d", i)
- socketEntry, err := createSocket(iface, ifaceLink, enableGSO)
+ socketEntry, err := createSocket(iface, ifaceLink, hardwareGSO)
if err != nil {
return fmt.Errorf("failed to createSocket for %s : %v", iface.Name, err)
}
@@ -246,6 +247,11 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, enableGSO
}
args.FilePayload.Files = append(args.FilePayload.Files, socketEntry.deviceFile)
}
+ if link.GSOMaxSize == 0 && softwareGSO {
+ // Hardware GSO is disabled. Let's enable software GSO.
+ link.GSOMaxSize = stack.SoftwareGSOMaxSize
+ link.SoftwareGSOEnabled = true
+ }
// Collect the addresses for the interface, enable forwarding,
// and remove them from the host.
diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD
index fa58313a0..205638803 100644
--- a/runsc/specutils/BUILD
+++ b/runsc/specutils/BUILD
@@ -5,6 +5,7 @@ package(licenses = ["notice"])
go_library(
name = "specutils",
srcs = [
+ "cri.go",
"fs.go",
"namespace.go",
"specutils.go",
diff --git a/runsc/specutils/cri.go b/runsc/specutils/cri.go
new file mode 100644
index 000000000..9c5877cd5
--- /dev/null
+++ b/runsc/specutils/cri.go
@@ -0,0 +1,110 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package specutils
+
+import (
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+const (
+ // ContainerdContainerTypeAnnotation is the OCI annotation set by
+ // containerd to indicate whether the container to create should have
+ // its own sandbox or a container within an existing sandbox.
+ ContainerdContainerTypeAnnotation = "io.kubernetes.cri.container-type"
+ // ContainerdContainerTypeContainer is the container type value
+ // indicating the container should be created in an existing sandbox.
+ ContainerdContainerTypeContainer = "container"
+ // ContainerdContainerTypeSandbox is the container type value
+ // indicating the container should be created in a new sandbox.
+ ContainerdContainerTypeSandbox = "sandbox"
+
+ // ContainerdSandboxIDAnnotation is the OCI annotation set to indicate
+ // which sandbox the container should be created in when the container
+ // is not the first container in the sandbox.
+ ContainerdSandboxIDAnnotation = "io.kubernetes.cri.sandbox-id"
+
+ // CRIOContainerTypeAnnotation is the OCI annotation set by
+ // CRI-O to indicate whether the container to create should have
+ // its own sandbox or a container within an existing sandbox.
+ CRIOContainerTypeAnnotation = "io.kubernetes.cri-o.ContainerType"
+
+ // CRIOContainerTypeContainer is the container type value
+ // indicating the container should be created in an existing sandbox.
+ CRIOContainerTypeContainer = "container"
+ // CRIOContainerTypeSandbox is the container type value
+ // indicating the container should be created in a new sandbox.
+ CRIOContainerTypeSandbox = "sandbox"
+
+ // CRIOSandboxIDAnnotation is the OCI annotation set to indicate
+ // which sandbox the container should be created in when the container
+ // is not the first container in the sandbox.
+ CRIOSandboxIDAnnotation = "io.kubernetes.cri-o.SandboxID"
+)
+
+// ContainerType represents the type of container requested by the calling container manager.
+type ContainerType int
+
+const (
+ // ContainerTypeUnspecified indicates that no known container type
+ // annotation was found in the spec.
+ ContainerTypeUnspecified ContainerType = iota
+ // ContainerTypeUnknown indicates that a container type was specified
+ // but is unknown to us.
+ ContainerTypeUnknown
+ // ContainerTypeSandbox indicates that the container should be run in a
+ // new sandbox.
+ ContainerTypeSandbox
+ // ContainerTypeContainer indicates that the container should be run in
+ // an existing sandbox.
+ ContainerTypeContainer
+)
+
+// SpecContainerType tries to determine the type of container specified by the
+// container manager using well-known container annotations.
+func SpecContainerType(spec *specs.Spec) ContainerType {
+ if t, ok := spec.Annotations[ContainerdContainerTypeAnnotation]; ok {
+ switch t {
+ case ContainerdContainerTypeSandbox:
+ return ContainerTypeSandbox
+ case ContainerdContainerTypeContainer:
+ return ContainerTypeContainer
+ default:
+ return ContainerTypeUnknown
+ }
+ }
+ if t, ok := spec.Annotations[CRIOContainerTypeAnnotation]; ok {
+ switch t {
+ case CRIOContainerTypeSandbox:
+ return ContainerTypeSandbox
+ case CRIOContainerTypeContainer:
+ return ContainerTypeContainer
+ default:
+ return ContainerTypeUnknown
+ }
+ }
+ return ContainerTypeUnspecified
+}
+
+// SandboxID returns the ID of the sandbox to join and whether an ID was found
+// in the spec.
+func SandboxID(spec *specs.Spec) (string, bool) {
+ if id, ok := spec.Annotations[ContainerdSandboxIDAnnotation]; ok {
+ return id, true
+ }
+ if id, ok := spec.Annotations[CRIOSandboxIDAnnotation]; ok {
+ return id, true
+ }
+ return "", false
+}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 3d9ced1b6..d3c2e4e78 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -108,23 +108,18 @@ func ValidateSpec(spec *specs.Spec) error {
}
}
- // Two annotations are use by containerd to support multi-container pods.
- // "io.kubernetes.cri.container-type"
- // "io.kubernetes.cri.sandbox-id"
- containerType, hasContainerType := spec.Annotations[ContainerdContainerTypeAnnotation]
- _, hasSandboxID := spec.Annotations[ContainerdSandboxIDAnnotation]
- switch {
- // Non-containerd use won't set a container type.
- case !hasContainerType:
- case containerType == ContainerdContainerTypeSandbox:
- // When starting a container in an existing sandbox, the sandbox ID
- // must be set.
- case containerType == ContainerdContainerTypeContainer:
- if !hasSandboxID {
- return fmt.Errorf("spec has container-type of %s, but no sandbox ID set", containerType)
+ // CRI specifies whether a container should start a new sandbox, or run
+ // another container in an existing sandbox.
+ switch SpecContainerType(spec) {
+ case ContainerTypeContainer:
+ // When starting a container in an existing sandbox, the
+ // sandbox ID must be set.
+ if _, ok := SandboxID(spec); !ok {
+ return fmt.Errorf("spec has container-type of container, but no sandbox ID set")
}
+ case ContainerTypeUnknown:
+ return fmt.Errorf("unknown container-type")
default:
- return fmt.Errorf("unknown container-type: %s", containerType)
}
return nil
@@ -338,39 +333,6 @@ func IsSupportedDevMount(m specs.Mount) bool {
return true
}
-const (
- // ContainerdContainerTypeAnnotation is the OCI annotation set by
- // containerd to indicate whether the container to create should have
- // its own sandbox or a container within an existing sandbox.
- ContainerdContainerTypeAnnotation = "io.kubernetes.cri.container-type"
- // ContainerdContainerTypeContainer is the container type value
- // indicating the container should be created in an existing sandbox.
- ContainerdContainerTypeContainer = "container"
- // ContainerdContainerTypeSandbox is the container type value
- // indicating the container should be created in a new sandbox.
- ContainerdContainerTypeSandbox = "sandbox"
-
- // ContainerdSandboxIDAnnotation is the OCI annotation set to indicate
- // which sandbox the container should be created in when the container
- // is not the first container in the sandbox.
- ContainerdSandboxIDAnnotation = "io.kubernetes.cri.sandbox-id"
-)
-
-// ShouldCreateSandbox returns true if the spec indicates that a new sandbox
-// should be created for the container. If false, the container should be
-// started in an existing sandbox.
-func ShouldCreateSandbox(spec *specs.Spec) bool {
- t, ok := spec.Annotations[ContainerdContainerTypeAnnotation]
- return !ok || t == ContainerdContainerTypeSandbox
-}
-
-// SandboxID returns the ID of the sandbox to join and whether an ID was found
-// in the spec.
-func SandboxID(spec *specs.Spec) (string, bool) {
- id, ok := spec.Annotations[ContainerdSandboxIDAnnotation]
- return id, ok
-}
-
// WaitForReady waits for a process to become ready. The process is ready when
// the 'ready' function returns true. It continues to wait if 'ready' returns
// false. It returns error on timeout, if the process stops or if 'ready' fails.
diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD
index d44ebc906..c96ca2eb6 100644
--- a/runsc/testutil/BUILD
+++ b/runsc/testutil/BUILD
@@ -9,6 +9,7 @@ go_library(
importpath = "gvisor.dev/gvisor/runsc/testutil",
visibility = ["//:sandbox"],
deps = [
+ "//pkg/log",
"//runsc/boot",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index edf8b126c..26467bdc7 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -25,7 +25,6 @@ import (
"fmt"
"io"
"io/ioutil"
- "log"
"math"
"math/rand"
"net/http"
@@ -42,6 +41,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -286,7 +286,7 @@ func WaitForHTTP(port int, timeout time.Duration) error {
url := fmt.Sprintf("http://localhost:%d/", port)
resp, err := c.Get(url)
if err != nil {
- log.Printf("Waiting %s: %v", url, err)
+ log.Infof("Waiting %s: %v", url, err)
return err
}
resp.Body.Close()
diff --git a/scripts/swgso_tests.sh b/scripts/swgso_tests.sh
new file mode 100755
index 000000000..0de2df1d2
--- /dev/null
+++ b/scripts/swgso_tests.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Install the runtime and perform basic tests.
+install_runsc_for_test swgso --software-gso=true --gso=false
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/test/README.md b/test/README.md
index 09c36b461..97fe7ea04 100644
--- a/test/README.md
+++ b/test/README.md
@@ -10,9 +10,31 @@ they may need extra setup in the test machine and extra configuration to run.
functionality.
- **image:** basic end to end test for popular images. These require the same
setup as integration tests.
-- **root:** tests that require to be run as root.
+- **root:** tests that require to be run as root. These require the same setup
+ as integration tests.
- **util:** utilities library to support the tests.
For the above noted cases, the relevant runtime must be installed via `runsc
-install` before running. This is handled automatically by the test scripts in
-the `kokoro` directory.
+install` before running. Just note that they require specific configuration to
+work. This is handled automatically by the test scripts in the `scripts`
+directory and they can be used to run tests locally on your machine. They are
+also used to run these tests in `kokoro`.
+
+**Example:**
+
+To run image and integration tests, run:
+
+`./scripts/docker_test.sh`
+
+To run root tests, run:
+
+`./scripts/root_test.sh`
+
+There are a few other interesting variations for image and integration tests:
+
+* overlay: sets writable overlay inside the sentry
+* hostnet: configures host network pass-thru, instead of netstack
+* kvm: runsc the test using the KVM platform, instead of ptrace
+
+The test will build runsc, configure it with your local docker, restart
+`dockerd`, and run tests. The location for runsc logs is printed to the output.
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
index c962a3159..4074d2285 100644
--- a/test/e2e/exec_test.go
+++ b/test/e2e/exec_test.go
@@ -109,7 +109,7 @@ func TestExecPrivileged(t *testing.T) {
t.Logf("Exec CapEff: %v", got)
want := fmt.Sprintf("CapEff:\t%016x\n", specutils.AllCapabilitiesUint64()&^bits.MaskOf64(int(linux.CAP_NET_RAW)))
if got != want {
- t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
+ t.Errorf("Wrong capabilities, got: %q, want: %q. Make sure runsc is not using '--net-raw'", got, want)
}
}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 87ef87e07..a53a23afd 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -79,6 +79,12 @@ syscall_test(test = "//test/syscalls/linux:clock_nanosleep_test")
syscall_test(test = "//test/syscalls/linux:concurrency_test")
syscall_test(
+ add_uds_tree = True,
+ test = "//test/syscalls/linux:connect_external_test",
+ use_tmpfs = True,
+)
+
+syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:creat_test",
)
@@ -716,6 +722,7 @@ go_binary(
"//runsc/specutils",
"//runsc/testutil",
"//test/syscalls/gtest",
+ "//test/uds",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl
index e94ef5602..dcf5b73ed 100644
--- a/test/syscalls/build_defs.bzl
+++ b/test/syscalls/build_defs.bzl
@@ -8,6 +8,7 @@ def syscall_test(
size = "small",
use_tmpfs = False,
add_overlay = False,
+ add_uds_tree = False,
tags = None):
_syscall_test(
test = test,
@@ -15,6 +16,7 @@ def syscall_test(
size = size,
platform = "native",
use_tmpfs = False,
+ add_uds_tree = add_uds_tree,
tags = tags,
)
@@ -24,6 +26,7 @@ def syscall_test(
size = size,
platform = "kvm",
use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
tags = tags,
)
@@ -33,6 +36,7 @@ def syscall_test(
size = size,
platform = "ptrace",
use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
tags = tags,
)
@@ -43,6 +47,7 @@ def syscall_test(
size = size,
platform = "ptrace",
use_tmpfs = False, # overlay is adding a writable tmpfs on top of root.
+ add_uds_tree = add_uds_tree,
tags = tags,
overlay = True,
)
@@ -55,6 +60,7 @@ def syscall_test(
size = size,
platform = "ptrace",
use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
tags = tags,
file_access = "shared",
)
@@ -67,7 +73,8 @@ def _syscall_test(
use_tmpfs,
tags,
file_access = "exclusive",
- overlay = False):
+ overlay = False,
+ add_uds_tree = False):
test_name = test.split(":")[1]
# Prepend "runsc" to non-native platform names.
@@ -103,6 +110,7 @@ def _syscall_test(
"--use-tmpfs=" + str(use_tmpfs),
"--file-access=" + file_access,
"--overlay=" + str(overlay),
+ "--add-uds-tree=" + str(add_uds_tree),
]
sh_test(
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index d243be9e4..833fbaa09 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -480,6 +480,21 @@ cc_binary(
)
cc_binary(
+ name = "connect_external_test",
+ testonly = 1,
+ srcs = ["connect_external.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "creat_test",
testonly = 1,
srcs = ["creat.cc"],
@@ -655,6 +670,7 @@ cc_binary(
"//test/util:thread_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:optional",
"@com_google_googletest//:gtest",
],
)
@@ -1558,6 +1574,8 @@ cc_binary(
"//test/util:fs_util",
"//test/util:test_main",
"//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc
index 1122ea240..328192a05 100644
--- a/test/syscalls/linux/accept_bind.cc
+++ b/test/syscalls/linux/accept_bind.cc
@@ -140,6 +140,18 @@ TEST_P(AllSocketPairTest, Connect) {
SyscallSucceeds());
}
+TEST_P(AllSocketPairTest, ConnectNonListening) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
TEST_P(AllSocketPairTest, ConnectToFilePath) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/connect_external.cc b/test/syscalls/linux/connect_external.cc
new file mode 100644
index 000000000..98032ac19
--- /dev/null
+++ b/test/syscalls/linux/connect_external.cc
@@ -0,0 +1,164 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <stdlib.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <string>
+#include <tuple>
+
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+// This file contains tests specific to connecting to host UDS managed outside
+// the sandbox / test.
+//
+// A set of ultity sockets will be created externally in $TEST_UDS_TREE and
+// $TEST_UDS_ATTACH_TREE for these tests to interact with.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+struct ProtocolSocket {
+ int protocol;
+ std::string name;
+};
+
+// Parameter is (socket root dir, ProtocolSocket).
+using GoferStreamSeqpacketTest =
+ ::testing::TestWithParam<std::tuple<std::string, ProtocolSocket>>;
+
+// Connect to a socket and verify that write/read work.
+//
+// An "echo" socket doesn't work for dgram sockets because our socket is
+// unnamed. The server thus has no way to reply to us.
+TEST_P(GoferStreamSeqpacketTest, Echo) {
+ std::string env;
+ ProtocolSocket proto;
+ std::tie(env, proto) = GetParam();
+
+ char *val = getenv(env.c_str());
+ ASSERT_NE(val, nullptr);
+ std::string root(val);
+
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0));
+
+ std::string socket_path = JoinPath(root, proto.name, "echo");
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
+
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ constexpr int kBufferSize = 64;
+ char send_buffer[kBufferSize];
+ memset(send_buffer, 'a', sizeof(send_buffer));
+
+ ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)),
+ SyscallSucceedsWithValue(sizeof(send_buffer)));
+
+ char recv_buffer[kBufferSize];
+ ASSERT_THAT(ReadFd(sock.get(), recv_buffer, sizeof(recv_buffer)),
+ SyscallSucceedsWithValue(sizeof(recv_buffer)));
+ ASSERT_EQ(0, memcmp(send_buffer, recv_buffer, sizeof(send_buffer)));
+}
+
+// It is not possible to connect to a bound but non-listening socket.
+TEST_P(GoferStreamSeqpacketTest, NonListening) {
+ std::string env;
+ ProtocolSocket proto;
+ std::tie(env, proto) = GetParam();
+
+ char *val = getenv(env.c_str());
+ ASSERT_NE(val, nullptr);
+ std::string root(val);
+
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, proto.protocol, 0));
+
+ std::string socket_path = JoinPath(root, proto.name, "nonlistening");
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
+
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ StreamSeqpacket, GoferStreamSeqpacketTest,
+ ::testing::Combine(
+ // Test access via standard path and attach point.
+ ::testing::Values("TEST_UDS_TREE", "TEST_UDS_ATTACH_TREE"),
+ ::testing::Values(ProtocolSocket{SOCK_STREAM, "stream"},
+ ProtocolSocket{SOCK_SEQPACKET, "seqpacket"})));
+
+// Parameter is socket root dir.
+using GoferDgramTest = ::testing::TestWithParam<std::string>;
+
+// Connect to a socket and verify that write works.
+//
+// An "echo" socket doesn't work for dgram sockets because our socket is
+// unnamed. The server thus has no way to reply to us.
+TEST_P(GoferDgramTest, Null) {
+ std::string env = GetParam();
+ char *val = getenv(env.c_str());
+ ASSERT_NE(val, nullptr);
+ std::string root(val);
+
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_DGRAM, 0));
+
+ std::string socket_path = JoinPath(root, "dgram/null");
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
+
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ constexpr int kBufferSize = 64;
+ char send_buffer[kBufferSize];
+ memset(send_buffer, 'a', sizeof(send_buffer));
+
+ ASSERT_THAT(WriteFd(sock.get(), send_buffer, sizeof(send_buffer)),
+ SyscallSucceedsWithValue(sizeof(send_buffer)));
+}
+
+INSTANTIATE_TEST_SUITE_P(Dgram, GoferDgramTest,
+ // Test access via standard path and attach point.
+ ::testing::Values("TEST_UDS_TREE",
+ "TEST_UDS_ATTACH_TREE"));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
index 4c7c95321..85734c290 100644
--- a/test/syscalls/linux/exec.cc
+++ b/test/syscalls/linux/exec.cc
@@ -33,6 +33,7 @@
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
+#include "absl/types/optional.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/multiprocess_util.h"
@@ -68,11 +69,12 @@ constexpr char kExit42[] = "--exec_exit_42";
constexpr char kExecWithThread[] = "--exec_exec_with_thread";
constexpr char kExecFromThread[] = "--exec_exec_from_thread";
-// Runs filename with argv and checks that the exit status is expect_status and
-// that stderr contains expect_stderr.
-void CheckOutput(const std::string& filename, const ExecveArray& argv,
- const ExecveArray& envv, int expect_status,
- const std::string& expect_stderr) {
+// Runs file specified by dirfd and pathname with argv and checks that the exit
+// status is expect_status and that stderr contains expect_stderr.
+void CheckExecHelper(const absl::optional<int32_t> dirfd,
+ const std::string& pathname, const ExecveArray& argv,
+ const ExecveArray& envv, const int flags,
+ int expect_status, const std::string& expect_stderr) {
int pipe_fds[2];
ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds());
@@ -110,8 +112,15 @@ void CheckOutput(const std::string& filename, const ExecveArray& argv,
// CloexecEventfd depend on that not happening.
};
- auto kill = ASSERT_NO_ERRNO_AND_VALUE(
- ForkAndExec(filename, argv, envv, remap_stderr, &child, &execve_errno));
+ Cleanup kill;
+ if (dirfd.has_value()) {
+ kill = ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(*dirfd, pathname, argv,
+ envv, flags, remap_stderr,
+ &child, &execve_errno));
+ } else {
+ kill = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(pathname, argv, envv, remap_stderr, &child, &execve_errno));
+ }
ASSERT_EQ(0, execve_errno);
@@ -140,57 +149,71 @@ void CheckOutput(const std::string& filename, const ExecveArray& argv,
EXPECT_TRUE(absl::StrContains(output, expect_stderr)) << output;
}
-TEST(ExecDeathTest, EmptyPath) {
+void CheckExec(const std::string& filename, const ExecveArray& argv,
+ const ExecveArray& envv, int expect_status,
+ const std::string& expect_stderr) {
+ CheckExecHelper(/*dirfd=*/absl::optional<int32_t>(), filename, argv, envv,
+ /*flags=*/0, expect_status, expect_stderr);
+}
+
+void CheckExecveat(const int32_t dirfd, const std::string& pathname,
+ const ExecveArray& argv, const ExecveArray& envv,
+ const int flags, int expect_status,
+ const std::string& expect_stderr) {
+ CheckExecHelper(absl::optional<int32_t>(dirfd), pathname, argv, envv, flags,
+ expect_status, expect_stderr);
+}
+
+TEST(ExecTest, EmptyPath) {
int execve_errno;
ASSERT_NO_ERRNO_AND_VALUE(ForkAndExec("", {}, {}, nullptr, &execve_errno));
EXPECT_EQ(execve_errno, ENOENT);
}
-TEST(ExecDeathTest, Basic) {
- CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {},
- ArgEnvExitStatus(0, 0),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n"));
+TEST(ExecTest, Basic) {
+ CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {},
+ ArgEnvExitStatus(0, 0),
+ absl::StrCat(WorkloadPath(kBasicWorkload), "\n"));
}
-TEST(ExecDeathTest, OneArg) {
- CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"},
- {}, ArgEnvExitStatus(1, 0),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n"));
+TEST(ExecTest, OneArg) {
+ CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"},
+ {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n"));
}
-TEST(ExecDeathTest, FiveArg) {
- CheckOutput(WorkloadPath(kBasicWorkload),
- {WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {},
- ArgEnvExitStatus(5, 0),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
+TEST(ExecTest, FiveArg) {
+ CheckExec(WorkloadPath(kBasicWorkload),
+ {WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {},
+ ArgEnvExitStatus(5, 0),
+ absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
}
-TEST(ExecDeathTest, OneEnv) {
- CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)},
- {"1"}, ArgEnvExitStatus(0, 1),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n"));
+TEST(ExecTest, OneEnv) {
+ CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {"1"},
+ ArgEnvExitStatus(0, 1),
+ absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n"));
}
-TEST(ExecDeathTest, FiveEnv) {
- CheckOutput(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)},
- {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
+TEST(ExecTest, FiveEnv) {
+ CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)},
+ {"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5),
+ absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
}
-TEST(ExecDeathTest, OneArgOneEnv) {
- CheckOutput(WorkloadPath(kBasicWorkload),
- {WorkloadPath(kBasicWorkload), "arg"}, {"env"},
- ArgEnvExitStatus(1, 1),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n"));
+TEST(ExecTest, OneArgOneEnv) {
+ CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "arg"},
+ {"env"}, ArgEnvExitStatus(1, 1),
+ absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n"));
}
-TEST(ExecDeathTest, InterpreterScript) {
- CheckOutput(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {},
- ArgEnvExitStatus(25, 0), "");
+TEST(ExecTest, InterpreterScript) {
+ CheckExec(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {},
+ ArgEnvExitStatus(25, 0), "");
}
// Everything after the path in the interpreter script is a single argument.
-TEST(ExecDeathTest, InterpreterScriptArgSplit) {
+TEST(ExecTest, InterpreterScriptArgSplit) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -199,12 +222,12 @@ TEST(ExecDeathTest, InterpreterScriptArgSplit) {
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo bar"),
0755));
- CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
- absl::StrCat(link.path(), "\nfoo bar\n", script.path(), "\n"));
+ CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
+ absl::StrCat(link.path(), "\nfoo bar\n", script.path(), "\n"));
}
// Original argv[0] is replaced with the script path.
-TEST(ExecDeathTest, InterpreterScriptArgvZero) {
+TEST(ExecTest, InterpreterScriptArgvZero) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -212,13 +235,13 @@ TEST(ExecDeathTest, InterpreterScriptArgvZero) {
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755));
- CheckOutput(script.path(), {"REPLACED"}, {}, ArgEnvExitStatus(1, 0),
- absl::StrCat(link.path(), "\n", script.path(), "\n"));
+ CheckExec(script.path(), {"REPLACED"}, {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(link.path(), "\n", script.path(), "\n"));
}
// Original argv[0] is replaced with the script path, exactly as passed to
// execve.
-TEST(ExecDeathTest, InterpreterScriptArgvZeroRelative) {
+TEST(ExecTest, InterpreterScriptArgvZeroRelative) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -230,12 +253,12 @@ TEST(ExecDeathTest, InterpreterScriptArgvZeroRelative) {
auto script_relative =
ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path()));
- CheckOutput(script_relative, {"REPLACED"}, {}, ArgEnvExitStatus(1, 0),
- absl::StrCat(link.path(), "\n", script_relative, "\n"));
+ CheckExec(script_relative, {"REPLACED"}, {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(link.path(), "\n", script_relative, "\n"));
}
// argv[0] is added as the script path, even if there was none.
-TEST(ExecDeathTest, InterpreterScriptArgvZeroAdded) {
+TEST(ExecTest, InterpreterScriptArgvZeroAdded) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -243,12 +266,12 @@ TEST(ExecDeathTest, InterpreterScriptArgvZeroAdded) {
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755));
- CheckOutput(script.path(), {}, {}, ArgEnvExitStatus(1, 0),
- absl::StrCat(link.path(), "\n", script.path(), "\n"));
+ CheckExec(script.path(), {}, {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(link.path(), "\n", script.path(), "\n"));
}
// A NUL byte in the script line ends parsing.
-TEST(ExecDeathTest, InterpreterScriptArgNUL) {
+TEST(ExecTest, InterpreterScriptArgNUL) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -258,12 +281,12 @@ TEST(ExecDeathTest, InterpreterScriptArgNUL) {
absl::StrCat("#!", link.path(), " foo", std::string(1, '\0'), "bar"),
0755));
- CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
- absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n"));
+ CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
+ absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n"));
}
// Trailing whitespace following interpreter path is ignored.
-TEST(ExecDeathTest, InterpreterScriptTrailingWhitespace) {
+TEST(ExecTest, InterpreterScriptTrailingWhitespace) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -271,12 +294,12 @@ TEST(ExecDeathTest, InterpreterScriptTrailingWhitespace) {
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " "), 0755));
- CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(1, 0),
- absl::StrCat(link.path(), "\n", script.path(), "\n"));
+ CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(1, 0),
+ absl::StrCat(link.path(), "\n", script.path(), "\n"));
}
// Multiple whitespace characters between interpreter and arg allowed.
-TEST(ExecDeathTest, InterpreterScriptArgWhitespace) {
+TEST(ExecTest, InterpreterScriptArgWhitespace) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
@@ -284,11 +307,11 @@ TEST(ExecDeathTest, InterpreterScriptArgWhitespace) {
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo"), 0755));
- CheckOutput(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
- absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n"));
+ CheckExec(script.path(), {script.path()}, {}, ArgEnvExitStatus(2, 0),
+ absl::StrCat(link.path(), "\nfoo\n", script.path(), "\n"));
}
-TEST(ExecDeathTest, InterpreterScriptNoPath) {
+TEST(ExecTest, InterpreterScriptNoPath) {
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "#!", 0755));
@@ -299,7 +322,7 @@ TEST(ExecDeathTest, InterpreterScriptNoPath) {
}
// AT_EXECFN is the path passed to execve.
-TEST(ExecDeathTest, ExecFn) {
+TEST(ExecTest, ExecFn) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload)));
@@ -314,18 +337,18 @@ TEST(ExecDeathTest, ExecFn) {
auto script_relative =
ASSERT_NO_ERRNO_AND_VALUE(GetRelativePath(cwd, script.path()));
- CheckOutput(script_relative, {script_relative}, {}, ArgEnvExitStatus(0, 0),
- absl::StrCat(script_relative, "\n"));
+ CheckExec(script_relative, {script_relative}, {}, ArgEnvExitStatus(0, 0),
+ absl::StrCat(script_relative, "\n"));
}
-TEST(ExecDeathTest, ExecName) {
+TEST(ExecTest, ExecName) {
std::string path = WorkloadPath(kStateWorkload);
- CheckOutput(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0),
- absl::StrCat(Basename(path).substr(0, 15), "\n"));
+ CheckExec(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0),
+ absl::StrCat(Basename(path).substr(0, 15), "\n"));
}
-TEST(ExecDeathTest, ExecNameScript) {
+TEST(ExecTest, ExecNameScript) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload)));
@@ -336,21 +359,21 @@ TEST(ExecDeathTest, ExecNameScript) {
std::string script_path = script.path();
- CheckOutput(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0),
- absl::StrCat(Basename(script_path).substr(0, 15), "\n"));
+ CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0),
+ absl::StrCat(Basename(script_path).substr(0, 15), "\n"));
}
// execve may be called by a multithreaded process.
-TEST(ExecDeathTest, WithSiblingThread) {
- CheckOutput("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {},
- W_EXITCODE(42, 0), "");
+TEST(ExecTest, WithSiblingThread) {
+ CheckExec("/proc/self/exe", {"/proc/self/exe", kExecWithThread}, {},
+ W_EXITCODE(42, 0), "");
}
// execve may be called from a thread other than the leader of a multithreaded
// process.
-TEST(ExecDeathTest, FromSiblingThread) {
- CheckOutput("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {},
- W_EXITCODE(42, 0), "");
+TEST(ExecTest, FromSiblingThread) {
+ CheckExec("/proc/self/exe", {"/proc/self/exe", kExecFromThread}, {},
+ W_EXITCODE(42, 0), "");
}
TEST(ExecTest, NotFound) {
@@ -376,7 +399,7 @@ void SignalHandler(int signo) {
// Signal handlers are reset on execve(2), unless they have default or ignored
// disposition.
-TEST(ExecStateDeathTest, HandlerReset) {
+TEST(ExecStateTest, HandlerReset) {
struct sigaction sa;
sa.sa_handler = SignalHandler;
ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
@@ -388,11 +411,11 @@ TEST(ExecStateDeathTest, HandlerReset) {
absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_DFL))),
};
- CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+ CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
}
// Ignored signal dispositions are not reset.
-TEST(ExecStateDeathTest, IgnorePreserved) {
+TEST(ExecStateTest, IgnorePreserved) {
struct sigaction sa;
sa.sa_handler = SIG_IGN;
ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
@@ -404,11 +427,11 @@ TEST(ExecStateDeathTest, IgnorePreserved) {
absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_IGN))),
};
- CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+ CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
}
// Signal masks are not reset on exec
-TEST(ExecStateDeathTest, SignalMask) {
+TEST(ExecStateTest, SignalMask) {
sigset_t s;
sigemptyset(&s);
sigaddset(&s, SIGUSR1);
@@ -420,12 +443,12 @@ TEST(ExecStateDeathTest, SignalMask) {
absl::StrCat(SIGUSR1),
};
- CheckOutput(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+ CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
}
// itimers persist across execve.
// N.B. Timers created with timer_create(2) should not be preserved!
-TEST(ExecStateDeathTest, ItimerPreserved) {
+TEST(ExecStateTest, ItimerPreserved) {
// The fork in ForkAndExec clears itimers, so only set them up after fork.
auto setup_itimer = [] {
// Ignore SIGALRM, as we don't actually care about timer
@@ -472,10 +495,10 @@ TEST(ExecStateDeathTest, ItimerPreserved) {
TEST(ProcSelfExe, ChangesAcrossExecve) {
// See exec_proc_exe_workload for more details. We simply
// assert that the /proc/self/exe link changes across execve.
- CheckOutput(WorkloadPath(kProcExeWorkload),
- {WorkloadPath(kProcExeWorkload),
- ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))},
- {}, W_EXITCODE(0, 0), "");
+ CheckExec(WorkloadPath(kProcExeWorkload),
+ {WorkloadPath(kProcExeWorkload),
+ ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))},
+ {}, W_EXITCODE(0, 0), "");
}
TEST(ExecTest, CloexecNormalFile) {
@@ -484,20 +507,20 @@ TEST(ExecTest, CloexecNormalFile) {
const FileDescriptor fd_closed_on_exec =
ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC));
- CheckOutput(WorkloadPath(kAssertClosedWorkload),
- {WorkloadPath(kAssertClosedWorkload),
- absl::StrCat(fd_closed_on_exec.get())},
- {}, W_EXITCODE(0, 0), "");
+ CheckExec(WorkloadPath(kAssertClosedWorkload),
+ {WorkloadPath(kAssertClosedWorkload),
+ absl::StrCat(fd_closed_on_exec.get())},
+ {}, W_EXITCODE(0, 0), "");
// The assert closed workload exits with code 2 if the file still exists. We
// can use this to do a negative test.
const FileDescriptor fd_open_on_exec =
ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY));
- CheckOutput(WorkloadPath(kAssertClosedWorkload),
- {WorkloadPath(kAssertClosedWorkload),
- absl::StrCat(fd_open_on_exec.get())},
- {}, W_EXITCODE(2, 0), "");
+ CheckExec(WorkloadPath(kAssertClosedWorkload),
+ {WorkloadPath(kAssertClosedWorkload),
+ absl::StrCat(fd_open_on_exec.get())},
+ {}, W_EXITCODE(2, 0), "");
}
TEST(ExecTest, CloexecEventfd) {
@@ -505,9 +528,40 @@ TEST(ExecTest, CloexecEventfd) {
ASSERT_THAT(efd = eventfd(0, EFD_CLOEXEC), SyscallSucceeds());
FileDescriptor fd(efd);
- CheckOutput(WorkloadPath(kAssertClosedWorkload),
- {WorkloadPath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {},
- W_EXITCODE(0, 0), "");
+ CheckExec(WorkloadPath(kAssertClosedWorkload),
+ {WorkloadPath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {},
+ W_EXITCODE(0, 0), "");
+}
+
+TEST(ExecveatTest, BasicWithFDCWD) {
+ std::string path = WorkloadPath(kBasicWorkload);
+ CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0),
+ absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, Basic) {
+ std::string absolute_path = WorkloadPath(kBasicWorkload);
+ std::string parent_dir = std::string(Dirname(absolute_path));
+ std::string relative_path = std::string(Basename(absolute_path));
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY));
+
+ CheckExecveat(dirfd.get(), relative_path, {absolute_path}, {}, /*flags=*/0,
+ ArgEnvExitStatus(0, 0), absl::StrCat(absolute_path, "\n"));
+}
+
+TEST(ExecveatTest, AbsolutePathWithFDCWD) {
+ std::string path = WorkloadPath(kBasicWorkload);
+ CheckExecveat(AT_FDCWD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0,
+ absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, AbsolutePath) {
+ std::string path = WorkloadPath(kBasicWorkload);
+ // File descriptor should be ignored when an absolute path is given.
+ const int32_t badFD = -1;
+ CheckExecveat(badFD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0,
+ absl::StrCat(path, "\n"));
}
// Priority consistent across calls to execve()
@@ -522,9 +576,8 @@ TEST(GetpriorityTest, ExecveMaintainsPriority) {
// Program run (priority_execve) will exit(X) where
// X=getpriority(PRIO_PROCESS,0). Check that this exit value is prio.
- CheckOutput(WorkloadPath(kPriorityWorkload),
- {WorkloadPath(kPriorityWorkload)}, {},
- W_EXITCODE(expected_exit_code, 0), "");
+ CheckExec(WorkloadPath(kPriorityWorkload), {WorkloadPath(kPriorityWorkload)},
+ {}, W_EXITCODE(expected_exit_code, 0), "");
}
void ExecWithThread() {
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 37b4e6575..fcf64ee59 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -61,6 +61,9 @@ namespace testing {
namespace {
+using ::testing::AnyOf;
+using ::testing::Eq;
+
constexpr char kMessage[] = "soweoneul malhaebwa";
constexpr in_port_t kPort = 0x409c; // htons(40000)
@@ -83,17 +86,14 @@ void SendUDPMessage(int sock) {
// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
TEST(BasicCookedPacketTest, WrongType) {
- // (b/129292371): Remove once we support packet sockets.
- SKIP_IF(IsRunningOnGvisor());
-
if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP),
SyscallFailsWithErrno(EPERM));
GTEST_SKIP();
}
- FileDescriptor sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP));
+ FileDescriptor sock = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_PACKET, SOCK_DGRAM, htons(ETH_P_PUP)));
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
@@ -124,9 +124,6 @@ class CookedPacketTest : public ::testing::TestWithParam<int> {
};
void CookedPacketTest::SetUp() {
- // (b/129292371): Remove once we support packet sockets.
- SKIP_IF(IsRunningOnGvisor());
-
if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
SyscallFailsWithErrno(EPERM));
@@ -138,9 +135,6 @@ void CookedPacketTest::SetUp() {
}
void CookedPacketTest::TearDown() {
- // (b/129292371): Remove once we support packet sockets.
- SKIP_IF(IsRunningOnGvisor());
-
// TearDown will be run even if we skip the test.
if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
EXPECT_THAT(close(socket_), SyscallSucceeds());
@@ -177,13 +171,16 @@ TEST_P(CookedPacketTest, Receive) {
ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
reinterpret_cast<struct sockaddr*>(&src), &src_len),
SyscallSucceedsWithValue(packet_size));
- ASSERT_EQ(src_len, sizeof(src));
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
+ // sizeof(sockaddr_ll).
+ ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
+ // TODO(b/129292371): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
- EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP));
EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
- EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK);
EXPECT_EQ(src.sll_halen, ETH_ALEN);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
@@ -213,6 +210,9 @@ TEST_P(CookedPacketTest, Receive) {
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
+ // TODO(b/129292371): Remove once we support packet socket writing.
+ SKIP_IF(IsRunningOnGvisor());
+
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 6491453b6..d258d353c 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -26,6 +26,7 @@
#include <sys/types.h>
#include <unistd.h>
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/base/internal/endian.h"
#include "test/syscalls/linux/socket_test_util.h"
@@ -61,6 +62,9 @@ namespace testing {
namespace {
+using ::testing::AnyOf;
+using ::testing::Eq;
+
constexpr char kMessage[] = "soweoneul malhaebwa";
constexpr in_port_t kPort = 0x409c; // htons(40000)
@@ -97,9 +101,6 @@ class RawPacketTest : public ::testing::TestWithParam<int> {
};
void RawPacketTest::SetUp() {
- // (b/129292371): Remove once we support packet sockets.
- SKIP_IF(IsRunningOnGvisor());
-
if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
SyscallFailsWithErrno(EPERM));
@@ -125,9 +126,6 @@ void RawPacketTest::SetUp() {
}
void RawPacketTest::TearDown() {
- // (b/129292371): Remove once we support packet sockets.
- SKIP_IF(IsRunningOnGvisor());
-
// TearDown will be run even if we skip the test.
if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
EXPECT_THAT(close(socket_), SyscallSucceeds());
@@ -164,16 +162,16 @@ TEST_P(RawPacketTest, Receive) {
ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
reinterpret_cast<struct sockaddr*>(&src), &src_len),
SyscallSucceedsWithValue(packet_size));
- // sizeof(src) is the size of a struct sockaddr_ll. sockaddr_ll ends with an 8
- // byte physical address field, but ethernet (MAC) addresses only use 6 bytes.
- // Thus src_len should get modified to be 2 less than the size of sockaddr_ll.
- ASSERT_EQ(src_len, sizeof(src) - 2);
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
+ // sizeof(sockaddr_ll).
+ ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
+ // TODO(b/129292371): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
- EXPECT_EQ(src.sll_protocol, htons(ETH_P_IP));
EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
- EXPECT_EQ(src.sll_hatype, ARPHRD_LOOPBACK);
EXPECT_EQ(src.sll_halen, ETH_ALEN);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
@@ -214,6 +212,9 @@ TEST_P(RawPacketTest, Receive) {
// Send via a packet socket.
TEST_P(RawPacketTest, Send) {
+ // TODO(b/129292371): Remove once we support packet socket writing.
+ SKIP_IF(IsRunningOnGvisor());
+
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -309,7 +310,7 @@ TEST_P(RawPacketTest, Send) {
}
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
- ::testing::Values(ETH_P_IP /*, ETH_P_ALL*/));
+ ::testing::Values(ETH_P_IP, ETH_P_ALL));
} // namespace
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index d0ef8d380..dcfd5f86c 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -16,16 +16,16 @@
#include <errno.h>
#include <netinet/in.h>
#include <poll.h>
-#include <sys/types.h>
#include <sys/socket.h>
#include <sys/syscall.h>
+#include <sys/types.h>
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
#include "absl/strings/str_split.h"
#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "gtest/gtest.h"
-#include "test/util/capability_util.h"
#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/test_util.h"
@@ -105,25 +105,29 @@ PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp,
EINVAL, absl::StrCat("failed to find ", type, "/", item, " in:", snmp));
}
-TEST(ProcNetSnmp, TcpReset) {
+TEST(ProcNetSnmp, TcpReset_NoRandomSave) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
- const DisableSave ds;
+ DisableSave ds;
uint64_t oldAttemptFails;
uint64_t oldActiveOpens;
uint64_t oldOutRsts;
auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
- oldOutRsts = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts"));
- oldAttemptFails = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails"));
+ oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
+ oldOutRsts =
+ ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts"));
+ oldAttemptFails = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails"));
FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0));
struct sockaddr_in sin = {
- .sin_family = AF_INET,
- .sin_port = htons(1234),
+ .sin_family = AF_INET,
+ .sin_port = htons(1234),
};
- sin.sin_addr.s_addr = inet_addr("127.0.0.1");
+
+ ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
ASSERT_THAT(connect(s.get(), (struct sockaddr *)&sin, sizeof(sin)),
SyscallFailsWithErrno(ECONNREFUSED));
@@ -131,41 +135,54 @@ TEST(ProcNetSnmp, TcpReset) {
uint64_t newActiveOpens;
uint64_t newOutRsts;
snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
- newOutRsts = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts"));
- newAttemptFails = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails"));
+ newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
+ newOutRsts =
+ ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "OutRsts"));
+ newAttemptFails = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "AttemptFails"));
EXPECT_EQ(oldActiveOpens, newActiveOpens - 1);
EXPECT_EQ(oldOutRsts, newOutRsts - 1);
EXPECT_EQ(oldAttemptFails, newAttemptFails - 1);
}
-TEST(ProcNetSnmp, TcpEstab) {
+TEST(ProcNetSnmp, TcpEstab_NoRandomSave) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
- const DisableSave ds;
+ DisableSave ds;
uint64_t oldEstabResets;
uint64_t oldActiveOpens;
uint64_t oldPassiveOpens;
uint64_t oldCurrEstab;
auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
- oldPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens"));
- oldCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
- oldEstabResets = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets"));
+ oldActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
+ oldPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens"));
+ oldCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
+ oldEstabResets = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets"));
FileDescriptor s_listen =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0));
-
struct sockaddr_in sin = {
- .sin_family = AF_INET,
- .sin_port = htons(1234),
+ .sin_family = AF_INET,
+ .sin_port = 0,
};
- sin.sin_addr.s_addr = inet_addr("127.0.0.1");
+
+ ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
ASSERT_THAT(bind(s_listen.get(), (struct sockaddr *)&sin, sizeof(sin)),
SyscallSucceeds());
ASSERT_THAT(listen(s_listen.get(), 1), SyscallSucceeds());
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = sizeof(sin);
+ ASSERT_THAT(
+ getsockname(s_listen.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen),
+ SyscallSucceeds());
+
FileDescriptor s_connect =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0));
ASSERT_THAT(connect(s_connect.get(), (struct sockaddr *)&sin, sizeof(sin)),
@@ -179,9 +196,12 @@ TEST(ProcNetSnmp, TcpEstab) {
uint64_t newPassiveOpens;
uint64_t newCurrEstab;
snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
- newPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens"));
- newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
+ newActiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "ActiveOpens"));
+ newPassiveOpens = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "PassiveOpens"));
+ newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
EXPECT_EQ(oldActiveOpens, newActiveOpens - 1);
EXPECT_EQ(oldPassiveOpens, newPassiveOpens - 1);
@@ -210,42 +230,47 @@ TEST(ProcNetSnmp, TcpEstab) {
s_connect.reset(-1);
// Wait until the process of the netstack.
- absl::SleepFor(absl::Seconds(1.0));
+ absl::SleepFor(absl::Seconds(1));
snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
- newEstabResets = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets"));
+ newCurrEstab = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "CurrEstab"));
+ newEstabResets = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Tcp", "EstabResets"));
EXPECT_EQ(oldCurrEstab, newCurrEstab);
EXPECT_EQ(oldEstabResets, newEstabResets - 2);
}
-TEST(ProcNetSnmp, UdpNoPorts) {
+TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
- const DisableSave ds;
+ DisableSave ds;
uint64_t oldOutDatagrams;
uint64_t oldNoPorts;
auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
- oldNoPorts = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts"));
+ oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
+ oldNoPorts =
+ ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts"));
- FileDescriptor s =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
struct sockaddr_in sin = {
- .sin_family = AF_INET,
- .sin_port = htons(1234),
+ .sin_family = AF_INET,
+ .sin_port = htons(4444),
};
- sin.sin_addr.s_addr = inet_addr("127.0.0.1");
+ ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)),
SyscallSucceedsWithValue(1));
uint64_t newOutDatagrams;
uint64_t newNoPorts;
snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
- newNoPorts = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts"));
+ newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
+ newNoPorts =
+ ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "NoPorts"));
EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1);
EXPECT_EQ(oldNoPorts, newNoPorts - 1);
@@ -258,24 +283,32 @@ TEST(ProcNetSnmp, UdpIn) {
uint64_t oldOutDatagrams;
uint64_t oldInDatagrams;
auto snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
- oldInDatagrams = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams"));
+ oldOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
+ oldInDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams"));
+ std::cerr << "snmp: " << std::endl << snmp << std::endl;
FileDescriptor server =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
-
struct sockaddr_in sin = {
- .sin_family = AF_INET,
- .sin_port = htons(1234),
+ .sin_family = AF_INET,
+ .sin_port = htons(0),
};
- sin.sin_addr.s_addr = inet_addr("127.0.0.1");
+ ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
ASSERT_THAT(bind(server.get(), (struct sockaddr *)&sin, sizeof(sin)),
+ SyscallSucceeds());
+ // Get the port bound by the server socket.
+ socklen_t addrlen = sizeof(sin);
+ ASSERT_THAT(
+ getsockname(server.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen),
SyscallSucceeds());
FileDescriptor client =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- ASSERT_THAT(sendto(client.get(), "a", 1, 0, (struct sockaddr *)&sin,
- sizeof(sin)), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(
+ sendto(client.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)),
+ SyscallSucceedsWithValue(1));
char buf[128];
ASSERT_THAT(recvfrom(server.get(), buf, sizeof(buf), 0, NULL, NULL),
@@ -284,8 +317,11 @@ TEST(ProcNetSnmp, UdpIn) {
uint64_t newOutDatagrams;
uint64_t newInDatagrams;
snmp = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
- newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
- newInDatagrams = ASSERT_NO_ERRNO_AND_VALUE(GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams"));
+ std::cerr << "new snmp: " << std::endl << snmp << std::endl;
+ newOutDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "OutDatagrams"));
+ newInDatagrams = ASSERT_NO_ERRNO_AND_VALUE(
+ GetSNMPMetricFromProc(snmp, "Udp", "InDatagrams"));
EXPECT_EQ(oldOutDatagrams, newOutDatagrams - 1);
EXPECT_EQ(oldInDatagrams, newInDatagrams - 1);
diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc
index 1c56540bc..3331288b7 100644
--- a/test/syscalls/linux/sendfile_socket.cc
+++ b/test/syscalls/linux/sendfile_socket.cc
@@ -185,7 +185,7 @@ TEST_P(SendFileTest, Shutdown) {
// Create a socket.
std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
const FileDescriptor client(std::get<0>(fds));
- FileDescriptor server(std::get<1>(fds)); // non-const, released below.
+ FileDescriptor server(std::get<1>(fds)); // non-const, reset below.
// If this is a TCP socket, then turn off linger.
if (GetParam() == AF_INET) {
@@ -210,14 +210,14 @@ TEST_P(SendFileTest, Shutdown) {
// checking the contents (other tests do that), so we just re-use the same
// buffer as above.
ScopedThread t([&]() {
- int done = 0;
+ size_t done = 0;
while (done < data.size()) {
- int n = read(server.get(), data.data(), data.size());
+ int n = RetryEINTR(read)(server.get(), data.data(), data.size());
ASSERT_THAT(n, SyscallSucceeds());
done += n;
}
// Close the server side socket.
- ASSERT_THAT(close(server.release()), SyscallSucceeds());
+ server.reset();
});
// Continuously stream from the file to the socket. Note we do not assert
diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc
index dafe82494..b5c82cd67 100644
--- a/test/syscalls/linux/socket_unix_non_stream.cc
+++ b/test/syscalls/linux/socket_unix_non_stream.cc
@@ -231,11 +231,21 @@ TEST_P(UnixNonStreamSocketPairTest, SendTimeout) {
setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
SyscallSucceeds());
- char buf[100] = {};
+ const int buf_size = 5 * kPageSize;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF, &buf_size,
+ sizeof(buf_size)),
+ SyscallSucceeds());
+ EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVBUF, &buf_size,
+ sizeof(buf_size)),
+ SyscallSucceeds());
+
+ // The buffer size should be big enough to avoid many iterations in the next
+ // loop. Otherwise, this will slow down cooperative_save tests.
+ std::vector<char> buf(kPageSize);
for (;;) {
int ret;
ASSERT_THAT(
- ret = RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0),
+ ret = RetryEINTR(send)(sockets->first_fd(), buf.data(), buf.size(), 0),
::testing::AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(EAGAIN)));
if (ret == -1) {
break;
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index 659c93945..8f38ed92f 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <poll.h>
#include <stdio.h>
#include <sys/un.h>
+
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
@@ -44,6 +46,50 @@ TEST_P(StreamUnixSocketPairTest, ReadOneSideClosed) {
SyscallSucceedsWithValue(0));
}
+TEST_P(StreamUnixSocketPairTest, RecvmsgOneSideClosed) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Set timeout so that it will not wait for ever.
+ struct timeval tv {
+ .tv_sec = 0, .tv_usec = 10
+ };
+ EXPECT_THAT(setsockopt(sockets->second_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv,
+ sizeof(tv)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+
+ char received_data[10] = {};
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(recvmsg(sockets->second_fd(), &msg, MSG_WAITALL),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(StreamUnixSocketPairTest, ReadOneSideClosedWithUnreadData) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char buf[10] = {};
+ ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(ECONNRESET));
+}
+
INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, StreamUnixSocketPairTest,
::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go
index c1e9ce22c..856398994 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/syscalls/syscall_test_runner.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/runsc/specutils"
"gvisor.dev/gvisor/runsc/testutil"
"gvisor.dev/gvisor/test/syscalls/gtest"
+ "gvisor.dev/gvisor/test/uds"
)
// Location of syscall tests, relative to the repo root.
@@ -50,6 +51,8 @@ var (
overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay")
parallel = flag.Bool("parallel", false, "run tests in parallel")
runscPath = flag.String("runsc", "", "path to runsc binary")
+
+ addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests")
)
// runTestCaseNative runs the test case directly on the host machine.
@@ -86,6 +89,19 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
// intepret them.
env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"})
+ if *addUDSTree {
+ socketDir, cleanup, err := uds.CreateSocketTree("/tmp")
+ if err != nil {
+ t.Fatalf("failed to create socket tree: %v", err)
+ }
+ defer cleanup()
+
+ env = append(env, "TEST_UDS_TREE="+socketDir)
+ // On Linux, the concept of "attach" location doesn't exist.
+ // Just pass the same path to make these test identical.
+ env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir)
+ }
+
cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName())
cmd.Env = env
cmd.Stdout = os.Stdout
@@ -96,101 +112,39 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
}
}
-// runsTestCaseRunsc runs the test case in runsc.
-func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
+// runRunsc runs spec in runsc in a standard test configuration.
+//
+// runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR.
+//
+// Returns an error if the sandboxed application exits non-zero.
+func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
+ bundleDir, err := testutil.SetupBundleDir(spec)
if err != nil {
- t.Fatalf("SetupRootDir failed: %v", err)
+ return fmt.Errorf("SetupBundleDir failed: %v", err)
}
- defer os.RemoveAll(rootDir)
-
- // Run a new container with the test executable and filter for the
- // given test suite and name.
- spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName())
-
- // Mark the root as writeable, as some tests attempt to
- // write to the rootfs, and expect EACCES, not EROFS.
- spec.Root.Readonly = false
-
- // Test spec comes with pre-defined mounts that we don't want. Reset it.
- spec.Mounts = nil
- if *useTmpfs {
- // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on
- // features only available in gVisor's internal tmpfs may fail.
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Destination: "/tmp",
- Type: "tmpfs",
- })
- } else {
- // Use a gofer-backed directory as '/tmp'.
- //
- // Tests might be running in parallel, so make sure each has a
- // unique test temp dir.
- //
- // Some tests (e.g., sticky) access this mount from other
- // users, so make sure it is world-accessible.
- tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "")
- if err != nil {
- t.Fatalf("could not create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- if err := os.Chmod(tmpDir, 0777); err != nil {
- t.Fatalf("could not chmod temp dir: %v", err)
- }
-
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Type: "bind",
- Destination: "/tmp",
- Source: tmpDir,
- })
- }
-
- // Set environment variable that indicates we are
- // running in gVisor and with the given platform.
- platformVar := "TEST_ON_GVISOR"
- env := append(os.Environ(), platformVar+"="+*platform)
-
- // Remove env variables that cause the gunit binary to write output
- // files, since they will stomp on eachother, and on the output files
- // from this go test.
- env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"})
-
- // Remove shard env variables so that the gunit binary does not try to
- // intepret them.
- env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"})
-
- // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to
- // be backed by tmpfs.
- for i, kv := range env {
- if strings.HasPrefix(kv, "TEST_TMPDIR=") {
- env[i] = "TEST_TMPDIR=/tmp"
- break
- }
- }
-
- spec.Process.Env = env
+ defer os.RemoveAll(bundleDir)
- bundleDir, err := testutil.SetupBundleDir(spec)
+ rootDir, err := testutil.SetupRootDir()
if err != nil {
- t.Fatalf("SetupBundleDir failed: %v", err)
+ return fmt.Errorf("SetupRootDir failed: %v", err)
}
- defer os.RemoveAll(bundleDir)
+ defer os.RemoveAll(rootDir)
+ name := tc.FullName()
id := testutil.UniqueContainerID()
- log.Infof("Running test %q in container %q", tc.FullName(), id)
+ log.Infof("Running test %q in container %q", name, id)
specutils.LogSpec(spec)
args := []string{
- "-platform", *platform,
"-root", rootDir,
- "-file-access", *fileAccess,
"-network=none",
"-log-format=text",
"-TESTONLY-unsafe-nonroot=true",
"-net-raw=true",
fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM),
"-watchdog-action=panic",
+ "-platform", *platform,
+ "-file-access", *fileAccess,
}
if *overlay {
args = append(args, "-overlay")
@@ -201,14 +155,18 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
if *strace {
args = append(args, "-strace")
}
+ if *addUDSTree {
+ args = append(args, "-fsgofer-host-uds")
+ }
+
if outDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
- tdir := filepath.Join(outDir, strings.Replace(tc.FullName(), "/", "_", -1))
+ tdir := filepath.Join(outDir, strings.Replace(name, "/", "_", -1))
if err := os.MkdirAll(tdir, 0755); err != nil {
- t.Fatalf("could not create test dir: %v", err)
+ return fmt.Errorf("could not create test dir: %v", err)
}
debugLogDir, err := ioutil.TempDir(tdir, "runsc")
if err != nil {
- t.Fatalf("could not create temp dir: %v", err)
+ return fmt.Errorf("could not create temp dir: %v", err)
}
debugLogDir += "/"
log.Infof("runsc logs: %s", debugLogDir)
@@ -248,7 +206,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
if !ok {
return
}
- t.Errorf("%s: Got signal: %v", tc.FullName(), s)
+ log.Warningf("%s: Got signal: %v", name, s)
done := make(chan bool)
go func() {
dArgs := append(args, "-alsologtostderr=true", "debug", "--stacks", id)
@@ -259,14 +217,14 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
done <- true
}()
- timeout := time.Tick(3 * time.Second)
+ timeout := time.After(3 * time.Second)
select {
case <-timeout:
- t.Logf("runsc debug --stacks is timeouted")
+ log.Infof("runsc debug --stacks is timeouted")
case <-done:
}
- t.Logf("Send SIGTERM to the sandbox process")
+ log.Warningf("Send SIGTERM to the sandbox process")
dArgs := append(args, "debug",
fmt.Sprintf("--signal=%d", syscall.SIGTERM),
id)
@@ -275,11 +233,143 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
cmd.Stderr = os.Stderr
cmd.Run()
}()
- if err = cmd.Run(); err != nil {
- t.Errorf("test %q exited with status %v, want 0", tc.FullName(), err)
- }
+
+ err = cmd.Run()
+
signal.Stop(sig)
close(sig)
+
+ return err
+}
+
+// setupUDSTree updates the spec to expose a UDS tree for gofer socket testing.
+func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) {
+ socketDir, cleanup, err := uds.CreateSocketTree("/tmp")
+ if err != nil {
+ return nil, fmt.Errorf("failed to create socket tree: %v", err)
+ }
+
+ // Standard access to entire tree.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets",
+ Source: socketDir,
+ Type: "bind",
+ })
+
+ // Individial attach points for each socket to test mounts that attach
+ // directly to the sockets.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/stream/echo",
+ Source: filepath.Join(socketDir, "stream/echo"),
+ Type: "bind",
+ })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/stream/nonlistening",
+ Source: filepath.Join(socketDir, "stream/nonlistening"),
+ Type: "bind",
+ })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/seqpacket/echo",
+ Source: filepath.Join(socketDir, "seqpacket/echo"),
+ Type: "bind",
+ })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/seqpacket/nonlistening",
+ Source: filepath.Join(socketDir, "seqpacket/nonlistening"),
+ Type: "bind",
+ })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp/sockets-attach/dgram/null",
+ Source: filepath.Join(socketDir, "dgram/null"),
+ Type: "bind",
+ })
+
+ spec.Process.Env = append(spec.Process.Env, "TEST_UDS_TREE=/tmp/sockets")
+ spec.Process.Env = append(spec.Process.Env, "TEST_UDS_ATTACH_TREE=/tmp/sockets-attach")
+
+ return cleanup, nil
+}
+
+// runsTestCaseRunsc runs the test case in runsc.
+func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
+ // Run a new container with the test executable and filter for the
+ // given test suite and name.
+ spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName())
+
+ // Mark the root as writeable, as some tests attempt to
+ // write to the rootfs, and expect EACCES, not EROFS.
+ spec.Root.Readonly = false
+
+ // Test spec comes with pre-defined mounts that we don't want. Reset it.
+ spec.Mounts = nil
+ if *useTmpfs {
+ // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on
+ // features only available in gVisor's internal tmpfs may fail.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp",
+ Type: "tmpfs",
+ })
+ } else {
+ // Use a gofer-backed directory as '/tmp'.
+ //
+ // Tests might be running in parallel, so make sure each has a
+ // unique test temp dir.
+ //
+ // Some tests (e.g., sticky) access this mount from other
+ // users, so make sure it is world-accessible.
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "")
+ if err != nil {
+ t.Fatalf("could not create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ if err := os.Chmod(tmpDir, 0777); err != nil {
+ t.Fatalf("could not chmod temp dir: %v", err)
+ }
+
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Type: "bind",
+ Destination: "/tmp",
+ Source: tmpDir,
+ })
+ }
+
+ // Set environment variable that indicates we are
+ // running in gVisor and with the given platform.
+ platformVar := "TEST_ON_GVISOR"
+ env := append(os.Environ(), platformVar+"="+*platform)
+
+ // Remove env variables that cause the gunit binary to write output
+ // files, since they will stomp on eachother, and on the output files
+ // from this go test.
+ env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"})
+
+ // Remove shard env variables so that the gunit binary does not try to
+ // intepret them.
+ env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"})
+
+ // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to
+ // be backed by tmpfs.
+ for i, kv := range env {
+ if strings.HasPrefix(kv, "TEST_TMPDIR=") {
+ env[i] = "TEST_TMPDIR=/tmp"
+ break
+ }
+ }
+
+ spec.Process.Env = env
+
+ if *addUDSTree {
+ cleanup, err := setupUDSTree(spec)
+ if err != nil {
+ t.Fatalf("error creating UDS tree: %v", err)
+ }
+ defer cleanup()
+ }
+
+ if err := runRunsc(tc, spec); err != nil {
+ t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err)
+ }
}
// filterEnv returns an environment with the blacklisted variables removed.
diff --git a/test/uds/BUILD b/test/uds/BUILD
new file mode 100644
index 000000000..a3843e699
--- /dev/null
+++ b/test/uds/BUILD
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "uds",
+ testonly = 1,
+ srcs = ["uds.go"],
+ importpath = "gvisor.dev/gvisor/test/uds",
+ deps = [
+ "//pkg/log",
+ "//pkg/unet",
+ ],
+)
diff --git a/test/uds/uds.go b/test/uds/uds.go
new file mode 100644
index 000000000..b714c61b0
--- /dev/null
+++ b/test/uds/uds.go
@@ -0,0 +1,228 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uds contains helpers for testing external UDS functionality.
+package uds
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/unet"
+)
+
+// createEchoSocket creates a socket that echoes back anything received.
+//
+// Only works for stream, seqpacket sockets.
+func createEchoSocket(path string, protocol int) (cleanup func(), err error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err)
+ }
+
+ if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err)
+ }
+
+ if err := syscall.Listen(fd, 0); err != nil {
+ return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err)
+ }
+
+ server, err := unet.NewServerSocket(fd)
+ if err != nil {
+ return nil, fmt.Errorf("error creating echo(%d) unet socket: %v", protocol, err)
+ }
+
+ acceptAndEchoOne := func() error {
+ s, err := server.Accept()
+ if err != nil {
+ return fmt.Errorf("failed to accept: %v", err)
+ }
+ defer s.Close()
+
+ for {
+ buf := make([]byte, 512)
+ for {
+ n, err := s.Read(buf)
+ if err == io.EOF {
+ return nil
+ }
+ if err != nil {
+ return fmt.Errorf("failed to read: %d, %v", n, err)
+ }
+
+ n, err = s.Write(buf[:n])
+ if err != nil {
+ return fmt.Errorf("failed to write: %d, %v", n, err)
+ }
+ }
+ }
+ }
+
+ go func() {
+ for {
+ if err := acceptAndEchoOne(); err != nil {
+ log.Warningf("Failed to handle echo(%d) socket: %v", protocol, err)
+ return
+ }
+ }
+ }()
+
+ cleanup = func() {
+ if err := server.Close(); err != nil {
+ log.Warningf("Failed to close echo(%d) socket: %v", protocol, err)
+ }
+ }
+
+ return cleanup, nil
+}
+
+// createNonListeningSocket creates a socket that is bound but not listening.
+//
+// Only relevant for stream, seqpacket sockets.
+func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err)
+ }
+
+ if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err)
+ }
+
+ cleanup = func() {
+ if err := syscall.Close(fd); err != nil {
+ log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err)
+ }
+ }
+
+ return cleanup, nil
+}
+
+// createNullSocket creates a socket that reads anything received.
+//
+// Only works for dgram sockets.
+func createNullSocket(path string, protocol int) (cleanup func(), err error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err)
+ }
+
+ if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err)
+ }
+
+ s, err := unet.NewSocket(fd)
+ if err != nil {
+ return nil, fmt.Errorf("error creating null(%d) unet socket: %v", protocol, err)
+ }
+
+ go func() {
+ buf := make([]byte, 512)
+ for {
+ n, err := s.Read(buf)
+ if err != nil {
+ log.Warningf("failed to read: %d, %v", n, err)
+ return
+ }
+ }
+ }()
+
+ cleanup = func() {
+ if err := s.Close(); err != nil {
+ log.Warningf("Failed to close null(%d) socket: %v", protocol, err)
+ }
+ }
+
+ return cleanup, nil
+}
+
+type socketCreator func(path string, proto int) (cleanup func(), err error)
+
+// CreateSocketTree creates a local tree of unix domain sockets for use in
+// testing:
+// * /stream/echo
+// * /stream/nonlistening
+// * /seqpacket/echo
+// * /seqpacket/nonlistening
+// * /dgram/null
+func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) {
+ dir, err = ioutil.TempDir(baseDir, "sockets")
+ if err != nil {
+ return "", nil, fmt.Errorf("error creating temp dir: %v", err)
+ }
+
+ var protocols = []struct {
+ protocol int
+ name string
+ sockets map[string]socketCreator
+ }{
+ {
+ protocol: syscall.SOCK_STREAM,
+ name: "stream",
+ sockets: map[string]socketCreator{
+ "echo": createEchoSocket,
+ "nonlistening": createNonListeningSocket,
+ },
+ },
+ {
+ protocol: syscall.SOCK_SEQPACKET,
+ name: "seqpacket",
+ sockets: map[string]socketCreator{
+ "echo": createEchoSocket,
+ "nonlistening": createNonListeningSocket,
+ },
+ },
+ {
+ protocol: syscall.SOCK_DGRAM,
+ name: "dgram",
+ sockets: map[string]socketCreator{
+ "null": createNullSocket,
+ },
+ },
+ }
+
+ var cleanups []func()
+ for _, proto := range protocols {
+ protoDir := filepath.Join(dir, proto.name)
+ if err := os.Mkdir(protoDir, 0755); err != nil {
+ return "", nil, fmt.Errorf("error creating %s dir: %v", proto.name, err)
+ }
+
+ for name, fn := range proto.sockets {
+ path := filepath.Join(protoDir, name)
+ cleanup, err := fn(path, proto.protocol)
+ if err != nil {
+ return "", nil, fmt.Errorf("error creating %s %s socket: %v", proto.name, name, err)
+ }
+
+ cleanups = append(cleanups, cleanup)
+ }
+ }
+
+ cleanup = func() {
+ for _, c := range cleanups {
+ c()
+ }
+
+ os.RemoveAll(dir)
+ }
+
+ return dir, cleanup, nil
+}
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index f7d231b14..88b1e7911 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -163,6 +163,26 @@ PosixError Chmod(absl::string_view path, int mode) {
return NoError();
}
+PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode,
+ dev_t dev) {
+ int res = mknodat(dfd.get(), std::string(path).c_str(), mode, dev);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("mknod ", path));
+ }
+
+ return NoError();
+}
+
+PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path,
+ int flags) {
+ int res = unlinkat(dfd.get(), std::string(path).c_str(), flags);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("unlink ", path));
+ }
+
+ return NoError();
+}
+
PosixError Mkdir(absl::string_view path, int mode) {
int res = mkdir(std::string(path).c_str(), mode);
if (res < 0) {
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index e5b555891..ee1b341d7 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -21,6 +21,7 @@
#include <unistd.h>
#include "absl/strings/string_view.h"
+#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
namespace gvisor {
@@ -44,6 +45,14 @@ PosixError Delete(absl::string_view path);
// Changes the mode of a file or returns an error.
PosixError Chmod(absl::string_view path, int mode);
+// Create a special or ordinary file.
+PosixError MknodAt(const FileDescriptor& dfd, absl::string_view path, int mode,
+ dev_t dev);
+
+// Unlink the file.
+PosixError UnlinkAt(const FileDescriptor& dfd, absl::string_view path,
+ int flags);
+
// Truncates a file to the given length or returns an error.
PosixError Truncate(absl::string_view path, int length);
diff --git a/test/util/multiprocess_util.cc b/test/util/multiprocess_util.cc
index 95f5f3b4f..8b676751b 100644
--- a/test/util/multiprocess_util.cc
+++ b/test/util/multiprocess_util.cc
@@ -14,6 +14,7 @@
#include "test/util/multiprocess_util.h"
+#include <asm/unistd.h>
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
@@ -30,11 +31,12 @@
namespace gvisor {
namespace testing {
-PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
- const ExecveArray& argv,
- const ExecveArray& envv,
- const std::function<void()>& fn, pid_t* child,
- int* execve_errno) {
+namespace {
+
+// exec_fn wraps a variant of the exec family, e.g. execve or execveat.
+PosixErrorOr<Cleanup> ForkAndExecHelper(const std::function<void()>& exec_fn,
+ const std::function<void()>& fn,
+ pid_t* child, int* execve_errno) {
int pfds[2];
int ret = pipe2(pfds, O_CLOEXEC);
if (ret < 0) {
@@ -76,7 +78,9 @@ PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
fn();
}
- execve(filename.c_str(), argv.get(), envv.get());
+ // Call variant of exec function.
+ exec_fn();
+
int error = errno;
if (WriteFd(pfds[1], &error, sizeof(error)) != sizeof(error)) {
// We can't do much if the write fails, but we can at least exit with a
@@ -116,6 +120,36 @@ PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
return std::move(cleanup);
}
+} // namespace
+
+PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
+ const ExecveArray& argv,
+ const ExecveArray& envv,
+ const std::function<void()>& fn, pid_t* child,
+ int* execve_errno) {
+ char* const* argv_data = argv.get();
+ char* const* envv_data = envv.get();
+ const std::function<void()> exec_fn = [=] {
+ execve(filename.c_str(), argv_data, envv_data);
+ };
+ return ForkAndExecHelper(exec_fn, fn, child, execve_errno);
+}
+
+PosixErrorOr<Cleanup> ForkAndExecveat(const int32_t dirfd,
+ const std::string& pathname,
+ const ExecveArray& argv,
+ const ExecveArray& envv, const int flags,
+ const std::function<void()>& fn,
+ pid_t* child, int* execve_errno) {
+ char* const* argv_data = argv.get();
+ char* const* envv_data = envv.get();
+ const std::function<void()> exec_fn = [=] {
+ syscall(__NR_execveat, dirfd, pathname.c_str(), argv_data, envv_data,
+ flags);
+ };
+ return ForkAndExecHelper(exec_fn, fn, child, execve_errno);
+}
+
PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn) {
pid_t pid = fork();
if (pid == 0) {
diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h
index 0aecd3439..c413d63ea 100644
--- a/test/util/multiprocess_util.h
+++ b/test/util/multiprocess_util.h
@@ -102,6 +102,13 @@ inline PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
return ForkAndExec(filename, argv, envv, [] {}, child, execve_errno);
}
+// Equivalent to ForkAndExec, except using dirfd and flags with execveat.
+PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, const std::string& pathname,
+ const ExecveArray& argv,
+ const ExecveArray& envv, int flags,
+ const std::function<void()>& fn,
+ pid_t* child, int* execve_errno);
+
// Calls fn in a forked subprocess and returns the exit status of the
// subprocess.
//
diff --git a/third_party/gvsync/BUILD b/third_party/gvsync/BUILD
index 8dab51daa..7d6d59c48 100644
--- a/third_party/gvsync/BUILD
+++ b/third_party/gvsync/BUILD
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template")
package(
default_visibility = ["//:sandbox"],
@@ -7,8 +8,6 @@ package(
exports_files(["LICENSE"])
-load("//tools/go_generics:defs.bzl", "go_template")
-
go_template(
name = "generic_atomicptr",
srcs = ["atomicptr_unsafe.go"],
diff --git a/third_party/gvsync/atomicptrtest/BUILD b/third_party/gvsync/atomicptrtest/BUILD
index 6cf69ea91..447ecf96a 100644
--- a/third_party/gvsync/atomicptrtest/BUILD
+++ b/third_party/gvsync/atomicptrtest/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "atomicptr_int",
out = "atomicptr_int_unsafe.go",
diff --git a/third_party/gvsync/seqatomictest/BUILD b/third_party/gvsync/seqatomictest/BUILD
index 9e87e0bc5..c858c20c4 100644
--- a/third_party/gvsync/seqatomictest/BUILD
+++ b/third_party/gvsync/seqatomictest/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "seqatomic_int",
out = "seqatomic_int_unsafe.go",
diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD
index a6f8cdd3c..9d26a88b7 100644
--- a/tools/go_generics/rules_tests/BUILD
+++ b/tools/go_generics/rules_tests/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-
go_template_instance(
name = "instance",
out = "instance_test.go",