summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/BUILD4
-rw-r--r--pkg/abi/linux/BUILD7
-rw-r--r--pkg/abi/linux/socket.go12
-rw-r--r--pkg/bits/BUILD3
-rw-r--r--pkg/bpf/BUILD3
-rw-r--r--pkg/cpuid/BUILD3
-rw-r--r--pkg/fd/BUILD3
-rw-r--r--pkg/fd/fd.go8
-rw-r--r--pkg/refs/BUILD5
-rw-r--r--pkg/segment/BUILD4
-rw-r--r--pkg/segment/test/BUILD3
-rw-r--r--pkg/sentry/arch/BUILD3
-rw-r--r--pkg/sentry/context/contexttest/BUILD4
-rw-r--r--pkg/sentry/device/BUILD3
-rw-r--r--pkg/sentry/fs/BUILD5
-rw-r--r--pkg/sentry/fs/dev/BUILD4
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD3
-rw-r--r--pkg/sentry/fs/filetest/BUILD4
-rw-r--r--pkg/sentry/fs/fsutil/BUILD5
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go26
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go19
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go18
-rw-r--r--pkg/sentry/fs/gofer/BUILD3
-rw-r--r--pkg/sentry/fs/gofer/file.go66
-rw-r--r--pkg/sentry/fs/gofer/file_state.go2
-rw-r--r--pkg/sentry/fs/gofer/fs.go11
-rw-r--r--pkg/sentry/fs/gofer/handles.go12
-rw-r--r--pkg/sentry/fs/gofer/inode.go108
-rw-r--r--pkg/sentry/fs/gofer/session.go5
-rw-r--r--pkg/sentry/fs/host/BUILD3
-rw-r--r--pkg/sentry/fs/host/socket.go3
-rw-r--r--pkg/sentry/fs/lock/BUILD5
-rw-r--r--pkg/sentry/fs/proc/BUILD3
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD3
-rw-r--r--pkg/sentry/fs/ramfs/BUILD3
-rw-r--r--pkg/sentry/fs/sys/BUILD4
-rw-r--r--pkg/sentry/fs/timerfd/BUILD4
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD3
-rw-r--r--pkg/sentry/fs/tty/BUILD3
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD5
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go4
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD3
-rw-r--r--pkg/sentry/fsimpl/ext/file_description.go3
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/ext/symlink.go4
-rw-r--r--pkg/sentry/fsimpl/memfs/BUILD20
-rw-r--r--pkg/sentry/fsimpl/memfs/filesystem.go55
-rw-r--r--pkg/sentry/fsimpl/memfs/memfs.go2
-rw-r--r--pkg/sentry/fsimpl/memfs/named_pipe.go59
-rw-r--r--pkg/sentry/fsimpl/memfs/pipe_test.go233
-rw-r--r--pkg/sentry/inet/BUILD4
-rw-r--r--pkg/sentry/kernel/BUILD5
-rw-r--r--pkg/sentry/kernel/auth/BUILD4
-rw-r--r--pkg/sentry/kernel/contexttest/BUILD4
-rw-r--r--pkg/sentry/kernel/epoll/BUILD5
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD3
-rw-r--r--pkg/sentry/kernel/fasync/BUILD4
-rw-r--r--pkg/sentry/kernel/fd_table.go22
-rw-r--r--pkg/sentry/kernel/fd_table_test.go36
-rw-r--r--pkg/sentry/kernel/futex/BUILD5
-rw-r--r--pkg/sentry/kernel/pipe/BUILD8
-rw-r--r--pkg/sentry/kernel/pipe/node.go72
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go24
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go213
-rw-r--r--pkg/sentry/kernel/pipe/reader_writer.go111
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go220
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD5
-rw-r--r--pkg/sentry/kernel/shm/BUILD4
-rw-r--r--pkg/sentry/kernel/time/BUILD4
-rw-r--r--pkg/sentry/limits/BUILD3
-rw-r--r--pkg/sentry/loader/BUILD3
-rw-r--r--pkg/sentry/memmap/BUILD5
-rw-r--r--pkg/sentry/mm/BUILD5
-rw-r--r--pkg/sentry/pgalloc/BUILD5
-rw-r--r--pkg/sentry/platform/BUILD4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go16
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go3
-rw-r--r--pkg/sentry/platform/ring0/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD3
-rw-r--r--pkg/sentry/socket/BUILD4
-rw-r--r--pkg/sentry/socket/control/BUILD4
-rw-r--r--pkg/sentry/socket/hostinet/BUILD4
-rw-r--r--pkg/sentry/socket/netfilter/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD3
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD4
-rw-r--r--pkg/sentry/socket/netstack/BUILD4
-rw-r--r--pkg/sentry/socket/netstack/netstack.go40
-rw-r--r--pkg/sentry/socket/netstack/provider.go53
-rw-r--r--pkg/sentry/socket/unix/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go5
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go16
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go9
-rw-r--r--pkg/sentry/socket/unix/unix.go3
-rw-r--r--pkg/sentry/syscalls/linux/BUILD4
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go11
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go8
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go48
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go1
-rw-r--r--pkg/sentry/time/BUILD3
-rw-r--r--pkg/sentry/usage/BUILD4
-rw-r--r--pkg/sentry/usermem/BUILD5
-rw-r--r--pkg/sentry/vfs/file_description.go4
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go4
-rw-r--r--pkg/sentry/vfs/syscalls.go22
-rw-r--r--pkg/state/BUILD3
-rw-r--r--pkg/tcpip/buffer/prependable.go5
-rw-r--r--pkg/tcpip/checker/checker.go61
-rw-r--r--pkg/tcpip/header/BUILD19
-rw-r--r--pkg/tcpip/header/checksum.go50
-rw-r--r--pkg/tcpip/header/checksum_test.go109
-rw-r--r--pkg/tcpip/header/eth.go62
-rw-r--r--pkg/tcpip/header/eth_test.go68
-rw-r--r--pkg/tcpip/header/icmpv6.go32
-rw-r--r--pkg/tcpip/header/ipv6.go12
-rw-r--r--pkg/tcpip/header/ndp_neighbor_advert.go110
-rw-r--r--pkg/tcpip/header/ndp_neighbor_solicit.go52
-rw-r--r--pkg/tcpip/header/ndp_options.go172
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go112
-rw-r--r--pkg/tcpip/header/ndp_test.go199
-rw-r--r--pkg/tcpip/link/channel/channel.go48
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go143
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go21
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go5
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go22
-rw-r--r--pkg/tcpip/link/loopback/BUILD1
-rw-r--r--pkg/tcpip/link/loopback/loopback.go27
-rw-r--r--pkg/tcpip/link/muxed/injectable.go34
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go2
-rw-r--r--pkg/tcpip/link/rawfile/BUILD5
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go11
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go22
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go16
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go85
-rw-r--r--pkg/tcpip/link/waitable/waitable.go28
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go19
-rw-r--r--pkg/tcpip/network/arp/arp.go9
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD1
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go16
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go10
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go10
-rw-r--r--pkg/tcpip/network/ip_test.go9
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go35
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go14
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go132
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go12
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go30
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go4
-rw-r--r--pkg/tcpip/stack/BUILD5
-rw-r--r--pkg/tcpip/stack/ndp.go319
-rw-r--r--pkg/tcpip/stack/ndp_test.go443
-rw-r--r--pkg/tcpip/stack/nic.go309
-rw-r--r--pkg/tcpip/stack/registration.go83
-rw-r--r--pkg/tcpip/stack/route.go48
-rw-r--r--pkg/tcpip/stack/stack.go231
-rw-r--r--pkg/tcpip/stack/stack_test.go302
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go2
-rw-r--r--pkg/tcpip/tcpip.go13
-rw-r--r--pkg/tcpip/transport/packet/BUILD46
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go363
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go72
-rw-r--r--pkg/tcpip/transport/raw/BUILD15
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go30
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go14
-rw-r--r--pkg/tcpip/transport/raw/protocol.go12
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/connect.go77
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go18
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/udp/protocol.go11
173 files changed, 5412 insertions, 720 deletions
diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD
index 32c601a03..f5c08ea06 100644
--- a/pkg/abi/BUILD
+++ b/pkg/abi/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "abi",
srcs = [
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index f45934466..7c17109a6 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -1,13 +1,12 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
# Package linux contains the constants and types needed to interface with a
# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix
# when the host OS may not be Linux.
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "linux",
srcs = [
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index d5b731390..2e2cc6be7 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -256,6 +256,17 @@ type SockAddrInet6 struct {
Scope_id uint32
}
+// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h.
+type SockAddrLink struct {
+ Family uint16
+ Protocol uint16
+ InterfaceIndex int32
+ ARPHardwareType uint16
+ PacketType byte
+ HardwareAddrLen byte
+ HardwareAddr [8]byte
+}
+
// UnixPathMax is the maximum length of the path in an AF_UNIX socket.
//
// From uapi/linux/un.h.
@@ -278,6 +289,7 @@ type SockAddr interface {
func (s *SockAddrInet) implementsSockAddr() {}
func (s *SockAddrInet6) implementsSockAddr() {}
+func (s *SockAddrLink) implementsSockAddr() {}
func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD
index 1b5dac99a..93b88a29a 100644
--- a/pkg/bits/BUILD
+++ b/pkg/bits/BUILD
@@ -1,10 +1,9 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-
go_library(
name = "bits",
srcs = [
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
index 8d31e068c..fba5643e8 100644
--- a/pkg/bpf/BUILD
+++ b/pkg/bpf/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "bpf",
srcs = [
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
index 32422f9e2..ed111fd2a 100644
--- a/pkg/cpuid/BUILD
+++ b/pkg/cpuid/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "cpuid",
srcs = [
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index c7f549428..afa8f7659 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -8,9 +8,6 @@ go_library(
srcs = ["fd.go"],
importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
- deps = [
- "//pkg/unet",
- ],
)
go_test(
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 7691b477b..83bcfe220 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -22,8 +22,6 @@ import (
"runtime"
"sync/atomic"
"syscall"
-
- "gvisor.dev/gvisor/pkg/unet"
)
// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
@@ -187,12 +185,6 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
return New(f), nil
}
-// DialUnix connects to a Unix Domain Socket and return the file descriptor.
-func DialUnix(path string) (*FD, error) {
- socket, err := unet.Connect(path, false)
- return New(socket.FD()), err
-}
-
// Close closes the file descriptor contained in the FD.
//
// Close is safe to call multiple times, but will return an error after the
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 827385139..7ad59dfd7 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "weak_ref_list",
out = "weak_ref_list.go",
diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD
index 700385907..1b487b887 100644
--- a/pkg/segment/BUILD
+++ b/pkg/segment/BUILD
@@ -1,10 +1,10 @@
+load("//tools/go_generics:defs.bzl", "go_template")
+
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
)
-load("//tools/go_generics:defs.bzl", "go_template")
-
go_template(
name = "generic_range",
srcs = ["range.go"],
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index 12d7c77d2..a27c35e21 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -1,13 +1,12 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"],
)
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "int_range",
out = "int_range.go",
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index c71cff9f3..18c73cc24 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@rules_cc//cc:defs.bzl", "cc_proto_library")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "arch",
srcs = [
diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD
index 3b6841b7e..581e7aa96 100644
--- a/pkg/sentry/context/contexttest/BUILD
+++ b/pkg/sentry/context/contexttest/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "contexttest",
testonly = 1,
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 0c86197f7..1098ed777 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "device",
srcs = ["device.go"],
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index 3119a61b6..378602cc9 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "fs",
srcs = [
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 80e106e6f..a0d9e8496 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "dev",
srcs = [
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index b9bd9ed17..277ee4c31 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "fdpipe",
srcs = [
diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD
index a9d6d9301..358dc2be3 100644
--- a/pkg/sentry/fs/filetest/BUILD
+++ b/pkg/sentry/fs/filetest/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "filetest",
testonly = 1,
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index b4ac83dc4..b2e8d9c77 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "dirty_set_impl",
out = "dirty_set_impl.go",
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index e239f12a5..b06a71cc2 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -209,3 +209,29 @@ func (f *HostFileMapper) unmapAndRemoveLocked(chunkStart uint64, m mapping) {
}
delete(f.mappings, chunkStart)
}
+
+// RegenerateMappings must be called when the file description mapped by f
+// changes, to replace existing mappings of the previous file description.
+func (f *HostFileMapper) RegenerateMappings(fd int) error {
+ f.mapsMu.Lock()
+ defer f.mapsMu.Unlock()
+
+ for chunkStart, m := range f.mappings {
+ prot := syscall.PROT_READ
+ if m.writable {
+ prot |= syscall.PROT_WRITE
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_MMAP,
+ m.addr,
+ chunkSize,
+ uintptr(prot),
+ syscall.MAP_SHARED|syscall.MAP_FIXED,
+ uintptr(fd),
+ uintptr(chunkStart))
+ if errno != 0 {
+ return errno
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index 693625ddc..30475f340 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -100,13 +100,30 @@ func (h *HostMappable) Translate(ctx context.Context, required, optional memmap.
}
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
-func (h *HostMappable) InvalidateUnsavable(ctx context.Context) error {
+func (h *HostMappable) InvalidateUnsavable(_ context.Context) error {
h.mu.Lock()
h.mappings.InvalidateAll(memmap.InvalidateOpts{})
h.mu.Unlock()
return nil
}
+// NotifyChangeFD must be called after the file description represented by
+// CachedFileObject.FD() changes.
+func (h *HostMappable) NotifyChangeFD() error {
+ // Update existing sentry mappings to refer to the new file description.
+ if err := h.hostFileMapper.RegenerateMappings(h.backingFile.FD()); err != nil {
+ return err
+ }
+
+ // Shoot down existing application mappings of the old file description;
+ // they will be remapped with the new file description on demand.
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ h.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ return nil
+}
+
// MapInternal implements platform.File.MapInternal.
func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index dd80757dc..798920d18 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -959,6 +959,23 @@ func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error
return nil
}
+// NotifyChangeFD must be called after the file description represented by
+// CachedFileObject.FD() changes.
+func (c *CachingInodeOperations) NotifyChangeFD() error {
+ // Update existing sentry mappings to refer to the new file description.
+ if err := c.hostFileMapper.RegenerateMappings(c.backingFile.FD()); err != nil {
+ return err
+ }
+
+ // Shoot down existing application mappings of the old file description;
+ // they will be remapped with the new file description on demand.
+ c.mapsMu.Lock()
+ defer c.mapsMu.Unlock()
+
+ c.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ return nil
+}
+
// Evict implements pgalloc.EvictableMemoryUser.Evict.
func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.EvictableRange) {
c.mapsMu.Lock()
@@ -1027,7 +1044,6 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
}
c.refs.MergeAdjacent(fr)
c.dataMu.Unlock()
-
}
// MapInternal implements platform.File.MapInternal. This is used when we
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 2b71ca0e1..4a005c605 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "gofer",
srcs = [
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index 9e2e412cd..7960b9c7b 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -214,28 +214,64 @@ func (f *fileOperations) readdirAll(ctx context.Context) (map[string]fs.DentAttr
return entries, nil
}
+// maybeSync will call FSync on the file if either the cache policy or file
+// flags require it.
+func (f *fileOperations) maybeSync(ctx context.Context, file *fs.File, offset, n int64) error {
+ if n == 0 {
+ // Nothing to sync.
+ return nil
+ }
+
+ if f.inodeOperations.session().cachePolicy.writeThrough(file.Dirent.Inode) {
+ // Call WriteOut directly, as some "writethrough" filesystems
+ // do not support sync.
+ return f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode)
+ }
+
+ flags := file.Flags()
+ var syncType fs.SyncType
+ switch {
+ case flags.Direct || flags.Sync:
+ syncType = fs.SyncAll
+ case flags.DSync:
+ syncType = fs.SyncData
+ default:
+ // No need to sync.
+ return nil
+ }
+
+ return f.Fsync(ctx, file, offset, offset+n, syncType)
+}
+
// Write implements fs.FileOperations.Write.
func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
if fs.IsDir(file.Dirent.Inode.StableAttr) {
// Not all remote file systems enforce this so this client does.
return 0, syserror.EISDIR
}
- cp := f.inodeOperations.session().cachePolicy
- if cp.useCachingInodeOps(file.Dirent.Inode) {
- n, err := f.inodeOperations.cachingInodeOps.Write(ctx, src, offset)
- if err != nil {
- return n, err
- }
- if cp.writeThrough(file.Dirent.Inode) {
- // Write out the file.
- err = f.inodeOperations.cachingInodeOps.WriteOut(ctx, file.Dirent.Inode)
- }
- return n, err
+
+ var (
+ n int64
+ err error
+ )
+ // The write is handled in different ways depending on the cache policy
+ // and availability of a host-mappable FD.
+ if f.inodeOperations.session().cachePolicy.useCachingInodeOps(file.Dirent.Inode) {
+ n, err = f.inodeOperations.cachingInodeOps.Write(ctx, src, offset)
+ } else if f.inodeOperations.fileState.hostMappable != nil {
+ n, err = f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset)
+ } else {
+ n, err = src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
}
- if f.inodeOperations.fileState.hostMappable != nil {
- return f.inodeOperations.fileState.hostMappable.Write(ctx, src, offset)
+
+ // We may need to sync the written bytes.
+ if syncErr := f.maybeSync(ctx, file, offset, n); syncErr != nil {
+ // Sync failed. Report 0 bytes written, since none of them are
+ // guaranteed to have been synced.
+ return 0, syncErr
}
- return src.CopyInTo(ctx, f.handles.readWriterAt(ctx, offset))
+
+ return n, err
}
// incrementReadCounters increments the read counters for the read starting at the given time. We
@@ -273,7 +309,7 @@ func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IO
}
// Fsync implements fs.FileOperations.Fsync.
-func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start int64, end int64, syncType fs.SyncType) error {
+func (f *fileOperations) Fsync(ctx context.Context, file *fs.File, start, end int64, syncType fs.SyncType) error {
switch syncType {
case fs.SyncAll, fs.SyncData:
if err := file.Dirent.Inode.WriteOut(ctx); err != nil {
diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go
index 9aa68a70e..c2fbb4be9 100644
--- a/pkg/sentry/fs/gofer/file_state.go
+++ b/pkg/sentry/fs/gofer/file_state.go
@@ -29,7 +29,7 @@ func (f *fileOperations) afterLoad() {
// Manually load the open handles.
var err error
// TODO(b/38173783): Context is not plumbed to save/restore.
- f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags)
+ f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags, f.inodeOperations.cachingInodeOps)
if err != nil {
return fmt.Errorf("failed to re-open handle: %v", err)
}
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
index 8f8ab5d29..cf96dd9fa 100644
--- a/pkg/sentry/fs/gofer/fs.go
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -58,6 +58,11 @@ const (
// If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to
// true.
limitHostFDTranslationKey = "limit_host_fd_translation"
+
+ // overlayfsStaleRead if present closes cached readonly file after the first
+ // write. This is done to workaround a limitation of overlayfs in kernels
+ // before 4.19 where open FDs are not updated after the file is copied up.
+ overlayfsStaleRead = "overlayfs_stale_read"
)
// defaultAname is the default attach name.
@@ -145,6 +150,7 @@ type opts struct {
version string
privateunixsocket bool
limitHostFDTranslation bool
+ overlayfsStaleRead bool
}
// options parses mount(2) data into structured options.
@@ -247,6 +253,11 @@ func options(data string) (opts, error) {
delete(options, limitHostFDTranslationKey)
}
+ if _, ok := options[overlayfsStaleRead]; ok {
+ o.overlayfsStaleRead = true
+ delete(options, overlayfsStaleRead)
+ }
+
// Fail to attach if the caller wanted us to do something that we
// don't support.
if len(options) > 0 {
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
index 27eeae3d9..39c8ec33d 100644
--- a/pkg/sentry/fs/gofer/handles.go
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -39,14 +39,22 @@ type handles struct {
// Host is an *fd.FD handle. May be nil.
Host *fd.FD
+
+ // isHostBorrowed tells whether 'Host' is owned or borrowed. If owned, it's
+ // closed on destruction, otherwise it's released.
+ isHostBorrowed bool
}
// DecRef drops a reference on handles.
func (h *handles) DecRef() {
h.DecRefWithDestructor(func() {
if h.Host != nil {
- if err := h.Host.Close(); err != nil {
- log.Warningf("error closing host file: %v", err)
+ if h.isHostBorrowed {
+ h.Host.Release()
+ } else {
+ if err := h.Host.Close(); err != nil {
+ log.Warningf("error closing host file: %v", err)
+ }
}
}
// FIXME(b/38173783): Context is not plumbed here.
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index d918d6620..99910388f 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -100,7 +100,7 @@ type inodeFileState struct {
// true.
//
// Once readHandles becomes non-nil, it can't be changed until
- // inodeFileState.Release(), because of a defect in the
+ // inodeFileState.Release()*, because of a defect in the
// fsutil.CachedFileObject interface: there's no way for the caller of
// fsutil.CachedFileObject.FD() to keep the returned FD open, so if we
// racily replace readHandles after inodeFileState.FD() has returned
@@ -108,6 +108,9 @@ type inodeFileState struct {
// FD. writeHandles can be changed if writeHandlesRW is false, since
// inodeFileState.FD() can't return a write-only FD, but can't be changed
// if writeHandlesRW is true for the same reason.
+ //
+ // * There is one notable exception in recreateReadHandles(), where it dup's
+ // the FD and invalidates the page cache.
readHandles *handles `state:"nosave"`
writeHandles *handles `state:"nosave"`
writeHandlesRW bool `state:"nosave"`
@@ -175,43 +178,124 @@ func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles)
// getHandles returns a set of handles for a new file using i opened with the
// given flags.
-func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags) (*handles, error) {
+func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) {
if !i.canShareHandles() {
return newHandles(ctx, i.file, flags)
}
+
i.handlesMu.Lock()
- defer i.handlesMu.Unlock()
+ h, invalidate, err := i.getHandlesLocked(ctx, flags)
+ i.handlesMu.Unlock()
+
+ if invalidate {
+ cache.NotifyChangeFD()
+ if i.hostMappable != nil {
+ i.hostMappable.NotifyChangeFD()
+ }
+ }
+
+ return h, err
+}
+
+// getHandlesLocked returns a pointer to cached handles and a boolean indicating
+// whether previously open read handle was recreated. Host mappings must be
+// invalidated if so.
+func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) {
// Do we already have usable shared handles?
if flags.Write {
if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) {
i.writeHandles.IncRef()
- return i.writeHandles, nil
+ return i.writeHandles, false, nil
}
} else if i.readHandles != nil {
i.readHandles.IncRef()
- return i.readHandles, nil
+ return i.readHandles, false, nil
}
+
// No; get new handles and cache them for future sharing.
h, err := newHandles(ctx, i.file, flags)
if err != nil {
- return nil, err
+ return nil, false, err
+ }
+
+ // Read handles invalidation is needed if:
+ // - Mount option 'overlayfs_stale_read' is set
+ // - Read handle is open: nothing to invalidate otherwise
+ // - Write handle is not open: file was not open for write and is being open
+ // for write now (will trigger copy up in overlayfs).
+ invalidate := false
+ if i.s.overlayfsStaleRead && i.readHandles != nil && i.writeHandles == nil && flags.Write {
+ if err := i.recreateReadHandles(ctx, h, flags); err != nil {
+ return nil, false, err
+ }
+ invalidate = true
}
i.setSharedHandlesLocked(flags, h)
- return h, nil
+ return h, invalidate, nil
+}
+
+func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handles, flags fs.FileFlags) error {
+ h := writer
+ if !flags.Read {
+ // Writer can't be used for read, must create a new handle.
+ var err error
+ h, err = newHandles(ctx, i.file, fs.FileFlags{Read: true})
+ if err != nil {
+ return err
+ }
+ defer h.DecRef()
+ }
+
+ if i.readHandles.Host == nil {
+ // If current readHandles doesn't have a host FD, it can simply be replaced.
+ i.readHandles.DecRef()
+
+ h.IncRef()
+ i.readHandles = h
+ return nil
+ }
+
+ if h.Host == nil {
+ // Current read handle has a host FD and can't be replaced with one that
+ // doesn't, because it breaks fsutil.CachedFileObject.FD() contract.
+ log.Warningf("Read handle can't be invalidated, reads may return stale data")
+ return nil
+ }
+
+ // Due to a defect in the fsutil.CachedFileObject interface,
+ // readHandles.Host.FD() may be used outside locks, making it impossible to
+ // reliably close it. To workaround it, we dup the new FD into the old one, so
+ // operations on the old will see the new data. Then, make the new handle take
+ // ownereship of the old FD and mark the old readHandle to not close the FD
+ // when done.
+ if err := syscall.Dup2(h.Host.FD(), i.readHandles.Host.FD()); err != nil {
+ return err
+ }
+
+ h.Host.Close()
+ h.Host = fd.New(i.readHandles.Host.FD())
+ i.readHandles.isHostBorrowed = true
+ i.readHandles.DecRef()
+
+ h.IncRef()
+ i.readHandles = h
+ return nil
}
// ReadToBlocksAt implements fsutil.CachedFileObject.ReadToBlocksAt.
func (i *inodeFileState) ReadToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
i.handlesMu.RLock()
- defer i.handlesMu.RUnlock()
- return i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts)
+ n, err := i.readHandles.readWriterAt(ctx, int64(offset)).ReadToBlocks(dsts)
+ i.handlesMu.RUnlock()
+ return n, err
}
// WriteFromBlocksAt implements fsutil.CachedFileObject.WriteFromBlocksAt.
func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
i.handlesMu.RLock()
- defer i.handlesMu.RUnlock()
- return i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs)
+ n, err := i.writeHandles.readWriterAt(ctx, int64(offset)).WriteFromBlocks(srcs)
+ i.handlesMu.RUnlock()
+ return n, err
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
@@ -449,7 +533,7 @@ func (i *inodeOperations) NonBlockingOpen(ctx context.Context, p fs.PermMask) (*
}
func (i *inodeOperations) getFileDefault(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
- h, err := i.fileState.getHandles(ctx, flags)
+ h, err := i.fileState.getHandles(ctx, flags, i.cachingInodeOps)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 50da865c1..0da608548 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -122,6 +122,10 @@ type session struct {
// CachingInodeOperations created by the session.
limitHostFDTranslation bool
+ // overlayfsStaleRead when set causes the readonly handle to be invalidated
+ // after file is open for write.
+ overlayfsStaleRead bool
+
// connID is a unique identifier for the session connection.
connID string `state:"wait"`
@@ -257,6 +261,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
aname: o.aname,
superBlockFlags: superBlockFlags,
limitHostFDTranslation: o.limitHostFDTranslation,
+ overlayfsStaleRead: o.overlayfsStaleRead,
mounter: mounter,
}
s.EnableLeakCheck("gofer.session")
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 3e532332e..1cbed07ae 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "host",
srcs = [
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 2392787cb..107336a3e 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -385,3 +385,6 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
func (c *ConnectedEndpoint) Release() {
c.ref.DecRefWithDestructor(c.close)
}
+
+// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
+func (c *ConnectedEndpoint) CloseUnread() {}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 5a7a5b8cd..8d62642e7 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "lock_range",
out = "lock_range.go",
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index c307603a6..75cbb0622 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "proc",
srcs = [
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index 76433c7d0..fe7067be1 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "seqfile",
srcs = ["seqfile.go"],
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index d0f351e5a..012cb3e44 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "ramfs",
srcs = [
diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD
index 70fa3af89..25f0f124e 100644
--- a/pkg/sentry/fs/sys/BUILD
+++ b/pkg/sentry/fs/sys/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "sys",
srcs = [
diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD
index 1d80daeaf..a215c1b95 100644
--- a/pkg/sentry/fs/timerfd/BUILD
+++ b/pkg/sentry/fs/timerfd/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "timerfd",
srcs = ["timerfd.go"],
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 11b680929..59ce400c2 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "tmpfs",
srcs = [
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 25811f668..95ad98cb0 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "tty",
srcs = [
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index b0c286b7a..7ccff8b0d 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
+package(licenses = ["notice"])
+
go_template_instance(
name = "dirent_list",
out = "dirent_list.go",
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 0b471d121..91802dc1e 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -301,8 +301,8 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
return offset, nil
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
// mmap(2) specifies that EACCESS should be returned for non-regular file fds.
return syserror.EACCES
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 2d50e30aa..fcfaf5c3e 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "disklayout",
srcs = [
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
index a0065343b..4d18b28cb 100644
--- a/pkg/sentry/fsimpl/ext/file_description.go
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -43,9 +43,6 @@ func (fd *fileDescription) inode() *inode {
return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
}
-// OnClose implements vfs.FileDescriptionImpl.OnClose.
-func (fd *fileDescription) OnClose() error { return nil }
-
// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
return fd.flags, nil
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index ffc76ba5b..aec33e00a 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -152,8 +152,8 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
return offset, nil
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
// TODO(b/134676337): Implement mmap(2).
return syserror.ENODEV
}
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
index e06548a98..bdf8705c1 100644
--- a/pkg/sentry/fsimpl/ext/symlink.go
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -105,7 +105,7 @@ func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int6
return 0, syserror.EBADF
}
-// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
-func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return syserror.EBADF
}
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index 7e364c5fd..04d667273 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -24,14 +24,18 @@ go_library(
"directory.go",
"filesystem.go",
"memfs.go",
+ "named_pipe.go",
"regular_file.go",
"symlink.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs",
deps = [
"//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/sentry/arch",
"//pkg/sentry/context",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
@@ -54,3 +58,19 @@ go_test(
"//pkg/syserror",
],
)
+
+go_test(
+ name = "memfs_test",
+ size = "small",
+ srcs = ["pipe_test.go"],
+ embed = [":memfs"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go
index f79e2d9c8..f006c40cd 100644
--- a/pkg/sentry/fsimpl/memfs/filesystem.go
+++ b/pkg/sentry/fsimpl/memfs/filesystem.go
@@ -233,7 +233,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if err != nil {
return err
}
- _, err = checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -241,8 +241,40 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return err
}
defer rp.Mount().EndWrite()
- // TODO: actually implement mknod
- return syserror.EPERM
+
+ switch opts.Mode.FileType() {
+ case 0:
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ fallthrough
+ case linux.ModeRegular:
+ // TODO(b/138862511): Implement.
+ return syserror.EINVAL
+
+ case linux.ModeNamedPipe:
+ child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode))
+ parentVFSD.InsertChild(&child.vfsd, pc)
+ parentInode.impl.(*directory).childList.PushBack(child)
+ return nil
+
+ case linux.ModeSocket:
+ // TODO(b/138862511): Implement.
+ return syserror.EINVAL
+
+ case linux.ModeCharacterDevice:
+ fallthrough
+ case linux.ModeBlockDevice:
+ // TODO(b/72101894): We don't support creating block or character
+ // devices at the moment.
+ //
+ // When we start supporting block and character devices, we'll
+ // need to check for CAP_MKNOD here.
+ return syserror.EPERM
+
+ default:
+ // "EINVAL - mode requested creation of something other than a
+ // regular file, device special file, FIFO or socket." - mknod(2)
+ return syserror.EINVAL
+ }
}
// OpenAt implements vfs.FilesystemImpl.OpenAt.
@@ -250,8 +282,9 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// Filter out flags that are not supported by memfs. O_DIRECTORY and
// O_NOFOLLOW have no effect here (they're handled by VFS by setting
// appropriate bits in rp), but are returned by
- // FileDescriptionImpl.StatusFlags().
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW
+ // FileDescriptionImpl.StatusFlags(). O_NONBLOCK is supported only by
+ // pipes.
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK
if opts.Flags&linux.O_CREAT == 0 {
fs.mu.RLock()
@@ -260,7 +293,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return nil, err
}
- return inode.open(rp, vfsd, opts.Flags, false)
+ return inode.open(ctx, rp, vfsd, opts.Flags, false)
}
mustCreate := opts.Flags&linux.O_EXCL != 0
@@ -275,7 +308,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
- return inode.open(rp, vfsd, opts.Flags, false)
+ return inode.open(ctx, rp, vfsd, opts.Flags, false)
}
afterTrailingSymlink:
// Walk to the parent directory of the last path component.
@@ -320,7 +353,7 @@ afterTrailingSymlink:
child := fs.newDentry(childInode)
vfsd.InsertChild(&child.vfsd, pc)
inode.impl.(*directory).childList.PushBack(child)
- return childInode.open(rp, &child.vfsd, opts.Flags, true)
+ return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true)
}
// Open existing file or follow symlink.
if mustCreate {
@@ -336,10 +369,10 @@ afterTrailingSymlink:
// symlink target.
goto afterTrailingSymlink
}
- return childInode.open(rp, childVFSD, opts.Flags, false)
+ return childInode.open(ctx, rp, childVFSD, opts.Flags, false)
}
-func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
+func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(flags)
if !afterCreate {
if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil {
@@ -378,6 +411,8 @@ func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte
case *symlink:
// Can't open symlinks without O_PATH (which is unimplemented).
return nil, syserror.ELOOP
+ case *namedPipe:
+ return newNamedPipeFD(ctx, impl, rp, vfsd, flags)
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go
index b78471c0f..64c851c1a 100644
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ b/pkg/sentry/fsimpl/memfs/memfs.go
@@ -227,6 +227,8 @@ func (i *inode) statTo(stat *linux.Statx) {
stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
stat.Size = uint64(len(impl.target))
stat.Blocks = allocatedBlocksForSize(stat.Size)
+ case *namedPipe:
+ stat.Mode |= linux.S_IFIFO
default:
panic(fmt.Sprintf("unknown inode type: %T", i.impl))
}
diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/memfs/named_pipe.go
new file mode 100644
index 000000000..732ed7c58
--- /dev/null
+++ b/pkg/sentry/fsimpl/memfs/named_pipe.go
@@ -0,0 +1,59 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type namedPipe struct {
+ inode inode
+
+ pipe *pipe.VFSPipe
+}
+
+// Preconditions:
+// * fs.mu must be locked.
+// * rp.Mount().CheckBeginWrite() has been called successfully.
+func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
+ file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)}
+ file.inode.init(file, fs, creds, mode)
+ file.inode.nlink = 1 // Only the parent has a link.
+ return &file.inode
+}
+
+// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented
+// entirely via struct embedding.
+type namedPipeFD struct {
+ fileDescription
+
+ *pipe.VFSPipeFD
+}
+
+func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ var err error
+ var fd namedPipeFD
+ fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, rp, vfsd, &fd.vfsfd, flags)
+ if err != nil {
+ return nil, err
+ }
+ fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
+ return &fd.vfsfd, nil
+}
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/memfs/pipe_test.go
new file mode 100644
index 000000000..0674b81a3
--- /dev/null
+++ b/pkg/sentry/fsimpl/memfs/pipe_test.go
@@ -0,0 +1,233 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memfs
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const fileName = "mypipe"
+
+func TestSeparateFDs(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the read side. This is done in a concurrently because opening
+ // One end the pipe blocks until the other end is opened.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ rfdchan := make(chan *vfs.FileDescription)
+ go func() {
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY}
+ rfd, _ := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ rfdchan <- rfd
+ }()
+
+ // Open the write side.
+ openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY}
+ wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer wfd.DecRef()
+
+ rfd, ok := <-rfdchan
+ if !ok {
+ t.Fatalf("failed to open pipe for reading %q", fileName)
+ }
+ defer rfd.DecRef()
+
+ const msg = "vamos azul"
+ checkEmpty(ctx, t, rfd)
+ checkWrite(ctx, t, wfd, msg)
+ checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingRead(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the read side as nonblocking.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK}
+ rfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for reading %q: %v", fileName, err)
+ }
+ defer rfd.DecRef()
+
+ // Open the write side.
+ openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY}
+ wfd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer wfd.DecRef()
+
+ const msg = "geh blau"
+ checkEmpty(ctx, t, rfd)
+ checkWrite(ctx, t, wfd, msg)
+ checkRead(ctx, t, rfd, msg)
+}
+
+func TestNonblockingWriteError(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the write side as nonblocking, which should return ENXIO.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK}
+ _, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != syserror.ENXIO {
+ t.Fatalf("expected ENXIO, but got error: %v", err)
+ }
+}
+
+func TestSingleFD(t *testing.T) {
+ ctx, creds, vfsObj, root := setup(t)
+ defer root.DecRef()
+
+ // Open the pipe as readable and writable.
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ openOpts := vfs.OpenOptions{Flags: linux.O_RDWR}
+ fd, err := vfsObj.OpenAt(ctx, creds, &pop, &openOpts)
+ if err != nil {
+ t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
+ }
+ defer fd.DecRef()
+
+ const msg = "forza blu"
+ checkEmpty(ctx, t, fd)
+ checkWrite(ctx, t, fd, msg)
+ checkRead(ctx, t, fd, msg)
+}
+
+// setup creates a VFS with a pipe in the root directory at path fileName. The
+// returned VirtualDentry must be DecRef()'d be the caller. It calls t.Fatal
+// upon failure.
+func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry) {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+
+ // Create the pipe.
+ root := mntns.Root()
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }
+ mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644}
+ if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil {
+ t.Fatalf("failed to create file %q: %v", fileName, err)
+ }
+
+ // Sanity check: the file pipe exists and has the correct mode.
+ stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Pathname: fileName,
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat(%q) failed: %v", fileName, err)
+ }
+ if stat.Mode&^linux.S_IFMT != 0644 {
+ t.Errorf("got wrong permissions (%0o)", stat.Mode)
+ }
+ if stat.Mode&linux.S_IFMT != linux.ModeNamedPipe {
+ t.Errorf("got wrong file type (%0o)", stat.Mode)
+ }
+
+ return ctx, creds, vfsObj, root
+}
+
+// checkEmpty calls t.Fatal if the pipe in fd is not empty.
+func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+ readData := make([]byte, 1)
+ dst := usermem.BytesIOSequence(readData)
+ bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{})
+ if err != syserror.ErrWouldBlock {
+ t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err)
+ }
+ if bytesRead != 0 {
+ t.Fatalf("expected to read 0 bytes, but got %d", bytesRead)
+ }
+}
+
+// checkWrite calls t.Fatal if it fails to write all of msg to fd.
+func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+ writeData := []byte(msg)
+ src := usermem.BytesIOSequence(writeData)
+ bytesWritten, err := fd.Impl().Write(ctx, src, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("error writing to pipe %q: %v", fileName, err)
+ }
+ if bytesWritten != int64(len(writeData)) {
+ t.Fatalf("expected to write %d bytes, but wrote %d", len(writeData), bytesWritten)
+ }
+}
+
+// checkRead calls t.Fatal if it fails to read msg from fd.
+func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
+ readData := make([]byte, len(msg))
+ dst := usermem.BytesIOSequence(readData)
+ bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("error reading from pipe %q: %v", fileName, err)
+ }
+ if bytesRead != int64(len(msg)) {
+ t.Fatalf("expected to read %d bytes, but got %d", len(msg), bytesRead)
+ }
+ if !bytes.Equal(readData, []byte(msg)) {
+ t.Fatalf("expected to read %q from pipe, but got %q", msg, string(readData))
+ }
+}
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 184b566d9..d5284f0d9 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -1,10 +1,10 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
)
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "inet",
srcs = [
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index aba2414d4..e041c51b3 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,12 +1,11 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("@rules_cc//cc:defs.bzl", "cc_proto_library")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "pending_signals_list",
out = "pending_signals_list.go",
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 1d00a6310..51de4568a 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "atomicptr_credentials",
out = "atomicptr_credentials_unsafe.go",
diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD
index bec13a3d9..3a88a585c 100644
--- a/pkg/sentry/kernel/contexttest/BUILD
+++ b/pkg/sentry/kernel/contexttest/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "contexttest",
testonly = 1,
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index 65427b112..3361e8b7d 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "epoll_list",
out = "epoll_list.go",
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index 983ca67ed..e65b961e8 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "eventfd",
srcs = ["eventfd.go"],
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 5eddca115..49d81b712 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "fasync",
srcs = ["fasync.go"],
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index cc3f43a45..11f613a11 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -81,6 +81,9 @@ type FDTable struct {
// mu protects below.
mu sync.Mutex `state:"nosave"`
+ // next is start position to find fd.
+ next int32
+
// used contains the number of non-nil entries. It must be accessed
// atomically. It may be read atomically without holding mu (but not
// written).
@@ -226,6 +229,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
f.mu.Lock()
defer f.mu.Unlock()
+ // From f.next to find available fd.
+ if fd < f.next {
+ fd = f.next
+ }
+
// Install all entries.
for i := fd; i < end && len(fds) < len(files); i++ {
if d, _, _ := f.get(i); d == nil {
@@ -242,6 +250,11 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return nil, syscall.EMFILE
}
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fds[len(fds)-1] + 1
+ }
+
return fds, nil
}
@@ -361,6 +374,11 @@ func (f *FDTable) Remove(fd int32) *fs.File {
f.mu.Lock()
defer f.mu.Unlock()
+ // Update current available position.
+ if fd < f.next {
+ f.next = fd
+ }
+
orig, _, _ := f.get(fd)
if orig != nil {
orig.IncRef() // Reference for caller.
@@ -377,6 +395,10 @@ func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) {
f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
if cond(file, flags) {
f.set(fd, nil, FDFlags{}) // Clear from table.
+ // Update current available position.
+ if fd < f.next {
+ f.next = fd
+ }
}
})
}
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index 2413788e7..2bcb6216a 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -70,6 +70,42 @@ func TestFDTableMany(t *testing.T) {
if err := fdTable.NewFDAt(ctx, 1, file, FDFlags{}); err != nil {
t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
}
+
+ i := int32(2)
+ fdTable.Remove(i)
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i {
+ t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err)
+ }
+ })
+}
+
+func TestFDTableOverLimit(t *testing.T) {
+ runTest(t, func(ctx context.Context, fdTable *FDTable, file *fs.File, _ *limits.LimitSet) {
+ if _, err := fdTable.NewFDs(ctx, maxFD, []*fs.File{file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(maxFD, f): got nil, wanted error")
+ }
+
+ if _, err := fdTable.NewFDs(ctx, maxFD-2, []*fs.File{file, file, file}, FDFlags{}); err == nil {
+ t.Fatalf("fdTable.NewFDs(maxFD-2, {f,f,f}): got nil, wanted error")
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, maxFD-3, []*fs.File{file, file, file}, FDFlags{}); err != nil {
+ t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err)
+ } else {
+ for _, fd := range fds {
+ fdTable.Remove(fd)
+ }
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, maxFD-1, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != maxFD-1 {
+ t.Fatalf("fdTable.NewFDAt(1, r, FDFlags{}): got %v, wanted nil", err)
+ }
+
+ if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil {
+ t.Fatalf("Adding an FD to a resized map: got %v, want nil", err)
+ } else if len(fds) != 1 || fds[0] != 0 {
+ t.Fatalf("Added an FD to a resized map: got %v, want {1}", fds)
+ }
})
}
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 41f44999c..34286c7a8 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "atomicptr_bucket",
out = "atomicptr_bucket_unsafe.go",
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 2ce8952e2..9d34f6d4d 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "buffer_list",
out = "buffer_list.go",
@@ -25,8 +24,10 @@ go_library(
"device.go",
"node.go",
"pipe.go",
+ "pipe_util.go",
"reader.go",
"reader_writer.go",
+ "vfs.go",
"writer.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe",
@@ -41,6 +42,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/safemem",
"//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index a2dc72204..4a19ab7ce 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -18,7 +18,6 @@ import (
"sync"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -91,10 +90,10 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
switch {
case flags.Read && !flags.Write: // O_RDONLY.
r := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.rWakeup)
+ newHandleLocked(&i.rWakeup)
if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
- if !i.waitFor(&i.wWakeup, ctx) {
+ if !waitFor(&i.mu, &i.wWakeup, ctx) {
r.DecRef()
return nil, syserror.ErrInterrupted
}
@@ -107,7 +106,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
case flags.Write && !flags.Read: // O_WRONLY.
w := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.wWakeup)
+ newHandleLocked(&i.wWakeup)
if i.p.isNamed && !i.p.HasReaders() {
// On a nonblocking, write-only open, the open fails with ENXIO if the
@@ -117,7 +116,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
return nil, syserror.ENXIO
}
- if !i.waitFor(&i.rWakeup, ctx) {
+ if !waitFor(&i.mu, &i.rWakeup, ctx) {
w.DecRef()
return nil, syserror.ErrInterrupted
}
@@ -127,8 +126,8 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
case flags.Read && flags.Write: // O_RDWR.
// Pipes opened for read-write always succeeds without blocking.
rw := i.p.Open(ctx, d, flags)
- i.newHandleLocked(&i.rWakeup)
- i.newHandleLocked(&i.wWakeup)
+ newHandleLocked(&i.rWakeup)
+ newHandleLocked(&i.wWakeup)
return rw, nil
default:
@@ -136,65 +135,6 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
}
}
-// waitFor blocks until the underlying pipe has at least one reader/writer is
-// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this
-// function will block for either readers or writers, depending on where
-// 'wakeupChan' points.
-//
-// f.mu must be held by the caller. waitFor returns with f.mu held, but it will
-// drop f.mu before blocking for any reader/writers.
-func (i *inodeOperations) waitFor(wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
- // Ideally this function would simply use a condition variable. However, the
- // wait needs to be interruptible via 'sleeper', so we must sychronize via a
- // channel. The synchronization below relies on the fact that closing a
- // channel unblocks all receives on the channel.
-
- // Does an appropriate wakeup channel already exist? If not, create a new
- // one. This is all done under f.mu to avoid races.
- if *wakeupChan == nil {
- *wakeupChan = make(chan struct{})
- }
-
- // Grab a local reference to the wakeup channel since it may disappear as
- // soon as we drop f.mu.
- wakeup := *wakeupChan
-
- // Drop the lock and prepare to sleep.
- i.mu.Unlock()
- cancel := sleeper.SleepStart()
-
- // Wait for either a new reader/write to be signalled via 'wakeup', or
- // for the sleep to be cancelled.
- select {
- case <-wakeup:
- sleeper.SleepFinish(true)
- case <-cancel:
- sleeper.SleepFinish(false)
- }
-
- // Take the lock and check if we were woken. If we were woken and
- // interrupted, the former takes priority.
- i.mu.Lock()
- select {
- case <-wakeup:
- return true
- default:
- return false
- }
-}
-
-// newHandleLocked signals a new pipe reader or writer depending on where
-// 'wakeupChan' points. This unblocks any corresponding reader or writer
-// waiting for the other end of the channel to be opened, see Fifo.waitFor.
-//
-// i.mu must be held.
-func (*inodeOperations) newHandleLocked(wakeupChan *chan struct{}) {
- if *wakeupChan != nil {
- close(*wakeupChan)
- *wakeupChan = nil
- }
-}
-
func (*inodeOperations) Allocate(_ context.Context, _ *fs.Inode, _, _ int64) error {
return syserror.EPIPE
}
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 8e4e8e82e..1a1b38f83 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -111,11 +111,27 @@ func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
if atomicIOBytes > sizeBytes {
atomicIOBytes = sizeBytes
}
- return &Pipe{
- isNamed: isNamed,
- max: sizeBytes,
- atomicIOBytes: atomicIOBytes,
+ var p Pipe
+ initPipe(&p, isNamed, sizeBytes, atomicIOBytes)
+ return &p
+}
+
+func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) {
+ if sizeBytes < MinimumPipeSize {
+ sizeBytes = MinimumPipeSize
+ }
+ if sizeBytes > MaximumPipeSize {
+ sizeBytes = MaximumPipeSize
+ }
+ if atomicIOBytes <= 0 {
+ atomicIOBytes = 1
+ }
+ if atomicIOBytes > sizeBytes {
+ atomicIOBytes = sizeBytes
}
+ pipe.isNamed = isNamed
+ pipe.max = sizeBytes
+ pipe.atomicIOBytes = atomicIOBytes
}
// NewConnectedPipe initializes a pipe and returns a pair of objects
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
new file mode 100644
index 000000000..ef9641e6a
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -0,0 +1,213 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "io"
+ "math"
+ "sync"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains Pipe file functionality that is tied to neither VFS nor
+// the old fs architecture.
+
+// Release cleans up the pipe's state.
+func (p *Pipe) Release() {
+ p.rClose()
+ p.wClose()
+
+ // Wake up readers and writers.
+ p.Notify(waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads from the Pipe into dst.
+func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ n, err := p.read(ctx, readOps{
+ left: func() int64 {
+ return dst.NumBytes()
+ },
+ limit: func(l int64) {
+ dst = dst.TakeFirst64(l)
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, buf)
+ dst = dst.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// WriteTo writes to w from the Pipe.
+func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) {
+ ops := readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := buf.ReadToWriter(w, count, dup)
+ count -= n
+ return n, err
+ },
+ }
+ if dup {
+ // There is no notification for dup operations.
+ return p.dup(ctx, ops)
+ }
+ n, err := p.read(ctx, ops)
+ if n > 0 {
+ p.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// Write writes to the Pipe from src.
+func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ n, err := p.write(ctx, writeOps{
+ left: func() int64 {
+ return src.NumBytes()
+ },
+ limit: func(l int64) {
+ src = src.TakeFirst64(l)
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := src.CopyInTo(ctx, buf)
+ src = src.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// ReadFrom reads from r to the Pipe.
+func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) {
+ n, err := p.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := buf.WriteFromReader(r, count)
+ count -= n
+ return n, err
+ },
+ })
+ if n > 0 {
+ p.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// Readiness returns the ready events in the underlying pipe.
+func (p *Pipe) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return p.rwReadiness() & mask
+}
+
+// Ioctl implements ioctls on the Pipe.
+func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Switch on ioctl request.
+ switch int(args[1].Int()) {
+ case linux.FIONREAD:
+ v := p.queued()
+ if v > math.MaxInt32 {
+ v = math.MaxInt32 // Silently truncate.
+ }
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ default:
+ return 0, syscall.ENOTTY
+ }
+}
+
+// waitFor blocks until the underlying pipe has at least one reader/writer is
+// announced via 'wakeupChan', or until 'sleeper' is cancelled. Any call to this
+// function will block for either readers or writers, depending on where
+// 'wakeupChan' points.
+//
+// mu must be held by the caller. waitFor returns with mu held, but it will
+// drop mu before blocking for any reader/writers.
+func waitFor(mu *sync.Mutex, wakeupChan *chan struct{}, sleeper amutex.Sleeper) bool {
+ // Ideally this function would simply use a condition variable. However, the
+ // wait needs to be interruptible via 'sleeper', so we must sychronize via a
+ // channel. The synchronization below relies on the fact that closing a
+ // channel unblocks all receives on the channel.
+
+ // Does an appropriate wakeup channel already exist? If not, create a new
+ // one. This is all done under f.mu to avoid races.
+ if *wakeupChan == nil {
+ *wakeupChan = make(chan struct{})
+ }
+
+ // Grab a local reference to the wakeup channel since it may disappear as
+ // soon as we drop f.mu.
+ wakeup := *wakeupChan
+
+ // Drop the lock and prepare to sleep.
+ mu.Unlock()
+ cancel := sleeper.SleepStart()
+
+ // Wait for either a new reader/write to be signalled via 'wakeup', or
+ // for the sleep to be cancelled.
+ select {
+ case <-wakeup:
+ sleeper.SleepFinish(true)
+ case <-cancel:
+ sleeper.SleepFinish(false)
+ }
+
+ // Take the lock and check if we were woken. If we were woken and
+ // interrupted, the former takes priority.
+ mu.Lock()
+ select {
+ case <-wakeup:
+ return true
+ default:
+ return false
+ }
+}
+
+// newHandleLocked signals a new pipe reader or writer depending on where
+// 'wakeupChan' points. This unblocks any corresponding reader or writer
+// waiting for the other end of the channel to be opened, see Fifo.waitFor.
+//
+// Precondition: the mutex protecting wakeupChan must be held.
+func newHandleLocked(wakeupChan *chan struct{}) {
+ if *wakeupChan != nil {
+ close(*wakeupChan)
+ *wakeupChan = nil
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index 7c307f013..b4d29fc77 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -16,16 +16,12 @@ package pipe
import (
"io"
- "math"
- "syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
)
// ReaderWriter satisfies the FileOperations interface and services both
@@ -45,124 +41,27 @@ type ReaderWriter struct {
*Pipe
}
-// Release implements fs.FileOperations.Release.
-func (rw *ReaderWriter) Release() {
- rw.Pipe.rClose()
- rw.Pipe.wClose()
-
- // Wake up readers and writers.
- rw.Pipe.Notify(waiter.EventIn | waiter.EventOut)
-}
-
// Read implements fs.FileOperations.Read.
func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.read(ctx, readOps{
- left: func() int64 {
- return dst.NumBytes()
- },
- limit: func(l int64) {
- dst = dst.TakeFirst64(l)
- },
- read: func(buf *buffer) (int64, error) {
- n, err := dst.CopyOutFrom(ctx, buf)
- dst = dst.DropFirst64(n)
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventOut)
- }
- return n, err
+ return rw.Pipe.Read(ctx, dst)
}
// WriteTo implements fs.FileOperations.WriteTo.
func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) {
- ops := readOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- read: func(buf *buffer) (int64, error) {
- n, err := buf.ReadToWriter(w, count, dup)
- count -= n
- return n, err
- },
- }
- if dup {
- // There is no notification for dup operations.
- return rw.Pipe.dup(ctx, ops)
- }
- n, err := rw.Pipe.read(ctx, ops)
- if n > 0 {
- rw.Pipe.Notify(waiter.EventOut)
- }
- return n, err
+ return rw.Pipe.WriteTo(ctx, w, count, dup)
}
// Write implements fs.FileOperations.Write.
func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, writeOps{
- left: func() int64 {
- return src.NumBytes()
- },
- limit: func(l int64) {
- src = src.TakeFirst64(l)
- },
- write: func(buf *buffer) (int64, error) {
- n, err := src.CopyInTo(ctx, buf)
- src = src.DropFirst64(n)
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventIn)
- }
- return n, err
+ return rw.Pipe.Write(ctx, src)
}
// ReadFrom implements fs.FileOperations.WriteTo.
func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, writeOps{
- left: func() int64 {
- return count
- },
- limit: func(l int64) {
- count = l
- },
- write: func(buf *buffer) (int64, error) {
- n, err := buf.WriteFromReader(r, count)
- count -= n
- return n, err
- },
- })
- if n > 0 {
- rw.Pipe.Notify(waiter.EventIn)
- }
- return n, err
-}
-
-// Readiness returns the ready events in the underlying pipe.
-func (rw *ReaderWriter) Readiness(mask waiter.EventMask) waiter.EventMask {
- return rw.Pipe.rwReadiness() & mask
+ return rw.Pipe.ReadFrom(ctx, r, count)
}
// Ioctl implements fs.FileOperations.Ioctl.
func (rw *ReaderWriter) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- // Switch on ioctl request.
- switch int(args[1].Int()) {
- case linux.FIONREAD:
- v := rw.queued()
- if v > math.MaxInt32 {
- v = math.MaxInt32 // Silently truncate.
- }
- // Copy result to user-space.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
- return 0, err
- default:
- return 0, syscall.ENOTTY
- }
+ return rw.Pipe.Ioctl(ctx, io, args)
}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
new file mode 100644
index 000000000..6416e0dd8
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -0,0 +1,220 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// This file contains types enabling the pipe package to be used with the vfs
+// package.
+
+// VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should
+// not be copied.
+type VFSPipe struct {
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // pipe is the underlying pipe.
+ pipe Pipe
+
+ // Channels for synchronizing the creation of new readers and writers
+ // of this fifo. See waitFor and newHandleLocked.
+ //
+ // These are not saved/restored because all waiters are unblocked on
+ // save, and either automatically restart (via ERESTARTSYS) or return
+ // EINTR on resume. On restarts via ERESTARTSYS, the appropriate
+ // channel will be recreated.
+ rWakeup chan struct{} `state:"nosave"`
+ wWakeup chan struct{} `state:"nosave"`
+}
+
+// NewVFSPipe returns an initialized VFSPipe.
+func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe {
+ var vp VFSPipe
+ initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes)
+ return &vp
+}
+
+// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics
+// during open:
+//
+// "Normally, opening the FIFO blocks until the other end is opened also. A
+// process can open a FIFO in nonblocking mode. In this case, opening for
+// read-only will succeed even if no-one has opened on the write side yet,
+// opening for write-only will fail with ENXIO (no such device or address)
+// unless the other end has already been opened. Under Linux, opening a FIFO
+// for read and write will succeed both in blocking and nonblocking mode. POSIX
+// leaves this behavior undefined. This can be used to open a FIFO for writing
+// while there are no readers available." - fifo(7)
+func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+ vp.mu.Lock()
+ defer vp.mu.Unlock()
+
+ readable := vfs.MayReadFileWithOpenFlags(flags)
+ writable := vfs.MayWriteFileWithOpenFlags(flags)
+ if !readable && !writable {
+ return nil, syserror.EINVAL
+ }
+
+ vfd, err := vp.open(rp, vfsd, vfsfd, flags)
+ if err != nil {
+ return nil, err
+ }
+
+ switch {
+ case readable && writable:
+ // Pipes opened for read-write always succeed without blocking.
+ newHandleLocked(&vp.rWakeup)
+ newHandleLocked(&vp.wWakeup)
+
+ case readable:
+ newHandleLocked(&vp.rWakeup)
+ // If this pipe is being opened as nonblocking and there's no
+ // writer, we have to wait for a writer to open the other end.
+ if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ return nil, syserror.EINTR
+ }
+
+ case writable:
+ newHandleLocked(&vp.wWakeup)
+
+ if !vp.pipe.HasReaders() {
+ // Nonblocking, write-only opens fail with ENXIO when
+ // the read side isn't open yet.
+ if flags&linux.O_NONBLOCK != 0 {
+ return nil, syserror.ENXIO
+ }
+ // Wait for a reader to open the other end.
+ if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
+ return nil, syserror.EINTR
+ }
+ }
+
+ default:
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ return vfd, nil
+}
+
+// Preconditions: vp.mu must be held.
+func (vp *VFSPipe) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+ var fd VFSPipeFD
+ fd.flags = flags
+ fd.readable = vfs.MayReadFileWithOpenFlags(flags)
+ fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
+ fd.vfsfd = vfsfd
+ fd.pipe = &vp.pipe
+ if fd.writable {
+ // The corresponding Mount.EndWrite() is in VFSPipe.Release().
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ }
+
+ switch {
+ case fd.readable && fd.writable:
+ vp.pipe.rOpen()
+ vp.pipe.wOpen()
+ case fd.readable:
+ vp.pipe.rOpen()
+ case fd.writable:
+ vp.pipe.wOpen()
+ default:
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ return &fd, nil
+}
+
+// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is
+// expected that filesystesm will use this in a struct implementing
+// vfs.FileDescriptionImpl.
+type VFSPipeFD struct {
+ pipe *Pipe
+ flags uint32
+ readable bool
+ writable bool
+ vfsfd *vfs.FileDescription
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *VFSPipeFD) Release() {
+ var event waiter.EventMask
+ if fd.readable {
+ fd.pipe.rClose()
+ event |= waiter.EventIn
+ }
+ if fd.writable {
+ fd.pipe.wClose()
+ event |= waiter.EventOut
+ }
+ if event == 0 {
+ panic("invalid pipe flags: must be readable, writable, or both")
+ }
+
+ if fd.writable {
+ fd.vfsfd.VirtualDentry().Mount().EndWrite()
+ }
+
+ fd.pipe.Notify(event)
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *VFSPipeFD) OnClose(_ context.Context) error {
+ return nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ if !fd.readable {
+ return 0, syserror.EINVAL
+ }
+
+ return fd.pipe.Read(ctx, dst)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ if !fd.writable {
+ return 0, syserror.EINVAL
+ }
+
+ return fd.pipe.Write(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return fd.pipe.Ioctl(ctx, uio, args)
+}
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index 80e5e5da3..f4c00cd86 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "waiter_list",
out = "waiter_list.go",
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index aa7471eb6..cd48945e6 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "shm",
srcs = [
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 9beae4b31..31847e1df 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "time",
srcs = [
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 59649c770..156e67bf8 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "limits",
srcs = [
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index 3b322f5f3..2890393bd 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_embed_data")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_embed_data(
name = "vdso_bin",
src = "//vdso:vdso.so",
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 9687e7e76..3ef84245b 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "mappable_range",
out = "mappable_range.go",
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index b35c8c673..a804b8b5c 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "file_refcount_set",
out = "file_refcount_set.go",
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 3fd904c67..f404107af 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "evictable_range",
out = "evictable_range.go",
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 9aa6ec507..157bffa81 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "file_range",
out = "file_range.go",
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 9f0ecfbe4..b699b057d 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -327,6 +327,19 @@ func (t *thread) dumpAndPanic(message string) {
panic(message)
}
+func (t *thread) unexpectedStubExit() {
+ msg, err := t.getEventMessage()
+ status := syscall.WaitStatus(msg)
+ if status.Signaled() && status.Signal() == syscall.SIGKILL {
+ // SIGKILL can be only sent by an user or OOM-killer. In both
+ // these cases, we don't need to panic. There is no reasons to
+ // think that something wrong in gVisor.
+ log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid)
+ syscall.Kill(os.Getpid(), syscall.SIGKILL)
+ }
+ t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+}
+
// wait waits for a stop event.
//
// Precondition: outcome is a valid waitOutcome.
@@ -355,8 +368,7 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
}
if stopSig == syscall.SIGTRAP {
if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
- msg, err := t.getEventMessage()
- t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
+ t.unexpectedStubExit()
}
// Re-encode the trap cause the way it's expected.
return stopSig | syscall.Signal(status.TrapCause()<<8)
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index c075b5f91..3782d4332 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -129,6 +129,9 @@ func createStub() (*thread, error) {
// transitively) will be killed as well. It's simply not possible to
// safely handle a single stub getting killed: the exact state of
// execution is unknown and not recoverable.
+ //
+ // In addition, we set the PTRACE_O_TRACEEXIT option to log more
+ // information about a stub process when it receives a fatal signal.
return attachedThread(uintptr(syscall.SIGKILL)|syscall.CLONE_FILES, defaultAction)
}
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 8ed6c7652..48b0ceaec 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -1,9 +1,8 @@
load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-
go_template(
name = "defs",
srcs = [
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index d7029d5a9..780bf9a66 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "defs_impl",
out = "defs_impl.go",
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index ea090b686..934a90378 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -1,10 +1,9 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-
go_template(
name = "generic_walker",
srcs = [
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 3300f9a6b..26176b10d 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "socket",
srcs = ["socket.go"],
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 81dbd7309..4a6e83a8b 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "control",
srcs = ["control.go"],
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index a951f1bb0..4d174dda4 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "hostinet",
srcs = [
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 354a0d6ee..5eb06bbf4 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "netfilter",
srcs = [
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 7da68384e..f95803f91 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "netlink",
srcs = [
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 445080aa4..463544c1a 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -1,9 +1,8 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library")
-
go_library(
name = "port",
srcs = ["port.go"],
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
index 5dc8533ec..1d4912753 100644
--- a/pkg/sentry/socket/netlink/route/BUILD
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "route",
srcs = ["protocol.go"],
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 60523f79a..e414d8055 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "netstack",
srcs = [
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 6fd43fcbd..69dbfd197 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -53,6 +53,7 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -298,6 +299,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
// netstack representation taking any addresses into account.
@@ -309,12 +311,12 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
}
// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
-// AF_INET6 addresses.
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
//
// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
//
-// AddressAndFamily returns an address, its family.
+// AddressAndFamily returns an address and its family.
func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
@@ -373,6 +375,22 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
}
return out, family, nil
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(b/129292371): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
case linux.AF_UNSPEC:
return tcpip.FullAddress{}, family, nil
@@ -1953,12 +1971,14 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
return &out, uint32(2 + l)
}
return &out, uint32(3 + l)
+
case linux.AF_INET:
var out linux.SockAddrInet
copy(out.Addr[:], addr.Addr)
out.Family = linux.AF_INET
out.Port = htons(addr.Port)
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInetSize)
+
case linux.AF_INET6:
var out linux.SockAddrInet6
if len(addr.Addr) == 4 {
@@ -1974,7 +1994,17 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
if isLinkLocal(addr.Addr) {
out.Scope_id = uint32(addr.NIC)
}
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(b/129292371): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
default:
return nil, 0
}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 357a664cc..2d2c1ba2a 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -62,6 +62,10 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
+ // TODO(b/142504697): "In order to create a raw socket, a
+ // process must have the CAP_NET_RAW capability in the user
+ // namespace that governs its network namespace." - raw(7)
+
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -85,7 +89,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
return 0, true, syserr.ErrProtocolNotSupported
}
-// Socket creates a new socket object for the AF_INET or AF_INET6 family.
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Fail right away if we don't have a stack.
stack := t.NetworkContext()
@@ -99,6 +104,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return nil, nil
}
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocket(t, eps, stype, protocol)
+ }
+
// Figure out the transport protocol.
transProto, associated, err := getTransportProtocol(t, stype, protocol)
if err != nil {
@@ -121,12 +132,47 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return New(t, p.family, stype, int(transProto), wq, ep)
}
+func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // TODO(b/142504697): "In order to create a packet socket, a process
+ // must have the CAP_NET_RAW capability in the user namespace that
+ // governs its network namespace." - packet(7)
+
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return New(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
// Pair just returns nil sockets (not supported).
func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
return nil, nil, nil
}
-// init registers socket providers for AF_INET and AF_INET6.
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
func init() {
// Providers backed by netstack.
p := []provider{
@@ -138,6 +184,9 @@ func init() {
family: linux.AF_INET6,
netProto: ipv6.ProtocolNumber,
},
+ {
+ family: linux.AF_PACKET,
+ },
}
for i := range p {
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 830f4da10..5b6a154f6 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "unix",
srcs = [
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 0b0240336..788ad70d2 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
+package(licenses = ["notice"])
+
go_template_instance(
name = "transport_message_list",
out = "transport_message_list.go",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 4bd15808a..dea11e253 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -220,6 +220,11 @@ func (e *connectionedEndpoint) Close() {
case e.Connected():
e.connected.CloseSend()
e.receiver.CloseRecv()
+ // Still have unread data? If yes, we set this into the write
+ // end so that the peer can get ECONNRESET) when it does read.
+ if e.receiver.RecvQueuedSize() > 0 {
+ e.connected.CloseUnread()
+ }
c = e.connected
r = e.receiver
e.connected = nil
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index 0415fae9a..e27b1c714 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -33,6 +33,7 @@ type queue struct {
mu sync.Mutex `state:"nosave"`
closed bool
+ unread bool
used int64
limit int64
dataList messageList
@@ -161,6 +162,9 @@ func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) {
err := syserr.ErrWouldBlock
if q.closed {
err = syserr.ErrClosedForReceive
+ if q.unread {
+ err = syserr.ErrConnectionReset
+ }
}
q.mu.Unlock()
@@ -188,7 +192,9 @@ func (q *queue) Peek() (*message, *syserr.Error) {
if q.dataList.Front() == nil {
err := syserr.ErrWouldBlock
if q.closed {
- err = syserr.ErrClosedForReceive
+ if err = syserr.ErrClosedForReceive; q.unread {
+ err = syserr.ErrConnectionReset
+ }
}
return nil, err
}
@@ -208,3 +214,11 @@ func (q *queue) QueuedSize() int64 {
func (q *queue) MaxQueueSize() int64 {
return q.limit
}
+
+// CloseUnread sets flag to indicate that the peer is closed (not shutdown)
+// with unread data. So if read on this queue shall return ECONNRESET error.
+func (q *queue) CloseUnread() {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.unread = true
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 1867b3a5c..529a7a7a9 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -608,6 +608,10 @@ type ConnectedEndpoint interface {
// Release releases any resources owned by the ConnectedEndpoint. It should
// be called before droping all references to a ConnectedEndpoint.
Release()
+
+ // CloseUnread sets the fact that this end is closed with unread data to
+ // the peer socket.
+ CloseUnread()
}
// +stateify savable
@@ -711,6 +715,11 @@ func (e *connectedEndpoint) Release() {
e.writeQueue.DecRef()
}
+// CloseUnread implements ConnectedEndpoint.CloseUnread.
+func (e *connectedEndpoint) CloseUnread() {
+ e.writeQueue.CloseUnread()
+}
+
// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
// unix domain socket Endpoint implementations.
//
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 50c308134..1aaae8487 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -595,7 +595,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
total += n
}
- if err != nil || !waitAll || isPacket || n >= dst.NumBytes() {
+ streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed {
if total > 0 {
err = nil
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index cf2a56bed..fb2c1777f 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "linux",
srcs = [
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index b64c49ff5..68589a377 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -16,7 +16,12 @@
package linux
const (
- _LINUX_SYSNAME = "Linux"
- _LINUX_RELEASE = "4.4"
- _LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
+ // LinuxSysname is the OS name advertised by gVisor.
+ LinuxSysname = "Linux"
+
+ // LinuxRelease is the Linux release version number advertised by gVisor.
+ LinuxRelease = "4.4.0"
+
+ // LinuxVersion is the version info advertised by gVisor.
+ LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
)
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index e215ac049..aedb6d774 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -34,9 +34,9 @@ var AMD64 = &kernel.SyscallTable{
// guides the interface provided by this syscall table. The build
// version is that for a clean build with default kernel config, at 5
// minutes after v4.4 was tagged.
- Sysname: _LINUX_SYSNAME,
- Release: _LINUX_RELEASE,
- Version: _LINUX_VERSION,
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
},
AuditNumber: linux.AUDIT_ARCH_X86_64,
Table: map[uintptr]kernel.Syscall{
@@ -362,7 +362,7 @@ var AMD64 = &kernel.SyscallTable{
319: syscalls.Supported("memfd_create", MemfdCreate),
320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
- 322: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
+ 322: syscalls.PartiallySupported("execveat", Execveat, "No support for AT_EMPTY_PATH, AT_SYMLINK_FOLLOW.", nil),
323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index 1d3b63020..4cf7f836a 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -30,9 +30,9 @@ var ARM64 = &kernel.SyscallTable{
OS: abi.Linux,
Arch: arch.ARM64,
Version: kernel.Version{
- Sysname: _LINUX_SYSNAME,
- Release: _LINUX_RELEASE,
- Version: _LINUX_VERSION,
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
},
AuditNumber: linux.AUDIT_ARCH_AARCH64,
Table: map[uintptr]kernel.Syscall{
@@ -107,10 +107,10 @@ var ARM64 = &kernel.SyscallTable{
71: syscalls.Supported("sendfile", Sendfile),
72: syscalls.Supported("pselect", Pselect),
73: syscalls.Supported("ppoll", Ppoll),
- 74: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426)
+ 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 77: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 77: syscalls.Supported("tee", Tee),
78: syscalls.Supported("readlinkat", Readlinkat),
80: syscalls.Supported("fstat", Fstat),
81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
@@ -245,7 +245,7 @@ var ARM64 = &kernel.SyscallTable{
210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
211: syscalls.Supported("sendmsg", SendMsg),
212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
- 213: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341)
+ 213: syscalls.Supported("readahead", Readahead),
214: syscalls.Supported("brk", Brk),
215: syscalls.Supported("munmap", Munmap),
216: syscalls.Supported("mremap", Mremap),
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 8ab7ffa25..6e425f1ec 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -15,10 +15,12 @@
package linux
import (
+ "path"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usermem"
@@ -67,8 +69,22 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
argvAddr := args[1].Pointer()
envvAddr := args[2].Pointer()
- // Extract our arguments.
- filename, err := t.CopyInString(filenameAddr, linux.PATH_MAX)
+ return execveat(t, linux.AT_FDCWD, filenameAddr, argvAddr, envvAddr, 0)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirFD := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+
+ return execveat(t, dirFD, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
if err != nil {
return 0, nil, err
}
@@ -89,14 +105,38 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
}
+ if flags != 0 {
+ // TODO(b/128449944): Handle AT_EMPTY_PATH and AT_SYMLINK_NOFOLLOW.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
root := t.FSContext().RootDirectory()
defer root.DecRef()
- wd := t.FSContext().WorkingDirectory()
+
+ var wd *fs.Dirent
+ if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) {
+ // If pathname is absolute, LoadTaskImage() will ignore the wd.
+ wd = t.FSContext().WorkingDirectory()
+ } else {
+ // Need to extract the given FD.
+ f := t.GetFile(dirFD)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ wd = f.Dirent
+ wd.IncRef()
+ if !fs.IsDir(wd.Inode.StableAttr) {
+ return 0, nil, syserror.ENOTDIR
+ }
+ }
defer wd.DecRef()
// Load the new TaskContext.
maxTraversals := uint(linux.MaxSymlinkTraversals)
- tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet())
+ tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, pathname, nil, argv, envv, t.Arch().FeatureSet())
if se != nil {
return 0, nil, se.ToError()
}
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 27cd2c336..ad4b67806 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
// Pwritev2 implements linux syscall pwritev2(2).
-// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality.
func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
// While the syscall is
// pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index beb43ba13..d3a4cd943 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -1,10 +1,9 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "seqatomic_parameters",
out = "seqatomic_parameters_unsafe.go",
diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD
index a34c39540..c32fe3241 100644
--- a/pkg/sentry/usage/BUILD
+++ b/pkg/sentry/usage/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "usage",
srcs = [
diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD
index cc5d25762..684f59a6b 100644
--- a/pkg/sentry/usermem/BUILD
+++ b/pkg/sentry/usermem/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "addr_range",
out = "addr_range.go",
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 7eb2b2821..3a9665800 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -102,7 +102,7 @@ type FileDescriptionImpl interface {
// OnClose is called when a file descriptor representing the
// FileDescription is closed. Note that returning a non-nil error does not
// prevent the file descriptor from being closed.
- OnClose() error
+ OnClose(ctx context.Context) error
// StatusFlags returns file description status flags, as for
// fcntl(F_GETFL).
@@ -180,7 +180,7 @@ type FileDescriptionImpl interface {
// ConfigureMMap mutates opts to implement mmap(2) for the file. Most
// implementations that support memory mapping can call
// GenericConfigureMMap with the appropriate memmap.Mappable.
- ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error
+ ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error
// Ioctl implements the ioctl(2) syscall.
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index ba230da72..4fbad7840 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -45,7 +45,7 @@ type FileDescriptionDefaultImpl struct{}
// OnClose implements FileDescriptionImpl.OnClose analogously to
// file_operations::flush == NULL in Linux.
-func (FileDescriptionDefaultImpl) OnClose() error {
+func (FileDescriptionDefaultImpl) OnClose(ctx context.Context) error {
return nil
}
@@ -117,7 +117,7 @@ func (FileDescriptionDefaultImpl) Sync(ctx context.Context) error {
// ConfigureMMap implements FileDescriptionImpl.ConfigureMMap analogously to
// file_operations::mmap == NULL in Linux.
-func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+func (FileDescriptionDefaultImpl) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
return syserror.ENODEV
}
diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go
index 23f2b9e08..abde0feaa 100644
--- a/pkg/sentry/vfs/syscalls.go
+++ b/pkg/sentry/vfs/syscalls.go
@@ -96,6 +96,26 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
}
}
+// MknodAt creates a file of the given mode at the given path. It returns an
+// error from the syserror package.
+func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
+ rp, err := vfs.getResolvingPath(creds, pop)
+ if err != nil {
+ return nil
+ }
+ for {
+ if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil {
+ vfs.putResolvingPath(rp)
+ return nil
+ }
+ // Handle mount traversals.
+ if !rp.handleError(err) {
+ vfs.putResolvingPath(rp)
+ return err
+ }
+ }
+}
+
// OpenAt returns a FileDescription providing access to the file at the given
// path. A reference is taken on the returned FileDescription.
func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
@@ -198,8 +218,6 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) err
//
// - VFS.LinkAt()
//
-// - VFS.MknodAt()
-//
// - VFS.ReadlinkAt()
//
// - VFS.RenameAt()
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 329904457..be93750bf 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,11 +1,10 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
go_template_instance(
name = "addr_range",
out = "addr_range.go",
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
index 4287464f3..48a2a2713 100644
--- a/pkg/tcpip/buffer/prependable.go
+++ b/pkg/tcpip/buffer/prependable.go
@@ -41,6 +41,11 @@ func NewPrependableFromView(v View) Prependable {
return Prependable{buf: v, usedIdx: 0}
}
+// NewEmptyPrependableFromView creates a new prependable buffer from a View.
+func NewEmptyPrependableFromView(v View) Prependable {
+ return Prependable{buf: v, usedIdx: len(v)}
+}
+
// View returns a View of the backing buffer that contains all prepended
// data so far.
func (p Prependable) View() View {
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 096ad71ab..02137e1c9 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -686,3 +686,64 @@ func ICMPv6Code(want byte) TransportChecker {
}
}
}
+
+// NDP creates a checker that checks that the packet contains a valid NDP
+// message for type of ty, with potentially additional checks specified by
+// checkers.
+//
+// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDP message as far as the size of the message (minSize) is concerned. The
+// values within the message are up to checkers to validate.
+func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ // Check normal ICMPv6 first.
+ ICMPv6(
+ ICMPv6Type(msgType),
+ ICMPv6Code(0))(t, h)
+
+ last := h[len(h)-1]
+
+ icmp := header.ICMPv6(last.Payload())
+ if got := len(icmp.NDPPayload()); got < minSize {
+ t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize)
+ }
+
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// NDPNS creates a checker that checks that the packet contains a valid NDP
+// Neighbor Solicitation message (as per the raw wire format), with potentially
+// additional checks specified by checkers.
+//
+// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNS message as far as the size of the messages concerned. The values within
+// the message are up to checkers to validate.
+func NDPNS(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
+}
+
+// NDPNSTargetAddress creates a checker that checks the Target Address field of
+// a header.NDPNeighborSolicit.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+
+ if got := ns.TargetAddress(); got != want {
+ t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index a255231a3..a3485b35c 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -16,6 +16,10 @@ go_library(
"ipv4.go",
"ipv6.go",
"ipv6_fragment.go",
+ "ndp_neighbor_advert.go",
+ "ndp_neighbor_solicit.go",
+ "ndp_options.go",
+ "ndp_router_advert.go",
"tcp.go",
"udp.go",
],
@@ -30,13 +34,26 @@ go_library(
)
go_test(
- name = "header_test",
+ name = "header_x_test",
size = "small",
srcs = [
+ "checksum_test.go",
"ipversion_test.go",
"tcp_test.go",
],
deps = [
":header",
+ "//pkg/tcpip/buffer",
+ ],
+)
+
+go_test(
+ name = "header_test",
+ size = "small",
+ srcs = [
+ "eth_test.go",
+ "ndp_test.go",
],
+ embed = [":header"],
+ deps = ["//pkg/tcpip"],
)
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 39a4d69be..9749c7f4d 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -23,11 +23,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
-func calculateChecksum(buf []byte, initial uint32) uint16 {
+func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
v := initial
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
l := len(buf)
- if l&1 != 0 {
+ odd = l&1 != 0
+ if odd {
l--
v += uint32(buf[l]) << 8
}
@@ -36,7 +42,7 @@ func calculateChecksum(buf []byte, initial uint32) uint16 {
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
}
- return ChecksumCombine(uint16(v), uint16(v>>16))
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
}
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
@@ -44,7 +50,8 @@ func calculateChecksum(buf []byte, initial uint32) uint16 {
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
- return calculateChecksum(buf, uint32(initial))
+ s, _ := calculateChecksum(buf, false, uint32(initial))
+ return s
}
// ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in
@@ -52,19 +59,40 @@ func Checksum(buf []byte, initial uint16) uint16 {
//
// The initial checksum must have been computed on an even number of bytes.
func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
- var odd bool
+ return ChecksumVVWithOffset(vv, initial, 0, vv.Size())
+}
+
+// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the
+// bytes in the given VectorizedView.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 {
+ odd := false
sum := initial
for _, v := range vv.Views() {
if len(v) == 0 {
continue
}
- s := uint32(sum)
- if odd {
- s += uint32(v[0])
- v = v[1:]
+
+ if off >= len(v) {
+ off -= len(v)
+ continue
+ }
+ v = v[off:]
+
+ l := len(v)
+ if l > size {
+ l = size
+ }
+ v = v[:l]
+
+ sum, odd = calculateChecksum(v, odd, uint32(sum))
+
+ size -= len(v)
+ if size == 0 {
+ break
}
- odd = len(v)&1 != 0
- sum = calculateChecksum(v, s)
+ off = 0
}
return sum
}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
new file mode 100644
index 000000000..86b466c1c
--- /dev/null
+++ b/pkg/tcpip/header/checksum_test.go
@@ -0,0 +1,109 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package header provides the implementation of the encoding and decoding of
+// network protocol headers.
+package header_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestChecksumVVWithOffset(t *testing.T) {
+ testCases := []struct {
+ name string
+ vv buffer.VectorisedView
+ off, size int
+ initial uint16
+ want uint16
+ }{
+ {
+ name: "empty",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ }),
+ off: 0,
+ size: 0,
+ want: 0,
+ },
+ {
+ name: "OneView",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ }),
+ off: 0,
+ size: 5,
+ want: 1294,
+ },
+ {
+ name: "TwoViews",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 0,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "TwoViewsWithOffset",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 1,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "ThreeViewsWithOffset",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
+ }),
+ off: 7,
+ size: 11,
+ want: 33819,
+ },
+ {
+ name: "ThreeViewsWithInitial",
+ vv: buffer.NewVectorisedView(0, []buffer.View{
+ buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}),
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
+ buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}),
+ }),
+ initial: 77,
+ off: 7,
+ size: 11,
+ want: 33896,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want {
+ t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want)
+ }
+ v := tc.vv.ToView()
+ v.TrimFront(tc.off)
+ v.CapLength(tc.size)
+ if got, want := header.Checksum(v, tc.initial), tc.want; got != want {
+ t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index 4c3d3311f..f5d2c127f 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -48,8 +48,48 @@ const (
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
+
+ // unspecifiedEthernetAddress is the unspecified ethernet address
+ // (all bits set to 0).
+ unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+
+ // unicastMulticastFlagMask is the mask of the least significant bit in
+ // the first octet (in network byte order) of an ethernet address that
+ // determines whether the ethernet address is a unicast or multicast. If
+ // the masked bit is a 1, then the address is a multicast, unicast
+ // otherwise.
+ //
+ // See the IEEE Std 802-2001 document for more details. Specifically,
+ // section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf:
+ // "A 48-bit universal address consists of two parts. The first 24 bits
+ // correspond to the OUI as assigned by the IEEE, expect that the
+ // assignee may set the LSB of the first octet to 1 for group addresses
+ // or set it to 0 for individual addresses."
+ unicastMulticastFlagMask = 1
+
+ // unicastMulticastFlagByteIdx is the byte that holds the
+ // unicast/multicast flag. See unicastMulticastFlagMask.
+ unicastMulticastFlagByteIdx = 0
+)
+
+const (
+ // EthernetProtocolAll is a catch-all for all protocols carried inside
+ // an ethernet frame. It is mainly used to create packet sockets that
+ // capture all traffic.
+ EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003
+
+ // EthernetProtocolPUP is the PARC Universial Packet protocol ethertype.
+ EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200
)
+// Ethertypes holds the protocol numbers describing the payload of an ethernet
+// frame. These types aren't necessarily supported by netstack, but can be used
+// to catch all traffic of a type via packet endpoints.
+var Ethertypes = []tcpip.NetworkProtocolNumber{
+ EthernetProtocolAll,
+ EthernetProtocolPUP,
+}
+
// SourceAddress returns the "MAC source" field of the ethernet frame header.
func (b Ethernet) SourceAddress() tcpip.LinkAddress {
return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
@@ -72,3 +112,25 @@ func (b Ethernet) Encode(e *EthernetFields) {
copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}
+
+// IsValidUnicastEthernetAddress returns true if addr is a valid unicast
+// ethernet address.
+func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
+ // Must be of the right length.
+ if len(addr) != EthernetAddressSize {
+ return false
+ }
+
+ // Must not be unspecified.
+ if addr == unspecifiedEthernetAddress {
+ return false
+ }
+
+ // Must not be a multicast.
+ if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
+ return false
+ }
+
+ // addr is a valid unicast ethernet address.
+ return true
+}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
new file mode 100644
index 000000000..6634c90f5
--- /dev/null
+++ b/pkg/tcpip/header/eth_test.go
@@ -0,0 +1,68 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestIsValidUnicastEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.LinkAddress
+ expected bool
+ }{
+ {
+ "Nil",
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty",
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "InvalidLength",
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Unspecified",
+ unspecifiedEthernetAddress,
+ false,
+ },
+ {
+ "Multicast",
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ false,
+ },
+ {
+ "Valid",
+ tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
+ true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := IsValidUnicastEthernetAddress(test.addr); got != test.expected {
+ t.Fatalf("got IsValidUnicastEthernetAddress = %t, want = %t", got, test.expected)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 1125a7d14..c2bfd8c79 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -25,6 +25,12 @@ import (
type ICMPv6 []byte
const (
+ // ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the
+ // sum of the size of the ICMPv6 Type, Code and Checksum fields, as
+ // per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6
+ // message body begins.
+ ICMPv6HeaderSize = 4
+
// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
ICMPv6MinimumSize = 8
@@ -37,10 +43,16 @@ const (
// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
// neighbor solicitation packet.
- ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize
+
+ // ICMPv6NeighborAdvertMinimumSize is the minimum size of a
+ // neighbor advertisement packet.
+ ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize
- // ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
- ICMPv6NeighborAdvertSize = 32
+ // ICMPv6NeighborAdvertSize is size of a neighbor advertisement
+ // including the NDP Target Link Layer option for an Ethernet
+ // address.
+ ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize
// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv6EchoMinimumSize = 8
@@ -68,6 +80,13 @@ const (
// icmpv6SequenceOffset is the offset of the sequence field
// in a ICMPv6 Echo Request/Reply message.
icmpv6SequenceOffset = 6
+
+ // NDPHopLimit is the expected IP hop limit value of 255 for received
+ // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
+ // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
+ // drop the NDP packet. All outgoing NDP packets must use this value for
+ // its IP hop limit field.
+ NDPHopLimit = 255
)
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
@@ -166,6 +185,13 @@ func (b ICMPv6) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
+// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6
+// packet's message body as defined by RFC 4443 section 2.1; the portion of the
+// ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
+func (b ICMPv6) NDPPayload() []byte {
+ return b[ICMPv6HeaderSize:]
+}
+
// Payload implements Transport.Payload.
func (b ICMPv6) Payload() []byte {
return b[ICMPv6PayloadOffset:]
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 9d3abc0e4..f1e60911b 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -87,7 +87,8 @@ const (
// section 5.
IPv6MinimumMTU = 1280
- // IPv6Any is the non-routable IPv6 "any" meta address.
+ // IPv6Any is the non-routable IPv6 "any" meta address. It is also
+ // known as the unspecified address.
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
)
@@ -100,6 +101,15 @@ var IPv6EmptySubnet = func() tcpip.Subnet {
return subnet
}()
+// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
+// by RFC 4291 section 2.5.6.
+//
+// The prefix is fe80::/64
+var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{
+ Address: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: 64,
+}
+
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
diff --git a/pkg/tcpip/header/ndp_neighbor_advert.go b/pkg/tcpip/header/ndp_neighbor_advert.go
new file mode 100644
index 000000000..505c92668
--- /dev/null
+++ b/pkg/tcpip/header/ndp_neighbor_advert.go
@@ -0,0 +1,110 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will
+// only contain the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.4 for more details.
+type NDPNeighborAdvert []byte
+
+const (
+ // NDPNAMinimumSize is the minimum size of a valid NDP Neighbor
+ // Advertisement message (body of an ICMPv6 packet).
+ NDPNAMinimumSize = 20
+
+ // ndpNATargetAddressOffset is the start of the Target Address
+ // field within an NDPNeighborAdvert.
+ ndpNATargetAddressOffset = 4
+
+ // ndpNAOptionsOffset is the start of the NDP options in an
+ // NDPNeighborAdvert.
+ ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize
+
+ // ndpNAFlagsOffset is the offset of the flags within an
+ // NDPNeighborAdvert
+ ndpNAFlagsOffset = 0
+
+ // ndpNARouterFlagMask is the mask of the Router Flag field in
+ // the flags byte within in an NDPNeighborAdvert.
+ ndpNARouterFlagMask = (1 << 7)
+
+ // ndpNASolicitedFlagMask is the mask of the Solicited Flag field in
+ // the flags byte within in an NDPNeighborAdvert.
+ ndpNASolicitedFlagMask = (1 << 6)
+
+ // ndpNAOverrideFlagMask is the mask of the Override Flag field in
+ // the flags byte within in an NDPNeighborAdvert.
+ ndpNAOverrideFlagMask = (1 << 5)
+)
+
+// TargetAddress returns the value within the Target Address field.
+func (b NDPNeighborAdvert) TargetAddress() tcpip.Address {
+ return tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize])
+}
+
+// SetTargetAddress sets the value within the Target Address field.
+func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) {
+ copy(b[ndpNATargetAddressOffset:][:IPv6AddressSize], addr)
+}
+
+// RouterFlag returns the value of the Router Flag field.
+func (b NDPNeighborAdvert) RouterFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNARouterFlagMask != 0
+}
+
+// SetRouterFlag sets the value in the Router Flag field.
+func (b NDPNeighborAdvert) SetRouterFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNARouterFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNARouterFlagMask
+ }
+}
+
+// SolicitedFlag returns the value of the Solicited Flag field.
+func (b NDPNeighborAdvert) SolicitedFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNASolicitedFlagMask != 0
+}
+
+// SetSolicitedFlag sets the value in the Solicited Flag field.
+func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNASolicitedFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNASolicitedFlagMask
+ }
+}
+
+// OverrideFlag returns the value of the Override Flag field.
+func (b NDPNeighborAdvert) OverrideFlag() bool {
+ return b[ndpNAFlagsOffset]&ndpNAOverrideFlagMask != 0
+}
+
+// SetOverrideFlag sets the value in the Override Flag field.
+func (b NDPNeighborAdvert) SetOverrideFlag(f bool) {
+ if f {
+ b[ndpNAFlagsOffset] |= ndpNAOverrideFlagMask
+ } else {
+ b[ndpNAFlagsOffset] &^= ndpNAOverrideFlagMask
+ }
+}
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPNeighborAdvert) Options() NDPOptions {
+ return NDPOptions(b[ndpNAOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_neighbor_solicit.go b/pkg/tcpip/header/ndp_neighbor_solicit.go
new file mode 100644
index 000000000..3a1b8e139
--- /dev/null
+++ b/pkg/tcpip/header/ndp_neighbor_solicit.go
@@ -0,0 +1,52 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only
+// contain the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.3 for more details.
+type NDPNeighborSolicit []byte
+
+const (
+ // NDPNSMinimumSize is the minimum size of a valid NDP Neighbor
+ // Solicitation message (body of an ICMPv6 packet).
+ NDPNSMinimumSize = 20
+
+ // ndpNSTargetAddessOffset is the start of the Target Address
+ // field within an NDPNeighborSolicit.
+ ndpNSTargetAddessOffset = 4
+
+ // ndpNSOptionsOffset is the start of the NDP options in an
+ // NDPNeighborSolicit.
+ ndpNSOptionsOffset = ndpNSTargetAddessOffset + IPv6AddressSize
+)
+
+// TargetAddress returns the value within the Target Address field.
+func (b NDPNeighborSolicit) TargetAddress() tcpip.Address {
+ return tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize])
+}
+
+// SetTargetAddress sets the value within the Target Address field.
+func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) {
+ copy(b[ndpNSTargetAddessOffset:][:IPv6AddressSize], addr)
+}
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPNeighborSolicit) Options() NDPOptions {
+ return NDPOptions(b[ndpNSOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
new file mode 100644
index 000000000..b28bde15b
--- /dev/null
+++ b/pkg/tcpip/header/ndp_options.go
@@ -0,0 +1,172 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // NDPTargetLinkLayerAddressOptionType is the type of the Target
+ // Link-Layer Address option, as per RFC 4861 section 4.6.1.
+ NDPTargetLinkLayerAddressOptionType = 2
+
+ // ndpTargetEthernetLinkLayerAddressSize is the size of a Target
+ // Link Layer Option for an Ethernet address.
+ ndpTargetEthernetLinkLayerAddressSize = 8
+
+ // lengthByteUnits is the multiplier factor for the Length field of an
+ // NDP option. That is, the length field for NDP options is in units of
+ // 8 octets, as per RFC 4861 section 4.6.
+ lengthByteUnits = 8
+)
+
+// NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6.
+type NDPOptions []byte
+
+// Serialize serializes the provided list of NDP options into o.
+//
+// Note, b must be of sufficient size to hold all the options in s. See
+// NDPOptionsSerializer.Length for details on the getting the total size
+// of a serialized NDPOptionsSerializer.
+//
+// Serialize may panic if b is not of sufficient size to hold all the options
+// in s.
+func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
+ done := 0
+
+ for _, o := range s {
+ l := paddedLength(o)
+
+ if l == 0 {
+ continue
+ }
+
+ b[0] = o.Type()
+
+ // We know this safe because paddedLength would have returned
+ // 0 if o had an invalid length (> 255 * lengthByteUnits).
+ b[1] = uint8(l / lengthByteUnits)
+
+ // Serialize NDP option body.
+ used := o.serializeInto(b[2:])
+
+ // Zero out remaining (padding) bytes, if any exists.
+ for i := used + 2; i < l; i++ {
+ b[i] = 0
+ }
+
+ b = b[l:]
+ done += l
+ }
+
+ return done
+}
+
+// ndpOption is the set of functions to be implemented by all NDP option types.
+type ndpOption interface {
+ // Type returns the type of this ndpOption.
+ Type() uint8
+
+ // Length returns the length of the body of this ndpOption, in bytes.
+ Length() int
+
+ // serializeInto serializes this ndpOption into the provided byte
+ // buffer.
+ //
+ // Note, the caller MUST provide a byte buffer with size of at least
+ // Length. Implementers of this function may assume that the byte buffer
+ // is of sufficient size. serializeInto MAY panic if the provided byte
+ // buffer is not of sufficient size.
+ //
+ // serializeInto will return the number of bytes that was used to
+ // serialize this ndpOption. Implementers must only use the number of
+ // bytes required to serialize this ndpOption. Callers MAY provide a
+ // larger buffer than required to serialize into.
+ serializeInto([]byte) int
+}
+
+// paddedLength returns the length of o, in bytes, with any padding bytes, if
+// required.
+func paddedLength(o ndpOption) int {
+ l := o.Length()
+
+ if l == 0 {
+ return 0
+ }
+
+ // Length excludes the 2 Type and Length bytes.
+ l += 2
+
+ // Add extra bytes if needed to make sure the option is
+ // lengthByteUnits-byte aligned. We do this by adding lengthByteUnits-1
+ // to l and then stripping off the last few LSBits from l. This will
+ // make sure that l is rounded up to the nearest unit of
+ // lengthByteUnits. This works since lengthByteUnits is a power of 2
+ // (= 8).
+ mask := lengthByteUnits - 1
+ l += mask
+ l &^= mask
+
+ if l/lengthByteUnits > 255 {
+ // Should never happen because an option can only have a max
+ // value of 255 for its Length field, so just return 0 so this
+ // option does not get serialized.
+ //
+ // Returning 0 here will make sure that this option does not get
+ // serialized when NDPOptions.Serialize is called with the
+ // NDPOptionsSerializer that holds this option, effectively
+ // skipping this option during serialization. Also note that
+ // a value of zero for the Length field in an NDP option is
+ // invalid so this is another sign to the caller that this NDP
+ // option is malformed, as per RFC 4861 section 4.6.
+ return 0
+ }
+
+ return l
+}
+
+// NDPOptionsSerializer is a serializer for NDP options.
+type NDPOptionsSerializer []ndpOption
+
+// Length returns the total number of bytes required to serialize.
+func (b NDPOptionsSerializer) Length() int {
+ l := 0
+
+ for _, o := range b {
+ l += paddedLength(o)
+ }
+
+ return l
+}
+
+// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option
+// as defined by RFC 4861 section 4.6.1.
+type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
+
+// Type implements ndpOption.Type.
+func (o NDPTargetLinkLayerAddressOption) Type() uint8 {
+ return NDPTargetLinkLayerAddressOptionType
+}
+
+// Length implements ndpOption.Length.
+func (o NDPTargetLinkLayerAddressOption) Length() int {
+ return len(o)
+}
+
+// serializeInto implements ndpOption.serializeInto.
+func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
new file mode 100644
index 000000000..bf7610863
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -0,0 +1,112 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "encoding/binary"
+ "time"
+)
+
+// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.2 for more details.
+type NDPRouterAdvert []byte
+
+const (
+ // NDPRAMinimumSize is the minimum size of a valid NDP Router
+ // Advertisement message (body of an ICMPv6 packet).
+ NDPRAMinimumSize = 12
+
+ // ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field
+ // within an NDPRouterAdvert.
+ ndpRACurrHopLimitOffset = 0
+
+ // ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags
+ // within an NDPRouterAdvert.
+ ndpRAFlagsOffset = 1
+
+ // ndpRAManagedAddrConfFlagMask is the mask of the Managed Address
+ // Configuration flag within the bit-field/flags byte of an
+ // NDPRouterAdvert.
+ ndpRAManagedAddrConfFlagMask = (1 << 7)
+
+ // ndpRAOtherConfFlagMask is the mask of the Other Configuration flag
+ // within the bit-field/flags byte of an NDPRouterAdvert.
+ ndpRAOtherConfFlagMask = (1 << 6)
+
+ // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
+ // field within an NDPRouterAdvert.
+ ndpRARouterLifetimeOffset = 2
+
+ // ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time
+ // field within an NDPRouterAdvert.
+ ndpRAReachableTimeOffset = 4
+
+ // ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer
+ // field within an NDPRouterAdvert.
+ ndpRARetransTimerOffset = 8
+
+ // ndpRAOptionsOffset is the start of the NDP options in an
+ // NDPRouterAdvert.
+ ndpRAOptionsOffset = 12
+)
+
+// CurrHopLimit returns the value of the Curr Hop Limit field.
+func (b NDPRouterAdvert) CurrHopLimit() uint8 {
+ return b[ndpRACurrHopLimitOffset]
+}
+
+// ManagedAddrConfFlag returns the value of the Managed Address Configuration
+// flag.
+func (b NDPRouterAdvert) ManagedAddrConfFlag() bool {
+ return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0
+}
+
+// OtherConfFlag returns the value of the Other Configuration flag.
+func (b NDPRouterAdvert) OtherConfFlag() bool {
+ return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
+}
+
+// RouterLifetime returns the lifetime associated with the default router. A
+// value of 0 means the source of the Router Advertisement is not a default
+// router and SHOULD NOT appear on the default router list. Note, a value of 0
+// only means that the router should not be used as a default router, it does
+// not apply to other information contained in the Router Advertisement.
+func (b NDPRouterAdvert) RouterLifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 4861 section 4.2.
+ return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:]))
+}
+
+// ReachableTime returns the time that a node assumes a neighbor is reachable
+// after having received a reachability confirmation. A value of 0 means
+// that it is unspecified by the source of the Router Advertisement message.
+func (b NDPRouterAdvert) ReachableTime() time.Duration {
+ // The field is the time in milliseconds, as per RFC 4861 section 4.2.
+ return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:]))
+}
+
+// RetransTimer returns the time between retransmitted Neighbor Solicitation
+// messages. A value of 0 means that it is unspecified by the source of the
+// Router Advertisement message.
+func (b NDPRouterAdvert) RetransTimer() time.Duration {
+ // The field is the time in milliseconds, as per RFC 4861 section 4.2.
+ return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:]))
+}
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPRouterAdvert) Options() NDPOptions {
+ return NDPOptions(b[ndpRAOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
new file mode 100644
index 000000000..0aac14f43
--- /dev/null
+++ b/pkg/tcpip/header/ndp_test.go
@@ -0,0 +1,199 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bytes"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// TestNDPNeighborSolicit tests the functions of NDPNeighborSolicit.
+func TestNDPNeighborSolicit(t *testing.T) {
+ b := []byte{
+ 0, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ }
+
+ // Test getting the Target Address.
+ ns := NDPNeighborSolicit(b)
+ addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ if got := ns.TargetAddress(); got != addr {
+ t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr)
+ }
+
+ // Test updating the Target Address.
+ addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ ns.SetTargetAddress(addr2)
+ if got := ns.TargetAddress(); got != addr2 {
+ t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr2)
+ }
+ // Make sure the address got updated in the backing buffer.
+ if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 {
+ t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+ }
+}
+
+// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
+func TestNDPNeighborAdvert(t *testing.T) {
+ b := []byte{
+ 160, 0, 0, 0,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ }
+
+ // Test getting the Target Address.
+ na := NDPNeighborAdvert(b)
+ addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
+ if got := na.TargetAddress(); got != addr {
+ t.Fatalf("got TargetAddress = %s, want %s", got, addr)
+ }
+
+ // Test getting the Router Flag.
+ if got := na.RouterFlag(); !got {
+ t.Fatalf("got RouterFlag = false, want = true")
+ }
+
+ // Test getting the Solicited Flag.
+ if got := na.SolicitedFlag(); got {
+ t.Fatalf("got SolicitedFlag = true, want = false")
+ }
+
+ // Test getting the Override Flag.
+ if got := na.OverrideFlag(); !got {
+ t.Fatalf("got OverrideFlag = false, want = true")
+ }
+
+ // Test updating the Target Address.
+ addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
+ na.SetTargetAddress(addr2)
+ if got := na.TargetAddress(); got != addr2 {
+ t.Fatalf("got TargetAddress = %s, want %s", got, addr2)
+ }
+ // Make sure the address got updated in the backing buffer.
+ if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 {
+ t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+ }
+
+ // Test updating the Router Flag.
+ na.SetRouterFlag(false)
+ if got := na.RouterFlag(); got {
+ t.Fatalf("got RouterFlag = true, want = false")
+ }
+
+ // Test updating the Solicited Flag.
+ na.SetSolicitedFlag(true)
+ if got := na.SolicitedFlag(); !got {
+ t.Fatalf("got SolicitedFlag = false, want = true")
+ }
+
+ // Test updating the Override Flag.
+ na.SetOverrideFlag(false)
+ if got := na.OverrideFlag(); got {
+ t.Fatalf("got OverrideFlag = true, want = false")
+ }
+
+ // Make sure flags got updated in the backing buffer.
+ if got := b[ndpNAFlagsOffset]; got != 64 {
+ t.Fatalf("got flags byte = %d, want = 64")
+ }
+}
+
+func TestNDPRouterAdvert(t *testing.T) {
+ b := []byte{
+ 64, 128, 1, 2,
+ 3, 4, 5, 6,
+ 7, 8, 9, 10,
+ }
+
+ ra := NDPRouterAdvert(b)
+
+ if got := ra.CurrHopLimit(); got != 64 {
+ t.Fatalf("got ra.CurrHopLimit = %d, want = 64", got)
+ }
+
+ if got := ra.ManagedAddrConfFlag(); !got {
+ t.Fatalf("got ManagedAddrConfFlag = false, want = true")
+ }
+
+ if got := ra.OtherConfFlag(); got {
+ t.Fatalf("got OtherConfFlag = true, want = false")
+ }
+
+ if got, want := ra.RouterLifetime(), time.Second*258; got != want {
+ t.Fatalf("got ra.RouterLifetime = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
+ t.Fatalf("got ra.ReachableTime = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
+ t.Fatalf("got ra.RetransTimer = %d, want = %d", got, want)
+ }
+}
+
+// TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a
+// NDPTargetLinkLayerAddressOption.
+func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedBuf []byte
+ addr tcpip.LinkAddress
+ }{
+ {
+ "Ethernet",
+ make([]byte, 8),
+ []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ "\x01\x02\x03\x04\x05\x06",
+ },
+ {
+ "Padding",
+ []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ []byte{2, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ "\x01\x02\x03\x04\x05\x06\x07\x08",
+ },
+ {
+ "Empty",
+ []byte{},
+ []byte{},
+ "",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ serializer := NDPOptionsSerializer{
+ NDPTargetLinkLayerAddressOption(test.addr),
+ }
+ if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
+ t.Fatalf("got Length = %d, want = %d", got, want)
+ }
+ opts.Serialize(serializer)
+ if !bytes.Equal(test.buf, test.expectedBuf) {
+ t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 18adb2085..14f197a77 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -72,7 +72,7 @@ func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.Vector
// InjectLinkAddr injects an inbound packet with a remote link address.
func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) {
- e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil))
+ e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil), nil /* linkHeader */)
}
// Attach saves the stack network-layer dispatcher for use later when packets
@@ -96,7 +96,7 @@ func (e *Endpoint) MTU() uint32 {
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
caps := stack.LinkEndpointCapabilities(0)
if e.GSO {
- caps |= stack.CapabilityGSO
+ caps |= stack.CapabilityHardwareGSO
}
return caps
}
@@ -134,5 +134,49 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+// WritePackets stores outbound packets into the channel.
+func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ payloadView := payload.ToView()
+ n := 0
+packetLoop:
+ for i := range hdrs {
+ hdr := &hdrs[i].Hdr
+ off := hdrs[i].Off
+ size := hdrs[i].Size
+ p := PacketInfo{
+ Header: hdr.View(),
+ Proto: protocol,
+ Payload: buffer.NewViewFromBytes(payloadView[off : off+size]),
+ GSO: gso,
+ }
+
+ select {
+ case e.C <- p:
+ n++
+ default:
+ break packetLoop
+ }
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ p := PacketInfo{
+ Header: packet.ToView(),
+ Proto: 0,
+ Payload: buffer.View{},
+ GSO: nil,
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index f80ac3435..ae4858529 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -165,6 +165,9 @@ type Options struct {
// disabled.
GSOMaxSize uint32
+ // SoftwareGSOEnabled indicates whether software GSO is enabled or not.
+ SoftwareGSOEnabled bool
+
// PacketDispatchMode specifies the type of inbound dispatcher to be
// used for this endpoint.
PacketDispatchMode PacketDispatchMode
@@ -242,7 +245,11 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
}
if isSocket {
if opts.GSOMaxSize != 0 {
- e.caps |= stack.CapabilityGSO
+ if opts.SoftwareGSOEnabled {
+ e.caps |= stack.CapabilitySoftwareGSO
+ } else {
+ e.caps |= stack.CapabilityHardwareGSO
+ }
e.gsoMaxSize = opts.GSOMaxSize
}
}
@@ -397,7 +404,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
eth.Encode(ethHdr)
}
- if e.Capabilities()&stack.CapabilityGSO != 0 {
+ if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
if gso != nil {
@@ -430,8 +437,130 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil)
}
-// WriteRawPacket writes a raw packet directly to the file descriptor.
-func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+// WritePackets writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ var ethHdrBuf []byte
+ // hdr + data
+ iovLen := 2
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ ethHdrBuf = make([]byte, header.EthernetMinimumSize)
+ eth := header.Ethernet(ethHdrBuf)
+ ethHdr := &header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ Type: protocol,
+ }
+
+ // Preserve the src address if it's set in the route.
+ if r.LocalLinkAddress != "" {
+ ethHdr.SrcAddr = r.LocalLinkAddress
+ } else {
+ ethHdr.SrcAddr = e.addr
+ }
+ eth.Encode(ethHdr)
+ iovLen++
+ }
+
+ n := len(hdrs)
+
+ views := payload.Views()
+ /*
+ * Each bondary in views can add one more iovec.
+ *
+ * payload | | | |
+ * -----------------------------
+ * packets | | | | | | |
+ * -----------------------------
+ * iovecs | | | | | | | | |
+ */
+ iovec := make([]syscall.Iovec, n*iovLen+len(views)-1)
+ mmsgHdrs := make([]rawfile.MMsgHdr, n)
+
+ iovecIdx := 0
+ viewIdx := 0
+ viewOff := 0
+ off := 0
+ nextOff := 0
+ for i := range hdrs {
+ prevIovecIdx := iovecIdx
+ mmsgHdr := &mmsgHdrs[i]
+ mmsgHdr.Msg.Iov = &iovec[iovecIdx]
+ packetSize := hdrs[i].Size
+ hdr := &hdrs[i].Hdr
+
+ off = hdrs[i].Off
+ if off != nextOff {
+ // We stop in a different point last time.
+ size := packetSize
+ viewIdx = 0
+ viewOff = 0
+ for size > 0 {
+ if size >= len(views[viewIdx]) {
+ viewIdx++
+ viewOff = 0
+ size -= len(views[viewIdx])
+ } else {
+ viewOff = size
+ size = 0
+ }
+ }
+ }
+ nextOff = off + packetSize
+
+ if ethHdrBuf != nil {
+ v := &iovec[iovecIdx]
+ v.Base = &ethHdrBuf[0]
+ v.Len = uint64(len(ethHdrBuf))
+ iovecIdx++
+ }
+
+ v := &iovec[iovecIdx]
+ hdrView := hdr.View()
+ v.Base = &hdrView[0]
+ v.Len = uint64(len(hdrView))
+ iovecIdx++
+
+ for packetSize > 0 {
+ vec := &iovec[iovecIdx]
+ iovecIdx++
+
+ v := views[viewIdx]
+ vec.Base = &v[viewOff]
+ s := len(v) - viewOff
+ if s <= packetSize {
+ viewIdx++
+ viewOff = 0
+ } else {
+ s = packetSize
+ viewOff += s
+ }
+ vec.Len = uint64(s)
+ packetSize -= s
+ }
+
+ mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx)
+ }
+
+ packets := 0
+ for packets < n {
+ sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs)
+ if err != nil {
+ return packets, err
+ }
+ packets += sent
+ mmsgHdrs = mmsgHdrs[sent:]
+ }
+ return packets, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ return rawfile.NonBlockingWrite(e.fds[0], packet.ToView())
+}
+
+// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
+func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
}
@@ -468,9 +597,9 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
-// Inject injects an inbound packet.
-func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+// InjectInbound injects an inbound packet.
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
}
// NewInjectable creates a new fd-based InjectableEndpoint.
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 04406bc9a..59378b96c 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -43,9 +43,10 @@ const (
)
type packetInfo struct {
- raddr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- contents buffer.View
+ raddr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ contents buffer.View
+ linkHeader buffer.View
}
type context struct {
@@ -92,8 +93,8 @@ func (c *context) cleanup() {
syscall.Close(c.fds[1])
}
-func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- c.ch <- packetInfo{remote, protocol, vv.ToView()}
+func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
+ c.ch <- packetInfo{remote, protocol, vv.ToView(), linkHeader}
}
func TestNoEthernetProperties(t *testing.T) {
@@ -293,11 +294,12 @@ func TestDeliverPacket(t *testing.T) {
b[i] = uint8(rand.Intn(256))
}
+ var hdr header.Ethernet
if !eth {
// So that it looks like an IPv4 packet.
b[0] = 0x40
} else {
- hdr := make(header.Ethernet, header.EthernetMinimumSize)
+ hdr = make(header.Ethernet, header.EthernetMinimumSize)
hdr.Encode(&header.EthernetFields{
SrcAddr: raddr,
DstAddr: laddr,
@@ -315,9 +317,10 @@ func TestDeliverPacket(t *testing.T) {
select {
case pi := <-c.ch:
want := packetInfo{
- raddr: raddr,
- proto: proto,
- contents: b,
+ raddr: raddr,
+ proto: proto,
+ contents: b,
+ linkHeader: buffer.View(hdr),
}
if !eth {
want.proto = header.IPv4ProtocolNumber
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 8bfeb97e4..554d45715 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -169,9 +169,10 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
+ eth = header.Ethernet(pkt)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -189,6 +190,6 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}), buffer.View(eth))
return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 7ca217e5b..12168a1dc 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -53,7 +53,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
d := &readVDispatcher{fd: fd, e: e}
d.views = make([]buffer.View, len(BufConfig))
iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
iovLen++
}
d.iovecs = make([]syscall.Iovec, iovLen)
@@ -63,7 +63,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
func (d *readVDispatcher) allocateViews(bufConfig []int) {
var vnetHdr [virtioNetHdrSize]byte
vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// The kernel adds virtioNetHdr before each packet, but
// we don't use it, so so we allocate a buffer for it,
// add it in iovecs but don't add it in a view.
@@ -106,7 +106,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
if err != nil {
return false, err
}
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// Skip virtioNetHdr which is added before each packet, it
// isn't used and it isn't in a view.
n -= virtioNetHdrSize
@@ -118,9 +118,10 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(d.views[0])
+ eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize])
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -141,7 +142,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
vv := buffer.NewVectorisedView(n, d.views[:used])
vv.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth))
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
@@ -194,7 +195,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
}
d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv)
iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// virtioNetHdr is prepended before each packet.
iovLen++
}
@@ -225,7 +226,7 @@ func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) {
for k := 0; k < len(d.views); k++ {
var vnetHdr [virtioNetHdrSize]byte
vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
// The kernel adds virtioNetHdr before each packet, but
// we don't use it, so so we allocate a buffer for it,
// add it in iovecs but don't add it in a view.
@@ -261,7 +262,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
// Process each of received packets.
for k := 0; k < nMsgs; k++ {
n := int(d.msgHdrs[k].Len)
- if d.e.Capabilities()&stack.CapabilityGSO != 0 {
+ if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
n -= virtioNetHdrSize
}
if n <= d.e.hdrSize {
@@ -271,9 +272,10 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
+ eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth := header.Ethernet(d.views[k][0])
+ eth = header.Ethernet(d.views[k][0])
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -293,7 +295,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
used := d.capViews(k, int(n), BufConfig)
vv := buffer.NewVectorisedView(int(n), d.views[k][:used])
vv.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv)
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth))
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
index 47a54845c..23e4d1418 100644
--- a/pkg/tcpip/link/loopback/BUILD
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -10,6 +10,7 @@ go_library(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index b36629d2c..a3b48fa73 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -23,6 +23,7 @@ package loopback
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,6 +71,9 @@ func (*endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
@@ -81,10 +85,27 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa
// Because we're immediately turning around and writing the packet back to the
// rx path, we intentionally don't preserve the remote and local link
// addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv)
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
return nil
}
-// Wait implements stack.LinkEndpoint.Wait.
-func (*endpoint) Wait() {}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ // Reject the packet if it's shorter than an ethernet header.
+ if packet.Size() < header.EthernetMinimumSize {
+ return tcpip.ErrBadAddress
+ }
+
+ // There should be an ethernet header at the beginning of packet.
+ linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize])
+ packet.TrimFront(len(linkHeader))
+ e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), packet, buffer.View(linkHeader))
+
+ return nil
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 7c946101d..682b60291 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -79,29 +79,47 @@ func (m *InjectableEndpoint) IsAttached() bool {
return m.dispatcher != nil
}
-// Inject implements stack.InjectableLinkEndpoint.
-func (m *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv)
+// InjectInbound implements stack.InjectableLinkEndpoint.
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+ m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
+}
+
+// WritePackets writes outbound packets to the appropriate
+// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
+// r.RemoteAddress has a route registered in this endpoint.
+func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ endpoint, ok := m.routes[r.RemoteAddress]
+ if !ok {
+ return 0, tcpip.ErrNoRoute
+ }
+ return endpoint.WritePackets(r, gso, hdrs, payload, protocol)
}
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
- return endpoint.WritePacket(r, nil /* gso */, hdr, payload, protocol)
+ return endpoint.WritePacket(r, gso, hdr, payload, protocol)
}
return tcpip.ErrNoRoute
}
-// WriteRawPacket writes outbound packets to the appropriate
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (m *InjectableEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ // WriteRawPacket doesn't get a route or network address, so there's
+ // nowhere to write this.
+ return tcpip.ErrNoRoute
+}
+
+// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
-func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
+func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
endpoint, ok := m.routes[dest]
if !ok {
return tcpip.ErrNoRoute
}
- return endpoint.WriteRawPacket(dest, packet)
+ return endpoint.InjectOutbound(dest, packet)
}
// Wait implements stack.LinkEndpoint.Wait.
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 3086fec00..9cd300af8 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -31,7 +31,7 @@ import (
func TestInjectableEndpointRawDispatch(t *testing.T) {
endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
- endpoint.WriteRawPacket(dstIP, []byte{0xFA})
+ endpoint.InjectOutbound(dstIP, []byte{0xFA})
buf := make([]byte, ipv4.MaxTotalSize)
bytesRead, err := sock.Read(buf)
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 2e8bc772a..05c7b8024 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -16,5 +16,8 @@ go_library(
visibility = [
"//visibility:public",
],
- deps = ["//pkg/tcpip"],
+ deps = [
+ "//pkg/tcpip",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
)
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 7e286a3a6..44e25d475 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -22,6 +22,7 @@ import (
"syscall"
"unsafe"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -101,6 +102,16 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
return nil
}
+// NonBlockingSendMMsg sends multiple messages on a socket.
+func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+ n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
+ if e != 0 {
+ return 0, TranslateErrno(e)
+ }
+
+ return int(n), nil
+}
+
// PollEvent represents the pollfd structure passed to a poll() system call.
type PollEvent struct {
FD int32
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 9e71d4edf..279e2b457 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -212,6 +212,26 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa
return nil
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ v := packet.ToView()
+ // Transmit the packet.
+ e.mu.Lock()
+ ok := e.tx.transmit(v, buffer.View{})
+ e.mu.Unlock()
+
+ if !ok {
+ return tcpip.ErrWouldBlock
+ }
+
+ return nil
+}
+
// dispatchLoop reads packets from the rx queue in a loop and dispatches them
// to the network stack.
func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
@@ -254,7 +274,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
// Send packet up the stack.
eth := header.Ethernet(b)
- d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView())
+ d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), buffer.View(eth))
}
// Clean state.
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 0e9ba0846..f3e9705c9 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -78,9 +78,10 @@ func (q *queueBuffers) cleanup() {
}
type packetInfo struct {
- addr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- vv buffer.VectorisedView
+ addr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ vv buffer.VectorisedView
+ linkHeader buffer.View
}
type testContext struct {
@@ -130,12 +131,13 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
return c
}
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
- addr: remoteLinkAddr,
- proto: proto,
- vv: vv.Clone(nil),
+ addr: remoteLinkAddr,
+ proto: proto,
+ vv: vv.Clone(nil),
+ linkHeader: linkHeader,
})
c.mu.Unlock()
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index e401dce44..39757ea2a 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -116,7 +116,7 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
logPacket("recv", protocol, vv.First(), nil)
}
@@ -147,7 +147,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local
panic(err)
}
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv)
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader)
}
// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
@@ -193,10 +193,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-// WritePacket implements the stack.LinkEndpoint interface. It is called by
-// higher-level protocols to write packets; it just logs the packet and forwards
-// the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
logPacket("send", protocol, hdr.View(), gso)
}
@@ -218,28 +215,74 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
panic(err)
}
length -= len(hdrBuf)
- if length > 0 {
- for _, v := range payload.Views() {
- if len(v) > length {
- v = v[:length]
- }
- n, err := buf.Write(v)
- if err != nil {
- panic(err)
- }
- length -= n
- if length == 0 {
- break
- }
- }
- }
+ logVectorisedView(payload, length, buf)
if _, err := e.file.Write(buf.Bytes()); err != nil {
panic(err)
}
}
+}
+
+// WritePacket implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and
+// forwards the request to the lower endpoint.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ e.dumpPacket(gso, hdr, payload, protocol)
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
+// WritePackets implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and
+// forwards the request to the lower endpoint.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ view := payload.ToView()
+ for _, d := range hdrs {
+ e.dumpPacket(gso, d.Hdr, buffer.NewVectorisedView(d.Size, []buffer.View{view[d.Off:][:d.Size]}), protocol)
+ }
+ return e.lower.WritePackets(r, gso, hdrs, payload, protocol)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
+ logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */)
+ }
+ if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+ length := packet.Size()
+ if length > int(e.maxPCAPLen) {
+ length = int(e.maxPCAPLen)
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
+ if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(packet.Size()))); err != nil {
+ panic(err)
+ }
+ logVectorisedView(packet, length, buf)
+ if _, err := e.file.Write(buf.Bytes()); err != nil {
+ panic(err)
+ }
+ }
+ return e.lower.WriteRawPacket(packet)
+}
+
+func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) {
+ if length <= 0 {
+ return
+ }
+ for _, v := range vv.Views() {
+ if len(v) > length {
+ v = v[:length]
+ }
+ n, err := buf.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ length -= n
+ if length == 0 {
+ return
+ }
+ }
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 5a1791cb5..a04fc1062 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
if !e.dispatchGate.Enter() {
return
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv)
+ e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader)
e.dispatchGate.Leave()
}
@@ -109,6 +109,30 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return err
}
+// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if Wait or WaitWrite haven't been called.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if !e.writeGate.Enter() {
+ return len(hdrs), nil
+ }
+
+ n, err := e.lower.WritePackets(r, gso, hdrs, payload, protocol)
+ e.writeGate.Leave()
+ return n, err
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WriteRawPacket(packet)
+ e.writeGate.Leave()
+ return err
+}
+
// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
// and waits for inflight ones to finish before returning.
func (e *Endpoint) WaitWrite() {
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index ae23c96b7..5f0f8fa2d 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -35,7 +35,7 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
e.dispatchCount++
}
@@ -70,6 +70,17 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P
return nil
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += len(hdrs)
+ return len(hdrs), nil
+}
+
+func (e *countedEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
@@ -109,21 +120,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{})
+ ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 922181ac0..46178459e 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -83,6 +83,11 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buf
return tcpip.ErrNotSupported
}
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketDescriptor, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -109,7 +114,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ fallthrough // also fill the cache from requests
case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 825ff3392..2cad0a0b6 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -28,6 +28,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/tcpip",
"//pkg/tcpip/buffer",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 1628a82be..6da5238ec 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -17,6 +17,7 @@
package fragmentation
import (
+ "fmt"
"log"
"sync"
"time"
@@ -82,7 +83,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
// Process processes an incoming fragment belonging to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
f.mu.Lock()
r, ok := f.reassemblers[id]
if ok && r.tooOld(f.timeout) {
@@ -97,8 +98,15 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
f.mu.Unlock()
- res, done, consumed := r.process(first, last, more, vv)
-
+ res, done, consumed, err := r.process(first, last, more, vv)
+ if err != nil {
+ // We probably got an invalid sequence of fragments. Just
+ // discard the reassembler and move on.
+ f.mu.Lock()
+ f.release(r)
+ f.mu.Unlock()
+ return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err)
+ }
f.mu.Lock()
f.size += consumed
if done {
@@ -114,7 +122,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
}
}
f.mu.Unlock()
- return res, done
+ return res, done, nil
}
func (f *Fragmentation) release(r *reassembler) {
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 799798544..72c0f53be 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -83,7 +83,10 @@ func TestFragmentationProcess(t *testing.T) {
t.Run(c.comment, func(t *testing.T) {
f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
for i, in := range c.in {
- vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if err != nil {
+ t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err)
+ }
if !reflect.DeepEqual(vv, c.out[i].vv) {
t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
}
@@ -114,7 +117,10 @@ func TestReassemblingTimeout(t *testing.T) {
time.Sleep(2 * timeout)
// Send another fragment that completes a packet.
// However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done := f.Process(0, 1, 1, false, vv(1, "1"))
+ _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
+ if err != nil {
+ t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
+ }
if done {
t.Errorf("Fragmentation does not respect the reassembling timeout.")
}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 8037f734b..9e002e396 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
return used
}
-func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
@@ -86,7 +86,7 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
- return buffer.VectorisedView{}, false, consumed
+ return buffer.VectorisedView{}, false, consumed, nil
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
@@ -96,13 +96,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise
}
// Check if all the holes have been deleted and we are ready to reassamble.
if r.deleted < len(r.holes) {
- return buffer.VectorisedView{}, false, consumed
+ return buffer.VectorisedView{}, false, consumed, nil
}
res, err := r.heap.reassemble()
if err != nil {
- panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+ return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err)
}
- return res, true, consumed
+ return res, true, consumed, nil
}
func (r *reassembler) tooOld(timeout time.Duration) bool {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index f644a8b08..8d74497ba 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -171,6 +171,15 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen
return nil
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, hdr []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index df1a08113..1339f8474 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -198,10 +198,9 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff
return nil
}
-// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- length := uint16(hdr.UsedLength() + payload.Size())
+ length := uint16(hdr.UsedLength() + payloadSize)
id := uint32(0)
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
@@ -219,6 +218,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+ e.addIPHeader(r, &hdr, payload.Size(), params)
if loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(payload.Views()))
@@ -242,6 +246,23 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ if loop&stack.PacketLoop != 0 {
+ panic("multiple packets in local loop")
+ }
+ if loop&stack.PacketOut == 0 {
+ return len(hdrs), nil
+ }
+
+ for i := range hdrs {
+ e.addIPHeader(r, &hdrs[i].Hdr, hdrs[i].Size, params)
+ }
+ n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+}
+
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
@@ -326,7 +347,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
return
}
var ready bool
- vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ var err error
+ vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
if !ready {
return
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 560638ce8..99f84acd7 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -47,10 +47,6 @@ func TestExcludeBroadcast(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
- }
-
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
NIC: 1,
@@ -441,6 +437,16 @@ func TestInvalidFragments(t *testing.T) {
1,
1,
},
+ {
+ "multiple_fragments_with_more_fragments_set_to_false",
+ [][]byte{
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ },
+ 1,
+ 1,
+ },
}
for _, tc := range testCases {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index b5df85455..b289e902f 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -21,15 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-const (
- // ndpHopLimit is the expected IP hop limit value of 255 for received
- // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
- // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
- // drop the NDP packet. All outgoing NDP packets must use this value for
- // its IP hop limit field.
- ndpHopLimit = 255
-)
-
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
@@ -79,6 +70,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
return
}
h := header.ICMPv6(v)
+ iph := header.IPv6(netHeader)
// As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
// 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
@@ -89,7 +81,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
header.ICMPv6RouterSolicit,
header.ICMPv6RouterAdvert,
header.ICMPv6RedirectMsg:
- if header.IPv6(netHeader).HopLimit() != ndpHopLimit {
+ if iph.HopLimit() != header.NDPHopLimit {
received.Invalid.Increment()
return
}
@@ -121,25 +113,72 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
+
+ ns := header.NDPNeighborSolicit(h.NDPPayload())
+ targetAddr := ns.TargetAddress()
+ s := r.Stack()
+ rxNICID := r.NICID()
+
+ isTentative, err := s.IsAddrTentative(rxNICID, targetAddr)
+ if err != nil {
+ // We will only get an error if rxNICID is unrecognized,
+ // which should not happen. For now short-circuit this
+ // packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ }
+
+ if isTentative {
+ // If the target address is tentative and the source
+ // of the packet is a unicast (specified) address, then
+ // the source of the packet is attempting to perform
+ // address resolution on the target. In this case, the
+ // solicitation is silently ignored, as per RFC 4862
+ // section 5.4.3.
+ //
+ // If the target address is tentative and the source of
+ // the packet is the unspecified address (::), then we
+ // know another node is also performing DAD for the
+ // same address (since targetAddr is tentative for us,
+ // we know we are also performing DAD on it). In this
+ // case we let the stack know so it can handle such a
+ // scenario and do nothing further with the NDP NS.
+ if iph.SourceAddress() == header.IPv6Any {
+ s.DupTentativeAddrDetected(rxNICID, targetAddr)
+ }
+
+ // Do not handle neighbor solicitations targeted
+ // to an address that is tentative on the received
+ // NIC any further.
+ return
+ }
+
+ // At this point we know that targetAddr is not tentative on
+ // rxNICID so the packet is processed as defined in RFC 4861,
+ // as per RFC 4862 section 5.4.3.
+
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize)
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
+ }
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
- pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag
- copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr)
- pkt[icmpV6OptOffset] = ndpOptDstLinkAddr
- pkt[icmpV6LengthOffset] = 1
- copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(targetAddr)
+ opts := na.Options()
+ opts.Serialize(optsSerializer)
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
@@ -154,7 +193,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
r.LocalAddress = targetAddr
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+ // TODO(tamird/ghanan): there exists an explicit NDP option that is
+ // used to update the neighbor table with link addresses for a
+ // neighbor from an NS (see the Source Link Layer option RFC
+ // 4861 section 4.6.1 and section 7.2.3).
+ //
+ // Furthermore, the entirety of NDP handling here seems to be
+ // contradicted by RFC 4861.
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+
+ // RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
+ //
+ // 7.1.2. Validation of Neighbor Advertisements
+ //
+ // The IP Hop Limit field has a value of 255, i.e., the packet
+ // could not possibly have been forwarded by a router.
+ if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}); err != nil {
sent.Dropped.Increment()
return
}
@@ -166,7 +220,42 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
+
+ na := header.NDPNeighborAdvert(h.NDPPayload())
+ targetAddr := na.TargetAddress()
+ stack := r.Stack()
+ rxNICID := r.NICID()
+
+ isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr)
+ if err != nil {
+ // We will only get an error if rxNICID is unrecognized,
+ // which should not happen. For now short-circuit this
+ // packet.
+ //
+ // TODO(b/141002840): Handle this better?
+ return
+ }
+
+ if isTentative {
+ // We just got an NA from a node that owns an address we
+ // are performing DAD on, implying the address is not
+ // unique. In this case we let the stack know so it can
+ // handle such a scenario and do nothing furthur with
+ // the NDP NA.
+ stack.DupTentativeAddrDetected(rxNICID, targetAddr)
+ return
+ }
+
+ // At this point we know that the targetAddress is not tentative
+ // on rxNICID. However, targetAddr may still be assigned to
+ // rxNICID but not tentative (it could be permanent). Such a
+ // scenario is beyond the scope of RFC 4862. As such, we simply
+ // ignore such a scenario for now and proceed as normal.
+ //
+ // TODO(b/143147598): Handle the scenario described above. Also
+ // inform the netstack integration that a duplicate address was
+ // detected outside of DAD.
+
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
@@ -178,7 +267,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
-
vv.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
@@ -262,7 +350,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: ndpHopLimit,
+ HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 501be208e..7c11dde55 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,7 +15,6 @@
package ipv6
import (
- "fmt"
"reflect"
"strings"
"testing"
@@ -144,7 +143,7 @@ func TestICMPCounts(t *testing.T) {
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: ndpHopLimit,
+ HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
@@ -179,13 +178,10 @@ func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) {
t := v.Type()
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
- switch v.Kind() {
- case reflect.Ptr:
- f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter))
- case reflect.Struct:
+ if s, ok := v.Interface().(*tcpip.StatCounter); ok {
+ f(t.Field(i).Name, s)
+ } else {
visitStats(v, f)
- default:
- panic(fmt.Sprintf("unexpected type %s", v.Type()))
}
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index cd1e34085..5898f8f9e 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -97,9 +97,8 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
- length := uint16(hdr.UsedLength() + payload.Size())
+func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) {
+ length := uint16(hdr.UsedLength() + payloadSize)
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
@@ -109,6 +108,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
+}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+ e.addIPHeader(r, &hdr, payload.Size(), params)
if loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(payload.Views()))
@@ -127,6 +131,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ if loop&stack.PacketLoop != 0 {
+ panic("not implemented")
+ }
+ if loop&stack.PacketOut == 0 {
+ return len(hdrs), nil
+ }
+
+ for i := range hdrs {
+ hdr := &hdrs[i].Hdr
+ size := hdrs[i].Size
+ e.addIPHeader(r, hdr, size, params)
+ }
+
+ n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+}
+
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index e30791fe3..c32716f2e 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -150,7 +150,7 @@ func TestHopLimitValidation(t *testing.T) {
// Receive the NDP packet with an invalid hop limit
// value.
- handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r)
+ handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r)
// Invalid count should have increased.
if got := invalid.Value(); got != 1 {
@@ -164,7 +164,7 @@ func TestHopLimitValidation(t *testing.T) {
}
// Receive the NDP packet with a valid hop limit value.
- handleIPv6Payload(hdr, ndpHopLimit, ep, &r)
+ handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r)
// Rx count of NDP packet of type typ.typ should have
// increased.
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 6a78432c9..460db3cf8 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -22,6 +22,7 @@ go_library(
"icmp_rate_limit.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
+ "ndp.go",
"nic.go",
"registration.go",
"route.go",
@@ -53,6 +54,7 @@ go_test(
name = "stack_x_test",
size = "small",
srcs = [
+ "ndp_test.go",
"stack_test.go",
"transport_demuxer_test.go",
"transport_test.go",
@@ -61,14 +63,17 @@ go_test(
":stack",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
new file mode 100644
index 000000000..ea2dbed2e
--- /dev/null
+++ b/pkg/tcpip/stack/ndp.go
@@ -0,0 +1,319 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "log"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ // defaultDupAddrDetectTransmits is the default number of NDP Neighbor
+ // Solicitation messages to send when doing Duplicate Address Detection
+ // for a tentative address.
+ //
+ // Default = 1 (from RFC 4862 section 5.1)
+ defaultDupAddrDetectTransmits = 1
+
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending NDP Neighbor solicitation messages.
+ //
+ // Default = 1s (from RFC 4861 section 10).
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending NDP Neighbor solicitation messages. Note, RFC 4861 does
+ // not impose a minimum Retransmit Timer, but we do here to make sure
+ // the messages are not sent all at once. We also come to this value
+ // because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1.
+ // Note, the unit of the RetransmitTimer field in the Router
+ // Advertisement is milliseconds.
+ //
+ // Min = 1ms.
+ minimumRetransmitTimer = time.Millisecond
+)
+
+// NDPDispatcher is the interface integrators of netstack must implement to
+// receive and handle NDP related events.
+type NDPDispatcher interface {
+ // OnDuplicateAddressDetectionStatus will be called when the DAD process
+ // for an address (addr) on a NIC (with ID nicid) completes. resolved
+ // will be set to true if DAD completed successfully (no duplicate addr
+ // detected); false otherwise (addr was detected to be a duplicate on
+ // the link the NIC is a part of, or it was stopped for some other
+ // reason, such as the address being removed). If an error occured
+ // during DAD, err will be set and resolved must be ignored.
+ //
+ // This function is permitted to block indefinitely without interfering
+ // with the stack's operation.
+ OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+}
+
+// NDPConfigurations is the NDP configurations for the netstack.
+type NDPConfigurations struct {
+ // The number of Neighbor Solicitation messages to send when doing
+ // Duplicate Address Detection for a tentative address.
+ //
+ // Note, a value of zero effectively disables DAD.
+ DupAddrDetectTransmits uint8
+
+ // The amount of time to wait between sending Neighbor solicitation
+ // messages.
+ //
+ // Must be greater than 0.5s.
+ RetransmitTimer time.Duration
+}
+
+// DefaultNDPConfigurations returns an NDPConfigurations populated with
+// default values.
+func DefaultNDPConfigurations() NDPConfigurations {
+ return NDPConfigurations{
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ }
+}
+
+// validate modifies an NDPConfigurations with valid values. If invalid values
+// are present in c, the corresponding default values will be used instead.
+//
+// If RetransmitTimer is less than minimumRetransmitTimer, then a value of
+// defaultRetransmitTimer will be used.
+func (c *NDPConfigurations) validate() {
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+}
+
+// ndpState is the per-interface NDP state.
+type ndpState struct {
+ // The NIC this ndpState is for.
+ nic *NIC
+
+ // The DAD state to send the next NS message, or resolve the address.
+ dad map[tcpip.Address]dadState
+}
+
+// dadState holds the Duplicate Address Detection timer and channel to signal
+// to the DAD goroutine that DAD should stop.
+type dadState struct {
+ // The DAD timer to send the next NS message, or resolve the address.
+ timer *time.Timer
+
+ // Used to let the DAD timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this dadState is associated with.
+ done *bool
+}
+
+// startDuplicateAddressDetection performs Duplicate Address Detection.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+ // addr must be a valid unicast IPv6 address.
+ if !header.IsV6UnicastAddress(addr) {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in
+ // the DAD process.
+ if _, ok := ndp.dad[addr]; ok {
+ // Should never happen because we should only ever call this
+ // function for newly created addresses. If we attemped to
+ // "add" an address that already existed, we would returned an
+ // error since we attempted to add a duplicate address, or its
+ // reference count would have been increased without doing the
+ // work that would have been done for an address that was brand
+ // new. See NIC.addPermanentAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ remaining := ndp.nic.stack.ndpConfigs.DupAddrDetectTransmits
+
+ {
+ done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref)
+ if err != nil {
+ return err
+ }
+ if done {
+ return nil
+ }
+ }
+
+ remaining--
+
+ var done bool
+ var timer *time.Timer
+ timer = time.AfterFunc(ndp.nic.stack.ndpConfigs.RetransmitTimer, func() {
+ var d bool
+ var err *tcpip.Error
+
+ // doDadIteration does a single iteration of the DAD loop.
+ //
+ // Returns true if the integrator needs to be informed of DAD
+ // completing.
+ doDadIteration := func() bool {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if done {
+ // If we reach this point, it means that the DAD
+ // timer fired after another goroutine already
+ // obtained the NIC lock and stopped DAD before
+ // this function obtained the NIC lock. Simply
+ // return here and do nothing further.
+ return false
+ }
+
+ ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ // This should never happen.
+ // We should have an endpoint for addr since we
+ // are still performing DAD on it. If the
+ // endpoint does not exist, but we are doing DAD
+ // on it, then we started DAD at some point, but
+ // forgot to stop it when the endpoint was
+ // deleted.
+ panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref)
+ if err != nil || d {
+ delete(ndp.dad, addr)
+
+ if err != nil {
+ log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
+ }
+
+ // Let the integrator know DAD has completed.
+ return true
+ }
+
+ remaining--
+ timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ return false
+ }
+
+ if doDadIteration() && ndp.nic.stack.ndpDisp != nil {
+ ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err)
+ }
+
+ })
+
+ ndp.dad[addr] = dadState{
+ timer: timer,
+ done: &done,
+ }
+
+ return nil
+}
+
+// doDuplicateAddressDetection is called on every iteration of the timer, and
+// when DAD starts.
+//
+// It handles resolving the address (if there are no more NS to send), or
+// sending the next NS if there are more NS to send.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The NIC that ndp belongs to (n) MUST be locked.
+//
+// Returns true if DAD has resolved; false if DAD is still ongoing.
+func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
+ if ref.getKind() != permanentTentative {
+ // The endpoint should still be marked as tentative
+ // since we are still performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ if remaining == 0 {
+ // DAD has resolved.
+ ref.setKind(permanent)
+ return true, nil
+ }
+
+ // Send a new NS.
+ snmc := header.SolicitedNodeAddr(addr)
+ snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}]
+ if !ok {
+ // This should never happen as if we have the
+ // address, we should have the solicited-node
+ // address.
+ panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
+ }
+
+ // Use the unspecified address as the source address when performing
+ // DAD.
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(addr)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}); err != nil {
+ sent.Dropped.Increment()
+ return false, err
+ }
+ sent.NeighborSolicit.Increment()
+
+ return false, nil
+}
+
+// stopDuplicateAddressDetection ends a running Duplicate Address Detection
+// process. Note, this may leave the DAD process for a tentative address in
+// such a state forever, unless some other external event resolves the DAD
+// process (receiving an NA from the true owner of addr, or an NS for addr
+// (implying another node is attempting to use addr)). It is up to the caller
+// of this function to handle such a scenario. Normally, addr will be removed
+// from n right after this function returns or the address successfully
+// resolved.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
+ dad, ok := ndp.dad[addr]
+ if !ok {
+ // Not currently performing DAD on addr, just return.
+ return
+ }
+
+ if dad.timer != nil {
+ dad.timer.Stop()
+ dad.timer = nil
+
+ *dad.done = true
+ dad.done = nil
+ }
+
+ delete(ndp.dad, addr)
+
+ // Let the integrator know DAD did not resolve.
+ if ndp.nic.stack.ndpDisp != nil {
+ go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
new file mode 100644
index 000000000..b089ce2ae
--- /dev/null
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -0,0 +1,443 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = "\x02\x02\x03\x04\x05\x06"
+)
+
+// TestDADDisabled tests that an address successfully resolves immediately
+// when DAD is not enabled (the default for an empty stack.Options).
+func TestDADDisabled(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Should get the address immediately since we should not have performed
+ // DAD on it.
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+ }
+
+ // We should not have sent any NDP NS messages.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ t.Fatalf("got NeighborSolicit = %d, want = 0", got)
+ }
+}
+
+// ndpDADEvent is a set of parameters that was passed to
+// ndpDispatcher.OnDuplicateAddressDetectionStatus.
+type ndpDADEvent struct {
+ nicid tcpip.NICID
+ addr tcpip.Address
+ resolved bool
+ err *tcpip.Error
+}
+
+var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+
+// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
+// related events happen for test purposes.
+type ndpDispatcher struct {
+ dadC chan ndpDADEvent
+}
+
+// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+//
+// If the DAD event matches what we are expecting, send signal on n.dadC.
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+ n.dadC <- ndpDADEvent{
+ nicid,
+ addr,
+ resolved,
+ err,
+ }
+}
+
+// TestDADResolve tests that an address successfully resolves after performing
+// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
+// Included in the subtests is a test to make sure that an invalid
+// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+func TestDADResolve(t *testing.T) {
+ tests := []struct {
+ name string
+ dupAddrDetectTransmits uint8
+ retransTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {"1:1s:1s", 1, time.Second, time.Second},
+ {"2:1s:1s", 2, time.Second, time.Second},
+ {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ // 0s is an invalid RetransmitTimer timer and will be fixed to
+ // the default RetransmitTimer value of 1s.
+ {"1:0s:1s", 1, 0, time.Second},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ }
+ opts.NDPConfigs.RetransmitTimer = test.retransTimer
+ opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit
+
+ // Should have sent an NDP NS immediately.
+ if got := stat.Value(); got != 1 {
+ t.Fatalf("got NeighborSolicit = %d, want = 1", got)
+
+ }
+
+ // Address should not be considered bound to the NIC yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for the remaining time - some delta (500ms), to
+ // make sure the address is still not resolved.
+ const delta = 500 * time.Millisecond
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+ }
+
+ // Should not have sent any more NS messages.
+ if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
+ }
+
+ // Validate the sent Neighbor Solicitation messages.
+ for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
+ p := <-e.C
+
+ // Make sure its an IPv6 packet.
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ // Check NDP packet.
+ checker.IPv6(t, p.Header.ToVectorisedView().First(),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr1)))
+ }
+ })
+ }
+
+}
+
+// TestDADFail tests to make sure that the DAD process fails if another node is
+// detected to be performing DAD on the same address (receive an NS message from
+// a node doing DAD for the same address), or if another node is detected to own
+// the address already (receive an NA message for the tentative address).
+func TestDADFail(t *testing.T) {
+ tests := []struct {
+ name string
+ makeBuf func(tgt tcpip.Address) buffer.Prependable
+ getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ "RxSolicit",
+ func(tgt tcpip.Address) buffer.Prependable {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(tgt)
+ snmc := header.SolicitedNodeAddr(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
+ })
+
+ return hdr
+
+ },
+ func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return s.NeighborSolicit
+ },
+ },
+ {
+ "RxAdvert",
+ func(tgt tcpip.Address) buffer.Prependable {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ return hdr
+
+ },
+ func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return s.NeighborAdvert
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ }
+ opts.NDPConfigs.RetransmitTimer = time.Second * 2
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Receive a packet to simulate multiple nodes owning or
+ // attempting to own the same address.
+ hdr := test.makeBuf(addr1)
+ e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ if got := stat.Value(); got != 1 {
+ t.Fatalf("got stat = %d, want = 1", got)
+ }
+
+ // Wait for DAD to fail and make sure the address did
+ // not get resolved.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the
+ // expected resolution time + extra 1s buffer,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ })
+ }
+}
+
+// TestDADStop tests to make sure that the DAD process stops when an address is
+// removed.
+func TestDADStop(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Remove the address. This should stop DAD.
+ if err := s.RemoveAddress(1, addr1); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err)
+ }
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 5993fe582..2d29fa88e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -19,7 +19,6 @@ import (
"sync"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/ilist"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -37,12 +36,17 @@ type NIC struct {
mu sync.RWMutex
spoofing bool
promiscuous bool
- primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
addressRanges []tcpip.Subnet
mcastJoins map[NetworkEndpointID]int32
+ // packetEPs is protected by mu, but the contained PacketEndpoint
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
stats NICStats
+
+ ndp ndpState
}
// NICStats includes transmitted and received stats.
@@ -77,15 +81,19 @@ const (
)
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
- return &NIC{
+ // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
+ // example, make sure that the link address it provides is a valid
+ // unicast ethernet address.
+ nic := &NIC{
stack: stack,
id: id,
name: name,
linkEP: ep,
loopback: loopback,
- primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
+ packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint),
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -96,7 +104,21 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
Bytes: &tcpip.StatCounter{},
},
},
+ ndp: ndpState{
+ dad: make(map[tcpip.Address]dadState),
+ },
+ }
+ nic.ndp.nic = nic
+
+ // Register supported packet endpoint protocols.
+ for _, netProto := range header.Ethertypes {
+ nic.packetEPs[netProto] = []PacketEndpoint{}
}
+ for _, netProto := range stack.networkProtocols {
+ nic.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ }
+
+ return nic
}
// enable enables the NIC. enable will attach the link to its LinkEndpoint and
@@ -121,11 +143,50 @@ func (n *NIC) enable() *tcpip.Error {
// when we perform Duplicate Address Detection, or Router Advertisement
// when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
// section 4.2 for more information.
- if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
- return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress)
+ //
+ // Also auto-generate an IPv6 link-local address based on the NIC's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+ _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]
+ if !ok {
+ return nil
}
- return nil
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ if !n.stack.autoGenIPv6LinkLocal {
+ return nil
+ }
+
+ l2addr := n.linkEP.LinkAddress()
+
+ // Only attempt to generate the link-local address if we have a
+ // valid MAC address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address
+ // (provided by LinkEndpoint.LinkAddress) before reaching this
+ // point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
+
+ addr := header.LinkLocalAddr(l2addr)
+
+ _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ },
+ }, CanBePrimaryEndpoint)
+
+ return err
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
@@ -161,18 +222,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
n.mu.RLock()
defer n.mu.RUnlock()
- list := n.primary[protocol]
- if list == nil {
- return nil
- }
-
- for e := list.Front(); e != nil; e = e.Next() {
- r := e.(*referencedNetworkEndpoint)
- // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set.
- switch r.ep.ID().LocalAddress {
- case header.IPv4Broadcast, header.IPv4Any:
- continue
- }
+ for _, r := range n.primary[protocol] {
if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
@@ -282,13 +332,35 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
if ref, ok := n.endpoints[id]; ok {
switch ref.getKind() {
- case permanent:
+ case permanentTentative, permanent:
// The NIC already have a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
case permanentExpired, temporary:
- // Promote the endpoint to become permanent.
+ // Promote the endpoint to become permanent and respect
+ // the new peb.
if ref.tryIncRef() {
ref.setKind(permanent)
+
+ refs := n.primary[ref.protocol]
+ for i, r := range refs {
+ if r == ref {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return ref, nil
+ case FirstPrimaryEndpoint:
+ if i == 0 {
+ return ref, nil
+ }
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ case NeverPrimaryEndpoint:
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ return ref, nil
+ }
+ }
+ }
+
+ n.insertPrimaryEndpointLocked(ref, peb)
+
return ref, nil
}
// tryIncRef failing means the endpoint is scheduled to be removed once
@@ -298,6 +370,7 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
n.removeEndpointLocked(ref)
}
}
+
return n.addAddressLocked(protocolAddress, peb, permanent)
}
@@ -321,6 +394,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
if err != nil {
return nil, err
}
+
+ isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
+
+ // If the address is an IPv6 address and it is a permanent address,
+ // mark it as tentative so it goes through the DAD process.
+ if isIPv6Unicast && kind == permanent {
+ kind = permanentTentative
+ }
+
ref := &referencedNetworkEndpoint{
refs: 1,
ep: ep,
@@ -338,7 +420,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
// If we are adding an IPv6 unicast address, join the solicited-node
// multicast address.
- if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) {
+ if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
return nil, err
@@ -347,17 +429,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
n.endpoints[id] = ref
- l, ok := n.primary[protocolAddress.Protocol]
- if !ok {
- l = &ilist.List{}
- n.primary[protocolAddress.Protocol] = l
- }
+ n.insertPrimaryEndpointLocked(ref, peb)
- switch peb {
- case CanBePrimaryEndpoint:
- l.PushBack(ref)
- case FirstPrimaryEndpoint:
- l.PushFront(ref)
+ // If we are adding a tentative IPv6 address, start DAD.
+ if isIPv6Unicast && kind == permanentTentative {
+ if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ return nil, err
+ }
}
return ref, nil
@@ -382,10 +460,12 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
- // Don't include expired or temporary endpoints to avoid confusion and
- // prevent the caller from using those.
+ // Don't include tentative, expired or temporary endpoints to
+ // avoid confusion and prevent the caller from using those.
switch ref.getKind() {
- case permanentExpired, temporary:
+ case permanentTentative, permanentExpired, temporary:
+ // TODO(b/140898488): Should tentative addresses be
+ // returned?
continue
}
addrs = append(addrs, tcpip.ProtocolAddress{
@@ -406,12 +486,12 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for proto, list := range n.primary {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
- // Don't include expired or tempory endpoints to avoid confusion and
- // prevent the caller from using those.
+ for _, ref := range list {
+ // Don't include tentative, expired or tempory endpoints
+ // to avoid confusion and prevent the caller from using
+ // those.
switch ref.getKind() {
- case permanentExpired, temporary:
+ case permanentTentative, permanentExpired, temporary:
continue
}
@@ -471,6 +551,19 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
return append(sns, n.addressRanges...)
}
+// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
+// by peb.
+//
+// n MUST be locked.
+func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ n.primary[r.protocol] = append(n.primary[r.protocol], r)
+ case FirstPrimaryEndpoint:
+ n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...)
+ }
+}
+
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
@@ -488,9 +581,12 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
}
delete(n.endpoints, id)
- wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
- if wasInList {
- n.primary[r.protocol].Remove(r)
+ refs := n.primary[r.protocol]
+ for i, ref := range refs {
+ if ref == r {
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ break
+ }
}
r.ep.Close()
@@ -504,10 +600,22 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
r, ok := n.endpoints[NetworkEndpointID{addr}]
- if !ok || r.getKind() != permanent {
+ if !ok {
return tcpip.ErrBadLocalAddress
}
+ kind := r.getKind()
+ if kind != permanent && kind != permanentTentative {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
+
+ // If we are removing a tentative IPv6 unicast address, stop DAD.
+ if isIPv6Unicast && kind == permanentTentative {
+ n.ndp.stopDuplicateAddressDetection(addr)
+ }
+
r.setKind(permanentExpired)
if !r.decRefLocked() {
// The endpoint still has references to it.
@@ -518,7 +626,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing an IPv6 unicast address, leave the solicited-node
// multicast address.
- if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) {
+ if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(addr)
if err := n.leaveGroupLocked(snmc); err != nil {
return err
@@ -548,6 +656,11 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
// exists yet. Otherwise it just increments its count. n MUST be locked before
// joinGroupLocked is called.
func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ // TODO(b/143102137): When implementing MLD, make sure MLD packets are
+ // not sent unless a valid link-local address is available for use on n
+ // as an MLD packet's source address must be a link-local address as
+ // outlined in RFC 3810 section 5.
+
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
@@ -611,7 +724,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size()))
@@ -621,6 +734,26 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
return
}
+ // If no local link layer address is provided, assume it was sent
+ // directly to this NIC.
+ if local == "" {
+ local = n.linkEP.LinkAddress()
+ }
+
+ // Are any packet sockets listening for this network protocol?
+ n.mu.RLock()
+ packetEPs := n.packetEPs[protocol]
+ // Check whether there are packet sockets listening for every protocol.
+ // If we received a packet with protocol EthernetProtocolAll, then the
+ // previous for loop will have handled it.
+ if protocol != header.EthernetProtocolAll {
+ packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...)
+ }
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ ep.HandlePacket(n.id, local, protocol, vv, linkHeader)
+ }
+
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
@@ -632,8 +765,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
src, dst := netProto.ParseAddresses(vv.First())
- n.stack.AddLinkAddress(n.id, src, remote)
-
if ref := n.getRef(protocol, dst); ref != nil {
handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
@@ -682,7 +813,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
return
}
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ // If a packet socket handled the packet, don't treat it as invalid.
+ if len(packetEPs) == 0 {
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ }
}
// DeliverTransportPacket delivers the packets to the appropriate transport
@@ -769,14 +903,58 @@ func (n *NIC) Stack() *Stack {
return n.stack
}
+// isAddrTentative returns true if addr is tentative on n.
+//
+// Note that if addr is not associated with n, then this function will return
+// false. It will only return true if the address is associated with the NIC
+// AND it is tentative.
+func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
+ ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return false
+ }
+
+ return ref.getKind() == permanentTentative
+}
+
+// dupTentativeAddrDetected attempts to inform n that a tentative addr
+// is a duplicate on a link.
+//
+// dupTentativeAddrDetected will delete the tentative address if it exists.
+func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return tcpip.ErrBadAddress
+ }
+
+ if ref.getKind() != permanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ return n.removePermanentAddressLocked(addr)
+}
+
type networkEndpointKind int32
const (
+ // A permanentTentative endpoint is a permanent address that is not yet
+ // considered to be fully bound to an interface in the traditional
+ // sense. That is, the address is associated with a NIC, but packets
+ // destined to the address MUST NOT be accepted and MUST be silently
+ // dropped, and the address MUST NOT be used as a source address for
+ // outgoing packets. For IPv6, addresses will be of this kind until
+ // NDP's Duplicate Address Detection has resolved, or be deleted if
+ // the process results in detecting a duplicate address.
+ permanentTentative networkEndpointKind = iota
+
// A permanent endpoint is created by adding a permanent address (vs. a
// temporary one) to the NIC. Its reference count is biased by 1 to avoid
// removal when no route holds a reference to it. It is removed by explicitly
// removing the permanent address from the NIC.
- permanent networkEndpointKind = iota
+ permanent
// An expired permanent endoint is a permanent endoint that had its address
// removed from the NIC, and it is waiting to be removed once no more routes
@@ -794,8 +972,37 @@ const (
temporary
)
+func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.packetEPs[netProto]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+ n.packetEPs[netProto] = append(eps, ep)
+
+ return nil
+}
+
+func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.packetEPs[netProto]
+ if !ok {
+ return
+ }
+
+ for i, epOther := range eps {
+ if epOther == ep {
+ n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
+ return
+ }
+ }
+}
+
type referencedNetworkEndpoint struct {
- ilist.Entry
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 9d6157f22..0869fb084 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -71,8 +71,8 @@ type TransportEndpoint interface {
// RawTransportEndpoint is the interface that needs to be implemented by raw
// transport protocol endpoints. RawTransportEndpoints receive the entire
-// packet - including the link, network, and transport headers - as delivered
-// to netstack.
+// packet - including the network and transport headers - as delivered to
+// netstack.
type RawTransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint. The packet contains all data from the link
@@ -80,6 +80,22 @@ type RawTransportEndpoint interface {
HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
}
+// PacketEndpoint is the interface that needs to be implemented by packet
+// transport protocol endpoints. These endpoints receive link layer headers in
+// addition to whatever they contain (usually network and transport layer
+// headers and a payload).
+type PacketEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive that
+ // match the endpoint.
+ //
+ // Implementers should treat packet as immutable and should copy it
+ // before before modification.
+ //
+ // linkHeader may have a length of 0, in which case the PacketEndpoint
+ // should construct its own ethernet header for applications.
+ HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View)
+}
+
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
@@ -185,6 +201,10 @@ type NetworkEndpoint interface {
// protocol.
WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error
+ // WritePackets writes packets to the given destination address and
+ // protocol.
+ WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error)
+
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error
@@ -242,9 +262,10 @@ type NetworkProtocol interface {
// packets to the appropriate network endpoint after it has been handled by
// the data link layer.
type NetworkDispatcher interface {
- // DeliverNetworkPacket finds the appropriate network protocol
- // endpoint and hands the packet over for further processing.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // DeliverNetworkPacket finds the appropriate network protocol endpoint
+ // and hands the packet over for further processing. linkHeader may have
+ // length 0 when the caller does not have ethernet data.
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -266,7 +287,11 @@ const (
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
- CapabilityGSO
+ CapabilityHardwareGSO
+
+ // CapabilitySoftwareGSO indicates the link endpoint supports of sending
+ // multiple packets using a single call (LinkEndpoint.WritePackets).
+ CapabilitySoftwareGSO
)
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
@@ -301,6 +326,18 @@ type LinkEndpoint interface {
// r.LocalLinkAddress if it is provided.
WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+ // WritePackets writes packets with the given protocol through the
+ // given route.
+ //
+ // Right now, WritePackets is used only when the software segmentation
+ // offload is enabled. If it will be used for something else, it may
+ // require to change syscall filters.
+ WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link. The packet
+ // should already have an ethernet header.
+ WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error
+
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
Attach(dispatcher NetworkDispatcher)
@@ -324,13 +361,14 @@ type LinkEndpoint interface {
type InjectableLinkEndpoint interface {
LinkEndpoint
- // Inject injects an inbound packet.
- Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // InjectInbound injects an inbound packet.
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
- // WriteRawPacket writes a fully formed outbound packet directly to the link.
+ // InjectOutbound writes a fully formed outbound packet directly to the
+ // link.
//
// dest is used by endpoints with multiple raw destinations.
- WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error
+ InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error
}
// A LinkAddressResolver is an extension to a NetworkProtocol that
@@ -379,11 +417,16 @@ type LinkAddressCache interface {
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// UnassociatedEndpointFactory produces endpoints for writing packets not
-// associated with a particular transport protocol. Such endpoints can be used
-// to write arbitrary packets that include the IP header.
-type UnassociatedEndpointFactory interface {
- NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+// RawFactory produces endpoints for writing various types of raw packets.
+type RawFactory interface {
+ // NewUnassociatedEndpoint produces endpoints for writing packets not
+ // associated with a particular transport protocol. Such endpoints can
+ // be used to write arbitrary packets that include the network header.
+ NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // NewPacketEndpoint produces endpoints for reading and writing packets
+ // that include network and (when cooked is false) link layer headers.
+ NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
// GSOType is the type of GSO segments.
@@ -394,8 +437,14 @@ type GSOType int
// Types of gso segments.
const (
GSONone GSOType = iota
+
+ // Hardware GSO types:
GSOTCPv4
GSOTCPv6
+
+ // GSOSW is used for software GSO segments which have to be sent by
+ // endpoint.WritePackets.
+ GSOSW
)
// GSO contains generic segmentation offload properties.
@@ -423,3 +472,7 @@ type GSOEndpoint interface {
// GSOMaxSize returns the maximum GSO packet size.
GSOMaxSize() uint32
}
+
+// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment.
+// This isn't a hard limit, because it is never set into packet headers.
+const SoftwareGSOMaxSize = (1 << 16)
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index e72373964..1a0a51b57 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -47,8 +47,8 @@ type Route struct {
// starts.
ref *referencedNetworkEndpoint
- // loop controls where WritePacket should send packets.
- loop PacketLooping
+ // Loop controls where WritePacket should send packets.
+ Loop PacketLooping
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -69,7 +69,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
LocalLinkAddress: localLinkAddr,
RemoteAddress: remoteAddr,
ref: ref,
- loop: loop,
+ Loop: loop,
}
}
@@ -159,7 +159,7 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
return tcpip.ErrInvalidEndpointState
}
- err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop)
+ err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.Loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
@@ -169,6 +169,44 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
return err
}
+// PacketDescriptor is a packet descriptor which contains a packet header and
+// offset and size of packet data in a payload view.
+type PacketDescriptor struct {
+ Hdr buffer.Prependable
+ Off int
+ Size int
+}
+
+// NewPacketDescriptors allocates a set of packet descriptors.
+func NewPacketDescriptors(n int, hdrSize int) []PacketDescriptor {
+ buf := make([]byte, n*hdrSize)
+ hdrs := make([]PacketDescriptor, n)
+ for i := range hdrs {
+ hdrs[i].Hdr = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
+ }
+ return hdrs
+}
+
+// WritePackets writes the set of packets through the given route.
+func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams) (int, *tcpip.Error) {
+ if !r.ref.isValidForOutgoing() {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ n, err := r.ref.ep.WritePackets(r, gso, hdrs, payload, params, r.Loop)
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(hdrs) - n))
+ }
+ r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
+ payloadSize := 0
+ for i := 0; i < n; i++ {
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdrs[i].Hdr.UsedLength()))
+ payloadSize += hdrs[i].Size
+ }
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize))
+ return n, err
+}
+
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
@@ -176,7 +214,7 @@ func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.
return tcpip.ErrInvalidEndpointState
}
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.Loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index f67975525..5ea432a24 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -351,10 +351,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
- // unassociatedFactory creates unassociated endpoints. If nil, raw
- // endpoints are disabled. It is set during Stack creation and is
- // immutable.
- unassociatedFactory UnassociatedEndpointFactory
+ // rawFactory creates raw endpoints. If nil, raw endpoints are
+ // disabled. It is set during Stack creation and is immutable.
+ rawFactory RawFactory
demux *transportDemuxer
@@ -399,6 +398,18 @@ type Stack struct {
//
// TODO(gvisor.dev/issue/940): S/R this field.
portSeed uint32
+
+ // ndpConfigs is the NDP configurations used by interfaces.
+ ndpConfigs NDPConfigurations
+
+ // autoGenIPv6LinkLocal determines whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
}
// Options contains optional Stack configuration.
@@ -422,9 +433,32 @@ type Options struct {
// stack (false).
HandleLocal bool
- // UnassociatedFactory produces unassociated endpoints raw endpoints.
- // Raw endpoints are enabled only if this is non-nil.
- UnassociatedFactory UnassociatedEndpointFactory
+ // NDPConfigs is the NDP configurations used by interfaces.
+ //
+ // By default, NDPConfigs will have a zero value for its
+ // DupAddrDetectTransmits field, implying that DAD will not be performed
+ // before assigning an address to a NIC.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determins whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // Note, setting this to true does not mean that a link-local address
+ // will be assigned right away, or at all. If Duplicate Address
+ // Detection is enabled, an address will only be assigned if it
+ // successfully resolves. If it fails, no further attempt will be made
+ // to auto-generate an IPv6 link-local address.
+ //
+ // The generated link-local address will follow RFC 4291 Appendix A
+ // guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // RawFactory produces raw endpoints. Raw endpoints are enabled only if
+ // this is non-nil.
+ RawFactory RawFactory
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -458,6 +492,9 @@ func (*TransportEndpointInfo) IsEndpointInfo() {}
// New allocates a new networking stack with only the requested networking and
// transport protocols configured with default options.
//
+// Note, NDPConfigurations will be fixed before being used by the Stack. That
+// is, if an invalid value was provided, it will be reset to the default value.
+//
// Protocol options can be changed by calling the
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
@@ -468,18 +505,24 @@ func New(opts Options) *Stack {
clock = &tcpip.StdClock{}
}
+ // Make sure opts.NDPConfigs contains valid values only.
+ opts.NDPConfigs.validate()
+
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
- nics: make(map[tcpip.NICID]*NIC),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- icmpRateLimiter: NewICMPRateLimiter(),
- portSeed: generateRandUint32(),
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ portSeed: generateRandUint32(),
+ ndpConfigs: opts.NDPConfigs,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ ndpDisp: opts.NDPDisp,
}
// Add specified network protocols.
@@ -497,8 +540,8 @@ func New(opts Options) *Stack {
}
}
- // Add the factory for unassociated endpoints, if present.
- s.unassociatedFactory = opts.UnassociatedFactory
+ // Add the factory for raw endpoints, if present.
+ s.rawFactory = opts.RawFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -633,12 +676,12 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if s.unassociatedFactory == nil {
+ if s.rawFactory == nil {
return nil, tcpip.ErrNotPermitted
}
if !associated {
- return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue)
+ return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue)
}
t, ok := s.transportProtocols[transport]
@@ -649,6 +692,16 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return t.proto.NewRawEndpoint(s, network, waiterQueue)
}
+// NewPacketEndpoint creates a new packet endpoint listening for the given
+// netProto.
+func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if s.rawFactory == nil {
+ return nil, tcpip.ErrNotPermitted
+ }
+
+ return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
+}
+
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
@@ -1118,6 +1171,109 @@ func (s *Stack) Resume() {
}
}
+// RegisterPacketEndpoint registers ep with the stack, causing it to receive
+// all traffic of the specified netProto on the given NIC. If nicID is 0, it
+// receives traffic from every NIC.
+func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If no NIC is specified, capture on all devices.
+ if nicID == 0 {
+ // Register with each NIC.
+ for _, nic := range s.nics {
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ s.unregisterPacketEndpointLocked(0, netProto, ep)
+ return err
+ }
+ }
+ return nil
+ }
+
+ // Capture on a specific device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnregisterPacketEndpoint unregisters ep for packets of the specified
+// netProto from the specified NIC. If nicID is 0, ep is unregistered from all
+// NICs.
+func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.unregisterPacketEndpointLocked(nicID, netProto, ep)
+}
+
+func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ // If no NIC is specified, unregister on all devices.
+ if nicID == 0 {
+ // Unregister with each NIC.
+ for _, nic := range s.nics {
+ nic.unregisterPacketEndpoint(netProto, ep)
+ }
+ return
+ }
+
+ // Unregister in a single device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return
+ }
+ nic.unregisterPacketEndpoint(netProto, ep)
+}
+
+// WritePacket writes data directly to the specified NIC. It adds an ethernet
+// header based on the arguments.
+func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicid]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ // Add our own fake ethernet header.
+ ethFields := header.EthernetFields{
+ SrcAddr: nic.linkEP.LinkAddress(),
+ DstAddr: dst,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ ethHeader := buffer.View(fakeHeader).ToVectorisedView()
+ ethHeader.Append(payload)
+
+ if err := nic.linkEP.WriteRawPacket(ethHeader); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicid tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicid]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ return err
+ }
+
+ return nil
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1238,6 +1394,37 @@ func (s *Stack) AllowICMPMessage() bool {
return s.icmpRateLimiter.Allow()
}
+// IsAddrTentative returns true if addr is tentative on the NIC with ID id.
+//
+// Note that if addr is not associated with a NIC with id ID, then this
+// function will return false. It will only return true if the address is
+// associated with the NIC AND it is tentative.
+func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return false, tcpip.ErrUnknownNICID
+ }
+
+ return nic.isAddrTentative(addr), nil
+}
+
+// DupTentativeAddrDetected attempts to inform the NIC with ID id that a
+// tentative addr on it is a duplicate on a link.
+func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.dupTentativeAddrDetected(addr)
+}
+
// PortSeed returns a 32 bit value that can be used as a seed value for port
// picking.
//
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 10fd1065f..9dae853d0 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -24,11 +24,14 @@ import (
"sort"
"strings"
"testing"
+ "time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -144,6 +147,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -1864,3 +1872,297 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
+
+// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
+// (or lack there-of if disabled (default)). Note, DAD will be disabled in
+// these tests.
+func TestNICAutoGenAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ shouldGen bool
+ }{
+ {
+ "Disabled",
+ false,
+ linkAddr1,
+ false,
+ },
+ {
+ "Enabled",
+ true,
+ linkAddr1,
+ true,
+ },
+ {
+ "Nil MAC",
+ true,
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty MAC",
+ true,
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "Invalid MAC",
+ true,
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Multicast MAC",
+ true,
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ false,
+ },
+ {
+ "Unspecified MAC",
+ true,
+ tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ }
+
+ if test.autoGen {
+ // Only set opts.AutoGenIPv6LinkLocal when
+ // test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by
+ // default.
+ opts.AutoGenIPv6LinkLocal = true
+ }
+
+ e := channel.New(10, 1280, test.linkAddr)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+
+ if test.shouldGen {
+ // Should have auto-generated an address and
+ // resolved immediately (DAD is disabled).
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+ } else {
+ // Should not have auto-generated an address.
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ }
+ })
+ }
+}
+
+// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
+// link-local addresses will only be assigned after the DAD process resolves.
+func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Address should not be considered bound to the
+ // NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ linkLocalAddr := header.LinkLocalAddr(linkAddr1)
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // We should get a resolution event after 1s (default time to
+ // resolve as per default NDP configurations). Waiting for that
+ // resolution time + an extra 1s without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != linkLocalAddr {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+}
+
+// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected
+// when an address's kind gets "promoted" to permanent from permanentExpired.
+func TestNewPEBOnPromotionToPermanent(t *testing.T) {
+ pebs := []stack.PrimaryEndpointBehavior{
+ stack.NeverPrimaryEndpoint,
+ stack.CanBePrimaryEndpoint,
+ stack.FirstPrimaryEndpoint,
+ }
+
+ for _, pi := range pebs {
+ for _, ps := range pebs {
+ t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ // Add a permanent address with initial
+ // PrimaryEndpointBehavior (peb), pi. If pi is
+ // NeverPrimaryEndpoint, the address should not
+ // be returned by a call to GetMainNICAddress;
+ // else, it should.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ addr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if pi == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+
+ }
+ } else if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatalf("NewSubnet failed:", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // Take a route through the address so its ref
+ // count gets incremented and does not actually
+ // get deleted when RemoveAddress is called
+ // below. This is because we want to test that a
+ // new peb is respected when an address gets
+ // "promoted" to permanent from a
+ // permanentExpired kind.
+ r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(1, "\x01"); err != nil {
+ t.Fatalf("RemoveAddress failed:", err)
+ }
+
+ //
+ // At this point, the address should still be
+ // known by the NIC, but have its
+ // kind = permanentExpired.
+ //
+
+ // Add some other address with peb set to
+ // FirstPrimaryEndpoint.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+
+ }
+
+ // Add back the address we removed earlier and
+ // make sure the new peb was respected.
+ // (The address should just be promoted now).
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ var primaryAddrs []tcpip.Address
+ for _, pa := range s.NICInfo()[1].ProtocolAddresses {
+ primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
+ }
+ var expectedList []tcpip.Address
+ switch ps {
+ case stack.FirstPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x01",
+ "\x03",
+ }
+ case stack.CanBePrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ "\x01",
+ }
+ case stack.NeverPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ }
+ }
+ if !cmp.Equal(primaryAddrs, expectedList) {
+ t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList)
+ }
+
+ // Once we remove the other address, if the new
+ // peb, ps, was NeverPrimaryEndpoint, no address
+ // should be returned by a call to
+ // GetMainNICAddress; else, our original address
+ // should be returned.
+ if err := s.RemoveAddress(1, "\x03"); err != nil {
+ t.Fatalf("RemoveAddress failed:", err)
+ }
+ addr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if ps == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+
+ }
+ } else {
+ if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 92267ce4d..97a1aec4b 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -465,7 +465,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
- return nil
+ return tcpip.ErrNotSupported
}
eps.mu.Lock()
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 26f338d8d..353ecd49b 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -255,7 +255,7 @@ type FullAddress struct {
// This may not be used by all endpoint types.
NIC NICID
- // Addr is the network address.
+ // Addr is the network or link layer address.
Addr Address
// Port is the transport port.
@@ -1100,15 +1100,12 @@ func (*TransportEndpointStats) IsEndpointStats() {}
func fillIn(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
- switch v.Kind() {
- case reflect.Ptr:
- if s := v.Addr().Interface().(**StatCounter); *s == nil {
- *s = &StatCounter{}
+ if s, ok := v.Addr().Interface().(**StatCounter); ok {
+ if *s == nil {
+ *s = new(StatCounter)
}
- case reflect.Struct:
+ } else {
fillIn(v)
- default:
- panic(fmt.Sprintf("unexpected type %s", v.Type()))
}
}
}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
new file mode 100644
index 000000000..8ea2e6ee5
--- /dev/null
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -0,0 +1,46 @@
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "packet_list",
+ out = "packet_list.go",
+ package = "packet",
+ prefix = "packet",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*packet",
+ "Linker": "*packet",
+ },
+)
+
+go_library(
+ name = "packet",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "packet_list.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet",
+ imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/sleep",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "packet_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
new file mode 100644
index 000000000..73cdaa265
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -0,0 +1,363 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packet provides the implementation of packet sockets (see
+// packet(7)). Packet sockets allow applications to:
+//
+// * manually write and inspect link, network, and transport headers
+// * receive all traffic of a given network protocol, or all protocols
+//
+// Packet sockets are similar to raw sockets, but provide even more power to
+// users, letting them effectively talk directly to the network device.
+//
+// Packet sockets skip the input and output iptables chains.
+package packet
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is pre-allocated space to back data. As long as the packet is
+ // made up of fewer than 8 buffer.Views, no extra allocation is
+ // necessary to store packet data.
+ views [8]buffer.View `state:"nosave"`
+ // timestampNS is the unix time at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
+// to have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ stack.TransportEndpointInfo
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+ cooked bool
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+}
+
+// NewEndpoint returns a new packet endpoint.
+func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := &endpoint{
+ stack: s,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ NetProto: netProto,
+ },
+ cooked: cooked,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
+ return nil, err
+ }
+ return ep, nil
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ ep.closed = true
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (ep *endpoint) ModerateRecvBuf(copied int) {}
+
+// IPTables implements tcpip.Endpoint.IPTables.
+func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+ return ep.stack.IPTables(), nil
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ ep.stats.ReadErrors.ReadClosed.Increment()
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // TODO(b/129292371): Implement.
+ return 0, nil, tcpip.ErrInvalidOptionValue
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
+// disconnected, and this function always returns tpcip.ErrNotSupported.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
+// connected, and this function always returnes tcpip.ErrNotSupported.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
+// with Shutdown, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
+// Listen, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
+// Accept, and this function always returns tcpip.ErrNotSupported.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ // TODO(gvisor.dev/issue/173): Add Bind support.
+
+ // "By default, all packets of the specified protocol type are passed
+ // to a packet socket. To get packets only from a specific interface
+ // use bind(2) specifying an address in a struct sockaddr_ll to bind
+ // the packet socket to an interface. Fields used for binding are
+ // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
+ // - packet(7).
+
+ return tcpip.ErrNotSupported
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
+// used with SetSockOpt, and this function always returns
+// tcpip.ErrNotSupported.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// HandlePacket implements stack.PacketEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, ethHeader buffer.View) {
+ ep.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if ep.rcvClosed {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return
+ }
+
+ if ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.rcvMu.Unlock()
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ var packet packet
+ // TODO(b/129292371): Return network protocol.
+ if len(ethHeader) > 0 {
+ // Get info directly from the ethernet header.
+ hdr := header.Ethernet(ethHeader)
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicid,
+ Addr: tcpip.Address(hdr.SourceAddress()),
+ }
+ } else {
+ // Guess the would-be ethernet header.
+ packet.senderAddr = tcpip.FullAddress{
+ NIC: nicid,
+ Addr: tcpip.Address(localAddr),
+ }
+ }
+
+ if ep.cooked {
+ // Cooked packets can simply be queued.
+ packet.data = vv.Clone(packet.views[:])
+ } else {
+ // Raw packets need their ethernet headers prepended before
+ // queueing.
+ if len(ethHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ ethHeader = buffer.View(fakeHeader)
+ }
+ combinedVV := buffer.View(ethHeader).ToVectorisedView()
+ combinedVV.Append(vv)
+ packet.data = combinedVV.Clone(packet.views[:])
+ }
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(&packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+ ep.stats.PacketsReceived.Increment()
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// State implements socket.Socket.State.
+func (ep *endpoint) State() uint32 {
+ return 0
+}
+
+// Info returns a copy of the endpoint info.
+func (ep *endpoint) Info() tcpip.EndpointInfo {
+ ep.mu.RLock()
+ // Make a copy of the endpoint info.
+ ret := ep.TransportEndpointInfo
+ ep.mu.RUnlock()
+ return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (ep *endpoint) Stats() tcpip.EndpointStats {
+ return &ep.stats
+}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
new file mode 100644
index 000000000..9b88f17e4
--- /dev/null
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -0,0 +1,72 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package packet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and mutate endpoint state).
+ // The lock will be released after saveRcvBufSizeMax(), which would have
+ // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+ // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
+ if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index fba598d51..4af49218c 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -4,14 +4,14 @@ load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
go_template_instance(
- name = "packet_list",
- out = "packet_list.go",
+ name = "raw_packet_list",
+ out = "raw_packet_list.go",
package = "raw",
- prefix = "packet",
+ prefix = "rawPacket",
template = "//pkg/ilist:generic_list",
types = {
- "Element": "*packet",
- "Linker": "*packet",
+ "Element": "*rawPacket",
+ "Linker": "*rawPacket",
},
)
@@ -20,8 +20,8 @@ go_library(
srcs = [
"endpoint.go",
"endpoint_state.go",
- "packet_list.go",
"protocol.go",
+ "raw_packet_list.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
@@ -34,6 +34,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
)
@@ -41,7 +42,7 @@ go_library(
filegroup(
name = "autogen",
srcs = [
- "packet_list.go",
+ "raw_packet_list.go",
],
visibility = ["//:sandbox"],
)
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b4c660859..308f10d24 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -17,8 +17,7 @@
//
// * manually write and inspect transport layer headers and payloads
// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
-// * optionally write and inspect network layer and link layer headers for
-// packets
+// * optionally write and inspect network layer headers of packets
//
// Raw sockets don't have any notion of ports, and incoming packets are
// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
@@ -38,8 +37,8 @@ import (
)
// +stateify savable
-type packet struct {
- packetEntry
+type rawPacket struct {
+ rawPacketEntry
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
@@ -72,7 +71,7 @@ type endpoint struct {
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
- rcvList packetList
+ rcvList rawPacketList
rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
rcvClosed bool
@@ -90,7 +89,6 @@ type endpoint struct {
}
// NewEndpoint returns a raw endpoint for the given protocols.
-// TODO(b/129292371): IP_HDRINCL and AF_PACKET.
func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
}
@@ -187,17 +185,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
return buffer.View{}, tcpip.ControlMessages{}, err
}
- packet := e.rcvList.Front()
- e.rcvList.Remove(packet)
- e.rcvBufSize -= packet.data.Size()
+ pkt := e.rcvList.Front()
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
e.rcvMu.Unlock()
if addr != nil {
- *addr = packet.senderAddr
+ *addr = pkt.senderAddr
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+ return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -602,7 +600,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- packet := &packet{
+ pkt := &rawPacket{
senderAddr: tcpip.FullAddress{
NIC: route.NICID(),
Addr: route.RemoteAddress,
@@ -611,11 +609,11 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu
combinedVV := netHeader.ToVectorisedView()
combinedVV.Append(vv)
- packet.data = combinedVV.Clone(packet.views[:])
- packet.timestampNS = e.stack.NowNanoseconds()
+ pkt.data = combinedVV.Clone(pkt.views[:])
+ pkt.timestampNS = e.stack.NowNanoseconds()
- e.rcvList.PushBack(packet)
- e.rcvBufSize += packet.data.Size()
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += pkt.data.Size()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index a6c7cc43a..33bfb56cd 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -20,15 +20,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// saveData saves packet.data field.
-func (p *packet) saveData() buffer.VectorisedView {
+// saveData saves rawPacket.data field.
+func (p *rawPacket) saveData() buffer.VectorisedView {
// We cannot save p.data directly as p.data.views may alias to p.views,
// which is not allowed by state framework (in-struct pointer).
return p.data.Clone(nil)
}
-// loadData loads packet.data field.
-func (p *packet) loadData(data buffer.VectorisedView) {
+// loadData loads rawPacket.data field.
+func (p *rawPacket) loadData(data buffer.VectorisedView) {
// NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
// here because data.views is not guaranteed to be loaded by now. Plus,
// data.views will be allocated anyway so there really is little point
@@ -86,7 +86,9 @@ func (ep *endpoint) Resume(s *stack.Stack) {
}
}
- if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
- panic(err)
+ if ep.associated {
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ panic(err)
+ }
}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index a2512d666..f30aa2a4a 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -17,13 +17,19 @@ package raw
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/packet"
"gvisor.dev/gvisor/pkg/waiter"
)
-// EndpointFactory implements stack.UnassociatedEndpointFactory.
+// EndpointFactory implements stack.RawFactory.
type EndpointFactory struct{}
-// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
+func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
+
+// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index aed70e06f..f1dbc6f91 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -44,6 +44,7 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index b724d02bb..8db1cc028 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -607,17 +607,11 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu
return nil
}
-// sendTCP sends a TCP segment with the provided options via the provided
-// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDescriptor, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) {
optLen := len(opts)
- // Allocate a buffer for the TCP header.
- hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
-
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
- }
-
+ hdr := &d.Hdr
+ packetSize := d.Size
+ off := d.Off
// Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
tcp.Encode(&header.TCPFields{
@@ -631,7 +625,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
})
copy(tcp[header.TCPMinimumSize:], opts)
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(hdr.UsedLength() + packetSize)
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
@@ -641,14 +635,71 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVV(data, xsum)
+ xsum = header.ChecksumVVWithOffset(data, xsum, off, packetSize)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
+}
+
+func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ mss := int(gso.MSS)
+ n := (data.Size() + mss - 1) / mss
+
+ hdrs := stack.NewPacketDescriptors(n, header.TCPMinimumSize+int(r.MaxHeaderLength())+optLen)
+
+ size := data.Size()
+ off := 0
+ for i := 0; i < n; i++ {
+ packetSize := mss
+ if packetSize > size {
+ packetSize = size
+ }
+ size -= packetSize
+ hdrs[i].Off = off
+ hdrs[i].Size = packetSize
+ buildTCPHdr(r, id, &hdrs[i], data, flags, seq, ack, rcvWnd, opts, gso)
+ off += packetSize
+ seq = seq.Add(seqnum.Size(packetSize))
+ }
+ if ttl == 0 {
+ ttl = r.DefaultTTL()
+ }
+ sent, err := r.WritePackets(gso, hdrs, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos})
+ if err != nil {
+ r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
+ }
+ r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent))
+ return err
+}
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
+ return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso)
+ }
+
+ d := &stack.PacketDescriptor{
+ Hdr: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
+ Off: 0,
+ Size: data.Size(),
+ }
+ buildTCPHdr(r, id, d, data, flags, seq, ack, rcvWnd, opts, gso)
+
if ttl == 0 {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 31a22c1eb..c6bc5528c 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2328,11 +2328,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
return s
}
-func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityGSO == 0 {
- return
- }
-
+func (e *endpoint) initHardwareGSO() {
gso := &stack.GSO{}
switch e.route.NetProto {
case header.IPv4ProtocolNumber:
@@ -2350,6 +2346,18 @@ func (e *endpoint) initGSO() {
e.gso = gso
}
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ e.initHardwareGSO()
+ } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ e.gso = &stack.GSO{
+ MaxSize: e.route.GSOMaxSize(),
+ Type: stack.GSOSW,
+ NeedsCsum: false,
+ }
+ }
+}
+
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6e87245b7..91c8487f3 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -140,7 +140,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
- TransProto: header.TCPProtocolNumber,
+ TransProto: header.UDPProtocolNumber,
},
waiterQueue: waiterQueue,
// RFC 1075 section 5.4 recommends a TTL of 1 for membership
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index de026880f..5c3358a5e 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -121,8 +121,15 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
- payload.Append(vv)
+ // The buffers used by vv and netHeader may be used elsewhere
+ // in the system. For example, a raw or packet socket may use
+ // what UDP considers an unreachable destination. Thus we deep
+ // copy vv and netHeader to prevent multiple ownership and SR
+ // errors.
+ newNetHeader := make(buffer.View, len(netHeader))
+ copy(newNetHeader, netHeader)
+ payload := buffer.NewVectorisedView(len(newNetHeader), []buffer.View{newNetHeader})
+ payload.Append(vv.ToView().ToVectorisedView())
payload.CapLength(payloadLen)
hdr := buffer.NewPrependable(headerLen)