summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD4
-rw-r--r--pkg/abi/linux/epoll.go12
-rw-r--r--pkg/abi/linux/epoll_amd64.go (renamed from pkg/usermem/usermem_unsafe.go)22
-rw-r--r--pkg/abi/linux/epoll_arm64.go26
-rw-r--r--pkg/abi/linux/file.go2
-rw-r--r--pkg/abi/linux/file_amd64.go4
-rw-r--r--pkg/abi/linux/file_arm64.go4
-rw-r--r--pkg/abi/linux/fs.go2
-rw-r--r--pkg/abi/linux/ioctl.go26
-rw-r--r--pkg/abi/linux/ioctl_tun.go29
-rw-r--r--pkg/abi/linux/netfilter.go62
-rw-r--r--pkg/abi/linux/signal.go2
-rw-r--r--pkg/abi/linux/socket.go13
-rw-r--r--pkg/abi/linux/time.go8
-rw-r--r--pkg/abi/linux/xattr.go1
-rw-r--r--pkg/atomicbitops/BUILD10
-rw-r--r--pkg/atomicbitops/atomicbitops.go (renamed from pkg/atomicbitops/atomic_bitops.go)31
-rw-r--r--pkg/atomicbitops/atomicbitops_amd64.s (renamed from pkg/atomicbitops/atomic_bitops_amd64.s)38
-rw-r--r--pkg/atomicbitops/atomicbitops_arm64.s (renamed from pkg/atomicbitops/atomic_bitops_arm64.s)34
-rw-r--r--pkg/atomicbitops/atomicbitops_noasm.go (renamed from pkg/atomicbitops/atomic_bitops_common.go)42
-rw-r--r--pkg/atomicbitops/atomicbitops_test.go (renamed from pkg/atomicbitops/atomic_bitops_test.go)64
-rw-r--r--pkg/binary/binary.go10
-rw-r--r--pkg/cpuid/cpuid_x86.go19
-rw-r--r--pkg/fspath/BUILD4
-rw-r--r--pkg/fspath/builder.go8
-rw-r--r--pkg/fspath/fspath.go3
-rw-r--r--pkg/gohacks/BUILD11
-rw-r--r--pkg/gohacks/gohacks_unsafe.go57
-rw-r--r--pkg/log/glog.go6
-rw-r--r--pkg/log/json.go2
-rw-r--r--pkg/log/json_k8s.go2
-rw-r--r--pkg/log/log.go60
-rw-r--r--pkg/p9/buffer.go10
-rw-r--r--pkg/p9/messages.go672
-rw-r--r--pkg/p9/messages_test.go4
-rw-r--r--pkg/p9/p9.go68
-rw-r--r--pkg/p9/transport.go4
-rw-r--r--pkg/p9/transport_flipcall.go4
-rw-r--r--pkg/p9/transport_test.go8
-rw-r--r--pkg/sentry/control/BUILD5
-rw-r--r--pkg/sentry/control/proc.go127
-rw-r--r--pkg/sentry/fs/dev/BUILD5
-rw-r--r--pkg/sentry/fs/dev/dev.go10
-rw-r--r--pkg/sentry/fs/dev/net_tun.go170
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go17
-rw-r--r--pkg/sentry/fs/mount_test.go11
-rw-r--r--pkg/sentry/fs/mounts.go10
-rw-r--r--pkg/sentry/fs/proc/BUILD1
-rw-r--r--pkg/sentry/fs/proc/mounts.go16
-rw-r--r--pkg/sentry/fs/proc/net.go5
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fs/proc/task.go17
-rw-r--r--pkg/sentry/fs/tty/slave.go2
-rw-r--r--pkg/sentry/fsbridge/BUILD24
-rw-r--r--pkg/sentry/fsbridge/bridge.go54
-rw-r--r--pkg/sentry/fsbridge/fs.go181
-rw-r--r--pkg/sentry/fsbridge/vfs.go138
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs.go4
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go5
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go5
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go9
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go12
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go4
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go29
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go3
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go5
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go18
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go20
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go13
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD1
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go18
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go8
-rw-r--r--pkg/sentry/fsimpl/proc/task.go4
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go20
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_net.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go4
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go17
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD1
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go7
-rw-r--r--pkg/sentry/fsimpl/sys/sys_test.go7
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD2
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go24
-rw-r--r--pkg/sentry/fsimpl/testutil/testutil.go16
-rw-r--r--pkg/sentry/fsimpl/tmpfs/benchmark_test.go10
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go18
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go22
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go5
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go233
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go6
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go19
-rw-r--r--pkg/sentry/inet/BUILD1
-rw-r--r--pkg/sentry/inet/namespace.go99
-rw-r--r--pkg/sentry/kernel/BUILD3
-rw-r--r--pkg/sentry/kernel/fd_table.go49
-rw-r--r--pkg/sentry/kernel/fs_context.go120
-rw-r--r--pkg/sentry/kernel/kernel.go184
-rw-r--r--pkg/sentry/kernel/rseq.go16
-rw-r--r--pkg/sentry/kernel/task.go54
-rw-r--r--pkg/sentry/kernel/task_clone.go27
-rw-r--r--pkg/sentry/kernel/task_context.go2
-rw-r--r--pkg/sentry/kernel/task_exit.go7
-rw-r--r--pkg/sentry/kernel/task_log.go21
-rw-r--r--pkg/sentry/kernel/task_net.go19
-rw-r--r--pkg/sentry/kernel/task_start.go55
-rw-r--r--pkg/sentry/kernel/thread_group.go6
-rw-r--r--pkg/sentry/loader/BUILD2
-rw-r--r--pkg/sentry/loader/elf.go28
-rw-r--r--pkg/sentry/loader/interpreter.go6
-rw-r--r--pkg/sentry/loader/loader.go179
-rw-r--r--pkg/sentry/loader/vdso.go7
-rw-r--r--pkg/sentry/mm/BUILD2
-rw-r--r--pkg/sentry/mm/address_space.go23
-rw-r--r--pkg/sentry/mm/lifecycle.go11
-rw-r--r--pkg/sentry/mm/metadata.go10
-rw-r--r--pkg/sentry/mm/mm.go4
-rw-r--r--pkg/sentry/platform/kvm/machine.go18
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD2
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go71
-rw-r--r--pkg/sentry/socket/hostinet/socket.go11
-rw-r--r--pkg/sentry/socket/netfilter/BUILD5
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go95
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go340
-rw-r--r--pkg/sentry/socket/netfilter/targets.go35
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go143
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go (renamed from pkg/tcpip/iptables/udp_matcher.go)91
-rw-r--r--pkg/sentry/socket/netlink/message.go15
-rw-r--r--pkg/sentry/socket/netstack/netstack.go64
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/epoll.go89
-rw-r--r--pkg/sentry/strace/linux64_amd64.go6
-rw-r--r--pkg/sentry/strace/linux64_arm64.go4
-rw-r--r--pkg/sentry/strace/socket.go7
-rw-r--r--pkg/sentry/strace/strace.go36
-rw-r--r--pkg/sentry/strace/syscalls.go10
-rw-r--r--pkg/sentry/syscalls/linux/BUILD3
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go40
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go28
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_amd64.go75
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_arm64.go77
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go17
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD28
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go225
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go44
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go137
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go147
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go326
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fscontext.go131
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go149
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go35
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go140
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go92
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/path.go94
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go584
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go511
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go380
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go346
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go87
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sys_read.go95
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go353
-rw-r--r--pkg/sentry/vfs/BUILD2
-rw-r--r--pkg/sentry/vfs/context.go7
-rw-r--r--pkg/sentry/vfs/dentry.go4
-rw-r--r--pkg/sentry/vfs/device.go3
-rw-r--r--pkg/sentry/vfs/epoll.go22
-rw-r--r--pkg/sentry/vfs/file_description.go10
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go10
-rw-r--r--pkg/sentry/vfs/filesystem.go2
-rw-r--r--pkg/sentry/vfs/filesystem_type.go1
-rw-r--r--pkg/sentry/vfs/mount.go14
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go20
-rw-r--r--pkg/sentry/vfs/options.go7
-rw-r--r--pkg/sentry/vfs/permissions.go19
-rw-r--r--pkg/sentry/vfs/resolving_path.go2
-rw-r--r--pkg/sentry/vfs/vfs.go53
-rw-r--r--pkg/sleep/commit_noasm.go13
-rw-r--r--pkg/sleep/sleep_unsafe.go23
-rw-r--r--pkg/syncevent/BUILD39
-rw-r--r--pkg/syncevent/broadcaster.go218
-rw-r--r--pkg/syncevent/broadcaster_test.go376
-rw-r--r--pkg/syncevent/receiver.go103
-rw-r--r--pkg/syncevent/source.go59
-rw-r--r--pkg/syncevent/syncevent.go32
-rw-r--r--pkg/syncevent/syncevent_example_test.go108
-rw-r--r--pkg/syncevent/waiter_amd64.s32
-rw-r--r--pkg/syncevent/waiter_arm64.s34
-rw-r--r--pkg/syncevent/waiter_asm_unsafe.go (renamed from pkg/fspath/builder_unsafe.go)15
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_test.go414
-rw-r--r--pkg/syncevent/waiter_unsafe.go206
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go63
-rw-r--r--pkg/tcpip/buffer/view.go6
-rw-r--r--pkg/tcpip/checker/checker.go14
-rw-r--r--pkg/tcpip/iptables/BUILD1
-rw-r--r--pkg/tcpip/iptables/iptables.go125
-rw-r--r--pkg/tcpip/iptables/targets.go43
-rw-r--r--pkg/tcpip/iptables/types.go50
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go180
-rw-r--r--pkg/tcpip/link/tun/BUILD18
-rw-r--r--pkg/tcpip/link/tun/device.go352
-rw-r--r--pkg/tcpip/link/tun/protocol.go56
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go6
-rw-r--r--pkg/tcpip/stack/ndp.go25
-rw-r--r--pkg/tcpip/stack/ndp_test.go900
-rw-r--r--pkg/tcpip/stack/nic.go154
-rw-r--r--pkg/tcpip/stack/registration.go25
-rw-r--r--pkg/tcpip/stack/stack.go112
-rw-r--r--pkg/tcpip/stack/stack_test.go524
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go20
-rw-r--r--pkg/tcpip/stack/transport_test.go15
-rw-r--r--pkg/tcpip/tcpip.go48
-rw-r--r--pkg/tcpip/time_unsafe.go2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go16
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go5
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go9
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go31
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go33
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go67
-rw-r--r--pkg/tcpip/transport/udp/protocol.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go120
-rw-r--r--pkg/usermem/BUILD2
-rw-r--r--pkg/usermem/usermem.go9
247 files changed, 11522 insertions, 2416 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 1f3c0c687..322d1ccc4 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -17,6 +17,8 @@ go_library(
"dev.go",
"elf.go",
"epoll.go",
+ "epoll_amd64.go",
+ "epoll_arm64.go",
"errors.go",
"eventfd.go",
"exec.go",
@@ -28,6 +30,7 @@ go_library(
"futex.go",
"inotify.go",
"ioctl.go",
+ "ioctl_tun.go",
"ip.go",
"ipc.go",
"limits.go",
@@ -59,6 +62,7 @@ go_library(
"wait.go",
"xattr.go",
],
+ marshal = True,
visibility = ["//visibility:public"],
deps = [
"//pkg/abi",
diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go
index 0e881aa3c..1121a1a92 100644
--- a/pkg/abi/linux/epoll.go
+++ b/pkg/abi/linux/epoll.go
@@ -14,12 +14,9 @@
package linux
-// EpollEvent is equivalent to struct epoll_event from epoll(2).
-type EpollEvent struct {
- Events uint32
- Fd int32
- Data int32
-}
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+)
// Event masks.
const (
@@ -60,3 +57,6 @@ const (
EPOLL_CTL_DEL = 0x2
EPOLL_CTL_MOD = 0x3
)
+
+// SizeOfEpollEvent is the size of EpollEvent struct.
+var SizeOfEpollEvent = int(binary.Size(EpollEvent{}))
diff --git a/pkg/usermem/usermem_unsafe.go b/pkg/abi/linux/epoll_amd64.go
index 876783e78..34ff18009 100644
--- a/pkg/usermem/usermem_unsafe.go
+++ b/pkg/abi/linux/epoll_amd64.go
@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package linux
-import (
- "unsafe"
-)
-
-// stringFromImmutableBytes is equivalent to string(bs), except that it never
-// copies even if escape analysis can't prove that bs does not escape. This is
-// only valid if bs is never mutated after stringFromImmutableBytes returns.
-func stringFromImmutableBytes(bs []byte) string {
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
+// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
+type EpollEvent struct {
+ Events uint32
+ // Linux makes struct epoll_event::data a __u64. We represent it as
+ // [2]int32 because, on amd64, Linux also makes struct epoll_event
+ // __attribute__((packed)), such that there is no padding between Events
+ // and Data.
+ Data [2]int32
}
diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go
new file mode 100644
index 000000000..f86c35329
--- /dev/null
+++ b/pkg/abi/linux/epoll_arm64.go
@@ -0,0 +1,26 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
+type EpollEvent struct {
+ Events uint32
+ // Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding
+ // here.
+ _ int32
+ Data [2]int32
+}
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index c3ab15a4f..e229ac21c 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -241,6 +241,8 @@ const (
)
// Statx represents struct statx.
+//
+// +marshal
type Statx struct {
Mask uint32
Blksize uint32
diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go
index 9d307e840..6b72364ea 100644
--- a/pkg/abi/linux/file_amd64.go
+++ b/pkg/abi/linux/file_amd64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package linux
// Constants for open(2).
@@ -23,6 +25,8 @@ const (
)
// Stat represents struct stat.
+//
+// +marshal
type Stat struct {
Dev uint64
Ino uint64
diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go
index 26a54f416..6492c9038 100644
--- a/pkg/abi/linux/file_arm64.go
+++ b/pkg/abi/linux/file_arm64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package linux
// Constants for open(2).
@@ -23,6 +25,8 @@ const (
)
// Stat represents struct stat.
+//
+// +marshal
type Stat struct {
Dev uint64
Ino uint64
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 2c652baa2..158d2db5b 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -38,6 +38,8 @@ const (
)
// Statfs is struct statfs, from uapi/asm-generic/statfs.h.
+//
+// +marshal
type Statfs struct {
// Type is one of the filesystem magic values, defined above.
Type uint64
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 0e18db9ef..2062e6a4b 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -72,3 +72,29 @@ const (
SIOCGMIIPHY = 0x8947
SIOCGMIIREG = 0x8948
)
+
+// ioctl(2) directions. Used to calculate requests number.
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NONE = 0
+ _IOC_WRITE = 1
+ _IOC_READ = 2
+)
+
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NRBITS = 8
+ _IOC_TYPEBITS = 8
+ _IOC_SIZEBITS = 14
+ _IOC_DIRBITS = 2
+
+ _IOC_NRSHIFT = 0
+ _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS
+ _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS
+ _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
+)
+
+// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
+func IOC(dir, typ, nr, size uint32) uint32 {
+ return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
+}
diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go
new file mode 100644
index 000000000..c59c9c136
--- /dev/null
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// ioctl(2) request numbers from linux/if_tun.h
+var (
+ TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4)
+ TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4)
+)
+
+// Flags from net/if_tun.h
+const (
+ IFF_TUN = 0x0001
+ IFF_TAP = 0x0002
+ IFF_NO_PI = 0x1000
+ IFF_NOFILTER = 0x1000
+)
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 8e40bcc62..bd2e13ba1 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -225,11 +225,14 @@ type XTEntryTarget struct {
// SizeOfXTEntryTarget is the size of an XTEntryTarget.
const SizeOfXTEntryTarget = 32
-// XTStandardTarget is a builtin target, one of ACCEPT, DROP, JUMP, QUEUE, or
-// RETURN. It corresponds to struct xt_standard_target in
+// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
+// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
type XTStandardTarget struct {
- Target XTEntryTarget
+ Target XTEntryTarget
+ // A positive verdict indicates a jump, and is the offset from the
+ // start of the table to jump to. A negative value means one of the
+ // other built-in targets.
Verdict int32
_ [4]byte
}
@@ -348,6 +351,59 @@ func goString(cstring []byte) string {
return string(cstring)
}
+// XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp
+// in include/uapi/linux/netfilter/xt_tcpudp.h.
+type XTTCP struct {
+ // SourcePortStart specifies the inclusive start of the range of source
+ // ports to which the matcher applies.
+ SourcePortStart uint16
+
+ // SourcePortEnd specifies the inclusive end of the range of source ports
+ // to which the matcher applies.
+ SourcePortEnd uint16
+
+ // DestinationPortStart specifies the start of the destination port
+ // range to which the matcher applies.
+ DestinationPortStart uint16
+
+ // DestinationPortEnd specifies the end of the destination port
+ // range to which the matcher applies.
+ DestinationPortEnd uint16
+
+ // Option specifies that a particular TCP option must be set.
+ Option uint8
+
+ // FlagMask masks TCP flags when comparing to the FlagCompare byte. It allows
+ // for specification of which flags are important to the matcher.
+ FlagMask uint8
+
+ // FlagCompare, in combination with FlagMask, is used to match only packets
+ // that have certain flags set.
+ FlagCompare uint8
+
+ // InverseFlags flips the meaning of certain fields. See the
+ // TX_TCP_INV_* flags.
+ InverseFlags uint8
+}
+
+// SizeOfXTTCP is the size of an XTTCP.
+const SizeOfXTTCP = 12
+
+// Flags in XTTCP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_tcpudp.h.
+const (
+ // Invert the meaning of SourcePortStart/End.
+ XT_TCP_INV_SRCPT = 0x01
+ // Invert the meaning of DestinationPortStart/End.
+ XT_TCP_INV_DSTPT = 0x02
+ // Invert the meaning of FlagCompare.
+ XT_TCP_INV_FLAGS = 0x04
+ // Invert the meaning of Option.
+ XT_TCP_INV_OPTION = 0x08
+ // Enable all flags.
+ XT_TCP_INV_MASK = 0x0F
+)
+
// XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp
// in include/uapi/linux/netfilter/xt_tcpudp.h.
type XTUDP struct {
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index c69b04ea9..1c330e763 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -115,6 +115,8 @@ const (
)
// SignalSet is a signal mask with a bit corresponding to each signal.
+//
+// +marshal
type SignalSet uint64
// SignalSetSize is the size in bytes of a SignalSet.
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 766ee4014..4a14ef691 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -411,6 +411,15 @@ type ControlMessageCredentials struct {
GID uint32
}
+// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message.
+//
+// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+type ControlMessageIPPacketInfo struct {
+ NIC int32
+ LocalAddr InetAddr
+ DestinationAddr InetAddr
+}
+
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
@@ -431,6 +440,10 @@ const SizeOfControlMessageTOS = 1
// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message.
const SizeOfControlMessageTClass = 4
+// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO
+// control message.
+const SizeOfControlMessageIPPacketInfo = 12
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index 5c5a58cd4..e6860ed49 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -101,6 +101,8 @@ func NsecToTimeT(nsec int64) TimeT {
}
// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
type Timespec struct {
Sec int64
Nsec int64
@@ -155,6 +157,8 @@ func DurationToTimespec(dur time.Duration) Timespec {
const SizeOfTimeval = 16
// Timeval represents struct timeval in <time.h>.
+//
+// +marshal
type Timeval struct {
Sec int64
Usec int64
@@ -228,6 +232,8 @@ type Tms struct {
type TimerID int32
// StatxTimestamp represents struct statx_timestamp.
+//
+// +marshal
type StatxTimestamp struct {
Sec int64
Nsec uint32
@@ -256,6 +262,8 @@ func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
}
// Utime represents struct utimbuf used by utimes(2).
+//
+// +marshal
type Utime struct {
Actime int64
Modtime int64
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
index a3b6406fa..99180b208 100644
--- a/pkg/abi/linux/xattr.go
+++ b/pkg/abi/linux/xattr.go
@@ -18,6 +18,7 @@ package linux
const (
XATTR_NAME_MAX = 255
XATTR_SIZE_MAX = 65536
+ XATTR_LIST_MAX = 65536
XATTR_CREATE = 1
XATTR_REPLACE = 2
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 3948074ba..1a30f6967 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -5,10 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "atomicbitops",
srcs = [
- "atomic_bitops.go",
- "atomic_bitops_amd64.s",
- "atomic_bitops_arm64.s",
- "atomic_bitops_common.go",
+ "atomicbitops.go",
+ "atomicbitops_amd64.s",
+ "atomicbitops_arm64.s",
+ "atomicbitops_noasm.go",
],
visibility = ["//:sandbox"],
)
@@ -16,7 +16,7 @@ go_library(
go_test(
name = "atomicbitops_test",
size = "small",
- srcs = ["atomic_bitops_test.go"],
+ srcs = ["atomicbitops_test.go"],
library = ":atomicbitops",
deps = ["//pkg/sync"],
)
diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomicbitops.go
index fcc41a9ea..1be081719 100644
--- a/pkg/atomicbitops/atomic_bitops.go
+++ b/pkg/atomicbitops/atomicbitops.go
@@ -14,47 +14,34 @@
// +build amd64 arm64
-// Package atomicbitops provides basic bitwise operations in an atomic way.
-// The implementation on amd64 leverages the LOCK prefix directly instead of
-// relying on the generic cas primitives, and the arm64 leverages the LDAXR
-// and STLXR pair primitives.
+// Package atomicbitops provides extensions to the sync/atomic package.
//
-// WARNING: the bitwise ops provided in this package doesn't imply any memory
-// ordering. Using them to construct locks must employ proper memory barriers.
+// All read-modify-write operations implemented by this package have
+// acquire-release memory ordering (like sync/atomic).
package atomicbitops
-// AndUint32 atomically applies bitwise and operation to *addr with val.
+// AndUint32 atomically applies bitwise AND operation to *addr with val.
func AndUint32(addr *uint32, val uint32)
-// OrUint32 atomically applies bitwise or operation to *addr with val.
+// OrUint32 atomically applies bitwise OR operation to *addr with val.
func OrUint32(addr *uint32, val uint32)
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
+// XorUint32 atomically applies bitwise XOR operation to *addr with val.
func XorUint32(addr *uint32, val uint32)
// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) uint32
-// AndUint64 atomically applies bitwise and operation to *addr with val.
+// AndUint64 atomically applies bitwise AND operation to *addr with val.
func AndUint64(addr *uint64, val uint64)
-// OrUint64 atomically applies bitwise or operation to *addr with val.
+// OrUint64 atomically applies bitwise OR operation to *addr with val.
func OrUint64(addr *uint64, val uint64)
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
+// XorUint64 atomically applies bitwise XOR operation to *addr with val.
func XorUint64(addr *uint64, val uint64)
// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) uint64
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool
diff --git a/pkg/atomicbitops/atomic_bitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index db0972001..54c887ee5 100644
--- a/pkg/atomicbitops/atomic_bitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -75,41 +75,3 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
CMPXCHGQ DX, 0(DI)
MOVQ AX, ret+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- TESTL AX, AX
- JZ fail
- LEAL 1(AX), DX
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB AX, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- LEAL -1(AX), DX
- TESTL DX, DX
- JZ fail
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB DX, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 97f8808c1..5c780851b 100644
--- a/pkg/atomicbitops/atomic_bitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -50,7 +50,6 @@ TEXT ·CompareAndSwapUint32(SB),$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
-
again:
LDAXRW (R0), R3
CMPW R1, R3
@@ -95,7 +94,6 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
-
again:
LDAXR (R0), R3
CMP R1, R3
@@ -105,35 +103,3 @@ again:
done:
MOVD R3, prev+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- CBZ R1, fail
- ADDW $1, R1
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- SUBSW $1, R1, R1
- BEQ fail
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 85163ad62..3b2898256 100644
--- a/pkg/atomicbitops/atomic_bitops_common.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -20,7 +20,6 @@ import (
"sync/atomic"
)
-// AndUint32 atomically applies bitwise and operation to *addr with val.
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,7 +30,6 @@ func AndUint32(addr *uint32, val uint32) {
}
}
-// OrUint32 atomically applies bitwise or operation to *addr with val.
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -42,7 +40,6 @@ func OrUint32(addr *uint32, val uint32) {
}
}
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -53,8 +50,6 @@ func XorUint32(addr *uint32, val uint32) {
}
}
-// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
-// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -67,7 +62,6 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
-// AndUint64 atomically applies bitwise and operation to *addr with val.
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -78,7 +72,6 @@ func AndUint64(addr *uint64, val uint64) {
}
}
-// OrUint64 atomically applies bitwise or operation to *addr with val.
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -89,7 +82,6 @@ func OrUint64(addr *uint64, val uint64) {
}
}
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -100,8 +92,6 @@ func XorUint64(addr *uint64, val uint64) {
}
}
-// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
-// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
@@ -113,35 +103,3 @@ func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
}
}
}
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 0 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v+1) {
- return true
- }
- }
-}
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 1 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v-1) {
- return true
- }
- }
-}
diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomicbitops_test.go
index 9466d3e23..73af71bb4 100644
--- a/pkg/atomicbitops/atomic_bitops_test.go
+++ b/pkg/atomicbitops/atomicbitops_test.go
@@ -196,67 +196,3 @@ func TestCompareAndSwapUint64(t *testing.T) {
}
}
}
-
-func TestIncUnlessZeroInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: 0,
- ret: false,
- },
- {
- initial: 1,
- final: 2,
- ret: true,
- },
- {
- initial: 2,
- final: 3,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := IncUnlessZeroInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
-
-func TestDecUnlessOneInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: -1,
- ret: true,
- },
- {
- initial: 1,
- final: 1,
- ret: false,
- },
- {
- initial: 2,
- final: 1,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := DecUnlessOneInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
diff --git a/pkg/binary/binary.go b/pkg/binary/binary.go
index 631785f7b..25065aef9 100644
--- a/pkg/binary/binary.go
+++ b/pkg/binary/binary.go
@@ -254,3 +254,13 @@ func WriteUint64(w io.Writer, order binary.ByteOrder, num uint64) error {
_, err := w.Write(buf)
return err
}
+
+// AlignUp rounds a length up to an alignment. align must be a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a length down to an alignment. align must be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index 333ca0a04..a0bc55ea1 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -725,6 +725,18 @@ func vendorIDFromRegs(bx, cx, dx uint32) string {
return string(bytes)
}
+var maxXsaveSize = func() uint32 {
+ // Leaf 0 of xsaveinfo function returns the size for currently
+ // enabled xsave features in ebx, the maximum size if all valid
+ // features are saved with xsave in ecx, and valid XCR0 bits in
+ // edx:eax.
+ //
+ // If xSaveInfo isn't supported, cpuid will not fault but will
+ // return bogus values.
+ _, _, maxXsaveSize, _ := HostID(uint32(xSaveInfo), 0)
+ return maxXsaveSize
+}()
+
// ExtendedStateSize returns the number of bytes needed to save the "extended
// state" for this processor and the boundary it must be aligned to. Extended
// state includes floating point registers, and other cpu state that's not
@@ -736,12 +748,7 @@ func vendorIDFromRegs(bx, cx, dx uint32) string {
// about 2.5K worst case, with avx512).
func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
if fs.UseXsave() {
- // Leaf 0 of xsaveinfo function returns the size for currently
- // enabled xsave features in ebx, the maximum size if all valid
- // features are saved with xsave in ecx, and valid XCR0 bits in
- // edx:eax.
- _, _, maxSize, _ := HostID(uint32(xSaveInfo), 0)
- return uint(maxSize), 64
+ return uint(maxXsaveSize), 64
}
// If we don't support xsave, we fall back to fxsave, which requires
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index ee84471b2..67dd1e225 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -8,9 +8,11 @@ go_library(
name = "fspath",
srcs = [
"builder.go",
- "builder_unsafe.go",
"fspath.go",
],
+ deps = [
+ "//pkg/gohacks",
+ ],
)
go_test(
diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go
index 7ddb36826..6318d3874 100644
--- a/pkg/fspath/builder.go
+++ b/pkg/fspath/builder.go
@@ -16,6 +16,8 @@ package fspath
import (
"fmt"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
)
// Builder is similar to strings.Builder, but is used to produce pathnames
@@ -102,3 +104,9 @@ func (b *Builder) AppendString(str string) {
copy(b.buf[b.start:], b.buf[oldStart:])
copy(b.buf[len(b.buf)-len(str):], str)
}
+
+// String returns the accumulated string. No other methods should be called
+// after String.
+func (b *Builder) String() string {
+ return gohacks.StringFromImmutableBytes(b.buf[b.start:])
+}
diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go
index 9fb3fee24..4c983d5fd 100644
--- a/pkg/fspath/fspath.go
+++ b/pkg/fspath/fspath.go
@@ -67,7 +67,8 @@ func Parse(pathname string) Path {
// Path contains the information contained in a pathname string.
//
-// Path is copyable by value.
+// Path is copyable by value. The zero value for Path is equivalent to
+// fspath.Parse(""), i.e. the empty path.
type Path struct {
// Begin is an iterator to the first path component in the relative part of
// the path.
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
new file mode 100644
index 000000000..798a65eca
--- /dev/null
+++ b/pkg/gohacks/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gohacks",
+ srcs = [
+ "gohacks_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
new file mode 100644
index 000000000..aad675172
--- /dev/null
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -0,0 +1,57 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gohacks contains utilities for subverting the Go compiler.
+package gohacks
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Noescape hides a pointer from escape analysis. Noescape is the identity
+// function but escape analysis doesn't think the output depends on the input.
+// Noescape is inlined and currently compiles down to zero instructions.
+// USE CAREFULLY!
+//
+// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().)
+//
+//go:nosplit
+func Noescape(p unsafe.Pointer) unsafe.Pointer {
+ x := uintptr(p)
+ return unsafe.Pointer(x ^ 0)
+}
+
+// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the
+// same memory backing s instead of making a heap-allocated copy. This is only
+// valid if the returned slice is never mutated.
+func ImmutableBytesFromString(s string) []byte {
+ shdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
+ var bs []byte
+ bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ bshdr.Data = shdr.Data
+ bshdr.Len = shdr.Len
+ bshdr.Cap = shdr.Len
+ return bs
+}
+
+// StringFromImmutableBytes is equivalent to string(bs), except that it uses
+// the same memory backing bs instead of making a heap-allocated copy. This is
+// only valid if bs is never mutated after StringFromImmutableBytes returns.
+func StringFromImmutableBytes(bs []byte) string {
+ // This is cheaper than messing with reflect.StringHeader and
+ // reflect.SliceHeader, which as of this writing produces many dead stores
+ // of zeroes. Compare strings.Builder.String().
+ return *(*string)(unsafe.Pointer(&bs))
+}
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
index cab5fae55..b4f7bb5a4 100644
--- a/pkg/log/glog.go
+++ b/pkg/log/glog.go
@@ -46,7 +46,7 @@ var pid = os.Getpid()
// line The line number
// msg The user-supplied message
//
-func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
// Log level.
prefix := byte('?')
switch level {
@@ -64,9 +64,7 @@ func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, ar
microsecond := int(timestamp.Nanosecond() / 1000)
// 0 = this frame.
- // 1 = Debugf, etc.
- // 2 = Caller.
- _, file, line, ok := runtime.Caller(2)
+ _, file, line, ok := runtime.Caller(depth + 1)
if ok {
// Trim any directory path from the file.
slash := strings.LastIndexByte(file, byte('/'))
diff --git a/pkg/log/json.go b/pkg/log/json.go
index a278c8fc8..0943db1cc 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -62,7 +62,7 @@ type JSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := jsonLog{
Msg: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
index cee6eb514..6c6fc8b6f 100644
--- a/pkg/log/json_k8s.go
+++ b/pkg/log/json_k8s.go
@@ -33,7 +33,7 @@ type K8sJSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e *K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := k8sJSONLog{
Log: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 5056f17e6..a794da1aa 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -79,7 +79,7 @@ func (l Level) String() string {
type Emitter interface {
// Emit emits the given log statement. This allows for control over the
// timestamp used for logging.
- Emit(level Level, timestamp time.Time, format string, v ...interface{})
+ Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{})
}
// Writer writes the output to the given writer.
@@ -142,7 +142,7 @@ func (l *Writer) Write(data []byte) (int, error) {
}
// Emit emits the message.
-func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) {
fmt.Fprintf(l, format, args...)
}
@@ -150,9 +150,9 @@ func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...i
type MultiEmitter []Emitter
// Emit emits to all emitters.
-func (m *MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) {
for _, e := range *m {
- e.Emit(level, timestamp, format, v...)
+ e.Emit(1+depth, level, timestamp, format, v...)
}
}
@@ -167,7 +167,7 @@ type TestEmitter struct {
}
// Emit emits to the TestLogger.
-func (t *TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
t.Logf(format, v...)
}
@@ -198,22 +198,37 @@ type BasicLogger struct {
// Debugf implements logger.Debugf.
func (l *BasicLogger) Debugf(format string, v ...interface{}) {
- if l.IsLogging(Debug) {
- l.Emit(Debug, time.Now(), format, v...)
- }
+ l.DebugfAtDepth(1, format, v...)
}
// Infof implements logger.Infof.
func (l *BasicLogger) Infof(format string, v ...interface{}) {
- if l.IsLogging(Info) {
- l.Emit(Info, time.Now(), format, v...)
- }
+ l.InfofAtDepth(1, format, v...)
}
// Warningf implements logger.Warningf.
func (l *BasicLogger) Warningf(format string, v ...interface{}) {
+ l.WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs at a specific depth.
+func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Debug) {
+ l.Emit(1+depth, Debug, time.Now(), format, v...)
+ }
+}
+
+// InfofAtDepth logs at a specific depth.
+func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Info) {
+ l.Emit(1+depth, Info, time.Now(), format, v...)
+ }
+}
+
+// WarningfAtDepth logs at a specific depth.
+func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) {
if l.IsLogging(Warning) {
- l.Emit(Warning, time.Now(), format, v...)
+ l.Emit(1+depth, Warning, time.Now(), format, v...)
}
}
@@ -257,17 +272,32 @@ func SetLevel(newLevel Level) {
// Debugf logs to the global logger.
func Debugf(format string, v ...interface{}) {
- Log().Debugf(format, v...)
+ Log().DebugfAtDepth(1, format, v...)
}
// Infof logs to the global logger.
func Infof(format string, v ...interface{}) {
- Log().Infof(format, v...)
+ Log().InfofAtDepth(1, format, v...)
}
// Warningf logs to the global logger.
func Warningf(format string, v ...interface{}) {
- Log().Warningf(format, v...)
+ Log().WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs to the global logger.
+func DebugfAtDepth(depth int, format string, v ...interface{}) {
+ Log().DebugfAtDepth(1+depth, format, v...)
+}
+
+// InfofAtDepth logs to the global logger.
+func InfofAtDepth(depth int, format string, v ...interface{}) {
+ Log().InfofAtDepth(1+depth, format, v...)
+}
+
+// WarningfAtDepth logs to the global logger.
+func WarningfAtDepth(depth int, format string, v ...interface{}) {
+ Log().WarningfAtDepth(1+depth, format, v...)
}
// defaultStackSize is the default buffer size to allocate for stack traces.
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go
index 249536d8a..6a4951821 100644
--- a/pkg/p9/buffer.go
+++ b/pkg/p9/buffer.go
@@ -20,16 +20,16 @@ import (
// encoder is used for messages and 9P primitives.
type encoder interface {
- // Decode decodes from the given buffer. Decode may be called more than once
+ // decode decodes from the given buffer. decode may be called more than once
// to reuse the instance. It must clear any previous state.
//
// This may not fail, exhaustion will be recorded in the buffer.
- Decode(b *buffer)
+ decode(b *buffer)
- // Encode encodes to the given buffer.
+ // encode encodes to the given buffer.
//
// This may not fail.
- Encode(b *buffer)
+ encode(b *buffer)
}
// order is the byte order used for encoding.
@@ -39,7 +39,7 @@ var order = binary.LittleEndian
//
// This is passed to the encoder methods.
type buffer struct {
- // data is the underlying data. This may grow during Encode.
+ // data is the underlying data. This may grow during encode.
data []byte
// overflow indicates whether an overflow has occurred.
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index b1cede5f5..3863ad1f5 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -51,7 +51,7 @@ type payloader interface {
// SetPayload returns the decoded message.
//
// This is going to be total message size - FixedSize. But this should
- // be validated during Decode, which will be called after SetPayload.
+ // be validated during decode, which will be called after SetPayload.
SetPayload([]byte)
}
@@ -90,14 +90,14 @@ type Tversion struct {
Version string
}
-// Decode implements encoder.Decode.
-func (t *Tversion) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tversion) decode(b *buffer) {
t.MSize = b.Read32()
t.Version = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Tversion) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tversion) encode(b *buffer) {
b.Write32(t.MSize)
b.WriteString(t.Version)
}
@@ -121,14 +121,14 @@ type Rversion struct {
Version string
}
-// Decode implements encoder.Decode.
-func (r *Rversion) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rversion) decode(b *buffer) {
r.MSize = b.Read32()
r.Version = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (r *Rversion) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rversion) encode(b *buffer) {
b.Write32(r.MSize)
b.WriteString(r.Version)
}
@@ -149,13 +149,13 @@ type Tflush struct {
OldTag Tag
}
-// Decode implements encoder.Decode.
-func (t *Tflush) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tflush) decode(b *buffer) {
t.OldTag = b.ReadTag()
}
-// Encode implements encoder.Encode.
-func (t *Tflush) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tflush) encode(b *buffer) {
b.WriteTag(t.OldTag)
}
@@ -173,12 +173,12 @@ func (t *Tflush) String() string {
type Rflush struct {
}
-// Decode implements encoder.Decode.
-func (*Rflush) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rflush) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rflush) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rflush) encode(*buffer) {
}
// Type implements message.Type.
@@ -203,8 +203,8 @@ type Twalk struct {
Names []string
}
-// Decode implements encoder.Decode.
-func (t *Twalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twalk) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
@@ -214,8 +214,8 @@ func (t *Twalk) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (t *Twalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Twalk) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.Write16(uint16(len(t.Names)))
@@ -240,22 +240,22 @@ type Rwalk struct {
QIDs []QID
}
-// Decode implements encoder.Decode.
-func (r *Rwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rwalk) decode(b *buffer) {
n := b.Read16()
r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
- q.Decode(b)
+ q.decode(b)
r.QIDs = append(r.QIDs, q)
}
}
-// Encode implements encoder.Encode.
-func (r *Rwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rwalk) encode(b *buffer) {
b.Write16(uint16(len(r.QIDs)))
for _, q := range r.QIDs {
- q.Encode(b)
+ q.encode(b)
}
}
@@ -275,13 +275,13 @@ type Tclunk struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tclunk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tclunk) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tclunk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tclunk) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -299,12 +299,12 @@ func (t *Tclunk) String() string {
type Rclunk struct {
}
-// Decode implements encoder.Decode.
-func (*Rclunk) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rclunk) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rclunk) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rclunk) encode(*buffer) {
}
// Type implements message.Type.
@@ -325,13 +325,13 @@ type Tremove struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tremove) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tremove) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tremove) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tremove) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -349,12 +349,12 @@ func (t *Tremove) String() string {
type Rremove struct {
}
-// Decode implements encoder.Decode.
-func (*Rremove) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rremove) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rremove) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rremove) encode(*buffer) {
}
// Type implements message.Type.
@@ -374,13 +374,13 @@ type Rlerror struct {
Error uint32
}
-// Decode implements encoder.Decode.
-func (r *Rlerror) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rlerror) decode(b *buffer) {
r.Error = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rlerror) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rlerror) encode(b *buffer) {
b.Write32(r.Error)
}
@@ -409,16 +409,16 @@ type Tauth struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tauth) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tauth) decode(b *buffer) {
t.AuthenticationFID = b.ReadFID()
t.UserName = b.ReadString()
t.AttachName = b.ReadString()
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tauth) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tauth) encode(b *buffer) {
b.WriteFID(t.AuthenticationFID)
b.WriteString(t.UserName)
b.WriteString(t.AttachName)
@@ -437,7 +437,7 @@ func (t *Tauth) String() string {
// Rauth is an authentication response.
//
-// Encode, Decode and Length are inherited directly from QID.
+// encode and decode are inherited directly from QID.
type Rauth struct {
QID
}
@@ -463,16 +463,16 @@ type Tattach struct {
Auth Tauth
}
-// Decode implements encoder.Decode.
-func (t *Tattach) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tattach) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Auth.Decode(b)
+ t.Auth.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tattach) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tattach) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Auth.Encode(b)
+ t.Auth.encode(b)
}
// Type implements message.Type.
@@ -509,14 +509,14 @@ type Tlopen struct {
Flags OpenFlags
}
-// Decode implements encoder.Decode.
-func (t *Tlopen) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlopen) decode(b *buffer) {
t.FID = b.ReadFID()
t.Flags = b.ReadOpenFlags()
}
-// Encode implements encoder.Encode.
-func (t *Tlopen) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlopen) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteOpenFlags(t.Flags)
}
@@ -542,15 +542,15 @@ type Rlopen struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rlopen) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rlopen) decode(b *buffer) {
+ r.QID.decode(b)
r.IoUnit = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rlopen) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rlopen) encode(b *buffer) {
+ r.QID.encode(b)
b.Write32(r.IoUnit)
}
@@ -587,8 +587,8 @@ type Tlcreate struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tlcreate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlcreate) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.OpenFlags = b.ReadOpenFlags()
@@ -596,8 +596,8 @@ func (t *Tlcreate) Decode(b *buffer) {
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tlcreate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlcreate) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.WriteOpenFlags(t.OpenFlags)
@@ -617,7 +617,7 @@ func (t *Tlcreate) String() string {
// Rlcreate is a create response.
//
-// The Encode, Decode, etc. methods are inherited from Rlopen.
+// The encode, decode, etc. methods are inherited from Rlopen.
type Rlcreate struct {
Rlopen
}
@@ -647,16 +647,16 @@ type Tsymlink struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tsymlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tsymlink) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Target = b.ReadString()
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tsymlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tsymlink) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WriteString(t.Target)
@@ -679,14 +679,14 @@ type Rsymlink struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rsymlink) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rsymlink) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rsymlink) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rsymlink) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -711,15 +711,15 @@ type Tlink struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Tlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlink) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Target = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Tlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlink) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteFID(t.Target)
b.WriteString(t.Name)
@@ -744,12 +744,12 @@ func (*Rlink) Type() MsgType {
return MsgRlink
}
-// Decode implements encoder.Decode.
-func (*Rlink) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rlink) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rlink) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rlink) encode(*buffer) {
}
// String implements fmt.Stringer.
@@ -772,16 +772,16 @@ type Trenameat struct {
NewName string
}
-// Decode implements encoder.Decode.
-func (t *Trenameat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Trenameat) decode(b *buffer) {
t.OldDirectory = b.ReadFID()
t.OldName = b.ReadString()
t.NewDirectory = b.ReadFID()
t.NewName = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Trenameat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Trenameat) encode(b *buffer) {
b.WriteFID(t.OldDirectory)
b.WriteString(t.OldName)
b.WriteFID(t.NewDirectory)
@@ -802,12 +802,12 @@ func (t *Trenameat) String() string {
type Rrenameat struct {
}
-// Decode implements encoder.Decode.
-func (*Rrenameat) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rrenameat) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rrenameat) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rrenameat) encode(*buffer) {
}
// Type implements message.Type.
@@ -832,15 +832,15 @@ type Tunlinkat struct {
Flags uint32
}
-// Decode implements encoder.Decode.
-func (t *Tunlinkat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tunlinkat) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Flags = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tunlinkat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tunlinkat) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.Write32(t.Flags)
@@ -860,12 +860,12 @@ func (t *Tunlinkat) String() string {
type Runlinkat struct {
}
-// Decode implements encoder.Decode.
-func (*Runlinkat) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Runlinkat) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Runlinkat) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Runlinkat) encode(*buffer) {
}
// Type implements message.Type.
@@ -893,15 +893,15 @@ type Trename struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Trename) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Trename) decode(b *buffer) {
t.FID = b.ReadFID()
t.Directory = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Trename) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Trename) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.Directory)
b.WriteString(t.Name)
@@ -921,12 +921,12 @@ func (t *Trename) String() string {
type Rrename struct {
}
-// Decode implements encoder.Decode.
-func (*Rrename) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rrename) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rrename) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rrename) encode(*buffer) {
}
// Type implements message.Type.
@@ -945,13 +945,13 @@ type Treadlink struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Treadlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Treadlink) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Treadlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Treadlink) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -971,13 +971,13 @@ type Rreadlink struct {
Target string
}
-// Decode implements encoder.Decode.
-func (r *Rreadlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rreadlink) decode(b *buffer) {
r.Target = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (r *Rreadlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rreadlink) encode(b *buffer) {
b.WriteString(r.Target)
}
@@ -1003,15 +1003,15 @@ type Tread struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (t *Tread) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tread) decode(b *buffer) {
t.FID = b.ReadFID()
t.Offset = b.Read64()
t.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tread) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tread) encode(b *buffer) {
b.WriteFID(t.FID)
b.Write64(t.Offset)
b.Write32(t.Count)
@@ -1033,20 +1033,20 @@ type Rread struct {
Data []byte
}
-// Decode implements encoder.Decode.
+// decode implements encoder.decode.
//
// Data is automatically decoded via Payload.
-func (r *Rread) Decode(b *buffer) {
+func (r *Rread) decode(b *buffer) {
count := b.Read32()
if count != uint32(len(r.Data)) {
b.markOverrun()
}
}
-// Encode implements encoder.Encode.
+// encode implements encoder.encode.
//
// Data is automatically encoded via Payload.
-func (r *Rread) Encode(b *buffer) {
+func (r *Rread) encode(b *buffer) {
b.Write32(uint32(len(r.Data)))
}
@@ -1087,8 +1087,8 @@ type Twrite struct {
Data []byte
}
-// Decode implements encoder.Decode.
-func (t *Twrite) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twrite) decode(b *buffer) {
t.FID = b.ReadFID()
t.Offset = b.Read64()
count := b.Read32()
@@ -1097,10 +1097,10 @@ func (t *Twrite) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
+// encode implements encoder.encode.
//
// This uses the buffer payload to avoid a copy.
-func (t *Twrite) Encode(b *buffer) {
+func (t *Twrite) encode(b *buffer) {
b.WriteFID(t.FID)
b.Write64(t.Offset)
b.Write32(uint32(len(t.Data)))
@@ -1137,13 +1137,13 @@ type Rwrite struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (r *Rwrite) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rwrite) decode(b *buffer) {
r.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rwrite) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rwrite) encode(b *buffer) {
b.Write32(r.Count)
}
@@ -1178,8 +1178,8 @@ type Tmknod struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tmknod) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tmknod) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Mode = b.ReadFileMode()
@@ -1188,8 +1188,8 @@ func (t *Tmknod) Decode(b *buffer) {
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tmknod) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tmknod) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WriteFileMode(t.Mode)
@@ -1214,14 +1214,14 @@ type Rmknod struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rmknod) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rmknod) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rmknod) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rmknod) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -1249,16 +1249,16 @@ type Tmkdir struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tmkdir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tmkdir) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Permissions = b.ReadPermissions()
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tmkdir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tmkdir) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WritePermissions(t.Permissions)
@@ -1281,14 +1281,14 @@ type Rmkdir struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rmkdir) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rmkdir) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rmkdir) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rmkdir) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -1310,16 +1310,16 @@ type Tgetattr struct {
AttrMask AttrMask
}
-// Decode implements encoder.Decode.
-func (t *Tgetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tgetattr) decode(b *buffer) {
t.FID = b.ReadFID()
- t.AttrMask.Decode(b)
+ t.AttrMask.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tgetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tgetattr) encode(b *buffer) {
b.WriteFID(t.FID)
- t.AttrMask.Encode(b)
+ t.AttrMask.encode(b)
}
// Type implements message.Type.
@@ -1344,18 +1344,18 @@ type Rgetattr struct {
Attr Attr
}
-// Decode implements encoder.Decode.
-func (r *Rgetattr) Decode(b *buffer) {
- r.Valid.Decode(b)
- r.QID.Decode(b)
- r.Attr.Decode(b)
+// decode implements encoder.decode.
+func (r *Rgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.QID.decode(b)
+ r.Attr.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rgetattr) Encode(b *buffer) {
- r.Valid.Encode(b)
- r.QID.Encode(b)
- r.Attr.Encode(b)
+// encode implements encoder.encode.
+func (r *Rgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.QID.encode(b)
+ r.Attr.encode(b)
}
// Type implements message.Type.
@@ -1380,18 +1380,18 @@ type Tsetattr struct {
SetAttr SetAttr
}
-// Decode implements encoder.Decode.
-func (t *Tsetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tsetattr) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Valid.Decode(b)
- t.SetAttr.Decode(b)
+ t.Valid.decode(b)
+ t.SetAttr.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tsetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tsetattr) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Valid.Encode(b)
- t.SetAttr.Encode(b)
+ t.Valid.encode(b)
+ t.SetAttr.encode(b)
}
// Type implements message.Type.
@@ -1408,12 +1408,12 @@ func (t *Tsetattr) String() string {
type Rsetattr struct {
}
-// Decode implements encoder.Decode.
-func (*Rsetattr) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rsetattr) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rsetattr) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rsetattr) encode(*buffer) {
}
// Type implements message.Type.
@@ -1435,18 +1435,18 @@ type Tallocate struct {
Length uint64
}
-// Decode implements encoder.Decode.
-func (t *Tallocate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tallocate) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Mode.Decode(b)
+ t.Mode.decode(b)
t.Offset = b.Read64()
t.Length = b.Read64()
}
-// Encode implements encoder.Encode.
-func (t *Tallocate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tallocate) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Mode.Encode(b)
+ t.Mode.encode(b)
b.Write64(t.Offset)
b.Write64(t.Length)
}
@@ -1465,12 +1465,12 @@ func (t *Tallocate) String() string {
type Rallocate struct {
}
-// Decode implements encoder.Decode.
-func (*Rallocate) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rallocate) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rallocate) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rallocate) encode(*buffer) {
}
// Type implements message.Type.
@@ -1492,14 +1492,14 @@ type Tlistxattr struct {
Size uint64
}
-// Decode implements encoder.Decode.
-func (t *Tlistxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlistxattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.Size = b.Read64()
}
-// Encode implements encoder.Encode.
-func (t *Tlistxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlistxattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.Write64(t.Size)
}
@@ -1520,8 +1520,8 @@ type Rlistxattr struct {
Xattrs []string
}
-// Decode implements encoder.Decode.
-func (r *Rlistxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rlistxattr) decode(b *buffer) {
n := b.Read16()
r.Xattrs = r.Xattrs[:0]
for i := 0; i < int(n); i++ {
@@ -1529,8 +1529,8 @@ func (r *Rlistxattr) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (r *Rlistxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rlistxattr) encode(b *buffer) {
b.Write16(uint16(len(r.Xattrs)))
for _, x := range r.Xattrs {
b.WriteString(x)
@@ -1559,15 +1559,15 @@ type Txattrwalk struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Txattrwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Txattrwalk) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Txattrwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Txattrwalk) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.WriteString(t.Name)
@@ -1589,13 +1589,13 @@ type Rxattrwalk struct {
Size uint64
}
-// Decode implements encoder.Decode.
-func (r *Rxattrwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rxattrwalk) decode(b *buffer) {
r.Size = b.Read64()
}
-// Encode implements encoder.Encode.
-func (r *Rxattrwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rxattrwalk) encode(b *buffer) {
b.Write64(r.Size)
}
@@ -1627,16 +1627,16 @@ type Txattrcreate struct {
Flags uint32
}
-// Decode implements encoder.Decode.
-func (t *Txattrcreate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Txattrcreate) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.AttrSize = b.Read64()
t.Flags = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Txattrcreate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Txattrcreate) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.Write64(t.AttrSize)
@@ -1657,12 +1657,12 @@ func (t *Txattrcreate) String() string {
type Rxattrcreate struct {
}
-// Decode implements encoder.Decode.
-func (r *Rxattrcreate) Decode(*buffer) {
+// decode implements encoder.decode.
+func (r *Rxattrcreate) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (r *Rxattrcreate) Encode(*buffer) {
+// encode implements encoder.encode.
+func (r *Rxattrcreate) encode(*buffer) {
}
// Type implements message.Type.
@@ -1687,15 +1687,15 @@ type Tgetxattr struct {
Size uint64
}
-// Decode implements encoder.Decode.
-func (t *Tgetxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tgetxattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.Size = b.Read64()
}
-// Encode implements encoder.Encode.
-func (t *Tgetxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tgetxattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.Write64(t.Size)
@@ -1717,13 +1717,13 @@ type Rgetxattr struct {
Value string
}
-// Decode implements encoder.Decode.
-func (r *Rgetxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rgetxattr) decode(b *buffer) {
r.Value = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (r *Rgetxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rgetxattr) encode(b *buffer) {
b.WriteString(r.Value)
}
@@ -1752,16 +1752,16 @@ type Tsetxattr struct {
Flags uint32
}
-// Decode implements encoder.Decode.
-func (t *Tsetxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tsetxattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.Value = b.ReadString()
t.Flags = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tsetxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tsetxattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.WriteString(t.Value)
@@ -1782,12 +1782,12 @@ func (t *Tsetxattr) String() string {
type Rsetxattr struct {
}
-// Decode implements encoder.Decode.
-func (r *Rsetxattr) Decode(*buffer) {
+// decode implements encoder.decode.
+func (r *Rsetxattr) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (r *Rsetxattr) Encode(*buffer) {
+// encode implements encoder.encode.
+func (r *Rsetxattr) encode(*buffer) {
}
// Type implements message.Type.
@@ -1809,14 +1809,14 @@ type Tremovexattr struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Tremovexattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tremovexattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Tremovexattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tremovexattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
}
@@ -1835,12 +1835,12 @@ func (t *Tremovexattr) String() string {
type Rremovexattr struct {
}
-// Decode implements encoder.Decode.
-func (r *Rremovexattr) Decode(*buffer) {
+// decode implements encoder.decode.
+func (r *Rremovexattr) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (r *Rremovexattr) Encode(*buffer) {
+// encode implements encoder.encode.
+func (r *Rremovexattr) encode(*buffer) {
}
// Type implements message.Type.
@@ -1865,15 +1865,15 @@ type Treaddir struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (t *Treaddir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Treaddir) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Offset = b.Read64()
t.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Treaddir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Treaddir) encode(b *buffer) {
b.WriteFID(t.Directory)
b.Write64(t.Offset)
b.Write32(t.Count)
@@ -1907,14 +1907,14 @@ type Rreaddir struct {
payload []byte
}
-// Decode implements encoder.Decode.
-func (r *Rreaddir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rreaddir) decode(b *buffer) {
r.Count = b.Read32()
entriesBuf := buffer{data: r.payload}
r.Entries = r.Entries[:0]
for {
var d Dirent
- d.Decode(&entriesBuf)
+ d.decode(&entriesBuf)
if entriesBuf.isOverrun() {
// Couldn't decode a complete entry.
break
@@ -1923,11 +1923,11 @@ func (r *Rreaddir) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (r *Rreaddir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rreaddir) encode(b *buffer) {
entriesBuf := buffer{}
for _, d := range r.Entries {
- d.Encode(&entriesBuf)
+ d.encode(&entriesBuf)
if len(entriesBuf.data) >= int(r.Count) {
break
}
@@ -1972,13 +1972,13 @@ type Tfsync struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tfsync) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tfsync) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tfsync) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tfsync) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -1996,12 +1996,12 @@ func (t *Tfsync) String() string {
type Rfsync struct {
}
-// Decode implements encoder.Decode.
-func (*Rfsync) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rfsync) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rfsync) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rfsync) encode(*buffer) {
}
// Type implements message.Type.
@@ -2020,13 +2020,13 @@ type Tstatfs struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tstatfs) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tstatfs) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tstatfs) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tstatfs) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -2046,14 +2046,14 @@ type Rstatfs struct {
FSStat FSStat
}
-// Decode implements encoder.Decode.
-func (r *Rstatfs) Decode(b *buffer) {
- r.FSStat.Decode(b)
+// decode implements encoder.decode.
+func (r *Rstatfs) decode(b *buffer) {
+ r.FSStat.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rstatfs) Encode(b *buffer) {
- r.FSStat.Encode(b)
+// encode implements encoder.encode.
+func (r *Rstatfs) encode(b *buffer) {
+ r.FSStat.encode(b)
}
// Type implements message.Type.
@@ -2072,13 +2072,13 @@ type Tflushf struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tflushf) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tflushf) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tflushf) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tflushf) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -2096,12 +2096,12 @@ func (t *Tflushf) String() string {
type Rflushf struct {
}
-// Decode implements encoder.Decode.
-func (*Rflushf) Decode(*buffer) {
+// decode implements encoder.decode.
+func (*Rflushf) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rflushf) Encode(*buffer) {
+// encode implements encoder.encode.
+func (*Rflushf) encode(*buffer) {
}
// Type implements message.Type.
@@ -2126,8 +2126,8 @@ type Twalkgetattr struct {
Names []string
}
-// Decode implements encoder.Decode.
-func (t *Twalkgetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twalkgetattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
@@ -2137,8 +2137,8 @@ func (t *Twalkgetattr) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (t *Twalkgetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Twalkgetattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.Write16(uint16(len(t.Names)))
@@ -2169,26 +2169,26 @@ type Rwalkgetattr struct {
QIDs []QID
}
-// Decode implements encoder.Decode.
-func (r *Rwalkgetattr) Decode(b *buffer) {
- r.Valid.Decode(b)
- r.Attr.Decode(b)
+// decode implements encoder.decode.
+func (r *Rwalkgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.Attr.decode(b)
n := b.Read16()
r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
- q.Decode(b)
+ q.decode(b)
r.QIDs = append(r.QIDs, q)
}
}
-// Encode implements encoder.Encode.
-func (r *Rwalkgetattr) Encode(b *buffer) {
- r.Valid.Encode(b)
- r.Attr.Encode(b)
+// encode implements encoder.encode.
+func (r *Rwalkgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.Attr.encode(b)
b.Write16(uint16(len(r.QIDs)))
for _, q := range r.QIDs {
- q.Encode(b)
+ q.encode(b)
}
}
@@ -2210,15 +2210,15 @@ type Tucreate struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tucreate) Decode(b *buffer) {
- t.Tlcreate.Decode(b)
+// decode implements encoder.decode.
+func (t *Tucreate) decode(b *buffer) {
+ t.Tlcreate.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tucreate) Encode(b *buffer) {
- t.Tlcreate.Encode(b)
+// encode implements encoder.encode.
+func (t *Tucreate) encode(b *buffer) {
+ t.Tlcreate.encode(b)
b.WriteUID(t.UID)
}
@@ -2255,15 +2255,15 @@ type Tumkdir struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tumkdir) Decode(b *buffer) {
- t.Tmkdir.Decode(b)
+// decode implements encoder.decode.
+func (t *Tumkdir) decode(b *buffer) {
+ t.Tmkdir.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tumkdir) Encode(b *buffer) {
- t.Tmkdir.Encode(b)
+// encode implements encoder.encode.
+func (t *Tumkdir) encode(b *buffer) {
+ t.Tmkdir.encode(b)
b.WriteUID(t.UID)
}
@@ -2300,15 +2300,15 @@ type Tumknod struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tumknod) Decode(b *buffer) {
- t.Tmknod.Decode(b)
+// decode implements encoder.decode.
+func (t *Tumknod) decode(b *buffer) {
+ t.Tmknod.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tumknod) Encode(b *buffer) {
- t.Tmknod.Encode(b)
+// encode implements encoder.encode.
+func (t *Tumknod) encode(b *buffer) {
+ t.Tmknod.encode(b)
b.WriteUID(t.UID)
}
@@ -2345,15 +2345,15 @@ type Tusymlink struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tusymlink) Decode(b *buffer) {
- t.Tsymlink.Decode(b)
+// decode implements encoder.decode.
+func (t *Tusymlink) decode(b *buffer) {
+ t.Tsymlink.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tusymlink) Encode(b *buffer) {
- t.Tsymlink.Encode(b)
+// encode implements encoder.encode.
+func (t *Tusymlink) encode(b *buffer) {
+ t.Tsymlink.encode(b)
b.WriteUID(t.UID)
}
@@ -2391,14 +2391,14 @@ type Tlconnect struct {
Flags ConnectFlags
}
-// Decode implements encoder.Decode.
-func (t *Tlconnect) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlconnect) decode(b *buffer) {
t.FID = b.ReadFID()
t.Flags = b.ReadConnectFlags()
}
-// Encode implements encoder.Encode.
-func (t *Tlconnect) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlconnect) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteConnectFlags(t.Flags)
}
@@ -2418,11 +2418,11 @@ type Rlconnect struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rlconnect) Decode(*buffer) {}
+// decode implements encoder.decode.
+func (r *Rlconnect) decode(*buffer) {}
-// Encode implements encoder.Encode.
-func (r *Rlconnect) Encode(*buffer) {}
+// encode implements encoder.encode.
+func (r *Rlconnect) encode(*buffer) {}
// Type implements message.Type.
func (*Rlconnect) Type() MsgType {
@@ -2445,14 +2445,14 @@ type Tchannel struct {
Control uint32
}
-// Decode implements encoder.Decode.
-func (t *Tchannel) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tchannel) decode(b *buffer) {
t.ID = b.Read32()
t.Control = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tchannel) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tchannel) encode(b *buffer) {
b.Write32(t.ID)
b.Write32(t.Control)
}
@@ -2474,14 +2474,14 @@ type Rchannel struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rchannel) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rchannel) decode(b *buffer) {
r.Offset = b.Read64()
r.Length = b.Read64()
}
-// Encode implements encoder.Encode.
-func (r *Rchannel) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rchannel) encode(b *buffer) {
b.Write64(r.Offset)
b.Write64(r.Length)
}
@@ -2577,7 +2577,7 @@ func calculateSize(m message) uint32 {
return p.FixedSize()
}
var dataBuf buffer
- m.Encode(&dataBuf)
+ m.encode(&dataBuf)
return uint32(len(dataBuf.data))
}
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index 825c939da..c20324404 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -382,7 +382,7 @@ func TestEncodeDecode(t *testing.T) {
// Encode the original.
data := make([]byte, initialBufferLength)
buf := buffer{data: data[:0]}
- enc.Encode(&buf)
+ enc.encode(&buf)
// Create a new object, same as the first.
enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder)
@@ -399,7 +399,7 @@ func TestEncodeDecode(t *testing.T) {
}
// Mark sure it was okay.
- enc2.Decode(&buf2)
+ enc2.decode(&buf2)
if buf2.isOverrun() {
t.Errorf("object %#v->%#v got overrun on decode", enc, enc2)
continue
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 20ab31f7a..28d851ff5 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -450,15 +450,15 @@ func (q QID) String() string {
return fmt.Sprintf("QID{Type: %d, Version: %d, Path: %d}", q.Type, q.Version, q.Path)
}
-// Decode implements encoder.Decode.
-func (q *QID) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (q *QID) decode(b *buffer) {
q.Type = b.ReadQIDType()
q.Version = b.Read32()
q.Path = b.Read64()
}
-// Encode implements encoder.Encode.
-func (q *QID) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (q *QID) encode(b *buffer) {
b.WriteQIDType(q.Type)
b.Write32(q.Version)
b.Write64(q.Path)
@@ -515,8 +515,8 @@ type FSStat struct {
NameLength uint32
}
-// Decode implements encoder.Decode.
-func (f *FSStat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (f *FSStat) decode(b *buffer) {
f.Type = b.Read32()
f.BlockSize = b.Read32()
f.Blocks = b.Read64()
@@ -528,8 +528,8 @@ func (f *FSStat) Decode(b *buffer) {
f.NameLength = b.Read32()
}
-// Encode implements encoder.Encode.
-func (f *FSStat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (f *FSStat) encode(b *buffer) {
b.Write32(f.Type)
b.Write32(f.BlockSize)
b.Write64(f.Blocks)
@@ -679,8 +679,8 @@ func (a AttrMask) String() string {
return fmt.Sprintf("AttrMask{with: %s}", strings.Join(masks, " "))
}
-// Decode implements encoder.Decode.
-func (a *AttrMask) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *AttrMask) decode(b *buffer) {
mask := b.Read64()
a.Mode = mask&0x00000001 != 0
a.NLink = mask&0x00000002 != 0
@@ -698,8 +698,8 @@ func (a *AttrMask) Decode(b *buffer) {
a.DataVersion = mask&0x00002000 != 0
}
-// Encode implements encoder.Encode.
-func (a *AttrMask) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *AttrMask) encode(b *buffer) {
var mask uint64
if a.Mode {
mask |= 0x00000001
@@ -774,8 +774,8 @@ func (a Attr) String() string {
a.Mode, a.UID, a.GID, a.NLink, a.RDev, a.Size, a.BlockSize, a.Blocks, a.ATimeSeconds, a.ATimeNanoSeconds, a.MTimeSeconds, a.MTimeNanoSeconds, a.CTimeSeconds, a.CTimeNanoSeconds, a.BTimeSeconds, a.BTimeNanoSeconds, a.Gen, a.DataVersion)
}
-// Encode implements encoder.Encode.
-func (a *Attr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *Attr) encode(b *buffer) {
b.WriteFileMode(a.Mode)
b.WriteUID(a.UID)
b.WriteGID(a.GID)
@@ -796,8 +796,8 @@ func (a *Attr) Encode(b *buffer) {
b.Write64(a.DataVersion)
}
-// Decode implements encoder.Decode.
-func (a *Attr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *Attr) decode(b *buffer) {
a.Mode = b.ReadFileMode()
a.UID = b.ReadUID()
a.GID = b.ReadGID()
@@ -926,8 +926,8 @@ func (s SetAttrMask) Empty() bool {
return !s.Permissions && !s.UID && !s.GID && !s.Size && !s.ATime && !s.MTime && !s.CTime && !s.ATimeNotSystemTime && !s.MTimeNotSystemTime
}
-// Decode implements encoder.Decode.
-func (s *SetAttrMask) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (s *SetAttrMask) decode(b *buffer) {
mask := b.Read32()
s.Permissions = mask&0x00000001 != 0
s.UID = mask&0x00000002 != 0
@@ -972,8 +972,8 @@ func (s SetAttrMask) bitmask() uint32 {
return mask
}
-// Encode implements encoder.Encode.
-func (s *SetAttrMask) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (s *SetAttrMask) encode(b *buffer) {
b.Write32(s.bitmask())
}
@@ -994,8 +994,8 @@ func (s SetAttr) String() string {
return fmt.Sprintf("SetAttr{Permissions: 0o%o, UID: %d, GID: %d, Size: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}}", s.Permissions, s.UID, s.GID, s.Size, s.ATimeSeconds, s.ATimeNanoSeconds, s.MTimeSeconds, s.MTimeNanoSeconds)
}
-// Decode implements encoder.Decode.
-func (s *SetAttr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (s *SetAttr) decode(b *buffer) {
s.Permissions = b.ReadPermissions()
s.UID = b.ReadUID()
s.GID = b.ReadGID()
@@ -1006,8 +1006,8 @@ func (s *SetAttr) Decode(b *buffer) {
s.MTimeNanoSeconds = b.Read64()
}
-// Encode implements encoder.Encode.
-func (s *SetAttr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (s *SetAttr) encode(b *buffer) {
b.WritePermissions(s.Permissions)
b.WriteUID(s.UID)
b.WriteGID(s.GID)
@@ -1064,17 +1064,17 @@ func (d Dirent) String() string {
return fmt.Sprintf("Dirent{QID: %d, Offset: %d, Type: 0x%X, Name: %s}", d.QID, d.Offset, d.Type, d.Name)
}
-// Decode implements encoder.Decode.
-func (d *Dirent) Decode(b *buffer) {
- d.QID.Decode(b)
+// decode implements encoder.decode.
+func (d *Dirent) decode(b *buffer) {
+ d.QID.decode(b)
d.Offset = b.Read64()
d.Type = b.ReadQIDType()
d.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (d *Dirent) Encode(b *buffer) {
- d.QID.Encode(b)
+// encode implements encoder.encode.
+func (d *Dirent) encode(b *buffer) {
+ d.QID.encode(b)
b.Write64(d.Offset)
b.WriteQIDType(d.Type)
b.WriteString(d.Name)
@@ -1118,8 +1118,8 @@ func (a *AllocateMode) ToLinux() uint32 {
return rv
}
-// Decode implements encoder.Decode.
-func (a *AllocateMode) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *AllocateMode) decode(b *buffer) {
mask := b.Read32()
a.KeepSize = mask&0x01 != 0
a.PunchHole = mask&0x02 != 0
@@ -1130,8 +1130,8 @@ func (a *AllocateMode) Decode(b *buffer) {
a.Unshare = mask&0x40 != 0
}
-// Encode implements encoder.Encode.
-func (a *AllocateMode) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *AllocateMode) encode(b *buffer) {
mask := uint32(0)
if a.KeepSize {
mask |= 0x01
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 9c11e28ce..7cec0e86d 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -80,7 +80,7 @@ func send(s *unet.Socket, tag Tag, m message) error {
}
// Encode the message. The buffer will grow automatically.
- m.Encode(&dataBuf)
+ m.encode(&dataBuf)
// Get our vectors to send.
var hdr [headerLength]byte
@@ -316,7 +316,7 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message,
}
// Decode the message data.
- m.Decode(&dataBuf)
+ m.decode(&dataBuf)
if dataBuf.isOverrun() {
// No need to drain the socket.
return NoTag, nil, ErrNoValidMessage
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index 233f825e3..a0d274f3b 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -151,7 +151,7 @@ func (ch *channel) send(m message) (uint32, error) {
} else {
ch.buf.Write8(0) // No incoming FD.
}
- m.Encode(&ch.buf)
+ m.encode(&ch.buf)
ssz := uint32(len(ch.buf.data)) // Updated below.
// Is there a payload?
@@ -205,7 +205,7 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
ch.buf.data = ch.buf.data[:fs]
}
- r.Decode(&ch.buf)
+ r.decode(&ch.buf)
if ch.buf.isOverrun() {
// Nothing valid was available.
log.Debugf("recv [got %d bytes, needed more]", rsz)
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index 2f50ff3ea..3668fcad7 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -56,8 +56,8 @@ func TestSendRecv(t *testing.T) {
// badDecode overruns on decode.
type badDecode struct{}
-func (*badDecode) Decode(b *buffer) { b.markOverrun() }
-func (*badDecode) Encode(b *buffer) {}
+func (*badDecode) decode(b *buffer) { b.markOverrun() }
+func (*badDecode) encode(b *buffer) {}
func (*badDecode) Type() MsgType { return MsgTypeBadDecode }
func (*badDecode) String() string { return "badDecode{}" }
@@ -81,8 +81,8 @@ func TestRecvOverrun(t *testing.T) {
// unregistered is not registered on decode.
type unregistered struct{}
-func (*unregistered) Decode(b *buffer) {}
-func (*unregistered) Encode(b *buffer) {}
+func (*unregistered) decode(b *buffer) {}
+func (*unregistered) encode(b *buffer) {}
func (*unregistered) Type() MsgType { return MsgTypeUnregistered }
func (*unregistered) String() string { return "unregistered{}" }
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index e69496477..d16d78aa5 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -16,10 +16,13 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/fd",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/sentry/fs",
"//pkg/sentry/fs/host",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
@@ -27,8 +30,10 @@ go_library(
"//pkg/sentry/state",
"//pkg/sentry/strace",
"//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/sync",
+ "//pkg/syserror",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
],
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index ced51c66c..5457ba5e7 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -18,19 +18,26 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "path"
"sort"
"strings"
"text/tabwriter"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/urpc"
)
@@ -60,6 +67,12 @@ type ExecArgs struct {
// process's MountNamespace.
MountNamespace *fs.MountNamespace
+ // MountNamespaceVFS2 is the mount namespace to execute the new process in.
+ // A reference on MountNamespace must be held for the lifetime of the
+ // ExecArgs. If MountNamespace is nil, it will default to the init
+ // process's MountNamespace.
+ MountNamespaceVFS2 *vfs.MountNamespace
+
// WorkingDirectory defines the working directory for the new process.
WorkingDirectory string `json:"wd"`
@@ -150,6 +163,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
Envv: args.Envv,
WorkingDirectory: args.WorkingDirectory,
MountNamespace: args.MountNamespace,
+ MountNamespaceVFS2: args.MountNamespaceVFS2,
Credentials: creds,
FDTable: fdTable,
Umask: 0022,
@@ -166,24 +180,53 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
// be donated to the new process in CreateProcess.
initArgs.MountNamespace.IncRef()
}
+ if initArgs.MountNamespaceVFS2 != nil {
+ // initArgs must hold a reference on MountNamespaceVFS2, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespaceVFS2.IncRef()
+ }
ctx := initArgs.NewContext(proc.Kernel)
if initArgs.Filename == "" {
- // Get the full path to the filename from the PATH env variable.
- paths := fs.GetPath(initArgs.Envv)
- mns := initArgs.MountNamespace
- if mns == nil {
- mns = proc.Kernel.GlobalInit().Leader().MountNamespace()
- }
- f, err := mns.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
- if err != nil {
- return nil, 0, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
+ if kernel.VFS2Enabled {
+ // Get the full path to the filename from the PATH env variable.
+ if initArgs.MountNamespaceVFS2 == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ //
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
+ }
+
+ paths := fs.GetPath(initArgs.Envv)
+ vfsObj := proc.Kernel.VFS()
+ file, err := ResolveExecutablePath(ctx, vfsObj, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
+ }
+ initArgs.File = fsbridge.NewVFSFile(file)
+ } else {
+ // Get the full path to the filename from the PATH env variable.
+ paths := fs.GetPath(initArgs.Envv)
+ if initArgs.MountNamespace == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace()
+
+ // initArgs must hold a reference on MountNamespace, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespaceVFS2.IncRef()
+ }
+ f, err := initArgs.MountNamespace.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
+ }
+ initArgs.Filename = f
}
- initArgs.Filename = f
}
mounter := fs.FileOwnerFromContext(ctx)
+ // TODO(gvisor.dev/issue/1623): Use host FD when supported in VFS2.
var ttyFile *fs.File
for appFD, hostFile := range args.FilePayload.Files {
var appFile *fs.File
@@ -411,3 +454,67 @@ func ttyName(tty *kernel.TTY) string {
}
return fmt.Sprintf("pts/%d", tty.Index)
}
+
+// ResolveExecutablePath resolves the given executable name given a set of
+// paths that might contain it.
+func ResolveExecutablePath(ctx context.Context, vfsObj *vfs.VirtualFilesystem, wd, name string, paths []string) (*vfs.FileDescription, error) {
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef()
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Absolute paths can be used directly.
+ if path.IsAbs(name) {
+ return openExecutable(ctx, vfsObj, creds, root, name)
+ }
+
+ // Paths with '/' in them should be joined to the working directory, or
+ // to the root if working directory is not set.
+ if strings.IndexByte(name, '/') > 0 {
+ if len(wd) == 0 {
+ wd = "/"
+ }
+ if !path.IsAbs(wd) {
+ return nil, fmt.Errorf("working directory %q must be absolute", wd)
+ }
+ return openExecutable(ctx, vfsObj, creds, root, path.Join(wd, name))
+ }
+
+ // Otherwise, we must lookup the name in the paths, starting from the
+ // calling context's root directory.
+ for _, p := range paths {
+ if !path.IsAbs(p) {
+ // Relative paths aren't safe, no one should be using them.
+ log.Warningf("Skipping relative path %q in $PATH", p)
+ continue
+ }
+
+ binPath := path.Join(p, name)
+ f, err := openExecutable(ctx, vfsObj, creds, root, binPath)
+ if err != nil {
+ return nil, err
+ }
+ if f == nil {
+ continue // Not found/no access.
+ }
+ return f, nil
+ }
+ return nil, syserror.ENOENT
+}
+
+func openExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry, path string) (*vfs.FileDescription, error) {
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root, // binPath is absolute, Start can be anything.
+ Path: fspath.Parse(path),
+ FollowFinalSymlink: true,
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ }
+ f, err := vfsObj.OpenAt(ctx, creds, &pop, opts)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ return nil, nil
+ }
+ return f, err
+}
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 4c4b7d5cc..9b6bb26d0 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -9,6 +9,7 @@ go_library(
"device.go",
"fs.go",
"full.go",
+ "net_tun.go",
"null.go",
"random.go",
"tty.go",
@@ -19,15 +20,19 @@ go_library(
"//pkg/context",
"//pkg/rand",
"//pkg/safemem",
+ "//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
+ "//pkg/sentry/socket/netstack",
"//pkg/syserror",
+ "//pkg/tcpip/link/tun",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index 35bd23991..7e66c29b0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
})
}
-func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555))
+func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
@@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// A devpts is typically mounted at /dev/pts to provide
// pseudoterminal support. Place an empty directory there for
// the devpts to be mounted over.
- "pts": newDirectory(ctx, msrc),
+ "pts": newDirectory(ctx, nil, msrc),
// Similarly, applications expect a ptmx device at /dev/ptmx
// connected to the terminals provided by /dev/pts/. Rather
// than creating a device directly (which requires a hairy
@@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"ptmx": newSymlink(ctx, "pts/ptmx", msrc),
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+
+ "net": newDirectory(ctx, map[string]*fs.Inode{
+ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
+ }, msrc),
}
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
new file mode 100644
index 000000000..755644488
--- /dev/null
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// +stateify savable
+type netTunInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*netTunInodeOperations)(nil)
+
+func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations {
+ return &netTunInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
+}
+
+// +stateify savable
+type netTunFileOperations struct {
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ device tun.Device
+}
+
+var _ fs.FileOperations = (*netTunFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (fops *netTunFileOperations) Release() {
+ fops.device.Release()
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fops.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
+ flags := fops.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Write implements fs.FileOperations.Write.
+func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fops.device.Write(data)
+}
+
+// Read implements fs.FileOperations.Read.
+func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := fops.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fops.device.Readiness(mask)
+}
+
+// EventRegister implements watier.Waitable.EventRegister.
+func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fops.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements watier.Waitable.EventUnregister.
+func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ fops.device.EventUnregister(e)
+}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 67278aa86..e82afd112 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -65,13 +65,18 @@ type mapping struct {
writable bool
}
-// NewHostFileMapper returns a HostFileMapper with no references or cached
-// mappings.
+// Init must be called on zero-value HostFileMappers before first use.
+func (f *HostFileMapper) Init() {
+ f.refs = make(map[uint64]int32)
+ f.mappings = make(map[uint64]mapping)
+}
+
+// NewHostFileMapper returns an initialized HostFileMapper allocated on the
+// heap with no references or cached mappings.
func NewHostFileMapper() *HostFileMapper {
- return &HostFileMapper{
- refs: make(map[uint64]int32),
- mappings: make(map[uint64]mapping),
- }
+ f := &HostFileMapper{}
+ f.Init()
+ return f
}
// IncRefOn increments the reference count on all offsets in mr.
diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go
index e672a438c..a3d10770b 100644
--- a/pkg/sentry/fs/mount_test.go
+++ b/pkg/sentry/fs/mount_test.go
@@ -36,11 +36,12 @@ func mountPathsAre(root *Dirent, got []*Mount, want ...string) error {
gotPaths := make(map[string]struct{}, len(got))
gotStr := make([]string, len(got))
for i, g := range got {
- groot := g.Root()
- name, _ := groot.FullName(root)
- groot.DecRef()
- gotStr[i] = name
- gotPaths[name] = struct{}{}
+ if groot := g.Root(); groot != nil {
+ name, _ := groot.FullName(root)
+ groot.DecRef()
+ gotStr[i] = name
+ gotPaths[name] = struct{}{}
+ }
}
if len(got) != len(want) {
return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want)
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 574a2cc91..c7981f66e 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -100,10 +100,14 @@ func newUndoMount(d *Dirent) *Mount {
}
}
-// Root returns the root dirent of this mount. Callers must call DecRef on the
-// returned dirent.
+// Root returns the root dirent of this mount.
+//
+// This may return nil if the mount has already been free. Callers must handle this
+// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent.
func (m *Mount) Root() *Dirent {
- m.root.IncRef()
+ if !m.root.TryIncRef() {
+ return nil
+ }
return m.root
}
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 280093c5e..77c2c5c0e 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -36,6 +36,7 @@ go_library(
"//pkg/sentry/fs/proc/device",
"//pkg/sentry/fs/proc/seqfile",
"//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index c10888100..94deb553b 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -60,13 +60,15 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
})
for _, m := range ms {
mroot := m.Root()
+ if mroot == nil {
+ continue // No longer valid.
+ }
mountPath, desc := mroot.FullName(rootDir)
mroot.DecRef()
if !desc {
// MountSources that are not descendants of the chroot jail are ignored.
continue
}
-
fn(mountPath, m)
}
}
@@ -91,6 +93,12 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
var buf bytes.Buffer
forEachMount(mif.t, func(mountPath string, m *fs.Mount) {
+ mroot := m.Root()
+ if mroot == nil {
+ return // No longer valid.
+ }
+ defer mroot.DecRef()
+
// Format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
@@ -107,9 +115,6 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
// (3) Major:Minor device ID. We don't have a superblock, so we
// just use the root inode device number.
- mroot := m.Root()
- defer mroot.DecRef()
-
sa := mroot.Inode.StableAttr
fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
@@ -207,6 +212,9 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
//
// The "needs dump"and fsck flags are always 0, which is allowed.
root := m.Root()
+ if root == nil {
+ return // No longer valid.
+ }
defer root.DecRef()
flags := root.Inode.MountSource.Flags
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 6f2775344..95d5817ff 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -43,7 +43,10 @@ import (
// newNet creates a new proc net entry.
func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
"snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 0772d4ae4..d4c4b533d 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -357,7 +357,9 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"ipv4": p.newSysNetIPv4Dir(ctx, msrc, s),
"core": p.newSysNetCore(ctx, msrc, s),
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index ca020e11e..8ab8d8a02 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -249,7 +250,7 @@ func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
return newProcInode(t, exeSymlink, msrc, fs.Symlink, t)
}
-func (e *exe) executable() (d *fs.Dirent, err error) {
+func (e *exe) executable() (file fsbridge.File, err error) {
e.t.WithMuLocked(func(t *kernel.Task) {
mm := t.MemoryManager()
if mm == nil {
@@ -262,8 +263,8 @@ func (e *exe) executable() (d *fs.Dirent, err error) {
// The MemoryManager may be destroyed, in which case
// MemoryManager.destroy will simply set the executable to nil
// (with locks held).
- d = mm.Executable()
- if d == nil {
+ file = mm.Executable()
+ if file == nil {
err = syserror.ENOENT
}
})
@@ -283,15 +284,7 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
}
defer exec.DecRef()
- root := fs.RootFromContext(ctx)
- if root == nil {
- // This doesn't correspond to anything in Linux because the vfs is
- // global there.
- return "", syserror.EINVAL
- }
- defer root.DecRef()
- n, _ := exec.FullName(root)
- return n, nil
+ return exec.PathnameWithDeleted(ctx), nil
}
// namespaceSymlink represents a symlink in the namespacefs, such as the files
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index db55cdc48..6a2dbc576 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -73,7 +73,7 @@ func (si *slaveInodeOperations) Release(ctx context.Context) {
}
// Truncate implements fs.InodeOperations.Truncate.
-func (slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
diff --git a/pkg/sentry/fsbridge/BUILD b/pkg/sentry/fsbridge/BUILD
new file mode 100644
index 000000000..6c798f0bd
--- /dev/null
+++ b/pkg/sentry/fsbridge/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "fsbridge",
+ srcs = [
+ "bridge.go",
+ "fs.go",
+ "vfs.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsbridge/bridge.go b/pkg/sentry/fsbridge/bridge.go
new file mode 100644
index 000000000..8e7590721
--- /dev/null
+++ b/pkg/sentry/fsbridge/bridge.go
@@ -0,0 +1,54 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsbridge provides common interfaces to bridge between VFS1 and VFS2
+// files.
+package fsbridge
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// File provides a common interface to bridge between VFS1 and VFS2 files.
+type File interface {
+ // PathnameWithDeleted returns an absolute pathname to vd, consistent with
+ // Linux's d_path(). In particular, if vd.Dentry() has been disowned,
+ // PathnameWithDeleted appends " (deleted)" to the returned pathname.
+ PathnameWithDeleted(ctx context.Context) string
+
+ // ReadFull read all contents from the file.
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error)
+
+ // ConfigureMMap mutates opts to implement mmap(2) for the file.
+ ConfigureMMap(context.Context, *memmap.MMapOpts) error
+
+ // Type returns the file type, e.g. linux.S_IFREG.
+ Type(context.Context) (linux.FileMode, error)
+
+ // IncRef increments reference.
+ IncRef()
+
+ // DecRef decrements reference.
+ DecRef()
+}
+
+// Lookup provides a common interface to open files.
+type Lookup interface {
+ // OpenPath opens a file.
+ OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, remainingTraversals *uint, resolveFinal bool) (File, error)
+}
diff --git a/pkg/sentry/fsbridge/fs.go b/pkg/sentry/fsbridge/fs.go
new file mode 100644
index 000000000..093ce1fb3
--- /dev/null
+++ b/pkg/sentry/fsbridge/fs.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsbridge
+
+import (
+ "io"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fsFile implements File interface over fs.File.
+//
+// +stateify savable
+type fsFile struct {
+ file *fs.File
+}
+
+var _ File = (*fsFile)(nil)
+
+// NewFSFile creates a new File over fs.File.
+func NewFSFile(file *fs.File) File {
+ return &fsFile{file: file}
+}
+
+// PathnameWithDeleted implements File.
+func (f *fsFile) PathnameWithDeleted(ctx context.Context) string {
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // This doesn't correspond to anything in Linux because the vfs is
+ // global there.
+ return ""
+ }
+ defer root.DecRef()
+
+ name, _ := f.file.Dirent.FullName(root)
+ return name
+}
+
+// ReadFull implements File.
+func (f *fsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.file.Preadv(ctx, dst, offset+total)
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// ConfigureMMap implements File.
+func (f *fsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return f.file.ConfigureMMap(ctx, opts)
+}
+
+// Type implements File.
+func (f *fsFile) Type(context.Context) (linux.FileMode, error) {
+ return linux.FileMode(f.file.Dirent.Inode.StableAttr.Type.LinuxType()), nil
+}
+
+// IncRef implements File.
+func (f *fsFile) IncRef() {
+ f.file.IncRef()
+}
+
+// DecRef implements File.
+func (f *fsFile) DecRef() {
+ f.file.DecRef()
+}
+
+// fsLookup implements Lookup interface using fs.File.
+//
+// +stateify savable
+type fsLookup struct {
+ mntns *fs.MountNamespace
+
+ root *fs.Dirent
+ workingDir *fs.Dirent
+}
+
+var _ Lookup = (*fsLookup)(nil)
+
+// NewFSLookup creates a new Lookup using VFS1.
+func NewFSLookup(mntns *fs.MountNamespace, root, workingDir *fs.Dirent) Lookup {
+ return &fsLookup{
+ mntns: mntns,
+ root: root,
+ workingDir: workingDir,
+ }
+}
+
+// OpenPath implements Lookup.
+func (l *fsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, remainingTraversals *uint, resolveFinal bool) (File, error) {
+ var d *fs.Dirent
+ var err error
+ if resolveFinal {
+ d, err = l.mntns.FindInode(ctx, l.root, l.workingDir, path, remainingTraversals)
+ } else {
+ d, err = l.mntns.FindLink(ctx, l.root, l.workingDir, path, remainingTraversals)
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer d.DecRef()
+
+ if !resolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
+ return nil, syserror.ELOOP
+ }
+
+ fsPerm := openOptionsToPermMask(&opts)
+ if err := d.Inode.CheckPermission(ctx, fsPerm); err != nil {
+ return nil, err
+ }
+
+ // If they claim it's a directory, then make sure.
+ if strings.HasSuffix(path, "/") {
+ if d.Inode.StableAttr.Type != fs.Directory {
+ return nil, syserror.ENOTDIR
+ }
+ }
+
+ if opts.FileExec && d.Inode.StableAttr.Type != fs.RegularFile {
+ ctx.Infof("%q is not a regular file: %v", path, d.Inode.StableAttr.Type)
+ return nil, syserror.EACCES
+ }
+
+ f, err := d.Inode.GetFile(ctx, d, flagsToFileFlags(opts.Flags))
+ if err != nil {
+ return nil, err
+ }
+
+ return &fsFile{file: f}, nil
+}
+
+func openOptionsToPermMask(opts *vfs.OpenOptions) fs.PermMask {
+ mode := opts.Flags & linux.O_ACCMODE
+ return fs.PermMask{
+ Read: mode == linux.O_RDONLY || mode == linux.O_RDWR,
+ Write: mode == linux.O_WRONLY || mode == linux.O_RDWR,
+ Execute: opts.FileExec,
+ }
+}
+
+func flagsToFileFlags(flags uint32) fs.FileFlags {
+ return fs.FileFlags{
+ Direct: flags&linux.O_DIRECT != 0,
+ DSync: flags&(linux.O_DSYNC|linux.O_SYNC) != 0,
+ Sync: flags&linux.O_SYNC != 0,
+ NonBlocking: flags&linux.O_NONBLOCK != 0,
+ Read: (flags & linux.O_ACCMODE) != linux.O_WRONLY,
+ Write: (flags & linux.O_ACCMODE) != linux.O_RDONLY,
+ Append: flags&linux.O_APPEND != 0,
+ Directory: flags&linux.O_DIRECTORY != 0,
+ Async: flags&linux.O_ASYNC != 0,
+ LargeFile: flags&linux.O_LARGEFILE != 0,
+ Truncate: flags&linux.O_TRUNC != 0,
+ }
+}
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
new file mode 100644
index 000000000..6aa17bfc1
--- /dev/null
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -0,0 +1,138 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsbridge
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fsFile implements File interface over vfs.FileDescription.
+//
+// +stateify savable
+type vfsFile struct {
+ file *vfs.FileDescription
+}
+
+var _ File = (*vfsFile)(nil)
+
+// NewVFSFile creates a new File over fs.File.
+func NewVFSFile(file *vfs.FileDescription) File {
+ return &vfsFile{file: file}
+}
+
+// PathnameWithDeleted implements File.
+func (f *vfsFile) PathnameWithDeleted(ctx context.Context) string {
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef()
+
+ vfsObj := f.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem()
+ name, _ := vfsObj.PathnameWithDeleted(ctx, root, f.file.VirtualDentry())
+ return name
+}
+
+// ReadFull implements File.
+func (f *vfsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.file.PRead(ctx, dst, offset+total, vfs.ReadOptions{})
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// ConfigureMMap implements File.
+func (f *vfsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return f.file.ConfigureMMap(ctx, opts)
+}
+
+// Type implements File.
+func (f *vfsFile) Type(ctx context.Context) (linux.FileMode, error) {
+ stat, err := f.file.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return 0, err
+ }
+ return linux.FileMode(stat.Mode).FileType(), nil
+}
+
+// IncRef implements File.
+func (f *vfsFile) IncRef() {
+ f.file.IncRef()
+}
+
+// DecRef implements File.
+func (f *vfsFile) DecRef() {
+ f.file.DecRef()
+}
+
+// fsLookup implements Lookup interface using fs.File.
+//
+// +stateify savable
+type vfsLookup struct {
+ mntns *vfs.MountNamespace
+
+ root vfs.VirtualDentry
+ workingDir vfs.VirtualDentry
+}
+
+var _ Lookup = (*vfsLookup)(nil)
+
+// NewVFSLookup creates a new Lookup using VFS2.
+func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) Lookup {
+ return &vfsLookup{
+ mntns: mntns,
+ root: root,
+ workingDir: workingDir,
+ }
+}
+
+// OpenPath implements Lookup.
+//
+// remainingTraversals is not configurable in VFS2, all callers are using the
+// default anyways.
+//
+// TODO(gvisor.dev/issue/1623): Check mount has read and exec permission.
+func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
+ vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
+ creds := auth.CredentialsFromContext(ctx)
+ path := fspath.Parse(pathname)
+ pop := &vfs.PathOperation{
+ Root: l.root,
+ Start: l.workingDir,
+ Path: path,
+ FollowFinalSymlink: resolveFinal,
+ }
+ if path.Absolute {
+ pop.Start = l.root
+ }
+ fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return &vfsFile{file: fd}, nil
+}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index e03a0c665..abd4f24e7 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -28,6 +28,9 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
+// Name is the default filesystem name.
+const Name = "devtmpfs"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct {
initOnce sync.Once
@@ -107,6 +110,7 @@ func (a *Accessor) wrapContext(ctx context.Context) *accessorContext {
func (ac *accessorContext) Value(key interface{}) interface{} {
switch key {
case vfs.CtxMountNamespace:
+ ac.a.mntns.IncRef()
return ac.a.mntns
case vfs.CtxRoot:
ac.a.root.IncRef()
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
index 73308a2b5..b6d52c015 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
@@ -29,7 +29,10 @@ func TestDevtmpfs(t *testing.T) {
ctx := contexttest.Context(t)
creds := auth.CredentialsFromContext(ctx)
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
// Register tmpfs just so that we can have a root filesystem that isn't
// devtmpfs.
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index 2015a8871..89caee3df 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -52,7 +52,10 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ return nil, nil, nil, nil, err
+ }
vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index ebb72b75e..bd6ede995 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -188,14 +188,14 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
childType = fs.ToInodeType(childInode.diskInode.Mode().FileType())
}
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: child.diskDirent.FileName(),
Type: fs.ToDirentType(childType),
Ino: uint64(child.diskDirent.Inode()),
NextOff: fd.off + 1,
- }) {
+ }); err != nil {
dir.childList.InsertBefore(child, fd.iter)
- return nil
+ return err
}
fd.off++
}
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 05f992826..29bb73765 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -65,7 +65,10 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
@@ -496,9 +499,9 @@ func newIterDirentCb() *iterDirentsCb {
}
// Handle implements vfs.IterDirentsCallback.Handle.
-func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) bool {
+func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) error {
cb.dirents = append(cb.dirents, dirent)
- return true
+ return nil
}
// TestIterDirents tests the FileDescriptionImpl.IterDirents functionality.
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 07bf58953..e05429d41 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -296,7 +296,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if vfs.MayWriteFileWithOpenFlags(opts.Flags) || opts.Flags&(linux.O_CREAT|linux.O_EXCL|linux.O_TMPFILE) != 0 {
return nil, syserror.EROFS
}
- return inode.open(rp, vfsd, opts.Flags)
+ return inode.open(rp, vfsd, &opts)
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 191b39970..6962083f5 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -148,8 +148,8 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
}
// open creates and returns a file description for the dentry passed in.
-func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(flags)
+func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
if err := in.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
@@ -157,7 +157,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
switch in.impl.(type) {
case *regularFile:
var fd regularFileFD
- if err := fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -168,17 +168,17 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.EISDIR
}
var fd directoryFD
- if err := fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
case *symlink:
- if flags&linux.O_PATH == 0 {
+ if opts.Flags&linux.O_PATH == 0 {
// Can't open symlinks without O_PATH.
return nil, syserror.ELOOP
}
var fd symlinkFD
- fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
+ fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
default:
panic(fmt.Sprintf("unknown inode type: %T", in.impl))
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 6d4ebc2bf..5dbfc6250 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -65,8 +65,8 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
}
for fd.off < int64(len(fd.dirents)) {
- if !cb.Handle(fd.dirents[fd.off]) {
- return nil
+ if err := cb.Handle(fd.dirents[fd.off]); err != nil {
+ return err
}
fd.off++
}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 8eb61debf..5cfb0dc4c 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -400,6 +400,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
}
vfsObj := rp.VirtualFilesystem()
mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
childVFSD := parent.vfsd.Child(name)
@@ -593,7 +594,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
}
}
if rp.Done() {
- return start.openLocked(ctx, rp, opts.Flags)
+ return start.openLocked(ctx, rp, &opts)
}
afterTrailingSymlink:
@@ -633,12 +634,12 @@ afterTrailingSymlink:
start = parent
goto afterTrailingSymlink
}
- return child.openLocked(ctx, rp, opts.Flags)
+ return child.openLocked(ctx, rp, &opts)
}
// Preconditions: fs.renameMu must be locked.
-func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, flags uint32) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(flags)
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
if err := d.checkPermissions(rp.Credentials(), ats, d.isDir()); err != nil {
return nil, err
}
@@ -646,11 +647,11 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, flags ui
filetype := d.fileType()
switch {
case filetype == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD:
- if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, flags&linux.O_TRUNC != 0); err != nil {
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil {
return nil, err
}
fd := &regularFileFD{}
- if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
AllowDirectIO: true,
}); err != nil {
return nil, err
@@ -658,21 +659,21 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, flags ui
return &fd.vfsfd, nil
case filetype == linux.S_IFDIR:
// Can't open directories with O_CREAT.
- if flags&linux.O_CREAT != 0 {
+ if opts.Flags&linux.O_CREAT != 0 {
return nil, syserror.EISDIR
}
// Can't open directories writably.
if ats&vfs.MayWrite != 0 {
return nil, syserror.EISDIR
}
- if flags&linux.O_DIRECT != 0 {
+ if opts.Flags&linux.O_DIRECT != 0 {
return nil, syserror.EINVAL
}
if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil {
return nil, err
}
fd := &directoryFD{}
- if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -680,17 +681,17 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, flags ui
// Can't open symlinks without O_PATH (which is unimplemented).
return nil, syserror.ELOOP
default:
- if flags&linux.O_DIRECT != 0 {
+ if opts.Flags&linux.O_DIRECT != 0 {
return nil, syserror.EINVAL
}
- h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, flags&linux.O_TRUNC != 0)
+ h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0)
if err != nil {
return nil, err
}
fd := &specialFileFD{
handle: h,
}
- if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
h.close(ctx)
return nil, err
}
@@ -934,7 +935,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if oldParent == newParent && oldName == newName {
return nil
}
- if err := vfsObj.PrepareRenameDentry(vfs.MountNamespaceFromContext(ctx), &renamed.vfsd, replacedVFSD); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
return err
}
if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index d0552bd99..d00850e25 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -52,6 +52,9 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// Name is the default filesystem name.
+const Name = "9p"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 8e11e06b3..54c1031a7 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -571,6 +571,8 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
default:
panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop))
}
+ // After this point, d may be used as a memmap.Mappable.
+ d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
}
@@ -799,6 +801,9 @@ type dentryPlatformFile struct {
// If this dentry represents a regular file, and handle.fd >= 0,
// hostFileMapper caches mappings of handle.fd.
hostFileMapper fsutil.HostFileMapper
+
+ // hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
+ hostFileMapperInitOnce sync.Once
}
// IncRef implements platform.File.IncRef.
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 733792c78..d092ccb2a 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -53,9 +53,9 @@ func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.Dy
}
// Open implements Inode.Open.
-func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &DynamicBytesFD{}
- if err := fd.Init(rp.Mount(), vfsd, f.data, flags); err != nil {
+ if err := fd.Init(rp.Mount(), vfsd, f.data, opts.Flags); err != nil {
return nil, err
}
return &fd.vfsfd, nil
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 6104751c8..5650512e0 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -43,12 +43,12 @@ type GenericDirectoryFD struct {
}
// Init initializes a GenericDirectoryFD.
-func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) error {
- if vfs.AccessTypesForOpenFlags(flags)&vfs.MayWrite != 0 {
+func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, opts *vfs.OpenOptions) error {
+ if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 {
// Can't open directories for writing.
return syserror.EISDIR
}
- if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
return err
}
fd.children = children
@@ -116,8 +116,8 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
Ino: stat.Ino,
NextOff: 1,
}
- if !cb.Handle(dirent) {
- return nil
+ if err := cb.Handle(dirent); err != nil {
+ return err
}
fd.off++
}
@@ -132,8 +132,8 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
Ino: stat.Ino,
NextOff: 2,
}
- if !cb.Handle(dirent) {
- return nil
+ if err := cb.Handle(dirent); err != nil {
+ return err
}
fd.off++
}
@@ -153,8 +153,8 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
Ino: stat.Ino,
NextOff: fd.off + 1,
}
- if !cb.Handle(dirent) {
- return nil
+ if err := cb.Handle(dirent); err != nil {
+ return err
}
fd.off++
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index e49303c26..292f58afd 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -365,7 +365,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// appropriate bits in rp), but are returned by
// FileDescriptionImpl.StatusFlags().
opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW
- ats := vfs.AccessTypesForOpenFlags(opts.Flags)
+ ats := vfs.AccessTypesForOpenFlags(&opts)
// Do not create new file.
if opts.Flags&linux.O_CREAT == 0 {
@@ -379,7 +379,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- return inode.Open(rp, vfsd, opts.Flags)
+ return inode.Open(rp, vfsd, opts)
}
// May create new file.
@@ -398,7 +398,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- return inode.Open(rp, vfsd, opts.Flags)
+ return inode.Open(rp, vfsd, opts)
}
afterTrailingSymlink:
parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
@@ -438,7 +438,7 @@ afterTrailingSymlink:
return nil, err
}
parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
- return child.Impl().(*Dentry).inode.Open(rp, child, opts.Flags)
+ return child.Impl().(*Dentry).inode.Open(rp, child, opts)
}
// Open existing file or follow symlink.
if mustCreate {
@@ -463,7 +463,7 @@ afterTrailingSymlink:
if err := childInode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- return childInode.Open(rp, childVFSD, opts.Flags)
+ return childInode.Open(rp, childVFSD, opts)
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
@@ -544,6 +544,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
virtfs := rp.VirtualFilesystem()
srcDirDentry := srcDirVFSD.Impl().(*Dentry)
@@ -595,7 +596,10 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
parentDentry := vfsd.Parent().Impl().(*Dentry)
parentDentry.dirMu.Lock()
defer parentDentry.dirMu.Unlock()
- if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
+
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil {
@@ -697,7 +701,9 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
parentDentry := vfsd.Parent().Impl().(*Dentry)
parentDentry.dirMu.Lock()
defer parentDentry.dirMu.Unlock()
- if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil {
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index adca2313f..099d70a16 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -507,7 +507,7 @@ type InodeSymlink struct {
}
// Open implements Inode.Open.
-func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
return nil, syserror.ELOOP
}
@@ -549,8 +549,8 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, ino uint64, perm linux.F
}
// Open implements kernfs.Inode.
-func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 79ebea8a5..c74fa999b 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -303,7 +303,7 @@ type Inode interface {
// inode for its lifetime.
//
// Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry.
- Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error)
+ Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error)
}
type inodeRefs interface {
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index ee65cf491..0459fb305 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -45,7 +45,10 @@ type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry
func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System {
ctx := contexttest.Context(t)
creds := auth.CredentialsFromContext(ctx)
- v := vfs.New()
+ v := &vfs.VirtualFilesystem{}
+ if err := v.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
@@ -113,9 +116,9 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod
return &dir.dentry
}
-func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- if err := fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags); err != nil {
+ if err := fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts); err != nil {
return nil, err
}
return fd.VFSFileDescription(), nil
@@ -143,9 +146,9 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte
return &dir.dentry
}
-func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 12aac2e6a..a83245866 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -14,6 +14,7 @@ go_library(
"tasks_net.go",
"tasks_sys.go",
],
+ visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 11477b6a9..5c19d5522 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -26,15 +26,18 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
-// procFSType is the factory class for procfs.
+// Name is the default filesystem name.
+const Name = "proc"
+
+// FilesystemType is the factory class for procfs.
//
// +stateify savable
-type procFSType struct{}
+type FilesystemType struct{}
-var _ vfs.FilesystemType = (*procFSType)(nil)
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
// GetFilesystem implements vfs.FilesystemType.
-func (ft *procFSType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (ft *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
k := kernel.KernelFromContext(ctx)
if k == nil {
return nil, nil, fmt.Errorf("procfs requires a kernel")
@@ -47,12 +50,13 @@ func (ft *procFSType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFile
procfs := &kernfs.Filesystem{}
procfs.VFSFilesystem().Init(vfsObj, procfs)
- var data *InternalData
+ var cgroups map[string]string
if opts.InternalData != nil {
- data = opts.InternalData.(*InternalData)
+ data := opts.InternalData.(*InternalData)
+ cgroups = data.Cgroups
}
- _, dentry := newTasksInode(procfs, k, pidns, data.Cgroups)
+ _, dentry := newTasksInode(procfs, k, pidns, cgroups)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index 353e37195..f3f4e49b4 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -105,8 +105,8 @@ func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallb
Ino: i.inoGen.NextIno(),
NextOff: offset + 1,
}
- if !cb.Handle(dirent) {
- return offset, nil
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
}
offset++
}
@@ -114,9 +114,9 @@ func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallb
}
// Open implements kernfs.Inode.
-func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index eb5bc62c0..2d814668a 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -98,9 +98,9 @@ func (i *taskInode) Valid(ctx context.Context) bool {
}
// Open implements kernfs.Inode.
-func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 14bd334e8..10c08fa90 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -73,9 +73,9 @@ func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNames
"meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
"mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
"net": newNetDir(root, inoGen, k),
- "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}),
"uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
- "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}),
}
inode := &tasksInode{
@@ -151,8 +151,8 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
Ino: i.inoGen.NextIno(),
NextOff: offset + 1,
}
- if !cb.Handle(dirent) {
- return offset, nil
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
}
offset++
}
@@ -163,8 +163,8 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
Ino: i.inoGen.NextIno(),
NextOff: offset + 1,
}
- if !cb.Handle(dirent) {
- return offset, nil
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
}
offset++
}
@@ -196,8 +196,8 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
Ino: i.inoGen.NextIno(),
NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1,
}
- if !cb.Handle(dirent) {
- return offset, nil
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
}
offset++
}
@@ -205,9 +205,9 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
}
// Open implements kernfs.Inode.
-func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go
index 608fec017..d4e1812d8 100644
--- a/pkg/sentry/fsimpl/proc/tasks_net.go
+++ b/pkg/sentry/fsimpl/proc/tasks_net.go
@@ -39,7 +39,10 @@ import (
func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index c7ce74883..3d5dc463c 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -50,7 +50,9 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]*kernfs.Dentry{
"ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
"tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}),
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 6fc3524db..c5d531fe0 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -90,8 +90,7 @@ func setup(t *testing.T) *testutil.System {
ctx := k.SupervisorContext()
creds := auth.CredentialsFromContext(ctx)
- vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("procfs", &procFSType{}, &vfs.RegisterFilesystemTypeOptions{
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
fsOpts := vfs.GetFilesystemOptions{
@@ -102,11 +101,11 @@ func setup(t *testing.T) *testutil.System {
},
},
}
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "procfs", &fsOpts)
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", Name, &fsOpts)
if err != nil {
t.Fatalf("NewMountNamespace(): %v", err)
}
- return testutil.NewSystem(ctx, t, vfsObj, mntns)
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
}
func TestTasksEmpty(t *testing.T) {
@@ -131,7 +130,7 @@ func TestTasks(t *testing.T) {
var tasks []*kernel.Task
for i := 0; i < 5; i++ {
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc)
+ task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
@@ -213,7 +212,7 @@ func TestTasksOffset(t *testing.T) {
k := kernel.KernelFromContext(s.Ctx)
for i := 0; i < 3; i++ {
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- if _, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc); err != nil {
+ if _, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root); err != nil {
t.Fatalf("CreateTask(): %v", err)
}
}
@@ -337,7 +336,7 @@ func TestTask(t *testing.T) {
k := kernel.KernelFromContext(s.Ctx)
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- _, err := testutil.CreateTask(s.Ctx, "name", tc)
+ _, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
@@ -352,7 +351,7 @@ func TestProcSelf(t *testing.T) {
k := kernel.KernelFromContext(s.Ctx)
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- task, err := testutil.CreateTask(s.Ctx, "name", tc)
+ task, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
@@ -433,7 +432,7 @@ func TestTree(t *testing.T) {
var tasks []*kernel.Task
for i := 0; i < 5; i++ {
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc)
+ task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index 66c0d8bc8..a741e2bb6 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"sys.go",
],
+ visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index e35d52d17..c36c4fa11 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -28,6 +28,9 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// Name is the default filesystem name.
+const Name = "sysfs"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
@@ -97,9 +100,9 @@ func (d *dir) SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error {
}
// Open implements kernfs.Inode.Open.
-func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
index 8b1cf0bd0..4b3602d47 100644
--- a/pkg/sentry/fsimpl/sys/sys_test.go
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -34,16 +34,15 @@ func newTestSystem(t *testing.T) *testutil.System {
}
ctx := k.SupervisorContext()
creds := auth.CredentialsFromContext(ctx)
- v := vfs.New()
- v.MustRegisterFilesystemType("sysfs", sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ k.VFS().MustRegisterFilesystemType(sys.Name, sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mns, err := v.NewMountNamespace(ctx, creds, "", "sysfs", &vfs.GetFilesystemOptions{})
+ mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{})
if err != nil {
t.Fatalf("Failed to create new mount namespace: %v", err)
}
- return testutil.NewSystem(ctx, t, v, mns)
+ return testutil.NewSystem(ctx, t, k.VFS(), mns)
}
func TestReadCPUFile(t *testing.T) {
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index efd5974c4..e4f36f4ae 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -16,7 +16,7 @@ go_library(
"//pkg/cpuid",
"//pkg/fspath",
"//pkg/memutil",
- "//pkg/sentry/fs",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/sched",
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index 89f8c4915..488478e29 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -24,7 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/memutil"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
// Platforms are plugable.
_ "gvisor.dev/gvisor/pkg/sentry/platform/kvm"
@@ -99,36 +100,41 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("initializing kernel: %v", err)
}
- ctx := k.SupervisorContext()
+ kernel.VFS2Enabled = true
- // Create mount namespace without root as it's the minimum required to create
- // the global thread group.
- mntns, err := fs.NewMountNamespace(ctx, nil)
- if err != nil {
- return nil, err
+ if err := k.VFS().Init(); err != nil {
+ return nil, fmt.Errorf("VFS init: %v", err)
}
+ k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
ls, err := limits.NewLinuxLimitSet()
if err != nil {
return nil, err
}
- tg := k.NewThreadGroup(mntns, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls)
+ tg := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls)
k.TestOnly_SetGlobalInit(tg)
return k, nil
}
// CreateTask creates a new bare bones task for tests.
-func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup) (*kernel.Task, error) {
+func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) {
k := kernel.KernelFromContext(ctx)
config := &kernel.TaskConfig{
Kernel: k,
ThreadGroup: tc,
TaskContext: &kernel.TaskContext{Name: name},
Credentials: auth.CredentialsFromContext(ctx),
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ MountNamespaceVFS2: mntns,
+ FSContext: kernel.NewFSContextVFS2(root, cwd, 0022),
}
return k.TaskSet().NewTask(config)
}
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
index 69fd84ddd..e16808c63 100644
--- a/pkg/sentry/fsimpl/testutil/testutil.go
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -41,12 +41,12 @@ type System struct {
Creds *auth.Credentials
VFS *vfs.VirtualFilesystem
Root vfs.VirtualDentry
- mns *vfs.MountNamespace
+ MntNs *vfs.MountNamespace
}
// NewSystem constructs a System.
//
-// Precondition: Caller must hold a reference on mns, whose ownership
+// Precondition: Caller must hold a reference on MntNs, whose ownership
// is transferred to the new System.
func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns *vfs.MountNamespace) *System {
s := &System{
@@ -54,7 +54,7 @@ func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns
Ctx: ctx,
Creds: auth.CredentialsFromContext(ctx),
VFS: v,
- mns: mns,
+ MntNs: mns,
Root: mns.Root(),
}
return s
@@ -75,7 +75,7 @@ func (s *System) WithSubtest(t *testing.T) *System {
Ctx: s.Ctx,
Creds: s.Creds,
VFS: s.VFS,
- mns: s.mns,
+ MntNs: s.MntNs,
Root: s.Root,
}
}
@@ -90,7 +90,7 @@ func (s *System) WithTemporaryContext(ctx context.Context) *System {
Ctx: ctx,
Creds: s.Creds,
VFS: s.VFS,
- mns: s.mns,
+ MntNs: s.MntNs,
Root: s.Root,
}
}
@@ -98,7 +98,7 @@ func (s *System) WithTemporaryContext(ctx context.Context) *System {
// Destroy release resources associated with a test system.
func (s *System) Destroy() {
s.Root.DecRef()
- s.mns.DecRef() // Reference on mns passed to NewSystem.
+ s.MntNs.DecRef() // Reference on MntNs passed to NewSystem.
}
// ReadToEnd reads the contents of fd until EOF to a string.
@@ -226,7 +226,7 @@ func (d *DirentCollector) SkipDotsChecks(value bool) {
}
// Handle implements vfs.IterDirentsCallback.Handle.
-func (d *DirentCollector) Handle(dirent vfs.Dirent) bool {
+func (d *DirentCollector) Handle(dirent vfs.Dirent) error {
d.mu.Lock()
if d.dirents == nil {
d.dirents = make(map[string]*vfs.Dirent)
@@ -234,7 +234,7 @@ func (d *DirentCollector) Handle(dirent vfs.Dirent) bool {
d.order = append(d.order, &dirent)
d.dirents[dirent.Name] = &dirent
d.mu.Unlock()
- return true
+ return nil
}
// Count returns the number of dirents currently in the collector.
diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index 9fce5e4b4..383133e44 100644
--- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -175,7 +175,10 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ b.Fatalf("VFS init: %v", err)
+ }
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
@@ -366,7 +369,10 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ b.Fatalf("VFS init: %v", err)
+ }
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index dc0d27cf9..b4380af38 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -74,25 +74,25 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
defer fs.mu.Unlock()
if fd.off == 0 {
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: ".",
Type: linux.DT_DIR,
Ino: vfsd.Impl().(*dentry).inode.ino,
NextOff: 1,
- }) {
- return nil
+ }); err != nil {
+ return err
}
fd.off++
}
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: "..",
Type: parentInode.direntType(),
Ino: parentInode.ino,
NextOff: 2,
- }) {
- return nil
+ }); err != nil {
+ return err
}
fd.off++
}
@@ -111,14 +111,14 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
for child != nil {
// Skip other directoryFD iterators.
if child.inode != nil {
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: child.vfsd.Name(),
Type: child.inode.direntType(),
Ino: child.inode.ino,
NextOff: fd.off + 1,
- }) {
+ }); err != nil {
dir.childList.InsertBefore(child, fd.iter)
- return nil
+ return err
}
fd.off++
}
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 72bc15264..e1b551422 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -16,7 +16,6 @@ package tmpfs
import (
"fmt"
- "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -334,7 +333,7 @@ afterTrailingSymlink:
}
func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(opts.Flags)
+ ats := vfs.AccessTypesForOpenFlags(opts)
if !afterCreate {
if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil {
return nil, err
@@ -347,10 +346,9 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, err
}
if opts.Flags&linux.O_TRUNC != 0 {
- impl.mu.Lock()
- impl.data.Truncate(0, impl.memFile)
- atomic.StoreUint64(&impl.size, 0)
- impl.mu.Unlock()
+ if _, err := impl.truncate(0); err != nil {
+ return nil, err
+ }
}
return &fd.vfsfd, nil
case *directory:
@@ -486,7 +484,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
vfsObj := rp.VirtualFilesystem()
oldParentDir := oldParent.inode.impl.(*directory)
newParentDir := newParent.inode.impl.(*directory)
- if err := vfsObj.PrepareRenameDentry(vfs.MountNamespaceFromContext(ctx), renamedVFSD, replacedVFSD); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareRenameDentry(mntns, renamedVFSD, replacedVFSD); err != nil {
return err
}
if replaced != nil {
@@ -543,7 +543,9 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
defer mnt.EndWrite()
vfsObj := rp.VirtualFilesystem()
- if err := vfsObj.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), childVFSD); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil {
return err
}
parent.inode.impl.(*directory).childList.Remove(child)
@@ -631,7 +633,9 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
defer mnt.EndWrite()
vfsObj := rp.VirtualFilesystem()
- if err := vfsObj.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), childVFSD); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil {
return err
}
parent.inode.impl.(*directory).childList.Remove(child)
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index 5ee7f2a72..1614f2c39 100644
--- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -151,7 +151,10 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index dab346a41..711442424 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -15,6 +15,7 @@
package tmpfs
import (
+ "fmt"
"io"
"math"
"sync/atomic"
@@ -22,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -34,25 +36,53 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// regularFile is a regular (=S_IFREG) tmpfs file.
type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
memFile *pgalloc.MemoryFile
- // mu protects the fields below.
- mu sync.RWMutex
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // Protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ // dataMu protects the fields below.
+ dataMu sync.RWMutex
// data maps offsets into the file to offsets into memFile that store
// the file's data.
+ //
+ // Protected by dataMu.
data fsutil.FileRangeSet
- // size is the size of data, but accessed using atomic memory
- // operations to avoid locking in inode.stat().
- size uint64
-
// seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
seals uint32
+
+ // size is the size of data.
+ //
+ // Protected by both dataMu and inode.mu; reading it requires holding
+ // either mutex, while writing requires holding both AND using atomics.
+ // Readers that do not require consistency (like Stat) may read the
+ // value atomically without holding either lock.
+ size uint64
}
func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
@@ -66,39 +96,170 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
// truncate grows or shrinks the file to the given size. It returns true if the
// file size was updated.
-func (rf *regularFile) truncate(size uint64) (bool, error) {
- rf.mu.Lock()
- defer rf.mu.Unlock()
+func (rf *regularFile) truncate(newSize uint64) (bool, error) {
+ rf.inode.mu.Lock()
+ defer rf.inode.mu.Unlock()
+ return rf.truncateLocked(newSize)
+}
- if size == rf.size {
+// Preconditions: rf.inode.mu must be held.
+func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) {
+ oldSize := rf.size
+ if newSize == oldSize {
// Nothing to do.
return false, nil
}
- if size > rf.size {
- // Growing the file.
+ // Need to hold inode.mu and dataMu while modifying size.
+ rf.dataMu.Lock()
+ if newSize > oldSize {
+ // Can we grow the file?
if rf.seals&linux.F_SEAL_GROW != 0 {
- // Seal does not allow growth.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- rf.size = size
+ // We only need to update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
return true, nil
}
- // Shrinking the file
+ // We are shrinking the file. First check if this is allowed.
if rf.seals&linux.F_SEAL_SHRINK != 0 {
- // Seal does not allow shrink.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have
- // mappings.
+ // Update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+
+ // Invalidate past translations of truncated pages.
+ oldpgend := fs.OffsetPageEnd(int64(oldSize))
+ newpgend := fs.OffsetPageEnd(int64(newSize))
+ if newpgend < oldpgend {
+ rf.mapsMu.Lock()
+ rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ rf.mapsMu.Unlock()
+ }
- rf.data.Truncate(size, rf.memFile)
- rf.size = size
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ rf.dataMu.Lock()
+ rf.data.Truncate(newSize, rf.memFile)
+ rf.dataMu.Unlock()
return true, nil
}
+// AddMapping implements memmap.Mappable.AddMapping.
+func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if rf.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ rf.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+
+ rf.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return rf.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ rf.dataMu.Lock()
+ defer rf.dataMu.Unlock()
+
+ // Constrain translations to f.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(int64(rf.size))
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: rf.memFile,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by f.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (*regularFile) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
type regularFileFD struct {
fileDescription
@@ -152,8 +313,10 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Overflow.
return 0, syserror.EFBIG
}
+ f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
+ f.inode.mu.Unlock()
putRegularFileReadWriter(rw)
return n, err
}
@@ -215,6 +378,12 @@ func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng
return nil
}
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ file := fd.inode().impl.(*regularFile)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
+}
+
// regularFileReadWriter implements safemem.Reader and Safemem.Writer.
type regularFileReadWriter struct {
file *regularFile
@@ -244,14 +413,15 @@ func putRegularFileReadWriter(rw *regularFileReadWriter) {
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- rw.file.mu.RLock()
+ rw.file.dataMu.RLock()
+ defer rw.file.dataMu.RUnlock()
+ size := rw.file.size
// Compute the range to read (limited by file size and overflow-checked).
- if rw.off >= rw.file.size {
- rw.file.mu.RUnlock()
+ if rw.off >= size {
return 0, io.EOF
}
- end := rw.file.size
+ end := size
if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
end = rend
}
@@ -265,7 +435,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
// Get internal mappings.
ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -275,7 +444,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -291,7 +459,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -299,13 +466,16 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
}
}
- rw.file.mu.RUnlock()
return done, nil
}
// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: inode.mu must be held.
func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- rw.file.mu.Lock()
+ // Hold dataMu so we can modify size.
+ rw.file.dataMu.Lock()
+ defer rw.file.dataMu.Unlock()
// Compute the range to write (overflow-checked).
end := rw.off + srcs.NumBytes()
@@ -316,7 +486,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
// Check if seals prevent either file growth or all writes.
switch {
case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
- rw.file.mu.Unlock()
return 0, syserror.EPERM
case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
// When growth is sealed, Linux effectively allows writes which would
@@ -338,7 +507,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
}
if end <= rw.off {
// Truncation would result in no data being written.
- rw.file.mu.Unlock()
return 0, syserror.EPERM
}
}
@@ -395,9 +563,8 @@ exitLoop:
// If the write ends beyond the file's previous size, it causes the
// file to grow.
if rw.off > rw.file.size {
- atomic.StoreUint64(&rw.file.size, rw.off)
+ rw.file.size = rw.off
}
- rw.file.mu.Unlock()
return done, retErr
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
index e9f71e334..0399725cf 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -40,7 +40,11 @@ var nextFileID int64
func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
creds := auth.CredentialsFromContext(ctx)
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
+ }
+
vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 2108d0f4d..521206305 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -18,9 +18,10 @@
// Lock order:
//
// filesystem.mu
-// regularFileFD.offMu
-// regularFile.mu
// inode.mu
+// regularFileFD.offMu
+// regularFile.mapsMu
+// regularFile.dataMu
package tmpfs
import (
@@ -40,6 +41,9 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// Name is the default filesystem name.
+const Name = "tmpfs"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
@@ -223,12 +227,15 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
- // This is unnecessary; it's mostly to simulate what tmpfs would do.
if regFile, ok := i.impl.(*regularFile); ok {
- regFile.mu.Lock()
+ // Hold inode.mu and regFile.dataMu while mutating
+ // size.
+ i.mu.Lock()
+ regFile.dataMu.Lock()
regFile.data.DropAll(regFile.memFile)
atomic.StoreUint64(&regFile.size, 0)
- regFile.mu.Unlock()
+ regFile.dataMu.Unlock()
+ i.mu.Unlock()
}
} else if refs < 0 {
panic("tmpfs.inode.decRef() called without holding a reference")
@@ -317,7 +324,7 @@ func (i *inode) setStat(stat linux.Statx) error {
if mask&linux.STATX_SIZE != 0 {
switch impl := i.impl.(type) {
case *regularFile:
- updated, err := impl.truncate(stat.Size)
+ updated, err := impl.truncateLocked(stat.Size)
if err != nil {
return err
}
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 334432abf..07bf39fed 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -10,6 +10,7 @@ go_library(
srcs = [
"context.go",
"inet.go",
+ "namespace.go",
"test_stack.go",
],
deps = [
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
new file mode 100644
index 000000000..c16667e7f
--- /dev/null
+++ b/pkg/sentry/inet/namespace.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+// Namespace represents a network namespace. See network_namespaces(7).
+//
+// +stateify savable
+type Namespace struct {
+ // stack is the network stack implementation of this network namespace.
+ stack Stack `state:"nosave"`
+
+ // creator allows kernel to create new network stack for network namespaces.
+ // If nil, no networking will function if network is namespaced.
+ creator NetworkStackCreator
+
+ // isRoot indicates whether this is the root network namespace.
+ isRoot bool
+}
+
+// NewRootNamespace creates the root network namespace, with creator
+// allowing new network namespaces to be created. If creator is nil, no
+// networking will function if the network is namespaced.
+func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace {
+ return &Namespace{
+ stack: stack,
+ creator: creator,
+ isRoot: true,
+ }
+}
+
+// NewNamespace creates a new network namespace from the root.
+func NewNamespace(root *Namespace) *Namespace {
+ n := &Namespace{
+ creator: root.creator,
+ }
+ n.init()
+ return n
+}
+
+// Stack returns the network stack of n. Stack may return nil if no network
+// stack is configured.
+func (n *Namespace) Stack() Stack {
+ return n.stack
+}
+
+// IsRoot returns whether n is the root network namespace.
+func (n *Namespace) IsRoot() bool {
+ return n.isRoot
+}
+
+// RestoreRootStack restores the root network namespace with stack. This should
+// only be called when restoring kernel.
+func (n *Namespace) RestoreRootStack(stack Stack) {
+ if !n.isRoot {
+ panic("RestoreRootStack can only be called on root network namespace")
+ }
+ if n.stack != nil {
+ panic("RestoreRootStack called after a stack has already been set")
+ }
+ n.stack = stack
+}
+
+func (n *Namespace) init() {
+ // Root network namespace will have stack assigned later.
+ if n.isRoot {
+ return
+ }
+ if n.creator != nil {
+ var err error
+ n.stack, err = n.creator.CreateStack()
+ if err != nil {
+ panic(err)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (n *Namespace) afterLoad() {
+ n.init()
+}
+
+// NetworkStackCreator allows new instances of a network stack to be created. It
+// is used by the kernel to create new network namespaces when requested.
+type NetworkStackCreator interface {
+ // CreateStack creates a new network stack for a network namespace.
+ CreateStack() (Stack, error)
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 2231d6973..beba29a09 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -157,6 +157,7 @@ go_library(
"//pkg/context",
"//pkg/cpuid",
"//pkg/eventchannel",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/metric",
"//pkg/refs",
@@ -167,6 +168,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/hostcpu",
"//pkg/sentry/inet",
"//pkg/sentry/kernel/auth",
@@ -199,6 +201,7 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 23b88f7a6..58001d56c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -296,6 +296,50 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return fds, nil
}
+// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// the given file description. If it succeeds, it takes a reference on file.
+func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ if minfd < 0 {
+ // Don't accept negative FDs.
+ return -1, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if minfd >= end {
+ return -1, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // From f.next to find available fd.
+ fd := minfd
+ if fd < f.next {
+ fd = f.next
+ }
+ for fd < end {
+ if d, _, _ := f.get(fd); d == nil {
+ f.setVFS2(fd, file, flags)
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fd + 1
+ }
+ return fd, nil
+ }
+ fd++
+ }
+ return -1, syscall.EMFILE
+}
+
// NewFDAt sets the file reference for the given FD. If there is an active
// reference for that FD, the ref count for that existing reference is
// decremented.
@@ -316,9 +360,6 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
return syscall.EBADF
}
- f.mu.Lock()
- defer f.mu.Unlock()
-
// Check the limit for the provided file.
if limitSet := limits.FromContext(ctx); limitSet != nil {
if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
@@ -327,6 +368,8 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
}
// Install the entry.
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.setAll(fd, file, fileVFS2, flags)
return nil
}
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index 2448c1d99..47f78df9a 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -37,10 +38,16 @@ type FSContext struct {
// destroyed.
root *fs.Dirent
+ // rootVFS2 is the filesystem root.
+ rootVFS2 vfs.VirtualDentry
+
// cwd is the current working directory. Will be nil iff the FSContext
// has been destroyed.
cwd *fs.Dirent
+ // cwdVFS2 is the current working directory.
+ cwdVFS2 vfs.VirtualDentry
+
// umask is the current file mode creation mask. When a thread using this
// context invokes a syscall that creates a file, bits set in umask are
// removed from the permissions that the file is created with.
@@ -60,6 +67,19 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext {
return &f
}
+// NewFSContextVFS2 returns a new filesystem context.
+func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext {
+ root.IncRef()
+ cwd.IncRef()
+ f := FSContext{
+ rootVFS2: root,
+ cwdVFS2: cwd,
+ umask: umask,
+ }
+ f.EnableLeakCheck("kernel.FSContext")
+ return &f
+}
+
// destroy is the destructor for an FSContext.
//
// This will call DecRef on both root and cwd Dirents. If either call to
@@ -75,11 +95,17 @@ func (f *FSContext) destroy() {
f.mu.Lock()
defer f.mu.Unlock()
- f.root.DecRef()
- f.root = nil
-
- f.cwd.DecRef()
- f.cwd = nil
+ if VFS2Enabled {
+ f.rootVFS2.DecRef()
+ f.rootVFS2 = vfs.VirtualDentry{}
+ f.cwdVFS2.DecRef()
+ f.cwdVFS2 = vfs.VirtualDentry{}
+ } else {
+ f.root.DecRef()
+ f.root = nil
+ f.cwd.DecRef()
+ f.cwd = nil
+ }
}
// DecRef implements RefCounter.DecRef with destructor f.destroy.
@@ -93,12 +119,21 @@ func (f *FSContext) DecRef() {
func (f *FSContext) Fork() *FSContext {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwd.IncRef()
- f.root.IncRef()
+
+ if VFS2Enabled {
+ f.cwdVFS2.IncRef()
+ f.rootVFS2.IncRef()
+ } else {
+ f.cwd.IncRef()
+ f.root.IncRef()
+ }
+
return &FSContext{
- cwd: f.cwd,
- root: f.root,
- umask: f.umask,
+ cwd: f.cwd,
+ root: f.root,
+ cwdVFS2: f.cwdVFS2,
+ rootVFS2: f.rootVFS2,
+ umask: f.umask,
}
}
@@ -109,12 +144,23 @@ func (f *FSContext) Fork() *FSContext {
func (f *FSContext) WorkingDirectory() *fs.Dirent {
f.mu.Lock()
defer f.mu.Unlock()
- if f.cwd != nil {
- f.cwd.IncRef()
- }
+
+ f.cwd.IncRef()
return f.cwd
}
+// WorkingDirectoryVFS2 returns the current working directory.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.cwdVFS2.IncRef()
+ return f.cwdVFS2
+}
+
// SetWorkingDirectory sets the current working directory.
// This will take an extra reference on the Dirent.
//
@@ -137,6 +183,20 @@ func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) {
old.DecRef()
}
+// SetWorkingDirectoryVFS2 sets the current working directory.
+// This will take an extra reference on the VirtualDentry.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) SetWorkingDirectoryVFS2(d vfs.VirtualDentry) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ old := f.cwdVFS2
+ f.cwdVFS2 = d
+ d.IncRef()
+ old.DecRef()
+}
+
// RootDirectory returns the current filesystem root.
//
// This will return nil if called after destroy(), otherwise it will return a
@@ -150,6 +210,18 @@ func (f *FSContext) RootDirectory() *fs.Dirent {
return f.root
}
+// RootDirectoryVFS2 returns the current filesystem root.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.rootVFS2.IncRef()
+ return f.rootVFS2
+}
+
// SetRootDirectory sets the root directory.
// This will take an extra reference on the Dirent.
//
@@ -172,6 +244,28 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
old.DecRef()
}
+// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd.
+//
+// This is not a valid call after free.
+func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) {
+ if !vd.Ok() {
+ panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry")
+ }
+
+ f.mu.Lock()
+
+ if !f.rootVFS2.Ok() {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd))
+ }
+
+ old := f.rootVFS2
+ vd.IncRef()
+ f.rootVFS2 = vd
+ f.mu.Unlock()
+ old.DecRef()
+}
+
// Umask returns the current umask.
func (f *FSContext) Umask() uint {
f.mu.Lock()
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 3ee760ba2..c62fd6eb1 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -43,11 +43,13 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -71,6 +73,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
+// VFS2Enabled is set to true when VFS2 is enabled. Added as a global for allow
+// easy access everywhere. To be removed once VFS2 becomes the default.
+var VFS2Enabled = false
+
// Kernel represents an emulated Linux kernel. It must be initialized by calling
// Init() or LoadFrom().
//
@@ -105,7 +111,7 @@ type Kernel struct {
timekeeper *Timekeeper
tasks *TaskSet
rootUserNamespace *auth.UserNamespace
- networkStack inet.Stack `state:"nosave"`
+ rootNetworkNamespace *inet.Namespace
applicationCores uint
useHostCores bool
extraAuxv []arch.AuxEntry
@@ -238,6 +244,9 @@ type Kernel struct {
// SpecialOpts contains special kernel options.
SpecialOpts
+
+ // VFS keeps the filesystem state used across the kernel.
+ vfs vfs.VirtualFilesystem
}
// InitKernelArgs holds arguments to Init.
@@ -251,8 +260,9 @@ type InitKernelArgs struct {
// RootUserNamespace is the root user namespace.
RootUserNamespace *auth.UserNamespace
- // NetworkStack is the TCP/IP network stack. NetworkStack may be nil.
- NetworkStack inet.Stack
+ // RootNetworkNamespace is the root network namespace. If nil, no networking
+ // will be available.
+ RootNetworkNamespace *inet.Namespace
// ApplicationCores is the number of logical CPUs visible to sandboxed
// applications. The set of logical CPU IDs is [0, ApplicationCores); thus
@@ -311,7 +321,10 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.rootUTSNamespace = args.RootUTSNamespace
k.rootIPCNamespace = args.RootIPCNamespace
k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace
- k.networkStack = args.NetworkStack
+ k.rootNetworkNamespace = args.RootNetworkNamespace
+ if k.rootNetworkNamespace == nil {
+ k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil)
+ }
k.applicationCores = args.ApplicationCores
if args.UseHostCores {
k.useHostCores = true
@@ -534,8 +547,6 @@ func (ts *TaskSet) unregisterEpollWaiters() {
func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
- k.networkStack = net
-
initAppCores := k.applicationCores
// Load the pre-saved CPUID FeatureSet.
@@ -566,6 +577,10 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
log.Infof("Kernel load stats: %s", &stats)
log.Infof("Kernel load took [%s].", time.Since(kernelStart))
+ // rootNetworkNamespace should be populated after loading the state file.
+ // Restore the root network stack.
+ k.rootNetworkNamespace.RestoreRootStack(net)
+
// Load the memory file's state.
memoryStart := time.Now()
if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
@@ -624,7 +639,7 @@ type CreateProcessArgs struct {
// File is a passed host FD pointing to a file to load as the init binary.
//
// This is checked if and only if Filename is "".
- File *fs.File
+ File fsbridge.File
// Argvv is a list of arguments.
Argv []string
@@ -673,6 +688,13 @@ type CreateProcessArgs struct {
// increment it).
MountNamespace *fs.MountNamespace
+ // MountNamespaceVFS2 optionally contains the mount namespace for this
+ // process. If nil, the init process's mount namespace is used.
+ //
+ // Anyone setting MountNamespaceVFS2 must donate a reference (i.e.
+ // increment it).
+ MountNamespaceVFS2 *vfs.MountNamespace
+
// ContainerID is the container that the process belongs to.
ContainerID string
}
@@ -711,11 +733,22 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
return ctx.args.Credentials
case fs.CtxRoot:
if ctx.args.MountNamespace != nil {
- // MountNamespace.Root() will take a reference on the root
- // dirent for us.
+ // MountNamespace.Root() will take a reference on the root dirent for us.
return ctx.args.MountNamespace.Root()
}
return nil
+ case vfs.CtxRoot:
+ if ctx.args.MountNamespaceVFS2 == nil {
+ return nil
+ }
+ // MountNamespaceVFS2.Root() takes a reference on the root dirent for us.
+ return ctx.args.MountNamespaceVFS2.Root()
+ case vfs.CtxMountNamespace:
+ if ctx.k.globalInit == nil {
+ return nil
+ }
+ // MountNamespaceVFS2 takes a reference for us.
+ return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case ktime.CtxRealtimeClock:
@@ -757,34 +790,77 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
defer k.extMu.Unlock()
log.Infof("EXEC: %v", args.Argv)
- // Grab the mount namespace.
- mounts := args.MountNamespace
- if mounts == nil {
- mounts = k.GlobalInit().Leader().MountNamespace()
- mounts.IncRef()
- }
-
- tg := k.NewThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
ctx := args.NewContext(k)
- // Get the root directory from the MountNamespace.
- root := mounts.Root()
- // The call to newFSContext below will take a reference on root, so we
- // don't need to hold this one.
- defer root.DecRef()
-
- // Grab the working directory.
- remainingTraversals := uint(args.MaxSymlinkTraversals)
- wd := root // Default.
- if args.WorkingDirectory != "" {
- var err error
- wd, err = mounts.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
- if err != nil {
- return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ var (
+ opener fsbridge.Lookup
+ fsContext *FSContext
+ mntns *fs.MountNamespace
+ )
+
+ if VFS2Enabled {
+ mntnsVFS2 := args.MountNamespaceVFS2
+ if mntnsVFS2 == nil {
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ mntnsVFS2 = k.GlobalInit().Leader().MountNamespaceVFS2()
}
- defer wd.DecRef()
+ // Get the root directory from the MountNamespace.
+ root := args.MountNamespaceVFS2.Root()
+ // The call to newFSContext below will take a reference on root, so we
+ // don't need to hold this one.
+ defer root.DecRef()
+
+ // Grab the working directory.
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: wd,
+ Path: fspath.Parse(args.WorkingDirectory),
+ FollowFinalSymlink: true,
+ }
+ var err error
+ wd, err = k.VFS().GetDentryAt(ctx, args.Credentials, &pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef()
+ }
+ opener = fsbridge.NewVFSLookup(mntnsVFS2, root, wd)
+ fsContext = NewFSContextVFS2(root, wd, args.Umask)
+
+ } else {
+ mntns = args.MountNamespace
+ if mntns == nil {
+ mntns = k.GlobalInit().Leader().MountNamespace()
+ mntns.IncRef()
+ }
+ // Get the root directory from the MountNamespace.
+ root := mntns.Root()
+ // The call to newFSContext below will take a reference on root, so we
+ // don't need to hold this one.
+ defer root.DecRef()
+
+ // Grab the working directory.
+ remainingTraversals := args.MaxSymlinkTraversals
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ var err error
+ wd, err = mntns.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef()
+ }
+ opener = fsbridge.NewFSLookup(mntns, root, wd)
+ fsContext = newFSContext(root, wd, args.Umask)
}
+ tg := k.NewThreadGroup(mntns, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
+
// Check which file to start from.
switch {
case args.Filename != "":
@@ -805,11 +881,9 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
}
// Create a fresh task context.
- remainingTraversals = uint(args.MaxSymlinkTraversals)
+ remainingTraversals := args.MaxSymlinkTraversals
loadArgs := loader.LoadArgs{
- Mounts: mounts,
- Root: root,
- WorkingDirectory: wd,
+ Opener: opener,
RemainingTraversals: &remainingTraversals,
ResolveFinal: true,
Filename: args.Filename,
@@ -834,13 +908,15 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
Kernel: k,
ThreadGroup: tg,
TaskContext: tc,
- FSContext: newFSContext(root, wd, args.Umask),
+ FSContext: fsContext,
FDTable: args.FDTable,
Credentials: args.Credentials,
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores),
UTSNamespace: args.UTSNamespace,
IPCNamespace: args.IPCNamespace,
AbstractSocketNamespace: args.AbstractSocketNamespace,
+ MountNamespaceVFS2: args.MountNamespaceVFS2,
ContainerID: args.ContainerID,
}
t, err := k.tasks.NewTask(config)
@@ -1100,6 +1176,14 @@ func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
k.sendExternalSignal(info, context)
}
+// SendExternalSignalThreadGroup injects a signal into an specific ThreadGroup.
+// This function doesn't skip signals like SendExternalSignal does.
+func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.SignalInfo) error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return tg.SendSignal(info)
+}
+
// SendContainerSignal sends the given signal to all processes inside the
// namespace that match the given container ID.
func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
@@ -1178,10 +1262,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
return k.rootAbstractSocketNamespace
}
-// NetworkStack returns the network stack. NetworkStack may return nil if no
-// network stack is available.
-func (k *Kernel) NetworkStack() inet.Stack {
- return k.networkStack
+// RootNetworkNamespace returns the root network namespace, always non-nil.
+func (k *Kernel) RootNetworkNamespace() *inet.Namespace {
+ return k.rootNetworkNamespace
}
// GlobalInit returns the thread group with ID 1 in the root PID namespace, or
@@ -1378,6 +1461,20 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
return ctx.k.globalInit.mounts.Root()
}
return nil
+ case vfs.CtxRoot:
+ if ctx.k.globalInit == nil {
+ return vfs.VirtualDentry{}
+ }
+ mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ defer mntns.DecRef()
+ // Root() takes a reference on the root dirent for us.
+ return mntns.Root()
+ case vfs.CtxMountNamespace:
+ if ctx.k.globalInit == nil {
+ return nil
+ }
+ // MountNamespaceVFS2() takes a reference for us.
+ return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case ktime.CtxRealtimeClock:
@@ -1423,3 +1520,8 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
Registers: t.Arch().StateData().Proto(),
})
}
+
+// VFS returns the virtual filesystem for the kernel.
+func (k *Kernel) VFS() *vfs.VirtualFilesystem {
+ return &k.vfs
+}
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index efebfd872..ded95f532 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -303,26 +303,14 @@ func (t *Task) rseqAddrInterrupt() {
return
}
- buf = t.CopyScratchBuffer(linux.SizeOfRSeqCriticalSection)
- if _, err := t.CopyInBytes(critAddr, buf); err != nil {
+ var cs linux.RSeqCriticalSection
+ if err := cs.CopyIn(t, critAddr); err != nil {
t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
return
}
- // Manually marshal RSeqCriticalSection as this is in the hot path when
- // rseq is enabled. It must be as fast as possible.
- //
- // TODO(b/130243041): Replace with go_marshal.
- cs := linux.RSeqCriticalSection{
- Version: usermem.ByteOrder.Uint32(buf[0:4]),
- Flags: usermem.ByteOrder.Uint32(buf[4:8]),
- Start: usermem.ByteOrder.Uint64(buf[8:16]),
- PostCommitOffset: usermem.ByteOrder.Uint64(buf[16:24]),
- Abort: usermem.ByteOrder.Uint64(buf[24:32]),
- }
-
if cs.Version != 0 {
t.Debugf("Unknown version in %+v", cs)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 981e8c7fe..2cee2e6ed 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -424,6 +424,11 @@ type Task struct {
// abstractSockets is protected by mu.
abstractSockets *AbstractSocketNamespace
+ // mountNamespaceVFS2 is the task's mount namespace.
+ //
+ // It is protected by mu. It is owned by the task goroutine.
+ mountNamespaceVFS2 *vfs.MountNamespace
+
// parentDeathSignal is sent to this task's thread group when its parent exits.
//
// parentDeathSignal is protected by mu.
@@ -481,13 +486,10 @@ type Task struct {
numaPolicy int32
numaNodeMask uint64
- // If netns is true, the task is in a non-root network namespace. Network
- // namespaces aren't currently implemented in full; being in a network
- // namespace simply prevents the task from observing any network devices
- // (including loopback) or using abstract socket addresses (see unix(7)).
+ // netns is the task's network namespace. netns is never nil.
//
- // netns is protected by mu. netns is owned by the task goroutine.
- netns bool
+ // netns is protected by mu.
+ netns *inet.Namespace
// If rseqPreempted is true, before the next call to p.Switch(),
// interrupt rseq critical regions as defined by rseqAddr and
@@ -638,6 +640,11 @@ func (t *Task) Value(key interface{}) interface{} {
return int32(t.ThreadGroup().ID())
case fs.CtxRoot:
return t.fsContext.RootDirectory()
+ case vfs.CtxRoot:
+ return t.fsContext.RootDirectoryVFS2()
+ case vfs.CtxMountNamespace:
+ t.mountNamespaceVFS2.IncRef()
+ return t.mountNamespaceVFS2
case fs.CtxDirentCacheLimiter:
return t.k.DirentCacheLimiter
case inet.CtxStack:
@@ -701,6 +708,14 @@ func (t *Task) SyscallRestartBlock() SyscallRestartBlock {
// Preconditions: The caller must be running on the task goroutine, or t.mu
// must be locked.
func (t *Task) IsChrooted() bool {
+ if VFS2Enabled {
+ realRoot := t.mountNamespaceVFS2.Root()
+ defer realRoot.DecRef()
+ root := t.fsContext.RootDirectoryVFS2()
+ defer root.DecRef()
+ return root != realRoot
+ }
+
realRoot := t.tg.mounts.Root()
defer realRoot.DecRef()
root := t.fsContext.RootDirectory()
@@ -774,6 +789,15 @@ func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error)
return fds[0], nil
}
+// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ return t.fdTable.NewFDVFS2(t, fd, file, flags)
+}
+
// NewFDAt is a convenience wrapper for t.FDTable().NewFDAt.
//
// This automatically passes the task as the context.
@@ -783,6 +807,15 @@ func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error {
return t.fdTable.NewFDAt(t, fd, file, flags)
}
+// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return t.fdTable.NewFDAtVFS2(t, fd, file, flags)
+}
+
// WithMuLocked executes f with t.mu locked.
func (t *Task) WithMuLocked(f func(*Task)) {
t.mu.Lock()
@@ -796,6 +829,15 @@ func (t *Task) MountNamespace() *fs.MountNamespace {
return t.tg.mounts
}
+// MountNamespaceVFS2 returns t's MountNamespace. A reference is taken on the
+// returned mount namespace.
+func (t *Task) MountNamespaceVFS2() *vfs.MountNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mountNamespaceVFS2.IncRef()
+ return t.mountNamespaceVFS2
+}
+
// AbstractSockets returns t's AbstractSocketNamespace.
func (t *Task) AbstractSockets() *AbstractSocketNamespace {
return t.abstractSockets
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 53d4d211b..78866f280 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -54,8 +55,7 @@ type SharingOptions struct {
NewUserNamespace bool
// If NewNetworkNamespace is true, the task should have an independent
- // network namespace. (Note that network namespaces are not really
- // implemented; see comment on Task.netns for details.)
+ // network namespace.
NewNetworkNamespace bool
// If NewFiles is true, the task should use an independent file descriptor
@@ -199,6 +199,17 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
ipcns = NewIPCNamespace(userns)
}
+ netns := t.NetworkNamespace()
+ if opts.NewNetworkNamespace {
+ netns = inet.NewNamespace(netns)
+ }
+
+ // TODO(b/63601033): Implement CLONE_NEWNS.
+ mntnsVFS2 := t.mountNamespaceVFS2
+ if mntnsVFS2 != nil {
+ mntnsVFS2.IncRef()
+ }
+
tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace)
if err != nil {
return 0, nil, err
@@ -241,7 +252,9 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
rseqAddr := usermem.Addr(0)
rseqSignature := uint32(0)
if opts.NewThreadGroup {
- tg.mounts.IncRef()
+ if tg.mounts != nil {
+ tg.mounts.IncRef()
+ }
sh := t.tg.signalHandlers
if opts.NewSignalHandlers {
sh = sh.Fork()
@@ -260,11 +273,12 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
FDTable: fdTable,
Credentials: creds,
Niceness: t.Niceness(),
- NetworkNamespaced: t.netns,
+ NetworkNamespace: netns,
AllowedCPUMask: t.CPUMask(),
UTSNamespace: utsns,
IPCNamespace: ipcns,
AbstractSocketNamespace: t.abstractSockets,
+ MountNamespaceVFS2: mntnsVFS2,
RSeqAddr: rseqAddr,
RSeqSignature: rseqSignature,
ContainerID: t.ContainerID(),
@@ -274,9 +288,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else {
cfg.InheritParent = t
}
- if opts.NewNetworkNamespace {
- cfg.NetworkNamespaced = true
- }
nt, err := t.tg.pidns.owner.NewTask(cfg)
if err != nil {
if opts.NewThreadGroup {
@@ -473,7 +484,7 @@ func (t *Task) Unshare(opts *SharingOptions) error {
t.mu.Unlock()
return syserror.EPERM
}
- t.netns = true
+ t.netns = inet.NewNamespace(t.netns)
}
if opts.NewUTSNamespace {
if !haveCapSysAdmin {
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 2d6e7733c..2be982684 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -136,7 +136,7 @@ func (t *Task) Stack() *arch.Stack {
func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) {
// If File is not nil, we should load that instead of resolving Filename.
if args.File != nil {
- args.Filename = args.File.MappedName(ctx)
+ args.Filename = args.File.PathnameWithDeleted(ctx)
}
// Prepare a new user address space to load into.
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 435761e5a..c4ade6e8e 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -269,6 +269,13 @@ func (*runExitMain) execute(t *Task) taskRunState {
t.fsContext.DecRef()
t.fdTable.DecRef()
+ t.mu.Lock()
+ if t.mountNamespaceVFS2 != nil {
+ t.mountNamespaceVFS2.DecRef()
+ t.mountNamespaceVFS2 = nil
+ }
+ t.mu.Unlock()
+
// If this is the last task to exit from the thread group, release the
// thread group's resources.
if lastExiter {
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index 41259210c..eeccaa197 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -32,21 +32,21 @@ const (
// Infof logs an formatted info message by calling log.Infof.
func (t *Task) Infof(fmt string, v ...interface{}) {
if log.IsLogging(log.Info) {
- log.Infof(t.logPrefix.Load().(string)+fmt, v...)
+ log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Warningf logs a warning string by calling log.Warningf.
func (t *Task) Warningf(fmt string, v ...interface{}) {
if log.IsLogging(log.Warning) {
- log.Warningf(t.logPrefix.Load().(string)+fmt, v...)
+ log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Debugf creates a debug string that includes the task ID.
func (t *Task) Debugf(fmt string, v ...interface{}) {
if log.IsLogging(log.Debug) {
- log.Debugf(t.logPrefix.Load().(string)+fmt, v...)
+ log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
@@ -198,18 +198,11 @@ func (t *Task) traceExecEvent(tc *TaskContext) {
if !trace.IsEnabled() {
return
}
- d := tc.MemoryManager.Executable()
- if d == nil {
+ file := tc.MemoryManager.Executable()
+ if file == nil {
trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>")
return
}
- defer d.DecRef()
- root := t.fsContext.RootDirectory()
- if root == nil {
- trace.Logf(t.traceContext, traceCategory, "exec: << no root directory >>")
- return
- }
- defer root.DecRef()
- n, _ := d.FullName(root)
- trace.Logf(t.traceContext, traceCategory, "exec: %s", n)
+ defer file.DecRef()
+ trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t))
}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
index 172a31e1d..f7711232c 100644
--- a/pkg/sentry/kernel/task_net.go
+++ b/pkg/sentry/kernel/task_net.go
@@ -22,14 +22,23 @@ import (
func (t *Task) IsNetworkNamespaced() bool {
t.mu.Lock()
defer t.mu.Unlock()
- return t.netns
+ return !t.netns.IsRoot()
}
// NetworkContext returns the network stack used by the task. NetworkContext
// may return nil if no network stack is available.
+//
+// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
+// NetworkNamespace().
func (t *Task) NetworkContext() inet.Stack {
- if t.IsNetworkNamespaced() {
- return nil
- }
- return t.k.networkStack
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns.Stack()
+}
+
+// NetworkNamespace returns the network namespace observed by the task.
+func (t *Task) NetworkNamespace() *inet.Namespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index de838beef..a5035bb7f 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -17,10 +17,12 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -64,9 +66,8 @@ type TaskConfig struct {
// Niceness is the niceness of the new task.
Niceness int
- // If NetworkNamespaced is true, the new task should observe a non-root
- // network namespace.
- NetworkNamespaced bool
+ // NetworkNamespace is the network namespace to be used for the new task.
+ NetworkNamespace *inet.Namespace
// AllowedCPUMask contains the cpus that this task can run on.
AllowedCPUMask sched.CPUSet
@@ -80,6 +81,9 @@ type TaskConfig struct {
// AbstractSocketNamespace is the AbstractSocketNamespace of the new task.
AbstractSocketNamespace *AbstractSocketNamespace
+ // MountNamespaceVFS2 is the MountNamespace of the new task.
+ MountNamespaceVFS2 *vfs.MountNamespace
+
// RSeqAddr is a pointer to the the userspace linux.RSeq structure.
RSeqAddr usermem.Addr
@@ -116,28 +120,29 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
parent: cfg.Parent,
children: make(map[*Task]struct{}),
},
- runState: (*runApp)(nil),
- interruptChan: make(chan struct{}, 1),
- signalMask: cfg.SignalMask,
- signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
- tc: *tc,
- fsContext: cfg.FSContext,
- fdTable: cfg.FDTable,
- p: cfg.Kernel.Platform.NewContext(),
- k: cfg.Kernel,
- ptraceTracees: make(map[*Task]struct{}),
- allowedCPUMask: cfg.AllowedCPUMask.Copy(),
- ioUsage: &usage.IO{},
- niceness: cfg.Niceness,
- netns: cfg.NetworkNamespaced,
- utsns: cfg.UTSNamespace,
- ipcns: cfg.IPCNamespace,
- abstractSockets: cfg.AbstractSocketNamespace,
- rseqCPU: -1,
- rseqAddr: cfg.RSeqAddr,
- rseqSignature: cfg.RSeqSignature,
- futexWaiter: futex.NewWaiter(),
- containerID: cfg.ContainerID,
+ runState: (*runApp)(nil),
+ interruptChan: make(chan struct{}, 1),
+ signalMask: cfg.SignalMask,
+ signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
+ tc: *tc,
+ fsContext: cfg.FSContext,
+ fdTable: cfg.FDTable,
+ p: cfg.Kernel.Platform.NewContext(),
+ k: cfg.Kernel,
+ ptraceTracees: make(map[*Task]struct{}),
+ allowedCPUMask: cfg.AllowedCPUMask.Copy(),
+ ioUsage: &usage.IO{},
+ niceness: cfg.Niceness,
+ netns: cfg.NetworkNamespace,
+ utsns: cfg.UTSNamespace,
+ ipcns: cfg.IPCNamespace,
+ abstractSockets: cfg.AbstractSocketNamespace,
+ mountNamespaceVFS2: cfg.MountNamespaceVFS2,
+ rseqCPU: -1,
+ rseqAddr: cfg.RSeqAddr,
+ rseqSignature: cfg.RSeqSignature,
+ futexWaiter: futex.NewWaiter(),
+ containerID: cfg.ContainerID,
}
t.creds.Store(cfg.Credentials)
t.endStopCond.L = &t.tg.signalHandlers.mu
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 768e958d2..268f62e9d 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -256,7 +256,7 @@ type ThreadGroup struct {
tty *TTY
}
-// NewThreadGroup returns a new, empty thread group in PID namespace ns. The
+// NewThreadGroup returns a new, empty thread group in PID namespace pidns. The
// thread group leader will send its parent terminationSignal when it exits.
// The new thread group isn't visible to the system until a task has been
// created inside of it by a successful call to TaskSet.NewTask.
@@ -317,7 +317,9 @@ func (tg *ThreadGroup) release() {
for _, it := range its {
it.DestroyTimer()
}
- tg.mounts.DecRef()
+ if tg.mounts != nil {
+ tg.mounts.DecRef()
+ }
}
// forEachChildThreadGroupLocked indicates over all child ThreadGroups.
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index 23790378a..c6aa65f28 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -33,6 +33,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
@@ -40,6 +41,7 @@ go_library(
"//pkg/sentry/pgalloc",
"//pkg/sentry/uniqueid",
"//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 122ed05c2..616fafa2c 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -27,7 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -97,11 +97,11 @@ type elfInfo struct {
// accepts from the ELF, and it doesn't parse unnecessary parts of the file.
//
// ctx may be nil if f does not need it.
-func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
+func parseHeader(ctx context.Context, f fsbridge.File) (elfInfo, error) {
// Check ident first; it will tell us the endianness of the rest of the
// structs.
var ident [elf.EI_NIDENT]byte
- _, err := readFull(ctx, f, usermem.BytesIOSequence(ident[:]), 0)
+ _, err := f.ReadFull(ctx, usermem.BytesIOSequence(ident[:]), 0)
if err != nil {
log.Infof("Error reading ELF ident: %v", err)
// The entire ident array always exists.
@@ -137,7 +137,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
var hdr elf.Header64
hdrBuf := make([]byte, header64Size)
- _, err = readFull(ctx, f, usermem.BytesIOSequence(hdrBuf), 0)
+ _, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0)
if err != nil {
log.Infof("Error reading ELF header: %v", err)
// The entire header always exists.
@@ -187,7 +187,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
}
phdrBuf := make([]byte, totalPhdrSize)
- _, err = readFull(ctx, f, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
+ _, err = f.ReadFull(ctx, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
if err != nil {
log.Infof("Error reading ELF phdrs: %v", err)
// If phdrs were specified, they should all exist.
@@ -227,7 +227,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
// mapSegment maps a phdr into the Task. offset is the offset to apply to
// phdr.Vaddr.
-func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
+func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
// We must make a page-aligned mapping.
adjust := usermem.Addr(phdr.Vaddr).PageOffset()
@@ -395,7 +395,7 @@ type loadedELF struct {
//
// Preconditions:
// * f is an ELF file
-func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
+func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
first := true
var start, end usermem.Addr
var interpreter string
@@ -431,7 +431,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
}
path := make([]byte, phdr.Filesz)
- _, err := readFull(ctx, f, usermem.BytesIOSequence(path), int64(phdr.Off))
+ _, err := f.ReadFull(ctx, usermem.BytesIOSequence(path), int64(phdr.Off))
if err != nil {
// If an interpreter was specified, it should exist.
ctx.Infof("Error reading PT_INTERP path: %v", err)
@@ -564,7 +564,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
// Preconditions:
// * f is an ELF file
// * f is the first ELF loaded into m
-func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) {
+func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f fsbridge.File) (loadedELF, arch.Context, error) {
info, err := parseHeader(ctx, f)
if err != nil {
ctx.Infof("Failed to parse initial ELF: %v", err)
@@ -602,7 +602,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS
//
// Preconditions:
// * f is an ELF file
-func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, initial loadedELF) (loadedELF, error) {
+func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) {
info, err := parseHeader(ctx, f)
if err != nil {
if err == syserror.ENOEXEC {
@@ -649,16 +649,14 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error
// Refresh the traversal limit.
*args.RemainingTraversals = linux.MaxSymlinkTraversals
args.Filename = bin.interpreter
- d, i, err := openPath(ctx, args)
+ intFile, err := openPath(ctx, args)
if err != nil {
ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err)
return loadedELF{}, nil, err
}
- defer i.DecRef()
- // We don't need the Dirent.
- d.DecRef()
+ defer intFile.DecRef()
- interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin)
+ interp, err = loadInterpreterELF(ctx, args.MemoryManager, intFile, bin)
if err != nil {
ctx.Infof("Error loading interpreter: %v", err)
return loadedELF{}, nil, err
diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go
index 098a45d36..3886b4d33 100644
--- a/pkg/sentry/loader/interpreter.go
+++ b/pkg/sentry/loader/interpreter.go
@@ -19,7 +19,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,9 +37,9 @@ const (
)
// parseInterpreterScript returns the interpreter path and argv.
-func parseInterpreterScript(ctx context.Context, filename string, f *fs.File, argv []string) (newpath string, newargv []string, err error) {
+func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.File, argv []string) (newpath string, newargv []string, err error) {
line := make([]byte, interpMaxLineLength)
- n, err := readFull(ctx, f, usermem.BytesIOSequence(line), 0)
+ n, err := f.ReadFull(ctx, usermem.BytesIOSequence(line), 0)
// Short read is OK.
if err != nil && err != io.ErrUnexpectedEOF {
if err == io.EOF {
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 9a613d6b7..d6675b8f0 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -20,7 +20,6 @@ import (
"fmt"
"io"
"path"
- "strings"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -29,8 +28,10 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -41,16 +42,6 @@ type LoadArgs struct {
// MemoryManager is the memory manager to load the executable into.
MemoryManager *mm.MemoryManager
- // Mounts is the mount namespace in which to look up Filename.
- Mounts *fs.MountNamespace
-
- // Root is the root directory under which to look up Filename.
- Root *fs.Dirent
-
- // WorkingDirectory is the working directory under which to look up
- // Filename.
- WorkingDirectory *fs.Dirent
-
// RemainingTraversals is the maximum number of symlinks to follow to
// resolve Filename. This counter is passed by reference to keep it
// updated throughout the call stack.
@@ -65,7 +56,12 @@ type LoadArgs struct {
// File is an open fs.File object of the executable. If File is not
// nil, then File will be loaded and Filename will be ignored.
- File *fs.File
+ //
+ // The caller is responsible for checking that the user can execute this file.
+ File fsbridge.File
+
+ // Opener is used to open the executable file when 'File' is nil.
+ Opener fsbridge.Lookup
// CloseOnExec indicates that the executable (or one of its parent
// directories) was opened with O_CLOEXEC. If the executable is an
@@ -106,103 +102,32 @@ func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset in
// installed in the Task FDTable. The caller takes ownership of both.
//
// args.Filename must be a readable, executable, regular file.
-func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) {
+func openPath(ctx context.Context, args LoadArgs) (fsbridge.File, error) {
if args.Filename == "" {
ctx.Infof("cannot open empty name")
- return nil, nil, syserror.ENOENT
- }
-
- var d *fs.Dirent
- var err error
- if args.ResolveFinal {
- d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals)
- } else {
- d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals)
- }
- if err != nil {
- return nil, nil, err
- }
- // Defer a DecRef for the sake of failure cases.
- defer d.DecRef()
-
- if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
- return nil, nil, syserror.ELOOP
- }
-
- if err := checkPermission(ctx, d); err != nil {
- return nil, nil, err
- }
-
- // If they claim it's a directory, then make sure.
- //
- // N.B. we reject directories below, but we must first reject
- // non-directories passed as directories.
- if strings.HasSuffix(args.Filename, "/") && !fs.IsDir(d.Inode.StableAttr) {
- return nil, nil, syserror.ENOTDIR
- }
-
- if err := checkIsRegularFile(ctx, d, args.Filename); err != nil {
- return nil, nil, err
- }
-
- f, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- return nil, nil, err
- }
- // Defer a DecRef for the sake of failure cases.
- defer f.DecRef()
-
- if err := checkPread(ctx, f, args.Filename); err != nil {
- return nil, nil, err
- }
-
- d.IncRef()
- f.IncRef()
- return d, f, err
-}
-
-// checkFile performs checks on a file to be executed.
-func checkFile(ctx context.Context, f *fs.File, filename string) error {
- if err := checkPermission(ctx, f.Dirent); err != nil {
- return err
- }
-
- if err := checkIsRegularFile(ctx, f.Dirent, filename); err != nil {
- return err
+ return nil, syserror.ENOENT
}
- return checkPread(ctx, f, filename)
-}
-
-// checkPermission checks whether the file is readable and executable.
-func checkPermission(ctx context.Context, d *fs.Dirent) error {
- perms := fs.PermMask{
- // TODO(gvisor.dev/issue/160): Linux requires only execute
- // permission, not read. However, our backing filesystems may
- // prevent us from reading the file without read permission.
- //
- // Additionally, a task with a non-readable executable has
- // additional constraints on access via ptrace and procfs.
- Read: true,
- Execute: true,
+ // TODO(gvisor.dev/issue/160): Linux requires only execute permission,
+ // not read. However, our backing filesystems may prevent us from reading
+ // the file without read permission. Additionally, a task with a
+ // non-readable executable has additional constraints on access via
+ // ptrace and procfs.
+ opts := vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
}
- return d.Inode.CheckPermission(ctx, perms)
+ return args.Opener.OpenPath(ctx, args.Filename, opts, args.RemainingTraversals, args.ResolveFinal)
}
// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc.
-func checkIsRegularFile(ctx context.Context, d *fs.Dirent, filename string) error {
- attr := d.Inode.StableAttr
- if !fs.IsRegular(attr) {
- ctx.Infof("%s is not regular: %v", filename, attr)
- return syserror.EACCES
+func checkIsRegularFile(ctx context.Context, file fsbridge.File, filename string) error {
+ t, err := file.Type(ctx)
+ if err != nil {
+ return err
}
- return nil
-}
-
-// checkPread checks whether we can read the file at arbitrary offsets.
-func checkPread(ctx context.Context, f *fs.File, filename string) error {
- if !f.Flags().Pread {
- ctx.Infof("%s cannot be read at an offset: %+v", filename, f.Flags())
+ if t != linux.ModeRegular {
+ ctx.Infof("%q is not a regular file: %v", filename, t)
return syserror.EACCES
}
return nil
@@ -224,8 +149,10 @@ const (
maxLoaderAttempts = 6
)
-// loadExecutable loads an executable that is pointed to by args.File. If nil,
-// the path args.Filename is resolved and loaded. If the executable is an
+// loadExecutable loads an executable that is pointed to by args.File. The
+// caller is responsible for checking that the user can execute this file.
+// If nil, the path args.Filename is resolved and loaded (check that the user
+// can execute this file is done here in this case). If the executable is an
// interpreter script rather than an ELF, the binary of the corresponding
// interpreter will be loaded.
//
@@ -234,37 +161,27 @@ const (
// * arch.Context matching the binary arch
// * fs.Dirent of the binary file
// * Possibly updated args.Argv
-func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
+func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, fsbridge.File, []string, error) {
for i := 0; i < maxLoaderAttempts; i++ {
- var (
- d *fs.Dirent
- err error
- )
if args.File == nil {
- d, args.File, err = openPath(ctx, args)
- // We will return d in the successful case, but defer a DecRef for the
- // sake of intermediate loops and failure cases.
- if d != nil {
- defer d.DecRef()
- }
- if args.File != nil {
- defer args.File.DecRef()
+ var err error
+ args.File, err = openPath(ctx, args)
+ if err != nil {
+ ctx.Infof("Error opening %s: %v", args.Filename, err)
+ return loadedELF{}, nil, nil, nil, err
}
+ // Ensure file is release in case the code loops or errors out.
+ defer args.File.DecRef()
} else {
- d = args.File.Dirent
- d.IncRef()
- defer d.DecRef()
- err = checkFile(ctx, args.File, args.Filename)
- }
- if err != nil {
- ctx.Infof("Error opening %s: %v", args.Filename, err)
- return loadedELF{}, nil, nil, nil, err
+ if err := checkIsRegularFile(ctx, args.File, args.Filename); err != nil {
+ return loadedELF{}, nil, nil, nil, err
+ }
}
// Check the header. Is this an ELF or interpreter script?
var hdr [4]uint8
// N.B. We assume that reading from a regular file cannot block.
- _, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0)
+ _, err := args.File.ReadFull(ctx, usermem.BytesIOSequence(hdr[:]), 0)
// Allow unexpected EOF, as a valid executable could be only three bytes
// (e.g., #!a).
if err != nil && err != io.ErrUnexpectedEOF {
@@ -281,9 +198,10 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
ctx.Infof("Error loading ELF: %v", err)
return loadedELF{}, nil, nil, nil, err
}
- // An ELF is always terminal. Hold on to d.
- d.IncRef()
- return loaded, ac, d, args.Argv, err
+ // An ELF is always terminal. Hold on to file.
+ args.File.IncRef()
+ return loaded, ac, args.File, args.Argv, err
+
case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)):
if args.CloseOnExec {
return loadedELF{}, nil, nil, nil, syserror.ENOENT
@@ -295,6 +213,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
}
// Refresh the traversal limit for the interpreter.
*args.RemainingTraversals = linux.MaxSymlinkTraversals
+
default:
ctx.Infof("Unknown magic: %v", hdr)
return loadedELF{}, nil, nil, nil, syserror.ENOEXEC
@@ -317,11 +236,11 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
// * Load is called on the Task goroutine.
func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
// Load the executable itself.
- loaded, ac, d, newArgv, err := loadExecutable(ctx, args)
+ loaded, ac, file, newArgv, err := loadExecutable(ctx, args)
if err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
}
- defer d.DecRef()
+ defer file.DecRef()
// Load the VDSO.
vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
@@ -390,7 +309,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
m.SetEnvvStart(sl.EnvvStart)
m.SetEnvvEnd(sl.EnvvEnd)
m.SetAuxv(auxv)
- m.SetExecutable(d)
+ m.SetExecutable(file)
ac.SetIP(uintptr(loaded.entry))
ac.SetStack(uintptr(stack.Bottom))
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 52f446ed7..161b28c2c 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -69,6 +70,8 @@ type byteReader struct {
var _ fs.FileOperations = (*byteReader)(nil)
// newByteReaderFile creates a fake file to read data from.
+//
+// TODO(gvisor.dev/issue/1623): Convert to VFS2.
func newByteReaderFile(ctx context.Context, data []byte) *fs.File {
// Create a fake inode.
inode := fs.NewInode(
@@ -123,7 +126,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq
// * PT_LOAD segments don't extend beyond the end of the file.
//
// ctx may be nil if f does not need it.
-func validateVDSO(ctx context.Context, f *fs.File, size uint64) (elfInfo, error) {
+func validateVDSO(ctx context.Context, f fsbridge.File, size uint64) (elfInfo, error) {
info, err := parseHeader(ctx, f)
if err != nil {
log.Infof("Unable to parse VDSO header: %v", err)
@@ -221,7 +224,7 @@ type VDSO struct {
// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
// param page for updating by the kernel.
func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
- vdsoFile := newByteReaderFile(ctx, vdsoBin)
+ vdsoFile := fsbridge.NewFSFile(newByteReaderFile(ctx, vdsoBin))
// First make sure the VDSO is valid. vdsoFile does not use ctx, so a
// nil context can be passed.
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index e5729ced5..73591dab7 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -105,8 +105,8 @@ go_library(
"//pkg/safecopy",
"//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/fs",
"//pkg/sentry/fs/proc/seqfile",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/futex",
"//pkg/sentry/kernel/shm",
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index e58a63deb..94d39af60 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -18,7 +18,6 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -42,8 +41,15 @@ func (mm *MemoryManager) AddressSpace() platform.AddressSpace {
func (mm *MemoryManager) Activate() error {
// Fast path: the MemoryManager already has an active
// platform.AddressSpace, and we just need to indicate that we need it too.
- if atomicbitops.IncUnlessZeroInt32(&mm.active) {
- return nil
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 0 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active+1) {
+ return nil
+ }
}
for {
@@ -118,8 +124,15 @@ func (mm *MemoryManager) Activate() error {
func (mm *MemoryManager) Deactivate() {
// Fast path: this is not the last goroutine to deactivate the
// MemoryManager.
- if atomicbitops.DecUnlessOneInt32(&mm.active) {
- return
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 1 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active-1) {
+ return
+ }
}
mm.activeMu.Lock()
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 47b8fbf43..3c263ebaa 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -18,7 +18,6 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/limits"
@@ -229,7 +228,15 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
// IncUsers increments mm's user count and returns true. If the user count is
// already 0, IncUsers does nothing and returns false.
func (mm *MemoryManager) IncUsers() bool {
- return atomicbitops.IncUnlessZeroInt32(&mm.users)
+ for {
+ users := atomic.LoadInt32(&mm.users)
+ if users == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt32(&mm.users, users, users+1) {
+ return true
+ }
+ }
}
// DecUsers decrements mm's user count. If the user count reaches 0, all
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
index f550acae0..6a49334f4 100644
--- a/pkg/sentry/mm/metadata.go
+++ b/pkg/sentry/mm/metadata.go
@@ -16,7 +16,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -132,7 +132,7 @@ func (mm *MemoryManager) SetAuxv(auxv arch.Auxv) {
//
// An additional reference will be taken in the case of a non-nil executable,
// which must be released by the caller.
-func (mm *MemoryManager) Executable() *fs.Dirent {
+func (mm *MemoryManager) Executable() fsbridge.File {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
@@ -147,15 +147,15 @@ func (mm *MemoryManager) Executable() *fs.Dirent {
// SetExecutable sets the executable.
//
// This takes a reference on d.
-func (mm *MemoryManager) SetExecutable(d *fs.Dirent) {
+func (mm *MemoryManager) SetExecutable(file fsbridge.File) {
mm.metadataMu.Lock()
// Grab a new reference.
- d.IncRef()
+ file.IncRef()
// Set the executable.
orig := mm.executable
- mm.executable = d
+ mm.executable = file
mm.metadataMu.Unlock()
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 09e582dd3..637383c7a 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -37,7 +37,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
@@ -215,7 +215,7 @@ type MemoryManager struct {
// is not nil, it holds a reference on the Dirent.
//
// executable is protected by metadataMu.
- executable *fs.Dirent
+ executable fsbridge.File
// dumpability describes if and how this MemoryManager may be dumped to
// userspace.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 8076c7529..f1afc74dc 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -329,10 +329,12 @@ func (m *machine) Destroy() {
}
// Get gets an available vCPU.
+//
+// This will return with the OS thread locked.
func (m *machine) Get() *vCPU {
+ m.mu.RLock()
runtime.LockOSThread()
tid := procid.Current()
- m.mu.RLock()
// Check for an exact match.
if c := m.vCPUs[tid]; c != nil {
@@ -343,8 +345,22 @@ func (m *machine) Get() *vCPU {
// The happy path failed. We now proceed to acquire an exclusive lock
// (because the vCPU map may change), and scan all available vCPUs.
+ // In this case, we first unlock the OS thread. Otherwise, if mu is
+ // not available, the current system thread will be parked and a new
+ // system thread spawned. We avoid this situation by simply refreshing
+ // tid after relocking the system thread.
m.mu.RUnlock()
+ runtime.UnlockOSThread()
m.mu.Lock()
+ runtime.LockOSThread()
+ tid = procid.Current()
+
+ // Recheck for an exact match.
+ if c := m.vCPUs[tid]; c != nil {
+ c.lock()
+ m.mu.Unlock()
+ return c
+ }
for {
// Scan for an available vCPU.
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 971eed7fa..4f2406ce3 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -7,7 +7,7 @@ go_template(
name = "generic_walker",
srcs = select_arch(
amd64 = ["walker_amd64.go"],
- arm64 = ["walker_amd64.go"],
+ arm64 = ["walker_arm64.go"],
),
opt_types = [
"Visitor",
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 79e16d6e8..4d42d29cb 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
"//pkg/syserror",
+ "//pkg/tcpip",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 00265f15b..8834a1e1a 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -189,7 +190,7 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
- space := AlignDown(cap(buf)-len(buf), 4)
+ space := binary.AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
// the available space, we must align down.
@@ -282,19 +283,9 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int
return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
}
-// AlignUp rounds a length up to an alignment. align must be a power of 2.
-func AlignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
-
-// AlignDown rounds a down to an alignment. align must be a power of 2.
-func AlignDown(length int, align uint) int {
- return length & ^(int(align) - 1)
-}
-
// alignSlice extends a slice's length (up to the capacity) to align it.
func alignSlice(buf []byte, align uint) []byte {
- aligned := AlignUp(len(buf), align)
+ aligned := binary.AlignUp(len(buf), align)
if aligned > cap(buf) {
// Linux allows unaligned data if there isn't room for alignment.
// Since there isn't room for alignment, there isn't room for any
@@ -338,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
}
// PackTClass packs an IPV6_TCLASS socket control message.
-func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IPV6,
@@ -348,6 +339,22 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
)
}
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -372,12 +379,16 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackTClass(t, cmsgs.IP.TClass, buf)
}
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
return buf
}
// cmsgSpace is equivalent to CMSG_SPACE in Linux.
func cmsgSpace(t *kernel.Task, dataLen int) int {
- return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width())
+ return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
}
// CmsgsSpace returns the number of bytes needed to fit the control messages
@@ -404,6 +415,16 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
return space
}
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
@@ -437,7 +458,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_SOCKET:
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
if len(fds)+numRights > linux.SCM_MAX_FD {
@@ -448,7 +469,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
}
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
@@ -462,7 +483,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
return socket.ControlMessages{}, err
}
cmsgs.Unix.Credentials = scmCreds
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
default:
// Unknown message type.
@@ -476,7 +497,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
}
cmsgs.IP.HasTOS = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
@@ -489,7 +522,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
}
cmsgs.IP.HasTClass = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index de76388ac..22f78d2e2 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -289,7 +289,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
@@ -336,6 +336,8 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
switch name {
case linux.IP_TOS, linux.IP_RECVTOS:
optlen = sizeofInt32
+ case linux.IP_PKTINFO:
+ optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
@@ -473,7 +475,14 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
case syscall.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+
+ case syscall.IP_PKTINFO:
+ controlMessages.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
}
+
case syscall.SOL_IPV6:
switch unixCmsg.Header.Type {
case syscall.IPV6_TCLASS:
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index fa2a2cb66..7cd2ce55b 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -5,7 +5,11 @@ package(licenses = ["notice"])
go_library(
name = "netfilter",
srcs = [
+ "extensions.go",
"netfilter.go",
+ "targets.go",
+ "tcp_matcher.go",
+ "udp_matcher.go",
],
# This target depends on netstack and should only be used by epsocket,
# which is allowed to depend on netstack.
@@ -17,6 +21,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
new file mode 100644
index 000000000..b4b244abf
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// TODO(gvisor.dev/issue/170): The following per-matcher params should be
+// supported:
+// - Table name
+// - Match size
+// - User size
+// - Hooks
+// - Proto
+// - Family
+
+// matchMaker knows how to (un)marshal the matcher named name().
+type matchMaker interface {
+ // name is the matcher name as stored in the xt_entry_match struct.
+ name() string
+
+ // marshal converts from an iptables.Matcher to an ABI struct.
+ marshal(matcher iptables.Matcher) []byte
+
+ // unmarshal converts from the ABI matcher struct to an
+ // iptables.Matcher.
+ unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error)
+}
+
+// matchMakers maps the name of supported matchers to the matchMaker that
+// marshals and unmarshals it. It is immutable after package initialization.
+var matchMakers = map[string]matchMaker{}
+
+// registermatchMaker should be called by match extensions to register them
+// with the netfilter package.
+func registerMatchMaker(mm matchMaker) {
+ if _, ok := matchMakers[mm.name()]; ok {
+ panic(fmt.Sprintf("Multiple matches registered with name %q.", mm.name()))
+ }
+ matchMakers[mm.name()] = mm
+}
+
+func marshalMatcher(matcher iptables.Matcher) []byte {
+ matchMaker, ok := matchMakers[matcher.Name()]
+ if !ok {
+ panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
+ }
+ return matchMaker.marshal(matcher)
+}
+
+// marshalEntryMatch creates a marshalled XTEntryMatch with the given name and
+// data appended at the end.
+func marshalEntryMatch(name string, data []byte) []byte {
+ nflog("marshaling matcher %q", name)
+
+ // We have to pad this struct size to a multiple of 8 bytes.
+ size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ matcher := linux.KernelXTEntryMatch{
+ XTEntryMatch: linux.XTEntryMatch{
+ MatchSize: uint16(size),
+ },
+ Data: data,
+ }
+ copy(matcher.Name[:], name)
+
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, matcher)
+ return append(buf, make([]byte, size-len(buf))...)
+}
+
+func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) {
+ matchMaker, ok := matchMakers[match.Name.String()]
+ if !ok {
+ return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
+ }
+ return matchMaker.unmarshal(buf, filter)
+}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 445d08d11..2ec11f6ac 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -17,6 +17,7 @@
package netfilter
import (
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -34,10 +35,6 @@ import (
// shouldn't be reached - an error has occurred if we fall through to one.
const errorTargetName = "ERROR"
-const (
- matcherNameUDP = "udp"
-)
-
// Metadata is used to verify that we are correctly serializing and
// deserializing iptables into structs consumable by the iptables tool. We save
// a metadata struct when the tables are written, and when they are read out we
@@ -53,7 +50,9 @@ type metadata struct {
// nflog logs messages related to the writing and reading of iptables.
func nflog(format string, args ...interface{}) {
- log.Infof("netfilter: "+format, args...)
+ if log.IsLogging(log.Debug) {
+ log.Debugf("netfilter: "+format, args...)
+ }
}
// GetInfo returns information about iptables.
@@ -67,7 +66,8 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT
// Find the appropriate table.
table, err := findTable(stack, info.Name)
if err != nil {
- return linux.IPTGetinfo{}, err
+ nflog("%v", err)
+ return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
}
// Get the hooks that apply to this table.
@@ -94,39 +94,40 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
// Read in the struct and table name.
var userEntries linux.IPTGetEntries
if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
- log.Warningf("netfilter: couldn't copy in entries %q", userEntries.Name)
+ nflog("couldn't copy in entries %q", userEntries.Name)
return linux.KernelIPTGetEntries{}, syserr.FromError(err)
}
// Find the appropriate table.
table, err := findTable(stack, userEntries.Name)
if err != nil {
- log.Warningf("netfilter: couldn't find table %q", userEntries.Name)
- return linux.KernelIPTGetEntries{}, err
+ nflog("%v", err)
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
// Convert netstack's iptables rules to something that the iptables
// tool can understand.
entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table)
if err != nil {
- return linux.KernelIPTGetEntries{}, err
+ nflog("couldn't read entries: %v", err)
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
if meta != table.Metadata().(metadata) {
panic(fmt.Sprintf("Table %q metadata changed between writing and reading. Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta))
}
if binary.Size(entries) > uintptr(outLen) {
- log.Warningf("Insufficient GetEntries output size: %d", uintptr(outLen))
+ nflog("insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
return entries, nil
}
-func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, *syserr.Error) {
+func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) {
ipt := stack.IPTables()
table, ok := ipt.Tables[tablename.String()]
if !ok {
- return iptables.Table{}, syserr.ErrInvalidArgument
+ return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename)
}
return table, nil
}
@@ -154,15 +155,14 @@ func FillDefaultIPTables(stack *stack.Stack) {
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
+func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) {
// Return values.
var entries linux.KernelIPTGetEntries
var meta metadata
// The table name has to fit in the struct.
if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
- log.Warningf("Table name %q too long.", tablename)
- return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("table name %q too long.", tablename)
}
copy(entries.Name[:], tablename)
@@ -228,61 +228,27 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern
return entries, meta, nil
}
-func marshalMatcher(matcher iptables.Matcher) []byte {
- switch m := matcher.(type) {
- case *iptables.UDPMatcher:
- return marshalUDPMatcher(m)
- default:
- // TODO(gvisor.dev/issue/170): Support other matchers.
- panic(fmt.Errorf("unknown matcher of type %T", matcher))
- }
-}
-
-func marshalUDPMatcher(matcher *iptables.UDPMatcher) []byte {
- nflog("convert to binary: marshalling UDP matcher: %+v", matcher)
-
- // We have to pad this struct size to a multiple of 8 bytes.
- size := alignUp(linux.SizeOfXTEntryMatch+linux.SizeOfXTUDP, 8)
-
- linuxMatcher := linux.KernelXTEntryMatch{
- XTEntryMatch: linux.XTEntryMatch{
- MatchSize: uint16(size),
- },
- Data: make([]byte, 0, linux.SizeOfXTUDP),
- }
- copy(linuxMatcher.Name[:], matcherNameUDP)
-
- xtudp := linux.XTUDP{
- SourcePortStart: matcher.Data.SourcePortStart,
- SourcePortEnd: matcher.Data.SourcePortEnd,
- DestinationPortStart: matcher.Data.DestinationPortStart,
- DestinationPortEnd: matcher.Data.DestinationPortEnd,
- InverseFlags: matcher.Data.InverseFlags,
- }
- linuxMatcher.Data = binary.Marshal(linuxMatcher.Data, usermem.ByteOrder, xtudp)
-
- buf := make([]byte, 0, size)
- buf = binary.Marshal(buf, usermem.ByteOrder, linuxMatcher)
- buf = append(buf, make([]byte, size-len(buf))...)
- nflog("convert to binary: marshalled UDP matcher into %v", buf)
- return buf[:]
-}
-
func marshalTarget(target iptables.Target) []byte {
- switch target.(type) {
- case iptables.UnconditionalAcceptTarget:
- return marshalStandardTarget(iptables.Accept)
- case iptables.UnconditionalDropTarget:
- return marshalStandardTarget(iptables.Drop)
+ switch tg := target.(type) {
+ case iptables.AcceptTarget:
+ return marshalStandardTarget(iptables.RuleAccept)
+ case iptables.DropTarget:
+ return marshalStandardTarget(iptables.RuleDrop)
case iptables.ErrorTarget:
- return marshalErrorTarget()
+ return marshalErrorTarget(errorTargetName)
+ case iptables.UserChainTarget:
+ return marshalErrorTarget(tg.Name)
+ case iptables.ReturnTarget:
+ return marshalStandardTarget(iptables.RuleReturn)
+ case JumpTarget:
+ return marshalJumpTarget(tg)
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
}
-func marshalStandardTarget(verdict iptables.Verdict) []byte {
- nflog("convert to binary: marshalling standard target with size %d", linux.SizeOfXTStandardTarget)
+func marshalStandardTarget(verdict iptables.RuleVerdict) []byte {
+ nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
target := linux.XTStandardTarget{
@@ -296,56 +262,69 @@ func marshalStandardTarget(verdict iptables.Verdict) []byte {
return binary.Marshal(ret, usermem.ByteOrder, target)
}
-func marshalErrorTarget() []byte {
+func marshalErrorTarget(errorName string) []byte {
// This is an error target named error
target := linux.XTErrorTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTErrorTarget,
},
}
- copy(target.Name[:], errorTargetName)
+ copy(target.Name[:], errorName)
copy(target.Target.Name[:], errorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
return binary.Marshal(ret, usermem.ByteOrder, target)
}
+func marshalJumpTarget(jt JumpTarget) []byte {
+ nflog("convert to binary: marshalling jump target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ // Verdict is overloaded by the ABI. When positive, it holds
+ // the jump offset from the start of the table.
+ Verdict: int32(jt.Offset),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
// translateFromStandardVerdict translates verdicts the same way as the iptables
// tool.
-func translateFromStandardVerdict(verdict iptables.Verdict) int32 {
+func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 {
switch verdict {
- case iptables.Accept:
+ case iptables.RuleAccept:
return -linux.NF_ACCEPT - 1
- case iptables.Drop:
+ case iptables.RuleDrop:
return -linux.NF_DROP - 1
- case iptables.Queue:
- return -linux.NF_QUEUE - 1
- case iptables.Return:
+ case iptables.RuleReturn:
return linux.NF_RETURN
- case iptables.Jump:
+ default:
// TODO(gvisor.dev/issue/170): Support Jump.
- panic("Jump isn't supported yet")
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
}
- panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
}
-// translateToStandardVerdict translates from the value in a
+// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an iptables.Verdict.
-func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) {
+func translateToStandardTarget(val int32) (iptables.Target, error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return iptables.Accept, nil
+ return iptables.AcceptTarget{}, nil
case -linux.NF_DROP - 1:
- return iptables.Drop, nil
+ return iptables.DropTarget{}, nil
case -linux.NF_QUEUE - 1:
- log.Warningf("Unsupported iptables verdict QUEUE.")
+ return nil, errors.New("unsupported iptables verdict QUEUE")
case linux.NF_RETURN:
- log.Warningf("Unsupported iptables verdict RETURN.")
+ return iptables.ReturnTarget{}, nil
default:
- log.Warningf("Unknown iptables verdict %d.", val)
+ return nil, fmt.Errorf("unknown iptables verdict %d", val)
}
- return iptables.Invalid, syserr.ErrInvalidArgument
}
// SetEntries sets iptables rules for a single table. See
@@ -353,7 +332,7 @@ func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) {
func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Get the basic rules data (struct ipt_replace).
if len(optVal) < linux.SizeOfIPTReplace {
- log.Warningf("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal))
+ nflog("optVal has insufficient size for replace %d", len(optVal))
return syserr.ErrInvalidArgument
}
var replace linux.IPTReplace
@@ -367,7 +346,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
case iptables.TablenameFilter:
table = iptables.EmptyFilterTable()
default:
- log.Warningf("We don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
}
@@ -375,13 +354,14 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Convert input into a list of rules and their offsets.
var offset uint32
- var offsets []uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
nflog("set entries: processing entry at offset %d", offset)
// Get the struct ipt_entry.
if len(optVal) < linux.SizeOfIPTEntry {
- log.Warningf("netfilter: optVal has insufficient size for entry %d", len(optVal))
+ nflog("optVal has insufficient size for entry %d", len(optVal))
return syserr.ErrInvalidArgument
}
var entry linux.IPTEntry
@@ -391,7 +371,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
optVal = optVal[linux.SizeOfIPTEntry:]
if entry.TargetOffset < linux.SizeOfIPTEntry {
- log.Warningf("netfilter: entry has too-small target offset %d", entry.TargetOffset)
+ nflog("entry has too-small target offset %d", entry.TargetOffset)
return syserr.ErrInvalidArgument
}
@@ -399,7 +379,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// filtering fields.
filter, err := filterFromIPTIP(entry.IP)
if err != nil {
- return err
+ nflog("bad iptip: %v", err)
+ return syserr.ErrInvalidArgument
}
// TODO(gvisor.dev/issue/170): Matchers and targets can specify
@@ -407,25 +388,26 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Get matchers.
matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
if len(optVal) < int(matchersSize) {
- log.Warningf("netfilter: entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
+ nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
return syserr.ErrInvalidArgument
}
matchers, err := parseMatchers(filter, optVal[:matchersSize])
if err != nil {
- log.Warningf("netfilter: failed to parse matchers: %v", err)
- return err
+ nflog("failed to parse matchers: %v", err)
+ return syserr.ErrInvalidArgument
}
optVal = optVal[matchersSize:]
// Get the target of the rule.
targetSize := entry.NextOffset - entry.TargetOffset
if len(optVal) < int(targetSize) {
- log.Warningf("netfilter: entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
+ nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return syserr.ErrInvalidArgument
}
target, err := parseTarget(optVal[:targetSize])
if err != nil {
- return err
+ nflog("failed to parse target: %v", err)
+ return syserr.ErrInvalidArgument
}
optVal = optVal[targetSize:]
@@ -434,11 +416,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
Target: target,
Matchers: matchers,
})
- offsets = append(offsets, offset)
+ offsets[offset] = int(entryIdx)
offset += uint32(entry.NextOffset)
if initialOptValLen-len(optVal) != int(entry.NextOffset) {
- log.Warningf("netfilter: entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return syserr.ErrInvalidArgument
}
}
@@ -447,32 +430,77 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
- for ruleIdx, offset := range offsets {
+ for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
}
if offset == replace.Underflow[hook] {
+ if !validUnderflow(table.Rules[ruleIdx]) {
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP")
+ return syserr.ErrInvalidArgument
+ }
table.Underflows[hk] = ruleIdx
}
}
if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset {
- log.Warningf("Hook %v is unset.", hk)
+ nflog("hook %v is unset.", hk)
return syserr.ErrInvalidArgument
}
if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset {
- log.Warningf("Underflow %v is unset.", hk)
+ nflog("underflow %v is unset.", hk)
return syserr.ErrInvalidArgument
}
}
}
+ // Add the user chains.
+ for ruleIdx, rule := range table.Rules {
+ target, ok := rule.Target.(iptables.UserChainTarget)
+ if !ok {
+ continue
+ }
+
+ // We found a user chain. Before inserting it into the table,
+ // check that:
+ // - There's some other rule after it.
+ // - There are no matchers.
+ if ruleIdx == len(table.Rules)-1 {
+ nflog("user chain must have a rule or default policy")
+ return syserr.ErrInvalidArgument
+ }
+ if len(table.Rules[ruleIdx].Matchers) != 0 {
+ nflog("user chain's first node must have no matchers")
+ return syserr.ErrInvalidArgument
+ }
+ table.UserChains[target.Name] = ruleIdx + 1
+ }
+
+ // Set each jump to point to the appropriate rule. Right now they hold byte
+ // offsets.
+ for ruleIdx, rule := range table.Rules {
+ jump, ok := rule.Target.(JumpTarget)
+ if !ok {
+ continue
+ }
+
+ // Find the rule corresponding to the jump rule offset.
+ jumpTo, ok := offsets[jump.Offset]
+ if !ok {
+ nflog("failed to find a rule to jump to")
+ return syserr.ErrInvalidArgument
+ }
+ jump.RuleNum = jumpTo
+ rule.Target = jump
+ table.Rules[ruleIdx] = rule
+ }
+
// TODO(gvisor.dev/issue/170): Support other chains.
// Since we only support modifying the INPUT chain right now, make sure
// all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
if hook != iptables.Input {
- if _, ok := table.Rules[ruleIdx].Target.(iptables.UnconditionalAcceptTarget); !ok {
- log.Warningf("Hook %d is unsupported.", hook)
+ if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok {
+ nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
}
}
@@ -498,7 +526,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
// only the matchers.
-func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, *syserr.Error) {
+func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) {
nflog("set entries: parsing matchers of size %d", len(optVal))
var matchers []iptables.Matcher
for len(optVal) > 0 {
@@ -506,8 +534,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma
// Get the XTEntryMatch.
if len(optVal) < linux.SizeOfXTEntryMatch {
- log.Warningf("netfilter: optVal has insufficient size for entry match: %d", len(optVal))
- return nil, syserr.ErrInvalidArgument
+ return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
}
var match linux.XTEntryMatch
buf := optVal[:linux.SizeOfXTEntryMatch]
@@ -516,45 +543,18 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma
// Check some invariants.
if match.MatchSize < linux.SizeOfXTEntryMatch {
- log.Warningf("netfilter: match size is too small, must be at least %d", linux.SizeOfXTEntryMatch)
- return nil, syserr.ErrInvalidArgument
+
+ return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch)
}
if len(optVal) < int(match.MatchSize) {
- log.Warningf("netfilter: optVal has insufficient size for match: %d", len(optVal))
- return nil, syserr.ErrInvalidArgument
+ return nil, fmt.Errorf("optVal has insufficient size for match: %d", len(optVal))
}
- buf = optVal[linux.SizeOfXTEntryMatch:match.MatchSize]
- var matcher iptables.Matcher
- var err error
- switch match.Name.String() {
- case matcherNameUDP:
- if len(buf) < linux.SizeOfXTUDP {
- log.Warningf("netfilter: optVal has insufficient size for UDP match: %d", len(optVal))
- return nil, syserr.ErrInvalidArgument
- }
- // For alignment reasons, the match's total size may
- // exceed what's strictly necessary to hold matchData.
- var matchData linux.XTUDP
- binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData)
- log.Infof("parseMatchers: parsed XTUDP: %+v", matchData)
- matcher, err = iptables.NewUDPMatcher(filter, iptables.UDPMatcherParams{
- SourcePortStart: matchData.SourcePortStart,
- SourcePortEnd: matchData.SourcePortEnd,
- DestinationPortStart: matchData.DestinationPortStart,
- DestinationPortEnd: matchData.DestinationPortEnd,
- InverseFlags: matchData.InverseFlags,
- })
- if err != nil {
- log.Warningf("netfilter: failed to create UDP matcher: %v", err)
- return nil, syserr.ErrInvalidArgument
- }
-
- default:
- log.Warningf("netfilter: unsupported matcher with name %q", match.Name.String())
- return nil, syserr.ErrInvalidArgument
+ // Parse the specific matcher.
+ matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
+ if err != nil {
+ return nil, fmt.Errorf("failed to create matcher: %v", err)
}
-
matchers = append(matchers, matcher)
// TODO(gvisor.dev/issue/170): Check the revision field.
@@ -562,8 +562,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma
}
if len(optVal) != 0 {
- log.Warningf("netfilter: optVal should be exhausted after parsing matchers")
- return nil, syserr.ErrInvalidArgument
+ return nil, errors.New("optVal should be exhausted after parsing matchers")
}
return matchers, nil
@@ -571,11 +570,10 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma
// parseTarget parses a target from optVal. optVal should contain only the
// target.
-func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) {
+func parseTarget(optVal []byte) (iptables.Target, error) {
nflog("set entries: parsing target of size %d", len(optVal))
if len(optVal) < linux.SizeOfXTEntryTarget {
- log.Warningf("netfilter: optVal has insufficient size for entry target %d", len(optVal))
- return nil, syserr.ErrInvalidArgument
+ return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
}
var target linux.XTEntryTarget
buf := optVal[:linux.SizeOfXTEntryTarget]
@@ -584,32 +582,23 @@ func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) {
case "":
// Standard target.
if len(optVal) != linux.SizeOfXTStandardTarget {
- log.Warningf("netfilter.SetEntries: optVal has wrong size for standard target %d", len(optVal))
- return nil, syserr.ErrInvalidArgument
+ return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal))
}
var standardTarget linux.XTStandardTarget
buf = optVal[:linux.SizeOfXTStandardTarget]
binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
- verdict, err := translateToStandardVerdict(standardTarget.Verdict)
- if err != nil {
- return nil, err
- }
- switch verdict {
- case iptables.Accept:
- return iptables.UnconditionalAcceptTarget{}, nil
- case iptables.Drop:
- return iptables.UnconditionalDropTarget{}, nil
- default:
- log.Warningf("Unknown verdict: %v", verdict)
- return nil, syserr.ErrInvalidArgument
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict)
}
+ // A verdict >= 0 indicates a jump.
+ return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
case errorTargetName:
// Error target.
if len(optVal) != linux.SizeOfXTErrorTarget {
- log.Infof("netfilter.SetEntries: optVal has insufficient size for error target %d", len(optVal))
- return nil, syserr.ErrInvalidArgument
+ return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal))
}
var errorTarget linux.XTErrorTarget
buf = optVal[:linux.SizeOfXTErrorTarget]
@@ -622,24 +611,24 @@ func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) {
// somehow fall through every rule.
// * To mark the start of a user defined chain. These
// rules have an error with the name of the chain.
- switch errorTarget.Name.String() {
+ switch name := errorTarget.Name.String(); name {
case errorTargetName:
+ nflog("set entries: error target")
return iptables.ErrorTarget{}, nil
default:
- log.Infof("Unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String())
- return nil, syserr.ErrInvalidArgument
+ // User defined chain.
+ nflog("set entries: user-defined target %q", name)
+ return iptables.UserChainTarget{Name: name}, nil
}
}
// Unknown target.
- log.Infof("Unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
- return nil, syserr.ErrInvalidArgument
+ return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
}
-func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, *syserr.Error) {
+func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) {
if containsUnsupportedFields(iptip) {
- log.Warningf("netfilter: unsupported fields in struct iptip: %+v", iptip)
- return iptables.IPHeaderFilter{}, syserr.ErrInvalidArgument
+ return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
}
return iptables.IPHeaderFilter{
Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
@@ -662,6 +651,18 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool {
iptip.InverseFlags != 0
}
+func validUnderflow(rule iptables.Rule) bool {
+ if len(rule.Matchers) != 0 {
+ return false
+ }
+ switch rule.Target.(type) {
+ case iptables.AcceptTarget, iptables.DropTarget:
+ return true
+ default:
+ return false
+ }
+}
+
func hookFromLinux(hook int) iptables.Hook {
switch hook {
case linux.NF_INET_PRE_ROUTING:
@@ -677,8 +678,3 @@ func hookFromLinux(hook int) iptables.Hook {
}
panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
-
-// alignUp rounds a length up to an alignment. align must be a power of 2.
-func alignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
new file mode 100644
index 000000000..c421b87cf
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+)
+
+// JumpTarget implements iptables.Target.
+type JumpTarget struct {
+ // Offset is the byte offset of the rule to jump to. It is used for
+ // marshaling and unmarshaling.
+ Offset uint32
+
+ // RuleNum is the rule to jump to.
+ RuleNum int
+}
+
+// Action implements iptables.Target.Action.
+func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) {
+ return iptables.RuleJump, jt.RuleNum
+}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
new file mode 100644
index 000000000..f9945e214
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -0,0 +1,143 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameTCP = "tcp"
+
+func init() {
+ registerMatchMaker(tcpMarshaler{})
+}
+
+// tcpMarshaler implements matchMaker for TCP matching.
+type tcpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (tcpMarshaler) name() string {
+ return matcherNameTCP
+}
+
+// marshal implements matchMaker.marshal.
+func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
+ matcher := mr.(*TCPMatcher)
+ xttcp := linux.XTTCP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTTCP)
+ return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+ if len(buf) < linux.SizeOfXTTCP {
+ return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.XTTCP
+ binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTTCP: %+v", matchData)
+
+ if matchData.Option != 0 ||
+ matchData.FlagMask != 0 ||
+ matchData.FlagCompare != 0 ||
+ matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported TCP matcher flags set")
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber {
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ }
+
+ return &TCPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// TCPMatcher matches TCP packets and their headers. It implements Matcher.
+type TCPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*TCPMatcher) Name() string {
+ return matcherNameTCP
+}
+
+// Match implements Matcher.Match.
+func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
+ }
+
+ // We dont't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ // Now we need the transport header. However, this may not have been set
+ // yet.
+ // TODO(gvisor.dev/issue/170): Parsing the transport header should
+ // ultimately be moved into the iptables.Check codepath as matchers are
+ // added.
+ var tcpHeader header.TCP
+ if pkt.TransportHeader != nil {
+ tcpHeader = header.TCP(pkt.TransportHeader)
+ } else {
+ // The TCP header hasn't been parsed yet. We have to do it here.
+ if len(pkt.Data.First()) < header.TCPMinimumSize {
+ // There's no valid TCP header here, so we hotdrop the
+ // packet.
+ return false, true
+ }
+ tcpHeader = header.TCP(pkt.Data.First())
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := tcpHeader.SourcePort(); sourcePort < tm.sourcePortStart || tm.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := tcpHeader.DestinationPort(); destinationPort < tm.destinationPortStart || tm.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/tcpip/iptables/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 496931d7a..86aa11696 100644
--- a/pkg/tcpip/iptables/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -12,56 +12,89 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package iptables
+package netfilter
import (
"fmt"
- "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// TODO(gvisor.dev/issue/170): The following per-matcher params should be
-// supported:
-// - Table name
-// - Match size
-// - User size
-// - Hooks
-// - Proto
-// - Family
+const matcherNameUDP = "udp"
-// UDPMatcher matches UDP packets and their headers. It implements Matcher.
-type UDPMatcher struct {
- Data UDPMatcherParams
+func init() {
+ registerMatchMaker(udpMarshaler{})
}
-// UDPMatcherParams are the parameters used to create a UDPMatcher.
-type UDPMatcherParams struct {
- SourcePortStart uint16
- SourcePortEnd uint16
- DestinationPortStart uint16
- DestinationPortEnd uint16
- InverseFlags uint8
+// udpMarshaler implements matchMaker for UDP matching.
+type udpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (udpMarshaler) name() string {
+ return matcherNameUDP
}
-// NewUDPMatcher returns a new instance of UDPMatcher.
-func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherParams) (Matcher, error) {
- log.Infof("Adding rule with UDPMatcherParams: %+v", data)
+// marshal implements matchMaker.marshal.
+func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
+ matcher := mr.(*UDPMatcher)
+ xtudp := linux.XTUDP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTUDP)
+ return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+ if len(buf) < linux.SizeOfXTUDP {
+ return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may exceed what's
+ // strictly necessary to hold matchData.
+ var matchData linux.XTUDP
+ binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTUDP: %+v", matchData)
- if data.InverseFlags != 0 {
+ if matchData.InverseFlags != 0 {
return nil, fmt.Errorf("unsupported UDP matcher inverse flags set")
}
if filter.Protocol != header.UDPProtocolNumber {
- return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber)
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
}
- return &UDPMatcher{Data: data}, nil
+ return &UDPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// UDPMatcher matches UDP packets and their headers. It implements Matcher.
+type UDPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*UDPMatcher) Name() string {
+ return matcherNameUDP
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
@@ -98,10 +131,10 @@ func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName str
// Check whether the source and destination ports are within the
// matching range.
- if sourcePort := udpHeader.SourcePort(); sourcePort < um.Data.SourcePortStart || um.Data.SourcePortEnd < sourcePort {
+ if sourcePort := udpHeader.SourcePort(); sourcePort < um.sourcePortStart || um.sourcePortEnd < sourcePort {
return false, false
}
- if destinationPort := udpHeader.DestinationPort(); destinationPort < um.Data.DestinationPortStart || um.Data.DestinationPortEnd < destinationPort {
+ if destinationPort := udpHeader.DestinationPort(); destinationPort < um.destinationPortStart || um.destinationPortEnd < destinationPort {
return false, false
}
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index 4ea252ccb..0899c61d1 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -23,18 +23,11 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// alignUp rounds a length up to an alignment.
-//
-// Preconditions: align is a power of two.
-func alignUp(length int, align uint) int {
- return (length + int(align) - 1) &^ (int(align) - 1)
-}
-
// alignPad returns the length of padding required for alignment.
//
// Preconditions: align is a power of two.
func alignPad(length int, align uint) int {
- return alignUp(length, align) - length
+ return binary.AlignUp(length, align) - length
}
// Message contains a complete serialized netlink message.
@@ -138,7 +131,7 @@ func (m *Message) Finalize() []byte {
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
// length. See net/netlink/af_netlink.c:__nlmsg_put.
- aligned := alignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
m.putZeros(aligned - len(m.buf))
return m.buf
}
@@ -173,7 +166,7 @@ func (m *Message) PutAttr(atype uint16, v interface{}) {
m.Put(v)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -190,7 +183,7 @@ func (m *Message) PutAttrString(atype uint16, s string) {
m.putZeros(1)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index ed2fbcceb..e187276c5 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1318,6 +1318,22 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
return ib, nil
+ case linux.IPV6_RECVTCLASS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1414,6 +1430,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
return o, nil
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1762,6 +1793,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
+ // TODO(b/148887420): Add support for IPV6_PKTINFO.
linux.IPV6_PKTINFO,
linux.IPV6_ROUTER_ALERT,
linux.IPV6_XFRM_POLICY,
@@ -1787,6 +1819,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ case linux.IPV6_RECVTCLASS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1949,6 +1989,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ case linux.IP_PKTINFO:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1964,7 +2014,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_PKTINFO,
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
@@ -2061,7 +2110,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
- linux.IPV6_RECVTCLASS,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
linux.IPV6_TCLASS,
@@ -2395,10 +2443,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
func (s *SocketOperations) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
},
}
}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 5afff2564..5f181f017 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -75,6 +75,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
switch protocol {
case syscall.IPPROTO_ICMP:
return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
case syscall.IPPROTO_UDP:
return header.UDPProtocolNumber, true, nil
case syscall.IPPROTO_TCP:
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 762a946fe..88d5db9fc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"capability.go",
"clone.go",
+ "epoll.go",
"futex.go",
"linux64_amd64.go",
"linux64_arm64.go",
@@ -30,7 +31,6 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
- "//pkg/sentry/socket/control",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
new file mode 100644
index 000000000..a6e48b836
--- /dev/null
+++ b/pkg/sentry/strace/epoll.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
+ }
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x ", eventAddr)
+ writeEpollEvent(&sb, e)
+ return sb.String()
+}
+
+func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string {
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x {", eventsAddr)
+ addr := eventsAddr
+ for i := uint64(0); i < numEvents; i++ {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(addr, &e); err != nil {
+ fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
+ continue
+ }
+ writeEpollEvent(&sb, e)
+ if uint64(sb.Len()) >= maxBytes {
+ sb.WriteString("...")
+ break
+ }
+ if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok {
+ fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr)
+ continue
+ }
+ }
+ sb.WriteString("}")
+ return sb.String()
+}
+
+func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) {
+ events := epollEventEvents.Parse(uint64(e.Events))
+ fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1])
+}
+
+var epollCtlOps = abi.ValueSet{
+ linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD",
+ linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL",
+ linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD",
+}
+
+var epollEventEvents = abi.FlagSet{
+ {Flag: linux.EPOLLIN, Name: "EPOLLIN"},
+ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"},
+ {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"},
+ {Flag: linux.EPOLLERR, Name: "EPOLLERR"},
+ {Flag: linux.EPOLLHUP, Name: "EPULLHUP"},
+ {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"},
+ {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"},
+ {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"},
+ {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"},
+ {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"},
+ {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"},
+ {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"},
+ {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"},
+ {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"},
+ {Flag: linux.EPOLLET, Name: "EPOLLET"},
+}
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index a4de545e9..71b92eaee 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -256,8 +256,8 @@ var linuxAMD64 = SyscallMap{
229: makeSyscallInfo("clock_getres", Hex, PostTimespec),
230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
231: makeSyscallInfo("exit_group", Hex),
- 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex),
- 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex),
+ 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
234: makeSyscallInfo("tgkill", Hex, Hex, Signal),
235: makeSyscallInfo("utimes", Path, Timeval),
// 236: vserver (not implemented in the Linux kernel)
@@ -305,7 +305,7 @@ var linuxAMD64 = SyscallMap{
278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
- 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
282: makeSyscallInfo("signalfd", Hex, Hex, Hex),
283: makeSyscallInfo("timerfd_create", Hex, Hex),
284: makeSyscallInfo("eventfd", Hex),
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index 8bc38545f..bd7361a52 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -45,8 +45,8 @@ var linuxARM64 = SyscallMap{
18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
19: makeSyscallInfo("eventfd2", Hex, Hex),
20: makeSyscallInfo("epoll_create1", Hex),
- 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
- 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
+ 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
23: makeSyscallInfo("dup", FD),
24: makeSyscallInfo("dup3", FD, FD, Hex),
25: makeSyscallInfo("fcntl", FD, Hex, Hex),
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index f7ff4573e..51e6d81b2 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
@@ -220,13 +219,13 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
- i += control.AlignUp(length, width)
+ i += binary.AlignUp(length, width)
continue
}
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := control.AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
fds := make(linux.ControlMessageRights, numRights)
@@ -295,7 +294,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
default:
panic("unreachable")
}
- i += control.AlignUp(length, width)
+ i += binary.AlignUp(length, width)
}
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index a796b2396..77655558e 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -141,6 +141,10 @@ func path(t *kernel.Task, addr usermem.Addr) string {
}
func fd(t *kernel.Task, fd int32) string {
+ if kernel.VFS2Enabled {
+ return fdVFS2(t, fd)
+ }
+
root := t.FSContext().RootDirectory()
if root != nil {
defer root.DecRef()
@@ -169,6 +173,30 @@ func fd(t *kernel.Task, fd int32) string {
return fmt.Sprintf("%#x %s", fd, name)
}
+func fdVFS2(t *kernel.Task, fd int32) string {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+
+ vfsObj := root.Mount().Filesystem().VirtualFilesystem()
+ if fd == linux.AT_FDCWD {
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+
+ name, _ := vfsObj.PathnameWithDeleted(t, root, wd)
+ return fmt.Sprintf("AT_FDCWD %s", name)
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ // Cast FD to uint64 to avoid printing negative hex.
+ return fmt.Sprintf("%#x (bad FD)", uint64(fd))
+ }
+ defer file.DecRef()
+
+ name, _ := vfsObj.PathnameWithDeleted(t, root, file.VirtualDentry())
+ return fmt.Sprintf("%#x %s", fd, name)
+}
+
func fdpair(t *kernel.Task, addr usermem.Addr) string {
var fds [2]int32
_, err := t.CopyIn(addr, &fds)
@@ -453,6 +481,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
case PollFDs:
output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case EpollCtlOp:
+ output = append(output, epollCtlOps.Parse(uint64(args[arg].Int())))
+ case EpollEvent:
+ output = append(output, epollEvent(t, args[arg].Pointer()))
+ case EpollEvents:
+ output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
case SelectFDSet:
output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
case Oct:
@@ -521,6 +555,8 @@ func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uint
output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer())
case PollFDs:
output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true)
+ case EpollEvents:
+ output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize))
case GetSockOptVal:
output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval)
case SetSockOptVal:
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 446d1e0f6..7e69b9279 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -228,6 +228,16 @@ const (
// SockOptLevel is the optname argument in getsockopt(2) and
// setsockopt(2).
SockOptName
+
+ // EpollCtlOp is the op argument to epoll_ctl(2).
+ EpollCtlOp
+
+ // EpollEvent is the event argument in epoll_ctl(2).
+ EpollEvent
+
+ // EpollEvents is an array of struct epoll_event. It is the events
+ // argument in epoll_wait(2)/epoll_pwait(2).
+ EpollEvents
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index be16ee686..c7883e68e 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -42,8 +42,6 @@ go_library(
"sys_socket.go",
"sys_splice.go",
"sys_stat.go",
- "sys_stat_amd64.go",
- "sys_stat_arm64.go",
"sys_sync.go",
"sys_sysinfo.go",
"sys_syslog.go",
@@ -74,6 +72,7 @@ go_library(
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/epoll",
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index 5f11b496c..3ab93fbde 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// EpollCreate1 implements the epoll_create1(2) linux syscall.
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
@@ -83,8 +85,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
mask = waiter.EventMaskFromLinux(e.Events)
- data[0] = e.Fd
- data[1] = e.Data
+ data = e.Data
}
// Perform the requested operations.
@@ -165,3 +166,5 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+
+// LINT.ThenChange(vfs2/epoll.go)
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 421845ebb..c21f14dc0 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -130,6 +130,8 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string
return path, dirPath, nil
}
+// LINT.IfChange
+
func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -575,6 +577,10 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Ioctl implements linux syscall ioctl(2).
func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -650,6 +656,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/ioctl.go)
+
+// LINT.IfChange
+
// Getcwd implements the linux syscall getcwd(2).
func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -760,6 +770,10 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/fscontext.go)
+
+// LINT.IfChange
+
// Close implements linux syscall close(2).
func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -1094,6 +1108,8 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/fd.go)
+
const (
_FADV_NORMAL = 0
_FADV_RANDOM = 1
@@ -1141,6 +1157,8 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.IfChange
+
func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1421,6 +1439,10 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1480,6 +1502,10 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return n, nil, err
}
+// LINT.ThenChange(vfs2/stat.go)
+
+// LINT.IfChange
+
func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1516,6 +1542,10 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, unlinkAt(t, dirFD, addr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Truncate implements linux syscall truncate(2).
func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -1614,6 +1644,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/setstat.go)
+
// Umask implements linux syscall umask(2).
func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
mask := args[0].ModeT()
@@ -1621,6 +1653,8 @@ func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(mask), nil, nil
}
+// LINT.IfChange
+
// Change ownership of a file.
//
// uid and gid may be -1, in which case they will not be changed.
@@ -1987,6 +2021,10 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true)
}
+// LINT.ThenChange(vfs2/setstat.go)
+
+// LINT.IfChange
+
func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error {
newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
if err != nil {
@@ -2042,6 +2080,8 @@ func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
// Fallocate implements linux system call fallocate(2).
func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index f66f4ffde..b126fecc0 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -27,6 +27,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Getdents implements linux syscall getdents(2) for 64bit systems.
func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -244,3 +246,5 @@ func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error {
func (ds *direntSerializer) Written() int {
return ds.written
}
+
+// LINT.ThenChange(vfs2/getdents.go)
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index 297e920c4..3f7691eae 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Lseek implements linux syscall lseek(2).
func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -52,3 +54,5 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return uintptr(offset), nil, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 9959f6e61..91694d374 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -35,6 +35,8 @@ func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return uintptr(addr), nil, nil
}
+// LINT.IfChange
+
// Mmap implements linux syscall mmap(2).
func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
prot := args[2].Int()
@@ -104,6 +106,8 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(rv), nil, err
}
+// LINT.ThenChange(vfs2/mmap.go)
+
// Munmap implements linux syscall munmap(2).
func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64())
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 98db32d77..9c6728530 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -135,7 +136,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
// Set the underlying executable.
- t.MemoryManager().SetExecutable(file.Dirent)
+ t.MemoryManager().SetExecutable(fsbridge.NewFSFile(file))
case linux.PR_SET_MM_AUXV,
linux.PR_SET_MM_START_CODE,
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 227692f06..78a2cb750 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskRead contains events that can be triggered on reads.
EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
@@ -388,3 +390,5 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index c841abccb..701b27b4a 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -23,6 +23,26 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uattr.Links,
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: sattr.BlockSize,
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
// Stat implements linux syscall stat(2).
func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -112,7 +132,8 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
if err != nil {
return err
}
- return copyOutStat(t, statAddr, d.Inode.StableAttr, uattr)
+ s := statFromAttrs(t, d.Inode.StableAttr, uattr)
+ return s.CopyOut(t, statAddr)
}
// fstat implements fstat for the given *fs.File.
@@ -121,7 +142,8 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
if err != nil {
return err
}
- return copyOutStat(t, statAddr, f.Dirent.Inode.StableAttr, uattr)
+ s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
+ return s.CopyOut(t, statAddr)
}
// Statx implements linux syscall statx(2).
@@ -277,3 +299,5 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
_, err = t.CopyOut(addr, &statfs)
return err
}
+
+// LINT.ThenChange(vfs2/stat.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
deleted file mode 100644
index 75a567bd4..000000000
--- a/pkg/sentry/syscalls/linux/sys_stat_amd64.go
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//+build amd64
-
-package linux
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-// copyOutStat copies the attributes (sattr, uattr) to the struct stat at
-// address dst in t's address space. It encodes the stat struct to bytes
-// manually, as stat() is a very common syscall for many applications, and
-// t.CopyObjectOut has noticeable performance impact due to its many slice
-// allocations and use of reflection.
-func copyOutStat(t *kernel.Task, dst usermem.Addr, sattr fs.StableAttr, uattr fs.UnstableAttr) error {
- b := t.CopyScratchBuffer(int(linux.SizeOfStat))[:0]
-
- // Dev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.DeviceID))
- // Ino (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.InodeID))
- // Nlink (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uattr.Links)
- // Mode (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, sattr.Type.LinuxType()|uint32(uattr.Perms.LinuxMode()))
- // UID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
- // GID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()))
- // Padding (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, 0)
- // Rdev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)))
- // Size (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Size))
- // Blksize (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.BlockSize))
- // Blocks (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Usage/512))
-
- // ATime
- atime := uattr.AccessTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Nsec))
-
- // MTime
- mtime := uattr.ModificationTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Nsec))
-
- // CTime
- ctime := uattr.StatusChangeTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Nsec))
-
- _, err := t.CopyOutBytes(dst, b)
- return err
-}
diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
deleted file mode 100644
index 80c98d05c..000000000
--- a/pkg/sentry/syscalls/linux/sys_stat_arm64.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//+build arm64
-
-package linux
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-// copyOutStat copies the attributes (sattr, uattr) to the struct stat at
-// address dst in t's address space. It encodes the stat struct to bytes
-// manually, as stat() is a very common syscall for many applications, and
-// t.CopyObjectOut has noticeable performance impact due to its many slice
-// allocations and use of reflection.
-func copyOutStat(t *kernel.Task, dst usermem.Addr, sattr fs.StableAttr, uattr fs.UnstableAttr) error {
- b := t.CopyScratchBuffer(int(linux.SizeOfStat))[:0]
-
- // Dev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.DeviceID))
- // Ino (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.InodeID))
- // Mode (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, sattr.Type.LinuxType()|uint32(uattr.Perms.LinuxMode()))
- // Nlink (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Links))
- // UID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
- // GID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()))
- // Rdev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)))
- // Padding (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, 0)
- // Size (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Size))
- // Blksize (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(sattr.BlockSize))
- // Padding (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, 0)
- // Blocks (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Usage/512))
-
- // ATime
- atime := uattr.AccessTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Nsec))
-
- // MTime
- mtime := uattr.ModificationTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Nsec))
-
- // CTime
- ctime := uattr.StatusChangeTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Nsec))
-
- _, err := t.CopyOutBytes(dst, b)
- return err
-}
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 3e55235bd..5ad465ae3 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Sync implements linux system call sync(2).
func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
t.MountNamespace().SyncAll(t)
@@ -135,3 +137,5 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
+
+// LINT.ThenChange(vfs2/sync.go)
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 0c9e2255d..00915fdde 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/loader"
@@ -119,7 +120,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
defer root.DecRef()
var wd *fs.Dirent
- var executable *fs.File
+ var executable fsbridge.File
var closeOnExec bool
if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) {
// Even if the pathname is absolute, we may still need the wd
@@ -136,7 +137,15 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
closeOnExec = fdFlags.CloseOnExec
if atEmptyPath && len(pathname) == 0 {
- executable = f
+ // TODO(gvisor.dev/issue/160): Linux requires only execute permission,
+ // not read. However, our backing filesystems may prevent us from reading
+ // the file without read permission. Additionally, a task with a
+ // non-readable executable has additional constraints on access via
+ // ptrace and procfs.
+ if err := f.Dirent.Inode.CheckPermission(t, fs.PermMask{Read: true, Execute: true}); err != nil {
+ return 0, nil, err
+ }
+ executable = fsbridge.NewFSFile(f)
} else {
wd = f.Dirent
wd.IncRef()
@@ -152,9 +161,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
// Load the new TaskContext.
remainingTraversals := uint(linux.MaxSymlinkTraversals)
loadArgs := loader.LoadArgs{
- Mounts: t.MountNamespace(),
- Root: root,
- WorkingDirectory: wd,
+ Opener: fsbridge.NewFSLookup(t.MountNamespace(), root, wd),
RemainingTraversals: &remainingTraversals,
ResolveFinal: resolveFinal,
Filename: pathname,
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index aba892939..506ee54ce 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskWrite contains events that can be triggered on writes.
//
@@ -358,3 +360,5 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 9d8140b8a..2de5e3422 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// GetXattr implements linux syscall getxattr(2).
func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return getXattrFromPath(t, args, true)
@@ -418,3 +420,5 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
return d.Inode.RemoveXattr(t, d, name)
}
+
+// LINT.ThenChange(vfs2/xattr.go)
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 6b8a00b6e..f51761e81 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -5,18 +5,44 @@ package(licenses = ["notice"])
go_library(
name = "vfs2",
srcs = [
+ "epoll.go",
+ "epoll_unsafe.go",
+ "execve.go",
+ "fd.go",
+ "filesystem.go",
+ "fscontext.go",
+ "getdents.go",
+ "ioctl.go",
"linux64.go",
"linux64_override_amd64.go",
"linux64_override_arm64.go",
- "sys_read.go",
+ "mmap.go",
+ "path.go",
+ "poll.go",
+ "read_write.go",
+ "setstat.go",
+ "stat.go",
+ "sync.go",
+ "xattr.go",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/fspath",
+ "//pkg/gohacks",
"//pkg/sentry/arch",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
"//pkg/sentry/syscalls",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
new file mode 100644
index 000000000..d6cb0e79a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -0,0 +1,225 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EpollCreate1 implements Linux syscall epoll_create1(2).
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^linux.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements Linux syscall epoll_create(2).
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ // "Since Linux 2.6.8, the size argument is ignored, but must be greater
+ // than zero" - epoll_create(2)
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements Linux syscall epoll_ctl(2).
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ op := args[1].Int()
+ fd := args[2].Int()
+ eventAddr := args[3].Pointer()
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+ if epfile == file {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var event linux.EpollEvent
+ switch op {
+ case linux.EPOLL_CTL_ADD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.AddInterest(file, fd, event)
+ case linux.EPOLL_CTL_DEL:
+ return 0, nil, ep.DeleteInterest(file, fd)
+ case linux.EPOLL_CTL_MOD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.ModifyInterest(file, fd, event)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ const _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
+ if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
+ return 0, nil, syserror.EINVAL
+ }
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
+ // maxEvents), so that the buffer can be allocated on the stack.
+ var (
+ events [16]linux.EpollEvent
+ total int
+ ch chan struct{}
+ haveDeadline bool
+ deadline ktime.Time
+ )
+ for {
+ batchEvents := len(events)
+ if batchEvents > maxEvents {
+ batchEvents = maxEvents
+ }
+ n := ep.ReadEvents(events[:batchEvents])
+ maxEvents -= n
+ if n != 0 {
+ // Copy what we read out.
+ copiedEvents, err := copyOutEvents(t, eventsAddr, events[:n])
+ eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
+ total += copiedEvents
+ if err != nil {
+ if total != 0 {
+ return uintptr(total), nil, nil
+ }
+ return 0, nil, err
+ }
+ // If we've filled the application's event buffer, we're done.
+ if maxEvents == 0 {
+ return uintptr(total), nil, nil
+ }
+ // Loop if we read a full batch, under the expectation that there
+ // may be more events to read.
+ if n == batchEvents {
+ continue
+ }
+ }
+ // We get here if n != batchEvents. If we read any number of events
+ // (just now, or in a previous iteration of this loop), or if timeout
+ // is 0 (such that epoll_wait should be non-blocking), return the
+ // events we've read so far to the application.
+ if total != 0 || timeout == 0 {
+ return uintptr(total), nil, nil
+ }
+ // In the first iteration of this loop, register with the epoll
+ // instance for readability events, but then immediately continue the
+ // loop since we need to retry ReadEvents() before blocking. In all
+ // subsequent iterations, block until events are available, the timeout
+ // expires, or an interrupt arrives.
+ if ch == nil {
+ var w waiter.Entry
+ w, ch = waiter.NewChannelEntry(nil)
+ epfile.EventRegister(&w, waiter.EventIn)
+ defer epfile.EventUnregister(&w)
+ } else {
+ // Set up the timer if a timeout was specified.
+ if timeout > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ // total must be 0 since otherwise we would have returned
+ // above.
+ return 0, nil, err
+ }
+ }
+ }
+}
+
+// EpollPwait implements Linux syscall epoll_pwait(2).
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return EpollWait(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
new file mode 100644
index 000000000..825f325bf
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "reflect"
+ "runtime"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const sizeofEpollEvent = int(unsafe.Sizeof(linux.EpollEvent{}))
+
+func copyOutEvents(t *kernel.Task, addr usermem.Addr, events []linux.EpollEvent) (int, error) {
+ if len(events) == 0 {
+ return 0, nil
+ }
+ // Cast events to a byte slice for copying.
+ var eventBytes []byte
+ eventBytesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&eventBytes))
+ eventBytesHdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(&events[0])))
+ eventBytesHdr.Len = len(events) * sizeofEpollEvent
+ eventBytesHdr.Cap = len(events) * sizeofEpollEvent
+ copiedBytes, err := t.CopyOutBytes(addr, eventBytes)
+ runtime.KeepAlive(events)
+ copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
+ return copiedEvents, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
new file mode 100644
index 000000000..aef0078a8
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathnameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+ return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ var executable fsbridge.File
+ closeOnExec := false
+ if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute {
+ // We must open the executable ourselves since dirfd is used as the
+ // starting point while resolving path, but the task working directory
+ // is used as the starting point while resolving interpreters (Linux:
+ // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() =>
+ // do_open_execat(fd=AT_FDCWD)), and the loader package is currently
+ // incapable of handling this correctly.
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ start := dirfile.VirtualDentry()
+ start.IncRef()
+ dirfile.DecRef()
+ closeOnExec = dirfileFlags.CloseOnExec
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ })
+ start.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ executable = fsbridge.NewVFSFile(file)
+ }
+
+ // Load the new TaskContext.
+ mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
+ defer mntns.DecRef()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Opener: fsbridge.NewVFSLookup(mntns, root, wd),
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
new file mode 100644
index 000000000..3afcea665
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Close implements Linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ // Note that Remove provides a reference on the file that we may use to
+ // flush. It is still active until we drop the final reference below
+ // (and other reference-holding operations complete).
+ _, file := t.FDTable().Remove(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.OnClose(t)
+ return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
+}
+
+// Dup implements Linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newFD), nil, nil
+}
+
+// Dup2 implements Linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+
+ if oldfd == newfd {
+ // As long as oldfd is valid, dup2() does nothing and returns newfd.
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ file.DecRef()
+ return uintptr(newfd), nil, nil
+ }
+
+ return dup3(t, oldfd, newfd, 0)
+}
+
+// Dup3 implements Linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return dup3(t, oldfd, newfd, flags)
+}
+
+func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^linux.O_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(newfd), nil, nil
+}
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ cmd := args[1].Int()
+
+ file, flags := t.FDTable().GetVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ minfd := args[2].Int()
+ fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{
+ CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ t.FDTable().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ return 0, nil, nil
+ case linux.F_GETFL:
+ return uintptr(file.StatusFlags()), nil, nil
+ case linux.F_SETFL:
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ default:
+ // TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
new file mode 100644
index 000000000..fc5ceea4c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -0,0 +1,326 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Link implements Linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Linkat implements Linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+ if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
+ return syserror.ENOENT
+ }
+
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0))
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop)
+}
+
+// Mkdir implements Linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements Linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, mkdirat(t, dirfd, addr, mode)
+}
+
+func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+}
+
+// Mknod implements Linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ dev := args[2].Uint()
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev)
+}
+
+// Mknodat implements Linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ dev := args[3].Uint()
+ return 0, nil, mknodat(t, dirfd, addr, mode, dev)
+}
+
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ major, minor := linux.DecodeDeviceID(dev)
+ return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
+ Mode: linux.FileMode(mode &^ t.FSContext().Umask()),
+ DevMajor: uint32(major),
+ DevMinor: minor,
+ })
+}
+
+// Open implements Linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+ mode := args[2].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, flags, mode)
+}
+
+// Openat implements Linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ mode := args[3].ModeT()
+ return openat(t, dirfd, addr, flags, mode)
+}
+
+// Creat implements Linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode)
+}
+
+func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ return uintptr(fd), nil, err
+}
+
+// Rename implements Linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Renameat implements Linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */)
+}
+
+// Renameat2 implements Linux syscall renameat2(2).
+func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Uint()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error {
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ // "If oldpath refers to a symbolic link, the link is renamed" - rename(2)
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{
+ Flags: flags,
+ })
+}
+
+// Rmdir implements Linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlink implements Linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlinkat implements Linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if flags&^linux.AT_REMOVEDIR != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirat(t, dirfd, pathAddr)
+ }
+ return 0, nil, unlinkat(t, dirfd, pathAddr)
+}
+
+// Symlink implements Linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ linkpathAddr := args[1].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr)
+}
+
+// Symlinkat implements Linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ newdirfd := args[1].Int()
+ linkpathAddr := args[2].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr)
+}
+
+func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error {
+ target, err := t.CopyInString(targetAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ linkpath, err := copyInPath(t, linkpathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
new file mode 100644
index 000000000..317409a18
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getcwd implements Linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+
+ root := t.FSContext().RootDirectoryVFS2()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd)
+ root.DecRef()
+ wd.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Note this is >= because we need a terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Construct a byte slice containing a NUL terminator.
+ buf := t.CopyScratchBuffer(len(s) + 1)
+ copy(buf, s)
+ buf[len(buf)-1] = 0
+
+ // Write the pathname slice.
+ n, err := t.CopyOutBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Chdir implements Linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Fchdir implements Linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Chroot implements Linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetRootDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
new file mode 100644
index 000000000..ddc140b65
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Getdents implements Linux syscall getdents(2).
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, false /* isGetdents64 */)
+}
+
+// Getdents64 implements Linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, true /* isGetdents64 */)
+}
+
+func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ cb := getGetdentsCallback(t, addr, size, isGetdents64)
+ err := file.IterDirents(t, cb)
+ n := size - cb.remaining
+ putGetdentsCallback(cb)
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+type getdentsCallback struct {
+ t *kernel.Task
+ addr usermem.Addr
+ remaining int
+ isGetdents64 bool
+}
+
+var getdentsCallbackPool = sync.Pool{
+ New: func() interface{} {
+ return &getdentsCallback{}
+ },
+}
+
+func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback {
+ cb := getdentsCallbackPool.Get().(*getdentsCallback)
+ *cb = getdentsCallback{
+ t: t,
+ addr: addr,
+ remaining: size,
+ isGetdents64: isGetdents64,
+ }
+ return cb
+}
+
+func putGetdentsCallback(cb *getdentsCallback) {
+ cb.t = nil
+ getdentsCallbackPool.Put(cb)
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
+ var buf []byte
+ if cb.isGetdents64 {
+ // struct linux_dirent64 {
+ // ino64_t d_ino; /* 64-bit inode number */
+ // off64_t d_off; /* 64-bit offset to next structure */
+ // unsigned short d_reclen; /* Size of this dirent */
+ // unsigned char d_type; /* File type */
+ // char d_name[]; /* Filename (null-terminated) */
+ // };
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ buf[18] = dirent.Type
+ copy(buf[19:], dirent.Name)
+ buf[size-1] = 0 // NUL terminator
+ } else {
+ // struct linux_dirent {
+ // unsigned long d_ino; /* Inode number */
+ // unsigned long d_off; /* Offset to next linux_dirent */
+ // unsigned short d_reclen; /* Length of this linux_dirent */
+ // char d_name[]; /* Filename (null-terminated) */
+ // /* length is actually (d_reclen - 2 -
+ // offsetof(struct linux_dirent, d_name)) */
+ // /*
+ // char pad; // Zero padding byte
+ // char d_type; // File type (only since Linux
+ // // 2.6.4); offset is (d_reclen - 1)
+ // */
+ // };
+ if cb.t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
+ }
+ size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ copy(buf[18:], dirent.Name)
+ buf[size-3] = 0 // NUL terminator
+ buf[size-2] = 0 // zero padding byte
+ buf[size-1] = dirent.Type
+ }
+ n, err := cb.t.CopyOutBytes(cb.addr, buf)
+ if err != nil {
+ // Don't report partially-written dirents by advancing cb.addr or
+ // cb.remaining.
+ return err
+ }
+ cb.addr += usermem.Addr(n)
+ cb.remaining -= n
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
new file mode 100644
index 000000000..5a2418da9
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Ioctl implements Linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ ret, err := file.Ioctl(t, t.MemoryManager(), args)
+ return ret, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index c134714ee..7d220bc20 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package vfs2
import (
@@ -22,4 +24,142 @@ import (
// Override syscall table to add syscalls implementations from this package.
func Override(table map[uintptr]kernel.Syscall) {
table[0] = syscalls.Supported("read", Read)
+ table[1] = syscalls.Supported("write", Write)
+ table[2] = syscalls.Supported("open", Open)
+ table[3] = syscalls.Supported("close", Close)
+ table[4] = syscalls.Supported("stat", Stat)
+ table[5] = syscalls.Supported("fstat", Fstat)
+ table[6] = syscalls.Supported("lstat", Lstat)
+ table[7] = syscalls.Supported("poll", Poll)
+ table[8] = syscalls.Supported("lseek", Lseek)
+ table[9] = syscalls.Supported("mmap", Mmap)
+ table[16] = syscalls.Supported("ioctl", Ioctl)
+ table[17] = syscalls.Supported("pread64", Pread64)
+ table[18] = syscalls.Supported("pwrite64", Pwrite64)
+ table[19] = syscalls.Supported("readv", Readv)
+ table[20] = syscalls.Supported("writev", Writev)
+ table[21] = syscalls.Supported("access", Access)
+ delete(table, 22) // pipe
+ table[23] = syscalls.Supported("select", Select)
+ table[32] = syscalls.Supported("dup", Dup)
+ table[33] = syscalls.Supported("dup2", Dup2)
+ delete(table, 40) // sendfile
+ delete(table, 41) // socket
+ delete(table, 42) // connect
+ delete(table, 43) // accept
+ delete(table, 44) // sendto
+ delete(table, 45) // recvfrom
+ delete(table, 46) // sendmsg
+ delete(table, 47) // recvmsg
+ delete(table, 48) // shutdown
+ delete(table, 49) // bind
+ delete(table, 50) // listen
+ delete(table, 51) // getsockname
+ delete(table, 52) // getpeername
+ delete(table, 53) // socketpair
+ delete(table, 54) // setsockopt
+ delete(table, 55) // getsockopt
+ table[59] = syscalls.Supported("execve", Execve)
+ table[72] = syscalls.Supported("fcntl", Fcntl)
+ delete(table, 73) // flock
+ table[74] = syscalls.Supported("fsync", Fsync)
+ table[75] = syscalls.Supported("fdatasync", Fdatasync)
+ table[76] = syscalls.Supported("truncate", Truncate)
+ table[77] = syscalls.Supported("ftruncate", Ftruncate)
+ table[78] = syscalls.Supported("getdents", Getdents)
+ table[79] = syscalls.Supported("getcwd", Getcwd)
+ table[80] = syscalls.Supported("chdir", Chdir)
+ table[81] = syscalls.Supported("fchdir", Fchdir)
+ table[82] = syscalls.Supported("rename", Rename)
+ table[83] = syscalls.Supported("mkdir", Mkdir)
+ table[84] = syscalls.Supported("rmdir", Rmdir)
+ table[85] = syscalls.Supported("creat", Creat)
+ table[86] = syscalls.Supported("link", Link)
+ table[87] = syscalls.Supported("unlink", Unlink)
+ table[88] = syscalls.Supported("symlink", Symlink)
+ table[89] = syscalls.Supported("readlink", Readlink)
+ table[90] = syscalls.Supported("chmod", Chmod)
+ table[91] = syscalls.Supported("fchmod", Fchmod)
+ table[92] = syscalls.Supported("chown", Chown)
+ table[93] = syscalls.Supported("fchown", Fchown)
+ table[94] = syscalls.Supported("lchown", Lchown)
+ table[132] = syscalls.Supported("utime", Utime)
+ table[133] = syscalls.Supported("mknod", Mknod)
+ table[137] = syscalls.Supported("statfs", Statfs)
+ table[138] = syscalls.Supported("fstatfs", Fstatfs)
+ table[161] = syscalls.Supported("chroot", Chroot)
+ table[162] = syscalls.Supported("sync", Sync)
+ delete(table, 165) // mount
+ delete(table, 166) // umount2
+ delete(table, 187) // readahead
+ table[188] = syscalls.Supported("setxattr", Setxattr)
+ table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
+ table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
+ table[191] = syscalls.Supported("getxattr", Getxattr)
+ table[192] = syscalls.Supported("lgetxattr", Lgetxattr)
+ table[193] = syscalls.Supported("fgetxattr", Fgetxattr)
+ table[194] = syscalls.Supported("listxattr", Listxattr)
+ table[195] = syscalls.Supported("llistxattr", Llistxattr)
+ table[196] = syscalls.Supported("flistxattr", Flistxattr)
+ table[197] = syscalls.Supported("removexattr", Removexattr)
+ table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
+ table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
+ delete(table, 206) // io_setup
+ delete(table, 207) // io_destroy
+ delete(table, 208) // io_getevents
+ delete(table, 209) // io_submit
+ delete(table, 210) // io_cancel
+ table[213] = syscalls.Supported("epoll_create", EpollCreate)
+ table[217] = syscalls.Supported("getdents64", Getdents64)
+ delete(table, 221) // fdavise64
+ table[232] = syscalls.Supported("epoll_wait", EpollWait)
+ table[233] = syscalls.Supported("epoll_ctl", EpollCtl)
+ table[235] = syscalls.Supported("utimes", Utimes)
+ delete(table, 253) // inotify_init
+ delete(table, 254) // inotify_add_watch
+ delete(table, 255) // inotify_rm_watch
+ table[257] = syscalls.Supported("openat", Openat)
+ table[258] = syscalls.Supported("mkdirat", Mkdirat)
+ table[259] = syscalls.Supported("mknodat", Mknodat)
+ table[260] = syscalls.Supported("fchownat", Fchownat)
+ table[261] = syscalls.Supported("futimens", Futimens)
+ table[262] = syscalls.Supported("newfstatat", Newfstatat)
+ table[263] = syscalls.Supported("unlinkat", Unlinkat)
+ table[264] = syscalls.Supported("renameat", Renameat)
+ table[265] = syscalls.Supported("linkat", Linkat)
+ table[266] = syscalls.Supported("symlinkat", Symlinkat)
+ table[267] = syscalls.Supported("readlinkat", Readlinkat)
+ table[268] = syscalls.Supported("fchmodat", Fchmodat)
+ table[269] = syscalls.Supported("faccessat", Faccessat)
+ table[270] = syscalls.Supported("pselect", Pselect)
+ table[271] = syscalls.Supported("ppoll", Ppoll)
+ delete(table, 275) // splice
+ delete(table, 276) // tee
+ table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
+ table[280] = syscalls.Supported("utimensat", Utimensat)
+ table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
+ delete(table, 282) // signalfd
+ delete(table, 283) // timerfd_create
+ delete(table, 284) // eventfd
+ delete(table, 285) // fallocate
+ delete(table, 286) // timerfd_settime
+ delete(table, 287) // timerfd_gettime
+ delete(table, 288) // accept4
+ delete(table, 289) // signalfd4
+ delete(table, 290) // eventfd2
+ table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
+ table[292] = syscalls.Supported("dup3", Dup3)
+ delete(table, 293) // pipe2
+ delete(table, 294) // inotify_init1
+ table[295] = syscalls.Supported("preadv", Preadv)
+ table[296] = syscalls.Supported("pwritev", Pwritev)
+ delete(table, 299) // recvmmsg
+ table[306] = syscalls.Supported("syncfs", Syncfs)
+ delete(table, 307) // sendmmsg
+ table[316] = syscalls.Supported("renameat2", Renameat2)
+ delete(table, 319) // memfd_create
+ table[322] = syscalls.Supported("execveat", Execveat)
+ table[327] = syscalls.Supported("preadv2", Preadv2)
+ table[328] = syscalls.Supported("pwritev2", Pwritev2)
+ table[332] = syscalls.Supported("statx", Statx)
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
index 6af5c400f..a6b367468 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package vfs2
import (
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
new file mode 100644
index 000000000..60a43f0a0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mmap implements Linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := args[4].Int()
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef()
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // mmap unconditionally requires that the FD is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !file.IsWritable() {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
new file mode 100644
index 000000000..97da6c647
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) {
+ pathname, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fspath.Path{}, err
+ }
+ return fspath.Parse(pathname), nil
+}
+
+type taskPathOperation struct {
+ pop vfs.PathOperation
+ haveStartRef bool
+}
+
+func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) {
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ haveStartRef := false
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ root.DecRef()
+ return taskPathOperation{}, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ haveStartRef = true
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ root.DecRef()
+ return taskPathOperation{}, syserror.EBADF
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ haveStartRef = true
+ dirfile.DecRef()
+ }
+ }
+ return taskPathOperation{
+ pop: vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ },
+ haveStartRef: haveStartRef,
+ }, nil
+}
+
+func (tpop *taskPathOperation) Release() {
+ tpop.pop.Root.DecRef()
+ if tpop.haveStartRef {
+ tpop.pop.Start.DecRef()
+ tpop.haveStartRef = false
+ }
+}
+
+type shouldAllowEmptyPath bool
+
+const (
+ disallowEmptyPath shouldAllowEmptyPath = false
+ allowEmptyPath shouldAllowEmptyPath = true
+)
+
+type shouldFollowFinalSymlink bool
+
+const (
+ nofollowFinalSymlink shouldFollowFinalSymlink = false
+ followFinalSymlink shouldFollowFinalSymlink = true
+)
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
new file mode 100644
index 000000000..dbf4882da
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -0,0 +1,584 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum allowable files for poll & select. This has no
+// equivalent in Linux; it exists in gVisor since allocation failure in Go is
+// unrecoverable.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file description and waiter of a PollFD.
+type pollState struct {
+ file *vfs.FileDescription
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.GetFileVFS2(pfd.FD)
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef()
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef()
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or
+// positive if interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ haveTimeout := timeout >= 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := copyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
+ fdCount := 0
+ for i := 0; i < nBytes; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < nBytes; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note, another thread
+ // might be about to close fd. This is racy, but that's
+ // OK. Linux is racy in the same way.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef()
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Do the syscall, then count the number of bits set.
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ return tsRemaining.CopyOut(t, timespecAddr)
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ return tvRemaining.CopyOut(t, timevalAddr)
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // On an interrupt poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ var timeval linux.Timeval
+ if err := timeval.CopyIn(t, timevalAddr); err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ if t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
+ }
+ var maskStruct sigSetWithSize
+ if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ return 0, nil, err
+ }
+ if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ var timespec linux.Timespec
+ if err := timespec.CopyIn(t, timespecAddr); err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syserror.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
+
+func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error {
+ if maskAddr == 0 {
+ return nil
+ }
+ if maskSize != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if err := mask.CopyIn(t, maskAddr); err != nil {
+ return err
+ }
+ mask &^= kernel.UnblockableSignals
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
new file mode 100644
index 000000000..35f6308d6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -0,0 +1,511 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+ eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
+
+// Read implements Linux syscall read(2).
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Readv implements Linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
+func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.Read(t, dst, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Read(t, dst, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pread64 implements Linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Preadv implements Linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements Linux syscall preadv2(2).
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.ReadOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = read(t, file, dst, opts)
+ } else {
+ n, err = pread(t, file, dst, offset, opts)
+ }
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.PRead(t, dst, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PRead(t, dst, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Write implements Linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Writev implements Linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.Write(t, src, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Write(t, src, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pwrite64 implements Linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Pwritev implements Linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements Linux syscall pwritev2(2).
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.WriteOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = write(t, file, src, opts)
+ } else {
+ n, err = pwrite(t, file, src, offset, opts)
+ }
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
+
+func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.PWrite(t, src, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PWrite(t, src, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Lseek implements Linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newoff, err := file.Seek(t, offset, whence)
+ return uintptr(newoff), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
new file mode 100644
index 000000000..9250659ff
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -0,0 +1,380 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+// Chmod implements Linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode)
+}
+
+// Fchmodat implements Linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, fchmodat(t, dirfd, pathAddr, mode)
+}
+
+func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Fchmod implements Linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].ModeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Chown implements Linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */)
+}
+
+// Lchown implements Linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ owner := args[2].Int()
+ group := args[3].Int()
+ flags := args[4].Int()
+ return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags)
+}
+
+func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error {
+ userns := t.UserNamespace()
+ if owner != -1 {
+ kuid := userns.MapToKUID(auth.UID(owner))
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_UID
+ opts.Stat.UID = uint32(kuid)
+ }
+ if group != -1 {
+ kgid := userns.MapToKGID(auth.GID(group))
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_GID
+ opts.Stat.GID = uint32(kgid)
+ }
+ return nil
+}
+
+// Fchown implements Linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ owner := args[1].Int()
+ group := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, file.SetStat(t, opts)
+}
+
+// Truncate implements Linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Ftruncate implements Linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Utime implements Linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times linux.Utime
+ if err := times.CopyIn(t, timesAddr); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime.Sec = times.Actime
+ opts.Stat.Mtime.Sec = times.Modtime
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimes implements Linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Usec * 1000),
+ }
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Usec * 1000),
+ }
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimensat implements Linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Futimens implements Linux syscall futimens(2).
+func Futimens(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ timesAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.SetStat(t, opts)
+}
+
+func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_ATIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Nsec),
+ }
+ }
+ if times[1].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_MTIME
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Nsec),
+ }
+ }
+ return nil
+}
+
+func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.SetStat() instead of
+ // VirtualFilesystem.SetStatAt(), since the former may be able
+ // to use opened file state to expedite the SetStat.
+ err := dirfile.SetStat(t, *opts)
+ dirfile.DecRef()
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+ return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ }, opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
new file mode 100644
index 000000000..dca8d7011
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -0,0 +1,346 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Stat implements Linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */)
+}
+
+// Lstat implements Linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2).
+func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+ return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags)
+}
+
+func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+}
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
+
+func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
+ return linux.Timespec{
+ Sec: sxts.Sec,
+ Nsec: int64(sxts.Nsec),
+ }
+}
+
+// Fstat implements Linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ statx, err := file.Stat(t, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return 0, nil, stat.CopyOut(t, statAddr)
+}
+
+// Statx implements Linux syscall statx(2).
+func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+ mask := args[3].Uint()
+ statxAddr := args[4].Pointer()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: mask,
+ Sync: uint32(flags & linux.AT_STATX_SYNC_TYPE),
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for statx(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+}
+
+func userifyStatx(t *kernel.Task, statx *linux.Statx) {
+ userns := t.UserNamespace()
+ statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow())
+ statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow())
+}
+
+// Readlink implements Linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+ return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size)
+}
+
+// Access implements Linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Faccessat implements Linux syscall access(2).
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Readlinkat implements Linux syscall mknodat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ return readlinkat(t, dirfd, pathAddr, bufAddr, size)
+}
+
+func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
+ if int(size) <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // "Since Linux 2.6.39, pathname can be an empty string, in which case the
+ // call operates on the symbolic link referred to by dirfd ..." -
+ // readlinkat(2)
+ tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if len(target) > int(size) {
+ target = target[:size]
+ }
+ n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target))
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Statfs implements Linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
+
+// Fstatfs implements Linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufAddr := args[1].Pointer()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
new file mode 100644
index 000000000..365250b0b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements Linux syscall sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t)
+}
+
+// Syncfs implements Linux syscall syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SyncFS(t)
+}
+
+// Fsync implements Linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.Sync(t)
+}
+
+// Fdatasync implements Linux syscall fdatasync(2).
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata.
+ return Fsync(t, args)
+}
+
+// SyncFileRange implements Linux syscall sync_file_range(2).
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ flags := args[3].Uint()
+
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if nbytes < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of
+ // [offset, offset+nbytes).
+ return 0, nil, file.Sync(t)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_read.go b/pkg/sentry/syscalls/linux/vfs2/sys_read.go
deleted file mode 100644
index 7667524c7..000000000
--- a/pkg/sentry/syscalls/linux/vfs2/sys_read.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs2
-
-import (
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // EventMaskRead contains events that can be triggered on reads.
- EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
-)
-
-// Read implements linux syscall read(2). Note that we try to get a buffer that
-// is exactly the size requested because some applications like qemu expect
-// they can do large reads all at once. Bug for bug. Same for other read
-// calls below.
-func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- fd := args[0].Int()
- addr := args[1].Pointer()
- size := args[2].SizeT()
-
- file := t.GetFileVFS2(fd)
- if file == nil {
- return 0, nil, syserror.EBADF
- }
- defer file.DecRef()
-
- // Check that the size is legitimate.
- si := int(size)
- if si < 0 {
- return 0, nil, syserror.EINVAL
- }
-
- // Get the destination of the read.
- dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
- AddressSpaceActive: true,
- })
- if err != nil {
- return 0, nil, err
- }
-
- n, err := read(t, file, dst, vfs.ReadOptions{})
- t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, linux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
-}
-
-func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- n, err := file.Read(t, dst, opts)
- if err != syserror.ErrWouldBlock {
- return n, err
- }
-
- // Register for notifications.
- w, ch := waiter.NewChannelEntry(nil)
- file.EventRegister(&w, EventMaskRead)
-
- total := n
- for {
- // Shorten dst to reflect bytes previously read.
- dst = dst.DropFirst(int(n))
-
- // Issue the request and break out if it completes with anything other than
- // "would block".
- n, err := file.Read(t, dst, opts)
- total += n
- if err != syserror.ErrWouldBlock {
- break
- }
- if err := t.Block(ch); err != nil {
- break
- }
- }
- file.EventUnregister(&w)
-
- return total, err
-}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
new file mode 100644
index 000000000..89e9ff4d7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -0,0 +1,353 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Listxattr implements Linux syscall listxattr(2).
+func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, followFinalSymlink)
+}
+
+// Llistxattr implements Linux syscall llistxattr(2).
+func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, nofollowFinalSymlink)
+}
+
+func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Flistxattr implements Linux syscall flistxattr(2).
+func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ names, err := file.Listxattr(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Getxattr implements Linux syscall getxattr(2).
+func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, followFinalSymlink)
+}
+
+// Lgetxattr implements Linux syscall lgetxattr(2).
+func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, nofollowFinalSymlink)
+}
+
+func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Fgetxattr implements Linux syscall fgetxattr(2).
+func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := file.Getxattr(t, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Setxattr implements Linux syscall setxattr(2).
+func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, followFinalSymlink)
+}
+
+// Lsetxattr implements Linux syscall lsetxattr(2).
+func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, nofollowFinalSymlink)
+}
+
+func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Fsetxattr implements Linux syscall fsetxattr(2).
+func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Removexattr implements Linux syscall removexattr(2).
+func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, followFinalSymlink)
+}
+
+// Lremovexattr implements Linux syscall lremovexattr(2).
+func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, nofollowFinalSymlink)
+}
+
+func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name)
+}
+
+// Fremovexattr implements Linux syscall fremovexattr(2).
+func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Removexattr(t, name)
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) {
+ if size > linux.XATTR_LIST_MAX {
+ size = linux.XATTR_LIST_MAX
+ }
+ var buf bytes.Buffer
+ for _, name := range names {
+ buf.WriteString(name)
+ buf.WriteByte(0)
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the list.
+ return buf.Len(), nil
+ }
+ if buf.Len() > int(size) {
+ if size >= linux.XATTR_LIST_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(listAddr, buf.Bytes())
+}
+
+func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ return "", syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return "", err
+ }
+ return gohacks.StringFromImmutableBytes(buf), nil
+}
+
+func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ size = linux.XATTR_SIZE_MAX
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the value.
+ return len(value), nil
+ }
+ if len(value) > int(size) {
+ if size >= linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value))
+}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 14b39eb9d..07c8383e6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -43,6 +43,8 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/gohacks",
+ "//pkg/log",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go
index d97362b9a..82781e6d3 100644
--- a/pkg/sentry/vfs/context.go
+++ b/pkg/sentry/vfs/context.go
@@ -29,9 +29,10 @@ const (
CtxRoot
)
-// MountNamespaceFromContext returns the MountNamespace used by ctx. It does
-// not take a reference on the returned MountNamespace. If ctx is not
-// associated with a MountNamespace, MountNamespaceFromContext returns nil.
+// MountNamespaceFromContext returns the MountNamespace used by ctx. If ctx is
+// not associated with a MountNamespace, MountNamespaceFromContext returns nil.
+//
+// A reference is taken on the returned MountNamespace.
func MountNamespaceFromContext(ctx context.Context) *MountNamespace {
if v := ctx.Value(CtxMountNamespace); v != nil {
return v.(*MountNamespace)
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 486a76475..35b208721 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -71,6 +71,8 @@ import (
// lifetime. Dentry reference counts only indicate the extent to which VFS
// requires Dentries to exist; Filesystems may elect to cache or discard
// Dentries with zero references.
+//
+// +stateify savable
type Dentry struct {
// parent is this Dentry's parent in this Filesystem. If this Dentry is
// independent, parent is nil.
@@ -89,7 +91,7 @@ type Dentry struct {
children map[string]*Dentry
// mu synchronizes disowning and mounting over this Dentry.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
// impl is the DentryImpl associated with this Dentry. impl is immutable.
// This should be the last field in Dentry.
diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go
index 3af2aa58d..bda5576fa 100644
--- a/pkg/sentry/vfs/device.go
+++ b/pkg/sentry/vfs/device.go
@@ -56,6 +56,7 @@ type Device interface {
Open(ctx context.Context, mnt *Mount, d *Dentry, opts OpenOptions) (*FileDescription, error)
}
+// +stateify savable
type registeredDevice struct {
dev Device
opts RegisterDeviceOptions
@@ -63,6 +64,8 @@ type registeredDevice struct {
// RegisterDeviceOptions contains options to
// VirtualFilesystem.RegisterDevice().
+//
+// +stateify savable
type RegisterDeviceOptions struct {
// GroupName is the name shown for this device registration in
// /proc/devices. If GroupName is empty, this registration will not be
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 7c83f9a5a..3da45d744 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -85,8 +85,8 @@ type epollInterest struct {
ready bool
epollInterestEntry
- // userData is the epoll_data_t associated with this epollInterest.
- // userData is protected by epoll.mu.
+ // userData is the struct epoll_event::data associated with this
+ // epollInterest. userData is protected by epoll.mu.
userData [2]int32
}
@@ -157,7 +157,7 @@ func (ep *EpollInstance) Seek(ctx context.Context, offset int64, whence int32) (
// AddInterest implements the semantics of EPOLL_CTL_ADD.
//
// Preconditions: A reference must be held on file.
-func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, mask uint32, userData [2]int32) error {
+func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event linux.EpollEvent) error {
// Check for cyclic polling if necessary.
subep, _ := file.impl.(*EpollInstance)
if subep != nil {
@@ -183,12 +183,12 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, mask uint
}
// Register interest in file.
- mask |= linux.EPOLLERR | linux.EPOLLRDHUP
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP
epi := &epollInterest{
epoll: ep,
key: key,
mask: mask,
- userData: userData,
+ userData: event.Data,
}
ep.interest[key] = epi
wmask := waiter.EventMaskFromLinux(mask)
@@ -202,6 +202,9 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, mask uint
// Add epi to file.epolls so that it is removed when the last
// FileDescription reference is dropped.
file.epollMu.Lock()
+ if file.epolls == nil {
+ file.epolls = make(map[*epollInterest]struct{})
+ }
file.epolls[epi] = struct{}{}
file.epollMu.Unlock()
@@ -236,7 +239,7 @@ func (ep *EpollInstance) mightPollRecursive(ep2 *EpollInstance, remainingRecursi
// ModifyInterest implements the semantics of EPOLL_CTL_MOD.
//
// Preconditions: A reference must be held on file.
-func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, mask uint32, userData [2]int32) error {
+func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event linux.EpollEvent) error {
ep.interestMu.Lock()
defer ep.interestMu.Unlock()
@@ -250,13 +253,13 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, mask u
}
// Update epi for the next call to ep.ReadEvents().
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP
ep.mu.Lock()
epi.mask = mask
- epi.userData = userData
+ epi.userData = event.Data
ep.mu.Unlock()
// Re-register with the new mask.
- mask |= linux.EPOLLERR | linux.EPOLLRDHUP
file.EventUnregister(&epi.waiter)
wmask := waiter.EventMaskFromLinux(mask)
file.EventRegister(&epi.waiter, wmask)
@@ -363,8 +366,7 @@ func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int {
// Report ievents.
events[i] = linux.EpollEvent{
Events: ievents.ToLinux(),
- Fd: epi.userData[0],
- Data: epi.userData[1],
+ Data: epi.userData,
}
i++
if i == len(events) {
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 5bac660c7..9a1ad630c 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -435,11 +435,11 @@ type Dirent struct {
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
type IterDirentsCallback interface {
- // Handle handles the given iterated Dirent. It returns true if iteration
- // should continue, and false if FileDescriptionImpl.IterDirents should
- // terminate now and restart with the same Dirent the next time it is
- // called.
- Handle(dirent Dirent) bool
+ // Handle handles the given iterated Dirent. If Handle returns a non-nil
+ // error, FileDescriptionImpl.IterDirents must stop iteration and return
+ // the error; the next call to FileDescriptionImpl.IterDirents should
+ // restart with the same Dirent.
+ Handle(dirent Dirent) error
}
// OnClose is called when a file descriptor representing the FileDescription is
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 8fa26418e..3a75d4d62 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -107,7 +107,10 @@ func (fd *testFD) SetStat(ctx context.Context, opts SetStatOptions) error {
func TestGenCountFD(t *testing.T) {
ctx := contexttest.Context(t)
- vfsObj := New() // vfs.New()
+ vfsObj := &VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
fd := newTestFD(vfsObj, linux.O_RDWR, &genCount{})
defer fd.DecRef()
@@ -162,7 +165,10 @@ func TestGenCountFD(t *testing.T) {
func TestWritable(t *testing.T) {
ctx := contexttest.Context(t)
- vfsObj := New() // vfs.New()
+ vfsObj := &VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
fd := newTestFD(vfsObj, linux.O_RDWR, &storeData{data: "init"})
defer fd.DecRef()
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index a06a6caf3..556976d0b 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -29,6 +29,8 @@ import (
// Filesystem methods require that a reference is held.
//
// Filesystem is analogous to Linux's struct super_block.
+//
+// +stateify savable
type Filesystem struct {
// refs is the reference count. refs is accessed using atomic memory
// operations.
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index c58b70728..bb9cada81 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -44,6 +44,7 @@ type GetFilesystemOptions struct {
InternalData interface{}
}
+// +stateify savable
type registeredFilesystemType struct {
fsType FilesystemType
opts RegisterFilesystemTypeOptions
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 1fbb420f9..9912df799 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -38,6 +38,8 @@ import (
//
// Mount is analogous to Linux's struct mount. (gVisor does not distinguish
// between struct mount and struct vfsmount.)
+//
+// +stateify savable
type Mount struct {
// vfs, fs, and root are immutable. References are held on fs and root.
//
@@ -85,6 +87,8 @@ type Mount struct {
// MountNamespace methods require that a reference is held.
//
// MountNamespace is analogous to Linux's struct mnt_namespace.
+//
+// +stateify savable
type MountNamespace struct {
// root is the MountNamespace's root mount. root is immutable.
root *Mount
@@ -114,6 +118,7 @@ type MountNamespace struct {
func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
rft := vfs.getFilesystemType(fsTypeName)
if rft == nil {
+ ctx.Warningf("Unknown filesystem: %s", fsTypeName)
return nil, syserror.ENODEV
}
fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
@@ -231,9 +236,12 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
return syserror.EINVAL
}
vfs.mountMu.Lock()
- if mntns := MountNamespaceFromContext(ctx); mntns != nil && mntns != vd.mount.ns {
- vfs.mountMu.Unlock()
- return syserror.EINVAL
+ if mntns := MountNamespaceFromContext(ctx); mntns != nil {
+ defer mntns.DecRef()
+ if mntns != vd.mount.ns {
+ vfs.mountMu.Unlock()
+ return syserror.EINVAL
+ }
}
// TODO(jamieliu): Linux special-cases umount of the caller's root, which
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index bd90d36c4..bc7581698 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -26,6 +26,7 @@ import (
"sync/atomic"
"unsafe"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -64,6 +65,8 @@ func (mnt *Mount) storeKey(vd VirtualDentry) {
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
+//
+// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -75,8 +78,8 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq sync.SeqCount
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
+ seed uint32 // for hashing keys
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -89,7 +92,7 @@ type mountTable struct {
// length and cap in separate uint32s) for ~free.
size uint64
- slots unsafe.Pointer // []mountSlot; never nil after Init
+ slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}
type mountSlot struct {
@@ -158,7 +161,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
loop:
for {
@@ -359,12 +362,3 @@ func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
//go:linkname rand32 runtime.fastrand
func rand32() uint32
-
-// This is copy/pasted from runtime.noescape(), and is needed because arguments
-// apparently escape from all functions defined by linkname.
-//
-//go:nosplit
-func noescape(p unsafe.Pointer) unsafe.Pointer {
- x := uintptr(p)
- return unsafe.Pointer(x ^ 0)
-}
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index b7774bf28..6af7fdac1 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -61,7 +61,7 @@ type MountOptions struct {
type OpenOptions struct {
// Flags contains access mode and flags as specified for open(2).
//
- // FilesystemImpls is reponsible for implementing the following flags:
+ // FilesystemImpls are responsible for implementing the following flags:
// O_RDONLY, O_WRONLY, O_RDWR, O_APPEND, O_CREAT, O_DIRECT, O_DSYNC,
// O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_PATH, O_SYNC, O_TMPFILE, and
// O_TRUNC. VFS is responsible for handling O_DIRECTORY, O_LARGEFILE, and
@@ -72,6 +72,11 @@ type OpenOptions struct {
// If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the
// created file.
Mode linux.FileMode
+
+ // FileExec is set when the file is being opened to be executed.
+ // VirtualFilesystem.OpenAt() checks that the caller has execute permissions
+ // on the file, and that the file is a regular file.
+ FileExec bool
}
// ReadOptions contains options to FileDescription.PRead(),
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index f664581f4..8e250998a 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -103,17 +103,22 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
// AccessTypesForOpenFlags returns MayRead|MayWrite in this case.
//
// Use May{Read,Write}FileWithOpenFlags() for these checks instead.
-func AccessTypesForOpenFlags(flags uint32) AccessTypes {
- switch flags & linux.O_ACCMODE {
+func AccessTypesForOpenFlags(opts *OpenOptions) AccessTypes {
+ ats := AccessTypes(0)
+ if opts.FileExec {
+ ats |= MayExec
+ }
+
+ switch opts.Flags & linux.O_ACCMODE {
case linux.O_RDONLY:
- if flags&linux.O_TRUNC != 0 {
- return MayRead | MayWrite
+ if opts.Flags&linux.O_TRUNC != 0 {
+ return ats | MayRead | MayWrite
}
- return MayRead
+ return ats | MayRead
case linux.O_WRONLY:
- return MayWrite
+ return ats | MayWrite
default:
- return MayRead | MayWrite
+ return ats | MayRead | MayWrite
}
}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 8a0b382f6..eb4ebb511 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -228,7 +228,7 @@ func (rp *ResolvingPath) Advance() {
rp.pit = next
} else { // at end of path segment, continue with next one
rp.curPart--
- rp.pit = rp.parts[rp.curPart-1]
+ rp.pit = rp.parts[rp.curPart]
}
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 908c69f91..73f8043be 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -46,11 +46,13 @@ import (
//
// There is no analogue to the VirtualFilesystem type in Linux, as the
// equivalent state in Linux is global.
+//
+// +stateify savable
type VirtualFilesystem struct {
// mountMu serializes mount mutations.
//
// mountMu is analogous to Linux's namespace_sem.
- mountMu sync.Mutex
+ mountMu sync.Mutex `state:"nosave"`
// mounts maps (mount parent, mount point) pairs to mounts. (Since mounts
// are uniquely namespaced, including mount parent in the key correctly
@@ -89,44 +91,42 @@ type VirtualFilesystem struct {
// devices contains all registered Devices. devices is protected by
// devicesMu.
- devicesMu sync.RWMutex
+ devicesMu sync.RWMutex `state:"nosave"`
devices map[devTuple]*registeredDevice
// anonBlockDevMinor contains all allocated anonymous block device minor
// numbers. anonBlockDevMinorNext is a lower bound for the smallest
// unallocated anonymous block device number. anonBlockDevMinorNext and
// anonBlockDevMinor are protected by anonBlockDevMinorMu.
- anonBlockDevMinorMu sync.Mutex
+ anonBlockDevMinorMu sync.Mutex `state:"nosave"`
anonBlockDevMinorNext uint32
anonBlockDevMinor map[uint32]struct{}
// fsTypes contains all registered FilesystemTypes. fsTypes is protected by
// fsTypesMu.
- fsTypesMu sync.RWMutex
+ fsTypesMu sync.RWMutex `state:"nosave"`
fsTypes map[string]*registeredFilesystemType
// filesystems contains all Filesystems. filesystems is protected by
// filesystemsMu.
- filesystemsMu sync.Mutex
+ filesystemsMu sync.Mutex `state:"nosave"`
filesystems map[*Filesystem]struct{}
}
-// New returns a new VirtualFilesystem with no mounts or FilesystemTypes.
-func New() *VirtualFilesystem {
- vfs := &VirtualFilesystem{
- mountpoints: make(map[*Dentry]map[*Mount]struct{}),
- devices: make(map[devTuple]*registeredDevice),
- anonBlockDevMinorNext: 1,
- anonBlockDevMinor: make(map[uint32]struct{}),
- fsTypes: make(map[string]*registeredFilesystemType),
- filesystems: make(map[*Filesystem]struct{}),
- }
+// Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes.
+func (vfs *VirtualFilesystem) Init() error {
+ vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{})
+ vfs.devices = make(map[devTuple]*registeredDevice)
+ vfs.anonBlockDevMinorNext = 1
+ vfs.anonBlockDevMinor = make(map[uint32]struct{})
+ vfs.fsTypes = make(map[string]*registeredFilesystemType)
+ vfs.filesystems = make(map[*Filesystem]struct{})
vfs.mounts.Init()
// Construct vfs.anonMount.
anonfsDevMinor, err := vfs.GetAnonBlockDevMinor()
if err != nil {
- panic(fmt.Sprintf("VirtualFilesystem.GetAnonBlockDevMinor() failed during VirtualFilesystem construction: %v", err))
+ return err
}
anonfs := anonFilesystem{
devMinor: anonfsDevMinor,
@@ -137,8 +137,7 @@ func New() *VirtualFilesystem {
fs: &anonfs.vfsfs,
refs: 1,
}
-
- return vfs
+ return nil
}
// PathOperation specifies the path operated on by a VFS method.
@@ -379,6 +378,22 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
if err == nil {
vfs.putResolvingPath(rp)
+
+ // TODO(gvisor.dev/issue/1193): Move inside fsimpl to avoid another call
+ // to FileDescription.Stat().
+ if opts.FileExec {
+ // Only a regular file can be executed.
+ stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
+ if err != nil {
+ fd.DecRef()
+ return nil, err
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG {
+ fd.DecRef()
+ return nil, syserror.EACCES
+ }
+ }
+
return fd, nil
}
if !rp.handleError(err) {
@@ -724,6 +739,8 @@ func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
// VirtualDentry methods require that a reference is held on the VirtualDentry.
//
// VirtualDentry is analogous to Linux's struct path.
+//
+// +stateify savable
type VirtualDentry struct {
mount *Mount
dentry *Dentry
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
index 3af447fb9..f59061f37 100644
--- a/pkg/sleep/commit_noasm.go
+++ b/pkg/sleep/commit_noasm.go
@@ -28,15 +28,6 @@ import "sync/atomic"
// It is written in assembly because it is called from g0, so it doesn't have
// a race context.
func commitSleep(g uintptr, waitingG *uintptr) bool {
- for {
- // Check if the wait was aborted.
- if atomic.LoadUintptr(waitingG) == 0 {
- return false
- }
-
- // Try to store the G so that wakers know who to wake.
- if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) {
- return true
- }
- }
+ // Try to store the G so that wakers know who to wake.
+ return atomic.CompareAndSwapUintptr(waitingG, preparingG, g)
}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index acbf0229b..65bfcf778 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -299,20 +299,17 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
}
}
- for {
- // Nothing to do if there isn't a G waiting.
- g := atomic.LoadUintptr(&s.waitingG)
- if g == 0 {
- return
- }
+ // Nothing to do if there isn't a G waiting.
+ if atomic.LoadUintptr(&s.waitingG) == 0 {
+ return
+ }
- // Signal to the sleeper that a waker has been asserted.
- if atomic.CompareAndSwapUintptr(&s.waitingG, g, 0) {
- if g != preparingG {
- // We managed to get a G. Wake it up.
- goready(g, 0)
- }
- }
+ // Signal to the sleeper that a waker has been asserted.
+ switch g := atomic.SwapUintptr(&s.waitingG, 0); g {
+ case 0, preparingG:
+ default:
+ // We managed to get a G. Wake it up.
+ goready(g, 0)
}
}
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
new file mode 100644
index 000000000..0500a22cf
--- /dev/null
+++ b/pkg/syncevent/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "syncevent",
+ srcs = [
+ "broadcaster.go",
+ "receiver.go",
+ "source.go",
+ "syncevent.go",
+ "waiter_amd64.s",
+ "waiter_arm64.s",
+ "waiter_asm_unsafe.go",
+ "waiter_noasm_unsafe.go",
+ "waiter_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/atomicbitops",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "syncevent_test",
+ size = "small",
+ srcs = [
+ "broadcaster_test.go",
+ "syncevent_example_test.go",
+ "waiter_test.go",
+ ],
+ library = ":syncevent",
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go
new file mode 100644
index 000000000..4bff59e7d
--- /dev/null
+++ b/pkg/syncevent/broadcaster.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Broadcaster is an implementation of Source that supports any number of
+// subscribed Receivers.
+//
+// The zero value of Broadcaster is valid and has no subscribed Receivers.
+// Broadcaster is not copyable by value.
+//
+// All Broadcaster methods may be called concurrently from multiple goroutines.
+type Broadcaster struct {
+ // Broadcaster is implemented as a hash table where keys are assigned by
+ // the Broadcaster and returned as SubscriptionIDs, making it safe to use
+ // the identity function for hashing. The hash table resolves collisions
+ // using linear probing and features Robin Hood insertion and backward
+ // shift deletion in order to support a relatively high load factor
+ // efficiently, which matters since the cost of Broadcast is linear in the
+ // size of the table.
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // Invariants: len(table) is 0 or a power of 2.
+ table []broadcasterSlot
+
+ // load is the number of entries in table with receiver != nil.
+ load int
+
+ lastID SubscriptionID
+}
+
+type broadcasterSlot struct {
+ // Invariants: If receiver == nil, then filter == NoEvents and id == 0.
+ // Otherwise, id != 0.
+ receiver *Receiver
+ filter Set
+ id SubscriptionID
+}
+
+const (
+ broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1
+
+ broadcasterMaxLoadNum = 13
+ broadcasterMaxLoadDen = 16
+)
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID {
+ b.mu.Lock()
+
+ // Assign an ID for this subscription.
+ b.lastID++
+ id := b.lastID
+
+ // Expand the table if over the maximum load factor:
+ //
+ // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen
+ // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table)
+ b.load++
+ if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) {
+ // Double the number of slots in the new table.
+ newlen := broadcasterMinNonZeroTableSize
+ if len(b.table) != 0 {
+ newlen = 2 * len(b.table)
+ }
+ if newlen <= cap(b.table) {
+ // Reuse excess capacity in the current table, moving entries not
+ // already in their first-probed positions to better ones.
+ newtable := b.table[:newlen]
+ newmask := uint64(newlen - 1)
+ for i := range b.table {
+ if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) {
+ entry := b.table[i]
+ b.table[i] = broadcasterSlot{}
+ broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter)
+ }
+ }
+ b.table = newtable
+ } else {
+ newtable := make([]broadcasterSlot, newlen)
+ // Copy existing entries to the new table.
+ for i := range b.table {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ }
+ }
+ // Switch to the new table.
+ b.table = newtable
+ }
+ }
+
+ broadcasterTableInsert(b.table, id, r, filter)
+ b.mu.Unlock()
+ return id
+}
+
+// Preconditions: table must not be full. len(table) is a power of 2.
+func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) {
+ entry := broadcasterSlot{
+ receiver: r,
+ filter: filter,
+ id: id,
+ }
+ mask := uint64(len(table) - 1)
+ i := uint64(id) & mask
+ disp := uint64(0)
+ for {
+ if table[i].receiver == nil {
+ table[i] = entry
+ return
+ }
+ // If we've been displaced farther from our first-probed slot than the
+ // element stored in this one, swap elements and switch to inserting
+ // the replaced one. (This is Robin Hood insertion.)
+ slotDisp := (i - uint64(table[i].id)) & mask
+ if disp > slotDisp {
+ table[i], entry = entry, table[i]
+ disp = slotDisp
+ }
+ i = (i + 1) & mask
+ disp++
+ }
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) {
+ b.mu.Lock()
+
+ mask := uint64(len(b.table) - 1)
+ i := uint64(id) & mask
+ for {
+ if b.table[i].id == id {
+ // Found the element to remove. Move all subsequent elements
+ // backward until we either find an empty slot, or an element that
+ // is already in its first-probed slot. (This is backward shift
+ // deletion.)
+ for {
+ next := (i + 1) & mask
+ if b.table[next].receiver == nil {
+ break
+ }
+ if uint64(b.table[next].id)&mask == next {
+ break
+ }
+ b.table[i] = b.table[next]
+ i = next
+ }
+ b.table[i] = broadcasterSlot{}
+ break
+ }
+ i = (i + 1) & mask
+ }
+
+ // If a table 1/4 of the current size would still be at or under the
+ // maximum load factor (i.e. the current table size is at least two
+ // expansions bigger than necessary), halve the size of the table to reduce
+ // the cost of Broadcast. Since we are concerned with iteration time and
+ // not memory usage, reuse the existing slice to reduce future allocations
+ // from table re-expansion.
+ b.load--
+ if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) {
+ newlen := len(b.table) / 2
+ newtable := b.table[:newlen]
+ for i := newlen; i < len(b.table); i++ {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ b.table[i] = broadcasterSlot{}
+ }
+ }
+ b.table = newtable
+ }
+
+ b.mu.Unlock()
+}
+
+// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset
+// of events to which they subscribed. The order in which Receivers are
+// notified is unspecified.
+func (b *Broadcaster) Broadcast(events Set) {
+ b.mu.Lock()
+ for i := range b.table {
+ if intersection := events & b.table[i].filter; intersection != 0 {
+ // We don't need to check if broadcasterSlot.receiver is nil, since
+ // if it is then broadcasterSlot.filter is 0.
+ b.table[i].receiver.Notify(intersection)
+ }
+ }
+ b.mu.Unlock()
+}
+
+// FilteredEvents returns the set of events for which Broadcast will notify at
+// least one Receiver, i.e. the union of filters for all subscribed Receivers.
+func (b *Broadcaster) FilteredEvents() Set {
+ var es Set
+ b.mu.Lock()
+ for i := range b.table {
+ es |= b.table[i].filter
+ }
+ b.mu.Unlock()
+ return es
+}
diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go
new file mode 100644
index 000000000..e88779e23
--- /dev/null
+++ b/pkg/syncevent/broadcaster_test.go
@@ -0,0 +1,376 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestBroadcasterFilter(t *testing.T) {
+ const numReceivers = 2 * MaxEvents
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents))
+ }
+ for ev := 0; ev < MaxEvents; ev++ {
+ br.Broadcast(1 << ev)
+ for i := range ws {
+ want := NoEvents
+ if i%MaxEvents == ev {
+ want = 1 << ev
+ }
+ if got := ws[i].Receiver().PendingAndAckAll(); got != want {
+ t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want)
+ }
+ }
+ }
+}
+
+// TestBroadcasterManySubscriptions tests that subscriptions are not lost by
+// table expansion/compaction.
+func TestBroadcasterManySubscriptions(t *testing.T) {
+ const numReceivers = 5000 // arbitrary
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ ids := make([]SubscriptionID, numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Subscribe receiver i.
+ ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1)
+ // Check that receivers [0, i] are subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[j].Pending() != 1 {
+ t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i)
+ }
+ ws[j].Ack(1)
+ }
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Unsubscribe receiver unsub[i].
+ br.UnsubscribeEvents(ids[unsub[i]])
+ // Check that receivers [unsub[0], unsub[i]] are not subscribed, and that
+ // receivers (unsub[i], unsub[numReceivers]) are still subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[unsub[j]].Pending() != 0 {
+ t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ }
+ for j := i + 1; j < numReceivers; j++ {
+ if ws[unsub[j]].Pending() != 1 {
+ t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ ws[unsub[j]].Ack(1)
+ }
+ }
+}
+
+var (
+ receiverCountsNonZero = []int{1, 4, 16, 64}
+ receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...)
+)
+
+// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage
+// pattern X (described in terms of Broadcaster) with Broadcaster, a
+// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively.
+
+// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe
+// cycle.
+
+func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) {
+ var br Broadcaster
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ id := br.SubscribeEvents(w.Receiver(), 1)
+ br.UnsubscribeEvents(id)
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribe(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m[w.Receiver()] = Set(1)
+ mu.Unlock()
+ mu.Lock()
+ delete(m, w.Receiver())
+ mu.Unlock()
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) {
+ var q waiter.Queue
+ e, _ := waiter.NewChannelEntry(nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.EventRegister(&e, 1)
+ q.EventUnregister(&e)
+ }
+}
+
+// BenchmarkXxxSubscribeUnsubscribeBatch is similar to
+// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large
+// number of Receivers at a time in order to measure the amortized overhead of
+// table expansion/compaction. (Since waiter.Queue is implemented using a
+// linked list, BenchmarkQueueSubscribeUnsubscribe and
+// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same
+// result.)
+
+const numBatchReceivers = 1000
+
+func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+ ids := make([]SubscriptionID, numBatchReceivers)
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ br.UnsubscribeEvents(ids[unsub[j]])
+ }
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ m[ws[j].Receiver()] = Set(1)
+ mu.Unlock()
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ delete(m, ws[unsub[j]].Receiver())
+ mu.Unlock()
+ }
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) {
+ var q waiter.Queue
+ es := make([]waiter.Entry, numBatchReceivers)
+ for i := range es {
+ es[i], _ = waiter.NewChannelEntry(nil)
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventRegister(&es[j], 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventUnregister(&es[unsub[j]])
+ }
+ }
+}
+
+// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast
+// already-pending events to multiple Receivers.
+
+func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+ br.Broadcast(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ for i := 0; i < n; i++ {
+ e, _ := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ }
+ q.Notify(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ }
+ })
+ }
+}
+
+// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to
+// multiple Receivers, check that all Receivers have received the event, and
+// clear the event from all Receivers.
+
+func BenchmarkBroadcasterBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ chs := make([]chan struct{}, n)
+ for i := range chs {
+ e, ch := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ chs[i] = ch
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ for _, ch := range chs {
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("channel did not receive event")
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go
new file mode 100644
index 000000000..5c86e5400
--- /dev/null
+++ b/pkg/syncevent/receiver.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+)
+
+// Receiver is an event sink that holds pending events and invokes a callback
+// whenever new events become pending. Receiver's methods may be called
+// concurrently from multiple goroutines.
+//
+// Receiver.Init() must be called before first use.
+type Receiver struct {
+ // pending is the set of pending events. pending is accessed using atomic
+ // memory operations.
+ pending uint64
+
+ // cb is notified when new events become pending. cb is immutable after
+ // Init().
+ cb ReceiverCallback
+}
+
+// ReceiverCallback receives callbacks from a Receiver.
+type ReceiverCallback interface {
+ // NotifyPending is called when the corresponding Receiver has new pending
+ // events.
+ //
+ // NotifyPending is called synchronously from Receiver.Notify(), so
+ // implementations must not take locks that may be held by callers of
+ // Receiver.Notify(). NotifyPending may be called concurrently from
+ // multiple goroutines.
+ NotifyPending()
+}
+
+// Init must be called before first use of r.
+func (r *Receiver) Init(cb ReceiverCallback) {
+ r.cb = cb
+}
+
+// Pending returns the set of pending events.
+func (r *Receiver) Pending() Set {
+ return Set(atomic.LoadUint64(&r.pending))
+}
+
+// Notify sets the given events as pending.
+func (r *Receiver) Notify(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already pending.
+ if p&es == es {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-OR because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) {
+ // If the CAS fails, fall back to atomic-OR.
+ atomicbitops.OrUint64(&r.pending, uint64(es))
+ }
+ r.cb.NotifyPending()
+}
+
+// Ack unsets the given events as pending.
+func (r *Receiver) Ack(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already not pending.
+ if p&es == 0 {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-AND because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) {
+ // If the CAS fails, fall back to atomic-AND.
+ atomicbitops.AndUint64(&r.pending, ^uint64(es))
+ }
+}
+
+// PendingAndAckAll unsets all events as pending and returns the set of
+// previously-pending events.
+//
+// PendingAndAckAll should only be used in preference to a call to Pending
+// followed by a conditional call to Ack when the caller expects events to be
+// pending (e.g. after a call to ReceiverCallback.NotifyPending()).
+func (r *Receiver) PendingAndAckAll() Set {
+ return Set(atomic.SwapUint64(&r.pending, 0))
+}
diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go
new file mode 100644
index 000000000..ddffb171a
--- /dev/null
+++ b/pkg/syncevent/source.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+// Source represents an event source.
+type Source interface {
+ // SubscribeEvents causes the Source to notify the given Receiver of the
+ // given subset of events.
+ //
+ // Preconditions: r != nil. The ReceiverCallback for r must not take locks
+ // that are ordered prior to the Source; for example, it cannot call any
+ // Source methods.
+ SubscribeEvents(r *Receiver, filter Set) SubscriptionID
+
+ // UnsubscribeEvents causes the Source to stop notifying the Receiver
+ // subscribed by a previous call to SubscribeEvents that returned the given
+ // SubscriptionID.
+ //
+ // Preconditions: UnsubscribeEvents may be called at most once for any
+ // given SubscriptionID.
+ UnsubscribeEvents(id SubscriptionID)
+}
+
+// SubscriptionID identifies a call to Source.SubscribeEvents.
+type SubscriptionID uint64
+
+// UnsubscribeAndAck is a convenience function that unsubscribes r from the
+// given events from src and also clears them from r.
+func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) {
+ src.UnsubscribeEvents(id)
+ r.Ack(filter)
+}
+
+// NoopSource implements Source by never sending events to subscribed
+// Receivers.
+type NoopSource struct{}
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID {
+ return 0
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (NoopSource) UnsubscribeEvents(SubscriptionID) {
+}
+
+// See Broadcaster for a non-noop implementations of Source.
diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go
new file mode 100644
index 000000000..9fb6a06de
--- /dev/null
+++ b/pkg/syncevent/syncevent.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syncevent provides efficient primitives for goroutine
+// synchronization based on event bitmasks.
+package syncevent
+
+// Set is a bitmask where each bit represents a distinct user-defined event.
+// The event package does not treat any bits in Set specially.
+type Set uint64
+
+const (
+ // NoEvents is a Set containing no events.
+ NoEvents = Set(0)
+
+ // AllEvents is a Set containing all possible events.
+ AllEvents = ^Set(0)
+
+ // MaxEvents is the number of distinct events that can be represented by a Set.
+ MaxEvents = 64
+)
diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go
new file mode 100644
index 000000000..bfb18e2ea
--- /dev/null
+++ b/pkg/syncevent/syncevent_example_test.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+)
+
+func Example_ioReadinessInterrputible() {
+ const (
+ evReady = Set(1 << iota)
+ evInterrupt
+ )
+ errNotReady := fmt.Errorf("not ready for I/O")
+
+ // State of some I/O object.
+ var (
+ br Broadcaster
+ ready uint32
+ )
+ doIO := func() error {
+ if atomic.LoadUint32(&ready) == 0 {
+ return errNotReady
+ }
+ return nil
+ }
+ go func() {
+ // The I/O object eventually becomes ready for I/O.
+ time.Sleep(100 * time.Millisecond)
+ // When it does, it first ensures that future calls to isReady() return
+ // true, then broadcasts the readiness event to Receivers.
+ atomic.StoreUint32(&ready, 1)
+ br.Broadcast(evReady)
+ }()
+
+ // Each user of the I/O object owns a Waiter.
+ var w Waiter
+ w.Init()
+ // The Waiter may be asynchronously interruptible, e.g. for signal
+ // handling in the sentry.
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ w.Receiver().Notify(evInterrupt)
+ }()
+
+ // To use the I/O object:
+ //
+ // Optionally, if the I/O object is likely to be ready, attempt I/O first.
+ err := doIO()
+ if err == nil {
+ // Success, we're done.
+ return /* nil */
+ }
+ if err != errNotReady {
+ // Failure, I/O failed for some reason other than readiness.
+ return /* err */
+ }
+ // Subscribe for readiness events from the I/O object.
+ id := br.SubscribeEvents(w.Receiver(), evReady)
+ // When we are finished blocking, unsubscribe from readiness events and
+ // remove readiness events from the pending event set.
+ defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id)
+ for {
+ // Attempt I/O again. This must be done after the call to SubscribeEvents,
+ // since the I/O object might have become ready between the previous call
+ // to doIO and the call to SubscribeEvents.
+ err = doIO()
+ if err == nil {
+ return /* nil */
+ }
+ if err != errNotReady {
+ return /* err */
+ }
+ // Block until either the I/O object indicates it is ready, or we are
+ // interrupted.
+ events := w.Wait()
+ if events&evInterrupt != 0 {
+ // In the specific case of sentry signal handling, signal delivery
+ // is handled by another system, so we aren't responsible for
+ // acknowledging evInterrupt.
+ return /* errInterrupted */
+ }
+ // Note that, in a concurrent context, the I/O object might become
+ // ready and then not ready again. To handle this:
+ //
+ // - evReady must be acknowledged before calling doIO() again (rather
+ // than after), so that if the I/O object becomes ready *again* after
+ // the call to doIO(), the readiness event is not lost.
+ //
+ // - We must loop instead of just calling doIO() once after receiving
+ // evReady.
+ w.Ack(evReady)
+ }
+}
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s
new file mode 100644
index 000000000..985b56ae5
--- /dev/null
+++ b/pkg/syncevent/waiter_amd64.s
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVQ g+0(FP), DI
+ MOVQ wg+8(FP), SI
+
+ MOVQ $·preparingG(SB), AX
+ LOCK
+ CMPXCHGQ DI, 0(SI)
+
+ SETEQ AX
+ MOVB AX, ret+16(FP)
+
+ RET
+
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s
new file mode 100644
index 000000000..20d7ac23b
--- /dev/null
+++ b/pkg/syncevent/waiter_arm64.s
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVD wg+8(FP), R0
+ MOVD $·preparingG(SB), R1
+ MOVD g+0(FP), R2
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE ok
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ok:
+ CSET EQ, R0
+ MOVB R0, ret+16(FP)
+ RET
+
diff --git a/pkg/fspath/builder_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go
index 75606808d..0995e9053 100644
--- a/pkg/fspath/builder_unsafe.go
+++ b/pkg/syncevent/waiter_asm_unsafe.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,16 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package fspath
+// +build amd64 arm64
+
+package syncevent
import (
"unsafe"
)
-// String returns the accumulated string. No other methods should be called
-// after String.
-func (b *Builder) String() string {
- bs := b.buf[b.start:]
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
-}
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
new file mode 100644
index 000000000..1c4b0e39a
--- /dev/null
+++ b/pkg/syncevent/waiter_noasm_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// waiterUnlock is called from g0, so when the race detector is enabled,
+// waiterUnlock must be implemented in assembly since no race context is
+// available.
+//
+// +build !race
+// +build !amd64,!arm64
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// waiterUnlock is the "unlock function" passed to runtime.gopark by
+// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
+// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
+// should be aborted.
+//
+//go:nosplit
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with nil, in which case we should not sleep.
+ return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), g)
+}
diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go
new file mode 100644
index 000000000..3c8cbcdd8
--- /dev/null
+++ b/pkg/syncevent/waiter_test.go
@@ -0,0 +1,414 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestWaiterAlreadyPending(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ w.Notify(want)
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterAsyncNotify(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ w.Notify(want)
+ }()
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterWaitFor(t *testing.T) {
+ var w Waiter
+ w.Init()
+ evWaited := Set(1)
+ evOther := Set(2)
+ w.Notify(evOther)
+ notifiedEvent := uint32(0)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ atomic.StoreUint32(&notifiedEvent, 1)
+ w.Notify(evWaited)
+ }()
+ if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want {
+ t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want)
+ }
+ if atomic.LoadUint32(&notifiedEvent) == 0 {
+ t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event")
+ }
+}
+
+func TestWaiterWaitAndAckAll(t *testing.T) {
+ var w Waiter
+ w.Init()
+ w.Notify(AllEvents)
+ if got := w.WaitAndAckAll(); got != AllEvents {
+ t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents)
+ }
+ if got := w.Pending(); got != NoEvents {
+ t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got)
+ }
+}
+
+// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage
+// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and
+// buffered chan struct{} respectively. When the maximum number of event
+// sources is relevant, we use 3 event sources because this is representative
+// of the kernel.Task.block() use case: an interrupt source, a timeout source,
+// and the actual event source being waited on.
+
+// Event set used by most benchmarks.
+const evBench Set = 1
+
+// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of
+// an event that is already pending.
+
+func BenchmarkWaiterNotifyRedundant(b *testing.B) {
+ var w Waiter
+ w.Init()
+ w.Notify(evBench)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyRedundant(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+ w.Assert()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ }
+}
+
+func BenchmarkChannelNotifyRedundant(b *testing.B) {
+ ch := make(chan struct{}, 1)
+ ch <- struct{}{}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }
+}
+
+// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter an
+// event, return that event using a blocking check, and then unset the event as
+// pending.
+
+func BenchmarkWaiterNotifyWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+
+ // wait + ack
+ <-ch
+ }
+}
+
+// BenchmarkSleeperMultiNotifyWaitAck is equivalent to
+// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a
+// temporary sleep.Waker. This is necessary when multiple goroutines may wait
+// for the same event, since each sleep.Waker can wake only a single
+// sleep.Sleeper.
+//
+// The syncevent package does not require a distinct object for each
+// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and
+// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state
+// for channels, runtime.sudog, is inescapably runtime-allocated, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would
+// also be identical.
+
+func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ // The sleep package doesn't provide sync.Pool allocation of Wakers;
+ // we do for a fairer comparison.
+ wakerPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Waker{}
+ },
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := wakerPool.Get().(*sleep.Waker)
+ s.AddWaker(w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ wakerPool.Put(w)
+ }
+}
+
+// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also
+// includes allocation of a temporary Waiter. This models the case where a
+// goroutine not already associated with a Waiter needs one in order to block.
+//
+// The analogous state for channels is built into runtime.g, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be
+// identical.
+
+func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) {
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := GetWaiter()
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ PutWaiter(w)
+ }
+}
+
+func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) {
+ // The sleep package doesn't provide sync.Pool allocation of Sleepers;
+ // we do for a fairer comparison.
+ sleeperPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Sleeper{}
+ },
+ }
+ var w sleep.Waker
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s := sleeperPool.Get().(*sleep.Sleeper)
+ s.AddWaker(&w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ sleeperPool.Put(s)
+ }
+}
+
+// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows
+// for multiple event sources.
+
+func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ws[0].Assert()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, wanted 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+
+ // wait + clear
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an
+// event while another goroutine signals the event. This assumes that a new
+// goroutine doesn't run immediately (i.e. the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+
+func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(1)
+ }()
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Assert()
+ }()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }()
+ <-ch
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but
+// allows for multiple event sources.
+
+func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(evBench)
+ }()
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ ws[0].Assert()
+ }()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, expected 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+ }()
+
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
new file mode 100644
index 000000000..112e0e604
--- /dev/null
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -0,0 +1,206 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.11
+// +build !go1.15
+
+// Check go:linkname function signatures when updating Go version.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+//go:linkname goready runtime.goready
+func goready(g unsafe.Pointer, traceskip int)
+
+const (
+ waitReasonSelect = 9 // Go: src/runtime/runtime2.go
+ traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
+)
+
+// Waiter allows a goroutine to block on pending events received by a Receiver.
+//
+// Waiter.Init() must be called before first use.
+type Waiter struct {
+ r Receiver
+
+ // g is one of:
+ //
+ // - nil: No goroutine is blocking in Wait.
+ //
+ // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // completed waiterUnlock(). Thus the wait can only be interrupted by
+ // replacing the value of g with nil (the G may not be in state Gwaiting
+ // yet, so we can't call goready.)
+ //
+ // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
+ // goroutine blocked in Wait, which can only be woken by calling goready.
+ g unsafe.Pointer `state:"zerovalue"`
+}
+
+// Sentinel object for Waiter.g.
+var preparingG struct{}
+
+// Init must be called before first use of w.
+func (w *Waiter) Init() {
+ w.r.Init(w)
+}
+
+// Receiver returns the Receiver that receives events that unblock calls to
+// w.Wait().
+func (w *Waiter) Receiver() *Receiver {
+ return &w.r
+}
+
+// Pending returns the set of pending events.
+func (w *Waiter) Pending() Set {
+ return w.r.Pending()
+}
+
+// Wait blocks until at least one event is pending, then returns the set of
+// pending events. It does not affect the set of pending events; callers must
+// call w.Ack() to do so, or use w.WaitAndAck() instead.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) Wait() Set {
+ return w.WaitFor(AllEvents)
+}
+
+// WaitFor blocks until at least one event in es is pending, then returns the
+// set of pending events (including those not in es). It does not affect the
+// set of pending events; callers must call w.Ack() to do so.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitFor(es Set) Set {
+ for {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending.
+ if p := w.r.Pending(); p&es != NoEvents {
+ return p
+ }
+
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if p := w.r.Pending(); p&es != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ }
+}
+
+// Ack marks the given events as not pending.
+func (w *Waiter) Ack(es Set) {
+ w.r.Ack(es)
+}
+
+// WaitAndAckAll blocks until at least one event is pending, then marks all
+// events as not pending and returns the set of previously-pending events.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitAndAckAll() Set {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending. Call Pending() first since, in the common case that events are
+ // not yet pending, this skips an atomic swap on w.r.pending.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+
+ for {
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+
+ // Check for pending events. We call PendingAndAckAll() directly now since
+ // we only expect to be woken after events become pending.
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+}
+
+// Notify marks the given events as pending, possibly unblocking concurrent
+// calls to w.Wait() or w.WaitFor().
+func (w *Waiter) Notify(es Set) {
+ w.r.Notify(es)
+}
+
+// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter
+// should not call NotifyPending.
+func (w *Waiter) NotifyPending() {
+ // Optimization: Skip the atomic swap on w.g if there is no sleeping
+ // goroutine. NotifyPending is called after w.r.Pending() is updated, so
+ // concurrent and future calls to w.Wait() will observe pending events and
+ // abort sleeping.
+ if atomic.LoadPointer(&w.g) == nil {
+ return
+ }
+ // Wake a sleeping G, or prevent a G that is preparing to sleep from doing
+ // so. Swap is needed here to ensure that only one call to NotifyPending
+ // calls goready.
+ if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
+ goready(g, 0)
+ }
+}
+
+var waiterPool = sync.Pool{
+ New: func() interface{} {
+ w := &Waiter{}
+ w.Init()
+ return w
+ },
+}
+
+// GetWaiter returns an unused Waiter. PutWaiter should be called to release
+// the Waiter once it is no longer needed.
+//
+// Where possible, users should prefer to associate each goroutine that calls
+// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of
+// Waiters in hot paths.
+func GetWaiter() *Waiter {
+ return waiterPool.Get().(*Waiter)
+}
+
+// PutWaiter releases an unused Waiter previously returned by GetWaiter.
+func PutWaiter(w *Waiter) {
+ waiterPool.Put(w)
+}
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 2269f6237..4b5a0fca6 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -29,6 +29,7 @@ var (
EACCES = error(syscall.EACCES)
EAGAIN = error(syscall.EAGAIN)
EBADF = error(syscall.EBADF)
+ EBADFD = error(syscall.EBADFD)
EBUSY = error(syscall.EBUSY)
ECHILD = error(syscall.ECHILD)
ECONNREFUSED = error(syscall.ECONNREFUSED)
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ea0a0409a..3c552988a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
+ _, err := r.CreateEndpoint(&wq)
if err != nil {
t.Fatalf("r.CreateEndpoint() = %v", err)
}
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
+ // Endpoint will be closed in deferred s.Close (above).
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
stop = func() {
c1.Close()
c2.Close()
+ s.Close()
+ s.Wait()
}
if err := l.Close(); err != nil {
@@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 150310c11..17e94c562 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) {
vv.views = append(vv.views, vv2.views...)
vv.size += vv2.size
}
+
+// AppendView appends the given view into this vectorised view.
+func (vv *VectorisedView) AppendView(v View) {
+ vv.views = append(vv.views, v)
+ vv.size += len(v)
+}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 4d6ae0871..c6c160dfc 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -161,6 +161,20 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTClass creates a checker that checks the TCLASS field in
+// ControlMessages.
+func ReceiveTClass(want uint32) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTClass {
+ t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
+ }
+ if got := cm.TClass; got != want {
+ t.Fatalf("got cm.TClass = %d, want %d", got, want)
+ }
+ }
+}
+
// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
func ReceiveTOS(want uint8) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index bab26580b..d1b73cfdf 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -8,7 +8,6 @@ go_library(
"iptables.go",
"targets.go",
"types.go",
- "udp_matcher.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index 1b9485bbd..dbaccbb36 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -52,10 +52,10 @@ func DefaultTables() IPTables {
Tables: map[string]Table{
TablenameNat: Table{
Rules: []Rule{
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
BuiltinChains: map[Hook]int{
@@ -74,8 +74,8 @@ func DefaultTables() IPTables {
},
TablenameMangle: Table{
Rules: []Rule{
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
BuiltinChains: map[Hook]int{
@@ -90,9 +90,9 @@ func DefaultTables() IPTables {
},
TablenameFilter: Table{
Rules: []Rule{
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
BuiltinChains: map[Hook]int{
@@ -135,27 +135,53 @@ func EmptyFilterTable() Table {
}
}
+// A chainVerdict is what a table decides should be done with a packet.
+type chainVerdict int
+
+const (
+ // chainAccept indicates the packet should continue through netstack.
+ chainAccept chainVerdict = iota
+
+ // chainAccept indicates the packet should be dropped.
+ chainDrop
+
+ // chainReturn indicates the packet should return to the calling chain
+ // or the underflow rule of a builtin chain.
+ chainReturn
+)
+
// Check runs pkt through the rules for hook. It returns true when the packet
// should continue traversing the network stack and false when it should be
// dropped.
//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
- // TODO(gvisor.dev/issue/170): A lot of this is uncomplicated because
- // we're missing features. Jumps, the call stack, etc. aren't checked
- // for yet because we're yet to support them.
-
// Go through each table containing the hook.
for _, tablename := range it.Priorities[hook] {
- switch verdict := it.checkTable(hook, pkt, tablename); verdict {
+ table := it.Tables[tablename]
+ ruleIdx := table.BuiltinChains[hook]
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict {
// If the table returns Accept, move on to the next table.
- case Accept:
+ case chainAccept:
continue
// The Drop verdict is final.
- case Drop:
+ case chainDrop:
return false
- case Stolen, Queue, Repeat, None, Jump, Return, Continue:
- panic(fmt.Sprintf("Unimplemented verdict %v.", verdict))
+ case chainReturn:
+ // Any Return from a built-in chain means we have to
+ // call the underflow.
+ underflow := table.Rules[table.Underflows[hook]]
+ switch v, _ := underflow.Target.Action(pkt); v {
+ case RuleAccept:
+ continue
+ case RuleDrop:
+ return false
+ case RuleJump, RuleReturn:
+ panic("Underflows should only return RuleAccept or RuleDrop.")
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", v))
+ }
+
default:
panic(fmt.Sprintf("Unknown verdict %v.", verdict))
}
@@ -166,36 +192,59 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
}
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) Verdict {
+func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
- table := it.Tables[tablename]
- for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ {
- switch verdict := it.checkRule(hook, pkt, table, ruleIdx); verdict {
- // In either of these cases, this table is done with the packet.
- case Accept, Drop:
- return verdict
- // Continue traversing the rules of the table.
- case Continue:
- continue
- case Stolen, Queue, Repeat, None, Jump, Return:
- panic(fmt.Sprintf("Unimplemented verdict %v.", verdict))
+ for ruleIdx < len(table.Rules) {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict {
+ case RuleAccept:
+ return chainAccept
+
+ case RuleDrop:
+ return chainDrop
+
+ case RuleReturn:
+ return chainReturn
+
+ case RuleJump:
+ // "Jumping" to the next rule just means we're
+ // continuing on down the list.
+ if jumpTo == ruleIdx+1 {
+ ruleIdx++
+ continue
+ }
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict {
+ case chainAccept:
+ return chainAccept
+ case chainDrop:
+ return chainDrop
+ case chainReturn:
+ ruleIdx++
+ continue
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
+ }
+
default:
- panic(fmt.Sprintf("Unknown verdict %v.", verdict))
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
}
+
}
- panic(fmt.Sprintf("Traversed past the entire list of iptables rules in table %q.", tablename))
+ // We got through the entire table without a decision. Default to DROP
+ // for safety.
+ return chainDrop
}
// Precondition: pk.NetworkHeader is set.
-func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) Verdict {
+func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// First check whether the packet matches the IP header filter.
// TODO(gvisor.dev/issue/170): Support other fields of the filter.
if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() {
- return Continue
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
}
// Go through each rule matcher. If they all match, run
@@ -203,14 +252,14 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru
for _, matcher := range rule.Matchers {
matches, hotdrop := matcher.Match(hook, pkt, "")
if hotdrop {
- return Drop
+ return RuleDrop, 0
}
if !matches {
- return Continue
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
}
}
// All the matchers matched, so run the target.
- verdict, _ := rule.Target.Action(pkt)
- return verdict
+ return rule.Target.Action(pkt)
}
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
index 4dd281371..81a2e39a2 100644
--- a/pkg/tcpip/iptables/targets.go
+++ b/pkg/tcpip/iptables/targets.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// This file contains various Targets.
-
package iptables
import (
@@ -21,20 +19,20 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-// UnconditionalAcceptTarget accepts all packets.
-type UnconditionalAcceptTarget struct{}
+// AcceptTarget accepts packets.
+type AcceptTarget struct{}
// Action implements Target.Action.
-func (UnconditionalAcceptTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
- return Accept, ""
+func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
+ return RuleAccept, 0
}
-// UnconditionalDropTarget denies all packets.
-type UnconditionalDropTarget struct{}
+// DropTarget drops packets.
+type DropTarget struct{}
// Action implements Target.Action.
-func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
- return Drop, ""
+func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
+ return RuleDrop, 0
}
// ErrorTarget logs an error and drops the packet. It represents a target that
@@ -42,7 +40,26 @@ func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, strin
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
- log.Warningf("ErrorTarget triggered.")
- return Drop, ""
+func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
+ log.Debugf("ErrorTarget triggered.")
+ return RuleDrop, 0
+}
+
+// UserChainTarget marks a rule as the beginning of a user chain.
+type UserChainTarget struct {
+ Name string
+}
+
+// Action implements Target.Action.
+func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) {
+ panic("UserChainTarget should never be called.")
+}
+
+// ReturnTarget returns from the current chain. If the chain is a built-in, the
+// hook's underflow should be called.
+type ReturnTarget struct{}
+
+// Action implements Target.Action.
+func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) {
+ return RuleReturn, 0
}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 2ea8994ae..7d032fd23 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -56,44 +56,21 @@ const (
NumHooks
)
-// A Verdict is returned by a rule's target to indicate how traversal of rules
-// should (or should not) continue.
-type Verdict int
+// A RuleVerdict is what a rule decides should be done with a packet.
+type RuleVerdict int
const (
- // Invalid indicates an unkonwn or erroneous verdict.
- Invalid Verdict = iota
+ // RuleAccept indicates the packet should continue through netstack.
+ RuleAccept RuleVerdict = iota
- // Accept indicates the packet should continue traversing netstack as
- // normal.
- Accept
+ // RuleDrop indicates the packet should be dropped.
+ RuleDrop
- // Drop inicates the packet should be dropped, stopping traversing
- // netstack.
- Drop
+ // RuleJump indicates the packet should jump to another chain.
+ RuleJump
- // Stolen indicates the packet was co-opted by the target and should
- // stop traversing netstack.
- Stolen
-
- // Queue indicates the packet should be queued for userspace processing.
- Queue
-
- // Repeat indicates the packet should re-traverse the chains for the
- // current hook.
- Repeat
-
- // None indicates no verdict was reached.
- None
-
- // Jump indicates a jump to another chain.
- Jump
-
- // Continue indicates that traversal should continue at the next rule.
- Continue
-
- // Return indicates that traversal should return to the calling chain.
- Return
+ // RuleReturn indicates the packet should return to the previous chain.
+ RuleReturn
)
// IPTables holds all the tables for a netstack.
@@ -171,6 +148,9 @@ type IPHeaderFilter struct {
// A Matcher is the interface for matching packets.
type Matcher interface {
+ // Name returns the name of the Matcher.
+ Name() string
+
// Match returns whether the packet matches and whether the packet
// should be "hotdropped", i.e. dropped immediately. This is usually
// used for suspicious packets.
@@ -183,6 +163,6 @@ type Matcher interface {
type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
- // Jump, it also returns the name of the chain to jump to.
- Action(packet tcpip.PacketBuffer) (Verdict, string)
+ // Jump, it also returns the index of the rule to jump to.
+ Action(packet tcpip.PacketBuffer) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 3974c464e..b8b93e78e 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = ["channel.go"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 78d447acd..5944ba190 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -20,6 +20,7 @@ package channel
import (
"context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -33,6 +34,118 @@ type PacketInfo struct {
Route stack.Route
}
+// Notification is the interface for receiving notification from the packet
+// queue.
+type Notification interface {
+ // WriteNotify will be called when a write happens to the queue.
+ WriteNotify()
+}
+
+// NotificationHandle is an opaque handle to the registered notification target.
+// It can be used to unregister the notification when no longer interested.
+//
+// +stateify savable
+type NotificationHandle struct {
+ n Notification
+}
+
+type queue struct {
+ // mu protects fields below.
+ mu sync.RWMutex
+ // c is the outbound packet channel. Sending to c should hold mu.
+ c chan PacketInfo
+ numWrite int
+ numRead int
+ notify []*NotificationHandle
+}
+
+func (q *queue) Close() {
+ close(q.c)
+}
+
+func (q *queue) Read() (PacketInfo, bool) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ select {
+ case p := <-q.c:
+ q.numRead++
+ return p, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ // We have to receive from channel without holding the lock, since it can
+ // block indefinitely. This will cause a window that numWrite - numRead
+ // produces a larger number, but won't go to negative. numWrite >= numRead
+ // still holds.
+ select {
+ case pkt := <-q.c:
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.numRead++
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) Write(p PacketInfo) bool {
+ wrote := false
+
+ // It's important to make sure nobody can see numWrite until we increment it,
+ // so numWrite >= numRead holds.
+ q.mu.Lock()
+ select {
+ case q.c <- p:
+ wrote = true
+ q.numWrite++
+ default:
+ }
+ notify := q.notify
+ q.mu.Unlock()
+
+ if wrote {
+ // Send notification outside of lock.
+ for _, h := range notify {
+ h.n.WriteNotify()
+ }
+ }
+ return wrote
+}
+
+func (q *queue) Num() int {
+ q.mu.RLock()
+ defer q.mu.RUnlock()
+ n := q.numWrite - q.numRead
+ if n < 0 {
+ panic("numWrite < numRead")
+ }
+ return n
+}
+
+func (q *queue) AddNotify(notify Notification) *NotificationHandle {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ h := &NotificationHandle{n: notify}
+ q.notify = append(q.notify, h)
+ return h
+}
+
+func (q *queue) RemoveNotify(handle *NotificationHandle) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ // Make a copy, since we reads the array outside of lock when notifying.
+ notify := make([]*NotificationHandle, 0, len(q.notify))
+ for _, h := range q.notify {
+ if h != handle {
+ notify = append(notify, h)
+ }
+ }
+ q.notify = notify
+}
+
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
@@ -41,14 +154,16 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
LinkEPCapabilities stack.LinkEndpointCapabilities
- // c is where outbound packets are queued.
- c chan PacketInfo
+ // Outbound packet queue.
+ q *queue
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- c: make(chan PacketInfo, size),
+ q: &queue{
+ c: make(chan PacketInfo, size),
+ },
mtu: mtu,
linkAddr: linkAddr,
}
@@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
// Close closes e. Further packet injections will panic. Reads continue to
// succeed until all packets are read.
func (e *Endpoint) Close() {
- close(e.c)
+ e.q.Close()
}
-// Read does non-blocking read for one packet from the outbound packet queue.
+// Read does non-blocking read one packet from the outbound packet queue.
func (e *Endpoint) Read() (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- default:
- return PacketInfo{}, false
- }
+ return e.q.Read()
}
// ReadContext does blocking read for one packet from the outbound packet queue.
// It can be cancelled by ctx, and in this case, it returns false.
func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- case <-ctx.Done():
- return PacketInfo{}, false
- }
+ return e.q.ReadContext(ctx)
}
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
- select {
- case <-e.c:
- c++
- default:
+ if _, ok := e.Read(); !ok {
return c
}
+ c++
}
}
+// NumQueued returns the number of packet queued for outbound.
+func (e *Endpoint) NumQueued() int {
+ return e.q.Num()
+}
+
// InjectInbound injects an inbound packet.
func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
@@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
Route: route,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
@@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
route.Release()
payloadView := pkts[0].Data.ToView()
n := 0
-packetLoop:
for _, pkt := range pkts {
off := pkt.DataOffset
size := pkt.DataSize
@@ -185,12 +289,10 @@ packetLoop:
Route: route,
}
- select {
- case e.c <- p:
- n++
- default:
- break packetLoop
+ if !e.q.Write(p) {
+ break
}
+ n++
}
return n, nil
@@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
GSO: nil,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
+
+// AddNotify adds a notification target for receiving event about outgoing
+// packets.
+func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
+ return e.q.AddNotify(notify)
+}
+
+// RemoveNotify removes handle from the list of notification targets.
+func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
+ e.q.RemoveNotify(handle)
+}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index e5096ea38..e0db6cf54 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -4,6 +4,22 @@ package(licenses = ["notice"])
go_library(
name = "tun",
- srcs = ["tun_unsafe.go"],
+ srcs = [
+ "device.go",
+ "protocol.go",
+ "tun_unsafe.go",
+ ],
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
)
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
new file mode 100644
index 000000000..6ff47a742
--- /dev/null
+++ b/pkg/tcpip/link/tun/device.go
@@ -0,0 +1,352 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // drivers/net/tun.c:tun_net_init()
+ defaultDevMtu = 1500
+
+ // Queue length for outbound packet, arriving at fd side for read. Overflow
+ // causes packet drops. gVisor implementation-specific.
+ defaultDevOutQueueLen = 1024
+)
+
+var zeroMAC [6]byte
+
+// Device is an opened /dev/net/tun device.
+//
+// +stateify savable
+type Device struct {
+ waiter.Queue
+
+ mu sync.RWMutex `state:"nosave"`
+ endpoint *tunEndpoint
+ notifyHandle *channel.NotificationHandle
+ flags uint16
+}
+
+// beforeSave is invoked by stateify.
+func (d *Device) beforeSave() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ // TODO(b/110961832): Restore the device to stack. At this moment, the stack
+ // is not savable.
+ if d.endpoint != nil {
+ panic("/dev/net/tun does not support save/restore when a device is associated with it.")
+ }
+}
+
+// Release implements fs.FileOperations.Release.
+func (d *Device) Release() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Decrease refcount if there is an endpoint associated with this file.
+ if d.endpoint != nil {
+ d.endpoint.RemoveNotify(d.notifyHandle)
+ d.endpoint.DecRef()
+ d.endpoint = nil
+ }
+}
+
+// SetIff services TUNSETIFF ioctl(2) request.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.endpoint != nil {
+ return syserror.EINVAL
+ }
+
+ // Input validations.
+ isTun := flags&linux.IFF_TUN != 0
+ isTap := flags&linux.IFF_TAP != 0
+ supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
+ if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
+ return syserror.EINVAL
+ }
+
+ prefix := "tun"
+ if isTap {
+ prefix = "tap"
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix)
+ if err != nil {
+ return syserror.EINVAL
+ }
+
+ d.endpoint = endpoint
+ d.notifyHandle = d.endpoint.AddNotify(d)
+ d.flags = flags
+ return nil
+}
+
+func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) {
+ for {
+ // 1. Try to attach to an existing NIC.
+ if name != "" {
+ if nic, found := s.GetNICByName(name); found {
+ endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if !ok {
+ // Not a NIC created by tun device.
+ return nil, syserror.EOPNOTSUPP
+ }
+ if !endpoint.TryIncRef() {
+ // Race detected: NIC got deleted in between.
+ continue
+ }
+ return endpoint, nil
+ }
+ }
+
+ // 2. Creating a new NIC.
+ id := tcpip.NICID(s.UniqueID())
+ endpoint := &tunEndpoint{
+ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
+ stack: s,
+ nicID: id,
+ name: name,
+ }
+ if endpoint.name == "" {
+ endpoint.name = fmt.Sprintf("%s%d", prefix, id)
+ }
+ err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
+ Name: endpoint.name,
+ })
+ switch err {
+ case nil:
+ return endpoint, nil
+ case tcpip.ErrDuplicateNICID:
+ // Race detected: A NIC has been created in between.
+ continue
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+}
+
+// Write inject one inbound packet to the network interface.
+func (d *Device) Write(data []byte) (int64, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return 0, syserror.EBADFD
+ }
+ if !endpoint.IsAttached() {
+ return 0, syserror.EIO
+ }
+
+ dataLen := int64(len(data))
+
+ // Packet information.
+ var pktInfoHdr PacketInfoHeader
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ if len(data) < PacketInfoHeaderSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize])
+ data = data[PacketInfoHeaderSize:]
+ }
+
+ // Ethernet header (TAP only).
+ var ethHdr header.Ethernet
+ if d.hasFlags(linux.IFF_TAP) {
+ if len(data) < header.EthernetMinimumSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ ethHdr = header.Ethernet(data[:header.EthernetMinimumSize])
+ data = data[header.EthernetMinimumSize:]
+ }
+
+ // Try to determine network protocol number, default zero.
+ var protocol tcpip.NetworkProtocolNumber
+ switch {
+ case pktInfoHdr != nil:
+ protocol = pktInfoHdr.Protocol()
+ case ethHdr != nil:
+ protocol = ethHdr.Type()
+ }
+
+ // Try to determine remote link address, default zero.
+ var remote tcpip.LinkAddress
+ switch {
+ case ethHdr != nil:
+ remote = ethHdr.SourceAddress()
+ default:
+ remote = tcpip.LinkAddress(zeroMAC[:])
+ }
+
+ pkt := tcpip.PacketBuffer{
+ Data: buffer.View(data).ToVectorisedView(),
+ }
+ if ethHdr != nil {
+ pkt.LinkHeader = buffer.View(ethHdr)
+ }
+ endpoint.InjectLinkAddr(protocol, remote, pkt)
+ return dataLen, nil
+}
+
+// Read reads one outgoing packet from the network interface.
+func (d *Device) Read() ([]byte, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return nil, syserror.EBADFD
+ }
+
+ for {
+ info, ok := endpoint.Read()
+ if !ok {
+ return nil, syserror.ErrWouldBlock
+ }
+
+ v, ok := d.encodePkt(&info)
+ if !ok {
+ // Ignore unsupported packet.
+ continue
+ }
+ return v, nil
+ }
+}
+
+// encodePkt encodes packet for fd side.
+func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
+ var vv buffer.VectorisedView
+
+ // Packet information.
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ hdr := make(PacketInfoHeader, PacketInfoHeaderSize)
+ hdr.Encode(&PacketInfoFields{
+ Protocol: info.Proto,
+ })
+ vv.AppendView(buffer.View(hdr))
+ }
+
+ // If the packet does not already have link layer header, and the route
+ // does not exist, we can't compute it. This is possibly a raw packet, tun
+ // device doesn't support this at the moment.
+ if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" {
+ return nil, false
+ }
+
+ // Ethernet header (TAP only).
+ if d.hasFlags(linux.IFF_TAP) {
+ // Add ethernet header if not provided.
+ if info.Pkt.LinkHeader == nil {
+ hdr := &header.EthernetFields{
+ SrcAddr: info.Route.LocalLinkAddress,
+ DstAddr: info.Route.RemoteLinkAddress,
+ Type: info.Proto,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = d.endpoint.LinkAddress()
+ }
+
+ eth := make(header.Ethernet, header.EthernetMinimumSize)
+ eth.Encode(hdr)
+ vv.AppendView(buffer.View(eth))
+ } else {
+ vv.AppendView(info.Pkt.LinkHeader)
+ }
+ }
+
+ // Append upper headers.
+ vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):]))
+ // Append data payload.
+ vv.Append(info.Pkt.Data)
+
+ return vv.ToView(), true
+}
+
+// Name returns the name of the attached network interface. Empty string if
+// unattached.
+func (d *Device) Name() string {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ if d.endpoint != nil {
+ return d.endpoint.name
+ }
+ return ""
+}
+
+// Flags returns the flags set for d. Zero value if unset.
+func (d *Device) Flags() uint16 {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ return d.flags
+}
+
+func (d *Device) hasFlags(flags uint16) bool {
+ return d.flags&flags == flags
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint != nil && endpoint.NumQueued() == 0 {
+ mask &= ^waiter.EventIn
+ }
+ }
+ return mask & (waiter.EventIn | waiter.EventOut)
+}
+
+// WriteNotify implements channel.Notification.WriteNotify.
+func (d *Device) WriteNotify() {
+ d.Notify(waiter.EventIn)
+}
+
+// tunEndpoint is the link endpoint for the NIC created by the tun device.
+//
+// It is ref-counted as multiple opening files can attach to the same NIC.
+// The last owner is responsible for deleting the NIC.
+type tunEndpoint struct {
+ *channel.Endpoint
+
+ refs.AtomicRefCount
+
+ stack *stack.Stack
+ nicID tcpip.NICID
+ name string
+}
+
+// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
+func (e *tunEndpoint) DecRef() {
+ e.DecRefWithDestructor(func() {
+ e.stack.RemoveNIC(e.nicID)
+ })
+}
diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go
new file mode 100644
index 000000000..89d9d91a9
--- /dev/null
+++ b/pkg/tcpip/link/tun/protocol.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // PacketInfoHeaderSize is the size of the packet information header.
+ PacketInfoHeaderSize = 4
+
+ offsetFlags = 0
+ offsetProtocol = 2
+)
+
+// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is
+// not set.
+type PacketInfoFields struct {
+ Flags uint16
+ Protocol tcpip.NetworkProtocolNumber
+}
+
+// PacketInfoHeader is the wire representation of the packet information sent if
+// IFF_NO_PI flag is not set.
+type PacketInfoHeader []byte
+
+// Encode encodes f into h.
+func (h PacketInfoHeader) Encode(f *PacketInfoFields) {
+ binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags)
+ binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol))
+}
+
+// Flags returns the flag field in h.
+func (h PacketInfoHeader) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[offsetFlags:])
+}
+
+// Protocol returns the protocol field in h.
+func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:]))
+}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 4da13c5df..e9fcc89a8 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
}, nil
}
-// LinkAddressProtocol implements stack.LinkAddressResolver.
+// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
-// LinkAddressRequest implements stack.LinkAddressResolver.
+// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
@@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
}
-// ResolveStaticAddress implements stack.LinkAddressResolver.
+// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return broadcastMAC, true
@@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
return tcpip.LinkAddress([]byte(nil)), false
}
-// SetOption implements NetworkProtocol.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.NetworkProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements NetworkProtocol.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.NetworkProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
// NewProtocol returns an ARP network protocol.
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 6597e6781..4f1742938 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 180a480fd..9aef5234b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 045409bda..f651871ce 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -1148,22 +1148,27 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
return true
}
-// cleanupHostOnlyState cleans up any state that is only useful for hosts.
+// cleanupState cleans up ndp's state.
//
-// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a
-// host to a router. This function will invalidate all discovered on-link
-// prefixes, discovered routers, and auto-generated addresses as routers do not
-// normally process Router Advertisements to discover default routers and
-// on-link prefixes, and auto-generate addresses via SLAAC.
+// If hostOnly is true, then only host-specific state will be cleaned up.
+//
+// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
+// transitioning from a host to a router. This function will invalidate all
+// discovered on-link prefixes, discovered routers, and auto-generated
+// addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address will not be
+// invalidated as routers are also expected to generate a link-local address.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupHostOnlyState() {
+func (ndp *ndpState) cleanupState(hostOnly bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
linkLocalAddrs := 0
for addr := range ndp.autoGenAddresses {
// RFC 4862 section 5 states that routers are also expected to generate a
- // link-local address so we do not invalidate them.
- if linkLocalSubnet.Contains(addr) {
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if hostOnly && linkLocalSubnet.Contains(addr) {
linkLocalAddrs++
continue
}
@@ -1230,7 +1235,7 @@ func (ndp *ndpState) startSolicitingRouters() {
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
pkt := header.ICMPv6(hdr.Prepend(payloadSize))
pkt.SetType(header.ICMPv6RouterSolicit)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 1f6f77439..6e9306d09 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -267,6 +267,17 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s
}
}
+// channelLinkWithHeaderLength is a channel.Endpoint with a configurable
+// header length.
+type channelLinkWithHeaderLength struct {
+ *channel.Endpoint
+ headerLength uint16
+}
+
+func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
+ return l.headerLength
+}
+
// Check e to make sure that the event is for addr on nic with ID 1, and the
// resolved flag set to resolved with the specified err.
func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
@@ -323,21 +334,46 @@ func TestDADDisabled(t *testing.T) {
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+// This tests also validates the NDP NS packet that is transmitted.
func TestDADResolve(t *testing.T) {
const nicID = 1
tests := []struct {
name string
+ linkHeaderLen uint16
dupAddrDetectTransmits uint8
retransTimer time.Duration
expectedRetransmitTimer time.Duration
}{
- {"1:1s:1s", 1, time.Second, time.Second},
- {"2:1s:1s", 2, time.Second, time.Second},
- {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ {
+ name: "1:1s:1s",
+ dupAddrDetectTransmits: 1,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "2:1s:1s",
+ linkHeaderLen: 1,
+ dupAddrDetectTransmits: 2,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "1:2s:2s",
+ linkHeaderLen: 2,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 2 * time.Second,
+ expectedRetransmitTimer: 2 * time.Second,
+ },
// 0s is an invalid RetransmitTimer timer and will be fixed to
// the default RetransmitTimer value of 1s.
- {"1:0s:1s", 1, 0, time.Second},
+ {
+ name: "1:0s:1s",
+ linkHeaderLen: 3,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 0,
+ expectedRetransmitTimer: time.Second,
+ },
}
for _, test := range tests {
@@ -356,10 +392,13 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
- e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
+ if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -445,6 +484,10 @@ func TestDADResolve(t *testing.T) {
checker.NDPNSTargetAddress(addr1),
checker.NDPNSOptions(nil),
))
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
})
}
@@ -592,70 +635,94 @@ func TestDADFail(t *testing.T) {
}
}
-// TestDADStop tests to make sure that the DAD process stops when an address is
-// removed.
func TestDADStop(t *testing.T) {
const nicID = 1
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- ndpConfigs := stack.NDPConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 2,
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
+ tests := []struct {
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests to make sure that DAD stops when an address is removed.
+ {
+ name: "Remove address",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
+ }
+ },
+ },
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ // Tests to make sure that DAD stops when the NIC is disabled.
+ {
+ name: "Disable NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- // Remove the address. This should stop DAD.
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
- }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
- // Wait for DAD to fail (since the address was removed during DAD).
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ test.stopFn(t, s)
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
- // Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
- t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Errorf("got NeighborSolicit = %d, want <= 1", got)
+ }
+ })
}
}
@@ -2886,17 +2953,16 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
}
-// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers
-// and prefixes, and non-linklocal auto-generated addresses are invalidated when
-// a NIC becomes a router.
-func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
+// TestCleanupNDPState tests that all discovered routers and prefixes, and
+// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestCleanupNDPState(t *testing.T) {
t.Parallel()
const (
- lifetimeSeconds = 5
- maxEvents = 4
- nicID1 = 1
- nicID2 = 2
+ lifetimeSeconds = 5
+ maxRouterAndPrefixEvents = 4
+ nicID1 = 1
+ nicID2 = 2
)
prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2912,254 +2978,308 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
PrefixLen: 64,
}
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
+ tests := []struct {
+ name string
+ cleanupFn func(t *testing.T, s *stack.Stack)
+ keepAutoGenLinkLocal bool
+ maxAutoGenAddrEvents int
+ }{
+ // A NIC should still keep its auto-generated link-local address when
+ // becoming a router.
+ {
+ name: "Forwarding Enable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
+ keepAutoGenLinkLocal: true,
+ maxAutoGenAddrEvents: 4,
},
- NDPDisp: &ndpDisp,
- })
- expectRouterEvent := func() (bool, ndpRouterEvent) {
- select {
- case e := <-ndpDisp.routerC:
- return true, e
- default:
- }
+ // A NIC should cleanup all NDP state when it is disabled.
+ {
+ name: "NIC Disable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- return false, ndpRouterEvent{}
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ },
}
- expectPrefixEvent := func() (bool, ndpPrefixEvent) {
- select {
- case e := <-ndpDisp.prefixC:
- return true, e
- default:
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
- return false, ndpPrefixEvent{}
- }
+ expectRouterEvent := func() (bool, ndpRouterEvent) {
+ select {
+ case e := <-ndpDisp.routerC:
+ return true, e
+ default:
+ }
- expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
- select {
- case e := <-ndpDisp.autoGenAddrC:
- return true, e
- default:
- }
+ return false, ndpRouterEvent{}
+ }
- return false, ndpAutoGenAddrEvent{}
- }
+ expectPrefixEvent := func() (bool, ndpPrefixEvent) {
+ select {
+ case e := <-ndpDisp.prefixC:
+ return true, e
+ default:
+ }
- e1 := channel.New(0, 1280, linkAddr1)
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
- // We have other tests that make sure we receive the *correct* events
- // on normal discovery of routers/prefixes, and auto-generated
- // addresses. Here we just make sure we get an event and let other tests
- // handle the correctness check.
- expectAutoGenAddrEvent()
+ return false, ndpPrefixEvent{}
+ }
- e2 := channel.New(0, 1280, linkAddr2)
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
- }
- expectAutoGenAddrEvent()
+ expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ return true, e
+ default:
+ }
- // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
- // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
- // llAddr4) to discover multiple routers and prefixes, and auto-gen
- // multiple addresses.
+ return false, ndpAutoGenAddrEvent{}
+ }
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
- }
+ e1 := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+ // We have other tests that make sure we receive the *correct* events
+ // on normal discovery of routers/prefixes, and auto-generated
+ // addresses. Here we just make sure we get an event and let other tests
+ // handle the correctness check.
+ expectAutoGenAddrEvent()
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
- }
+ e2 := channel.New(0, 1280, linkAddr2)
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+ expectAutoGenAddrEvent()
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
- }
+ // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
+ // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
+ // llAddr4) to discover multiple routers and prefixes, and auto-gen
+ // multiple addresses.
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
+ }
- // We should have the auto-generated addresses added.
- nicinfo := s.NICInfo()
- nic1Addrs := nicinfo[nicID1].ProtocolAddresses
- nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
+ }
- // We can't proceed any further if we already failed the test (missing
- // some discovery/auto-generated address events or addresses).
- if t.Failed() {
- t.FailNow()
- }
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
+ }
- s.SetForwarding(true)
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
+ }
- // Collect invalidation events after becoming a router
- gotRouterEvents := make(map[ndpRouterEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectRouterEvent()
- if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotRouterEvents[e]++
- }
- gotPrefixEvents := make(map[ndpPrefixEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectPrefixEvent()
- if !ok {
- t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotPrefixEvents[e]++
- }
- gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectAutoGenAddrEvent()
- if !ok {
- t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotAutoGenAddrEvents[e]++
- }
+ // We should have the auto-generated addresses added.
+ nicinfo := s.NICInfo()
+ nic1Addrs := nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs := nicinfo[nicID2].ProtocolAddresses
+ if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
- // No need to proceed any further if we already failed the test (missing
- // some invalidation events).
- if t.Failed() {
- t.FailNow()
- }
+ // We can't proceed any further if we already failed the test (missing
+ // some discovery/auto-generated address events or addresses).
+ if t.Failed() {
+ t.FailNow()
+ }
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
- }
- expectedPrefixEvents := map[ndpPrefixEvent]int{
- {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
- t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
- }
- expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
- {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
- }
- if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
- t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
- }
+ test.cleanupFn(t, s)
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ // Collect invalidation events after having NDP state cleaned up.
+ gotRouterEvents := make(map[ndpRouterEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectRouterEvent()
+ if !ok {
+ t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotRouterEvents[e]++
+ }
+ gotPrefixEvents := make(map[ndpPrefixEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectPrefixEvent()
+ if !ok {
+ t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotPrefixEvents[e]++
+ }
+ gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
+ for i := 0; i < test.maxAutoGenAddrEvents; i++ {
+ ok, e := expectAutoGenAddrEvent()
+ if !ok {
+ t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
+ break
+ }
+ gotAutoGenAddrEvents[e]++
+ }
- // Should not get any more events (invalidation timers should have been
- // cancelled when we transitioned into a router).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
- select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
- default:
- }
- select {
- case <-ndpDisp.prefixC:
- t.Error("unexpected prefix event")
- default:
- }
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Error("unexpected auto-generated address event")
- default:
+ // No need to proceed any further if we already failed the test (missing
+ // some invalidation events).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ expectedRouterEvents := map[ndpRouterEvent]int{
+ {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
+ t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ }
+ expectedPrefixEvents := map[ndpPrefixEvent]int{
+ {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
+ t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
+ }
+ expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
+ {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
+ }
+
+ if !test.keepAutoGenLinkLocal {
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
+ }
+
+ if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
+ t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
+ }
+
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ }
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // Should not get any more events (invalidation timers should have been
+ // cancelled when the NDP state was cleaned up).
+ time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ select {
+ case <-ndpDisp.routerC:
+ t.Error("unexpected router event")
+ default:
+ }
+ select {
+ case <-ndpDisp.prefixC:
+ t.Error("unexpected prefix event")
+ default:
+ }
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Error("unexpected auto-generated address event")
+ default:
+ }
+ })
}
}
@@ -3259,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
func TestRouterSolicitation(t *testing.T) {
t.Parallel()
+ const nicID = 1
+
tests := []struct {
name string
+ linkHeaderLen uint16
maxRtrSolicit uint8
rtrSolicitInt time.Duration
effectiveRtrSolicitInt time.Duration
@@ -3277,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS with delay",
+ linkHeaderLen: 1,
maxRtrSolicit: 2,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3285,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Single RS without delay",
+ linkHeaderLen: 2,
maxRtrSolicit: 1,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3293,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS without delay and invalid zero interval",
+ linkHeaderLen: 3,
maxRtrSolicit: 2,
rtrSolicitInt: 0,
effectiveRtrSolicitInt: 4 * time.Second,
@@ -3330,8 +3456,11 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
- e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
waitForPkt := func(timeout time.Duration) {
t.Helper()
ctx, _ := context.WithTimeout(context.Background(), timeout)
@@ -3357,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPRS(),
)
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
@@ -3373,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) {
MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
// Make sure each RS got sent at the right
@@ -3406,77 +3539,130 @@ func TestRouterSolicitation(t *testing.T) {
})
}
-// TestStopStartSolicitingRouters tests that when forwarding is enabled or
-// disabled, router solicitations are stopped or started, respecitively.
func TestStopStartSolicitingRouters(t *testing.T) {
t.Parallel()
+ const nicID = 1
const interval = 500 * time.Millisecond
const delay = time.Second
const maxRtrSolicitations = 3
- e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
+ tests := []struct {
+ name string
+ startFn func(t *testing.T, s *stack.Stack)
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests that when forwarding is enabled or disabled, router solicitations
+ // are stopped or started, respectively.
+ {
+ name: "Forwarding enabled and disabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(false)
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
- // Enable forwarding which should stop router solicitations.
- s.SetForwarding(true)
- ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- // A single RS may have been sent before forwarding was enabled.
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok = e.ReadContext(ctx); ok {
- t.Fatal("Should not have sent more than one RS message")
- }
- }
+ // Tests that when a NIC is enabled or disabled, router solicitations
+ // are started or stopped, respectively.
+ {
+ name: "NIC disabled and enabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Enabling forwarding again should do nothing.
- s.SetForwarding(true)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
- }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Disable forwarding which should start router solicitations.
- s.SetForwarding(false)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- // Disabling forwarding again should do nothing.
- s.SetForwarding(false)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Stop soliciting routers.
+ test.stopFn(t, s)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ // A single RS may have been sent before forwarding was enabled.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok = e.ReadContext(ctx); ok {
+ t.Fatal("should not have sent more than one RS message")
+ }
+ }
+
+ // Stopping router solicitations after it has already been stopped should
+ // do nothing.
+ test.stopFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
+ }
+
+ // Start soliciting routers.
+ test.startFn(t, s)
+ waitForPkt(delay + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ }
+
+ // Starting router solicitations after it has already completed should do
+ // nothing.
+ test.startFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after finishing router solicitations")
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 78d451cca..46d3a6646 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -27,6 +27,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+var ipv4BroadcastAddr = tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 8 * header.IPv4AddressSize,
+ },
+}
+
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
@@ -132,11 +140,77 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
+ nic.linkEP.Attach(nic)
+
return nic
}
-// enable enables the NIC. enable will attach the link to its LinkEndpoint and
-// join the IPv6 All-Nodes Multicast address (ff02::1).
+// enabled returns true if n is enabled.
+func (n *NIC) enabled() bool {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ return enabled
+}
+
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if !enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if !n.mu.enabled {
+ return nil
+ }
+
+ // TODO(b/147015577): Should Routes that are currently bound to n be
+ // invalidated? Currently, Routes will continue to work when a NIC is enabled
+ // again, and applications may not know that the underlying NIC was ever
+ // disabled.
+
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ n.mu.ndp.stopSolicitingRouters()
+ n.mu.ndp.cleanupState(false /* hostOnly */)
+
+ // Stop DAD for all the unicast IPv6 endpoints that are in the
+ // permanentTentative state.
+ for _, r := range n.mu.endpoints {
+ if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+ }
+
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The address may have already been removed.
+ if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ n.mu.enabled = false
+ return nil
+}
+
+// enable enables n.
+//
+// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast
+// address (ff02::1), start DAD for permanent addresses, and start soliciting
+// routers if the stack is not operating as a router. If the stack is also
+// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
n.mu.RLock()
enabled := n.mu.enabled
@@ -154,14 +228,9 @@ func (n *NIC) enable() *tcpip.Error {
n.mu.enabled = true
- n.attachLinkEndpoint()
-
// Create an endpoint to receive broadcast packets on this interface.
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
- }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
@@ -183,6 +252,14 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
// state.
//
@@ -200,10 +277,6 @@ func (n *NIC) enable() *tcpip.Error {
}
}
- if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
- return err
- }
-
// Do not auto-generate an IPv6 link-local address for loopback devices.
if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
// The valid and preferred lifetime is infinite for the auto-generated
@@ -225,6 +298,33 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+// remove detaches NIC from the link endpoint, and marks existing referenced
+// network endpoints expired. This guarantees no packets between this NIC and
+// the network stack.
+func (n *NIC) remove() *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ // Remove permanent and permanentTentative addresses, so no packet goes out.
+ var errs []*tcpip.Error
+ for nid, ref := range n.mu.endpoints {
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
+ errs = append(errs, err)
+ }
+ }
+ }
+ if len(errs) > 0 {
+ return errs[0]
+ }
+
+ return nil
+}
+
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
@@ -234,7 +334,7 @@ func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
- n.mu.ndp.cleanupHostOnlyState()
+ n.mu.ndp.cleanupState(true /* hostOnly */)
n.mu.ndp.stopSolicitingRouters()
}
@@ -249,12 +349,6 @@ func (n *NIC) becomeIPv6Host() {
n.mu.ndp.startSolicitingRouters()
}
-// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
-// to start delivering packets.
-func (n *NIC) attachLinkEndpoint() {
- n.linkEP.Attach(n)
-}
-
// setPromiscuousMode enables or disables promiscuous mode.
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
@@ -712,6 +806,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
case permanentExpired, temporary:
continue
}
+
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -1009,6 +1104,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return nil
}
+// isInGroup returns true if n has joined the multicast group addr.
+func (n *NIC) isInGroup(addr tcpip.Address) bool {
+ n.mu.RLock()
+ joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
+ n.mu.RUnlock()
+
+ return joins != 0
+}
+
func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1215,11 +1319,21 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+// Name returns the name of n.
+func (n *NIC) Name() string {
+ return n.name
+}
+
// Stack returns the instance of the Stack that owns this NIC.
func (n *NIC) Stack() *Stack {
return n.stack
}
+// LinkEndpoint returns the link endpoint of n.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
// isAddrTentative returns true if addr is tentative on n.
//
// Note that if addr is not associated with n, then this function will return
@@ -1406,7 +1520,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
//
// r's NIC must be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- return r.getKind() != permanentExpired || r.nic.mu.spoofing
+ return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ec91f60dd..f9fd8f18f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -74,10 +74,11 @@ type TransportEndpoint interface {
// HandleControlPacket takes ownership of pkt.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
- // Close puts the endpoint in a closed state and frees all resources
- // associated with it. This cleanup may happen asynchronously. Wait can
- // be used to block on this asynchronous cleanup.
- Close()
+ // Abort initiates an expedited endpoint teardown. It puts the endpoint
+ // in a closed state and frees all resources associated with it. This
+ // cleanup may happen asynchronously. Wait can be used to block on this
+ // asynchronous cleanup.
+ Abort()
// Wait waits for any worker goroutines owned by the endpoint to stop.
//
@@ -160,6 +161,13 @@ type TransportProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -277,7 +285,7 @@ type NetworkProtocol interface {
// DefaultPrefixLen returns the protocol's default prefix length.
DefaultPrefixLen() int
- // ParsePorts returns the source and destination addresses stored in a
+ // ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
@@ -293,6 +301,13 @@ type NetworkProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index b793f1d74..ebb6c5e3b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -881,6 +881,8 @@ type NICOptions struct {
// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
// NICOptions. See the documentation on type NICOptions for details on how
// NICs can be configured.
+//
+// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -890,8 +892,16 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, opts.Name, ep, opts.Context)
+ // Make sure name is unique, unless unnamed.
+ if opts.Name != "" {
+ for _, n := range s.nics {
+ if n.Name() == opts.Name {
+ return tcpip.ErrDuplicateNICID
+ }
+ }
+ }
+ n := newNIC(s, id, opts.Name, ep, opts.Context)
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -901,34 +911,88 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
-// `LinkEndpoint.Attach` to start delivering packets to it.
+// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, nic := range s.nics {
+ if nic.Name() == name {
+ return nic, true
+ }
+ }
+ return nil, false
+}
+
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
return nic.enable()
}
+// DisableNIC disables the given NIC.
+func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.disable()
+}
+
// CheckNIC checks if a NIC is usable.
func (s *Stack) CheckNIC(id tcpip.NICID) bool {
s.mu.RLock()
+ defer s.mu.RUnlock()
+
nic, ok := s.nics[id]
- s.mu.RUnlock()
- if ok {
- return nic.linkEP.IsAttached()
+ if !ok {
+ return false
+ }
+
+ return nic.enabled()
+}
+
+// RemoveNIC removes NIC and all related routes from the network stack.
+func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ delete(s.nics, id)
+
+ // Remove routes in-place. n tracks the number of routes written.
+ n := 0
+ for i, r := range s.routeTable {
+ if r.NIC != id {
+ // Keep this route.
+ if i > n {
+ s.routeTable[n] = r
+ }
+ n++
+ }
}
- return false
+ s.routeTable = s.routeTable[:n]
+
+ return nic.remove()
}
// NICAddressRanges returns a map of NICIDs to their associated subnets.
@@ -980,7 +1044,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.linkEP.IsAttached(),
+ Running: nic.enabled(),
Promiscuous: nic.isPromiscuousMode(),
Loopback: nic.isLoopback(),
}
@@ -1142,7 +1206,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok {
+ if nic, ok := s.nics[id]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
@@ -1152,7 +1216,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok {
+ if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
@@ -1382,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
// Endpoints created or modified during this call may not get closed.
func (s *Stack) Close() {
for _, e := range s.RegisteredEndpoints() {
- e.Close()
+ e.Abort()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Close()
+ }
+ for _, p := range s.networkProtocols {
+ p.Close()
}
}
@@ -1400,6 +1470,12 @@ func (s *Stack) Wait() {
for _, e := range s.CleanupEndpoints() {
e.Wait()
}
+ for _, p := range s.transportProtocols {
+ p.proto.Wait()
+ }
+ for _, p := range s.networkProtocols {
+ p.Wait()
+ }
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1605,6 +1681,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
return tcpip.ErrUnknownNICID
}
+// IsInGroup returns true if the NIC with ID nicID has joined the multicast
+// group multicastAddr.
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.isInGroup(multicastAddr), nil
+ }
+ return false, tcpip.ErrUnknownNICID
+}
+
// IPTables returns the stack's iptables.
func (s *Stack) IPTables() iptables.IPTables {
s.tablesMu.RLock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 24133e6f2..edf6bec52 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -234,10 +235,33 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements TransportProtocol.Close.
+func (*fakeNetworkProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeNetworkProtocol) Wait() {}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
+// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
+// that LinkEndpoint.Attach was called.
+type linkEPWithMockedAttach struct {
+ stack.LinkEndpoint
+ attached bool
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
+ l.LinkEndpoint.Attach(d)
+ l.attached = true
+}
+
+func (l *linkEPWithMockedAttach) isAttached() bool {
+ return l.attached
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
@@ -509,6 +533,296 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr
}
}
+// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to
+// a NetworkDispatcher when the NIC is created.
+func TestAttachToLinkEndpointImmediately(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ nicOpts stack.NICOptions
+ }{
+ {
+ name: "Create enabled NIC",
+ nicOpts: stack.NICOptions{Disabled: false},
+ },
+ {
+ name: "Create disabled NIC",
+ nicOpts: stack.NICOptions{Disabled: true},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+
+ if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
+ }
+ if !e.isAttached() {
+ t.Fatalf("link endpoint not attached to a network disatcher")
+ }
+ })
+ }
+}
+
+func TestDisableUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := loopback.New()
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ checkNIC := func(enabled bool) {
+ t.Helper()
+
+ allNICInfo := s.NICInfo()
+ nicInfo, ok := allNICInfo[nicID]
+ if !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ } else if nicInfo.Flags.Running != enabled {
+ t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
+ }
+
+ if got := s.CheckNIC(nicID); got != enabled {
+ t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
+ }
+ }
+
+ // NIC should initially report itself as disabled.
+ checkNIC(false)
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(true)
+
+ // If the NIC is not reporting a correct enabled status, we cannot trust the
+ // next check so end the test here.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(false)
+}
+
+func TestRoutesWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Disabling NIC1 should result in no routes to odd addresses. Routes to even
+ // addresses should continue to be available as NIC2 is still enabled.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ nic1Dst := tcpip.Address("\x05")
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ nic2Dst := tcpip.Address("\x06")
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Disabling NIC2 should result in no routes to even addresses. No route
+ // should be available to any address as routes to odd addresses were made
+ // unavailable by disabling NIC1 above.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ // Enabling NIC1 should make routes to odd addresses available again. Routes
+ // to even addresses should continue to be unavailable as NIC2 is still
+ // disabled.
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+}
+
+func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ nic1Dst := tcpip.Address("\x05")
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
+
+ nic2Dst := tcpip.Address("\x06")
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
+
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC1 should fail.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC2 should fail.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ // Writes with Routes that use the re-enabled NIC1 should succeed.
+ // TODO(b/147015577): Should we instead completely invalidate all Routes that
+ // were bound to a disabled NIC at some point?
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+}
+
func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
@@ -1792,6 +2106,91 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
verifyAddresses(t, expectedAddresses, gotAddresses)
}
+func TestCreateNICWithOptions(t *testing.T) {
+ type callArgsAndExpect struct {
+ nicID tcpip.NICID
+ opts stack.NICOptions
+ err *tcpip.Error
+ }
+
+ tests := []struct {
+ desc string
+ calls []callArgsAndExpect
+ }{
+ {
+ desc: "DuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth1"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth2"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "DuplicateName",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "lo"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{Name: "lo"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "Unnamed",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ },
+ },
+ {
+ desc: "UnnamedDuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ for _, call := range test.calls {
+ if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
+ t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
+ }
+ }
+ })
+ }
+}
+
func TestNICStats(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
@@ -2088,13 +2487,29 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
e := channel.New(0, 1280, test.linkAddr)
s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: test.nicName}
+ nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
}
- var expectedMainAddr tcpip.AddressWithPrefix
+ // A new disabled NIC should not have any address, even if auto generation
+ // was enabled.
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should attempt auto-generation of a link-local
+ // address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ var expectedMainAddr tcpip.AddressWithPrefix
if test.shouldGen {
expectedMainAddr = tcpip.AddressWithPrefix{
Address: test.expectedAddr,
@@ -2524,6 +2939,111 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
}
+func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should add the IPv4 broadcast address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 1 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
+ }
+ want := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
+ if allNICAddrs[0] != want {
+ t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+ }
+
+ // Disabling the NIC should remove the IPv4 broadcast address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+}
+
+func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
+ // not been enabled yet.
+ isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+}
+
// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
// was disabled have DAD performed on them when the NIC is enabled.
func TestDoDADWhenNICEnabled(t *testing.T) {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index d686e6eb8..778c0a4d6 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// Close implements stack.TransportEndpoint.Close.
-func (ep *multiPortEndpoint) Close() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Close()
- }
-}
-
-// Wait implements stack.TransportEndpoint.Wait.
-func (ep *multiPortEndpoint) Wait() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Wait()
- }
-}
-
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 869c69a6d..5d1da2f8b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
+func (f *fakeTransportEndpoint) Abort() {
+ f.Close()
+}
+
func (f *fakeTransportEndpoint) Close() {
f.route.Release()
}
@@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Abort implements TransportProtocol.Abort.
+func (*fakeTransportProtocol) Abort() {}
+
+// Close implements tcpip.Endpoint.Close.
+func (*fakeTransportProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeTransportProtocol) Wait() {}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0e944712f..3dc5d87d6 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -323,11 +323,17 @@ type ControlMessages struct {
// TOS is the IPv4 type of service of the associated packet.
TOS uint8
- // HasTClass indicates whether Tclass is valid/set.
+ // HasTClass indicates whether TClass is valid/set.
HasTClass bool
- // Tclass is the IPv6 traffic class of the associated packet.
- TClass int32
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo IPPacketInfo
}
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
@@ -335,9 +341,15 @@ type ControlMessages struct {
// networking stack.
type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
- // associated with it.
+ // associated with it. Close initiates the teardown process, the
+ // Endpoint may not be fully closed when Close returns.
Close()
+ // Abort initiates an expedited endpoint teardown. As compared to
+ // Close, Abort prioritizes closing the Endpoint quickly over cleanly.
+ // Abort is best effort; implementing Abort with Close is acceptable.
+ Abort()
+
// Read reads data from the endpoint and optionally returns the sender.
//
// This method does not block if there is no data pending. It will also
@@ -496,13 +508,25 @@ type WriteOptions struct {
type SockOptBool int
const (
+ // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the
+ // IPV6_TCLASS ancillary message is passed with incoming packets.
+ ReceiveTClassOption SockOptBool = iota
+
// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
// ancillary message is passed with incoming packets.
- ReceiveTOSOption SockOptBool = iota
+ ReceiveTOSOption
// V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
V6OnlyOption
+
+ // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify
+ // if more inforamtion is provided with incoming packets such
+ // as interface index and address.
+ ReceiveIPPacketInfoOption
+
+ // TODO(b/146901447): convert existing bool socket options to be handled via
+ // Get/SetSockOptBool
)
// SockOptInt represents socket options which values have the int type.
@@ -685,6 +709,20 @@ type IPv4TOSOption uint8
// for all subsequent outgoing IPv6 packets from the endpoint.
type IPv6TrafficClassOption uint8
+// IPPacketInfo is the message struture for IP_PKTINFO.
+//
+// +stateify savable
+type IPPacketInfo struct {
+ // NIC is the ID of the NIC to be used.
+ NIC NICID
+
+ // LocalAddr is the local address.
+ LocalAddr Address
+
+ // DestinationAddr is the destination address.
+ DestinationAddr Address
+}
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 48764b978..2f98a996f 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -25,6 +25,8 @@ import (
)
// StdClock implements Clock with the time package.
+//
+// +stateify savable
type StdClock struct{}
var _ Clock = (*StdClock)(nil)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 42afb3f5b..426da1ee6 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 9ce500e80..113d92901 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index fc5bc69fa..5722815e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
return ep, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (ep *endpoint) Close() {
ep.mu.Lock()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ee9c4c58b..2ef5fac76 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (e *endpoint) Close() {
e.mu.Lock()
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 08afb7c17..13e383ffc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
@@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
- e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
s := sleep.Sleeper{}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5c5397823..7730e6445 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.snd.updateMaxPayloadSize(mtu, count)
}
- if n&notifyReset != 0 {
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
return tcpip.ErrConnectionAborted
}
@@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 {
+ if n&notifyClose != 0 || n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index e18012ac0..d792b07d6 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -68,17 +68,28 @@ func (q *epQueue) empty() bool {
type processor struct {
epQ epQueue
newEndpointWaker sleep.Waker
+ closeWaker sleep.Waker
id int
+ wg sync.WaitGroup
}
func newProcessor(id int) *processor {
p := &processor{
id: id,
}
+ p.wg.Add(1)
go p.handleSegments()
return p
}
+func (p *processor) close() {
+ p.closeWaker.Assert()
+}
+
+func (p *processor) wait() {
+ p.wg.Wait()
+}
+
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
@@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) {
func (p *processor) handleSegments() {
const newEndpointWaker = 1
+ const closeWaker = 2
s := sleep.Sleeper{}
s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ s.AddWaker(&p.closeWaker, closeWaker)
defer s.Done()
for {
- s.Fetch(true)
+ id, ok := s.Fetch(true)
+ if ok && id == closeWaker {
+ p.wg.Done()
+ return
+ }
for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
if ep.segmentQueue.empty() {
continue
@@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher {
}
}
+func (d *dispatcher) close() {
+ for _, p := range d.processors {
+ p.close()
+ }
+}
+
+func (d *dispatcher) wait() {
+ for _, p := range d.processors {
+ p.wait()
+ }
+}
+
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index f2be0e651..f1ad19dac 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -121,6 +121,8 @@ const (
notifyDrain
notifyReset
notifyResetByPeer
+ // notifyAbort is a request for an expedited teardown.
+ notifyAbort
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
}
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ // The abort notification is not processed synchronously, so no
+ // synchronization is needed.
+ //
+ // If the endpoint becomes connected after this check, we still close
+ // the endpoint. This worst case results in a slower abort.
+ //
+ // If the endpoint disconnected after the check, nothing needs to be
+ // done, so sending a notification which will potentially be ignored is
+ // fine.
+ if e.EndpointState().connected() {
+ e.notifyProtocolGoroutine(notifyAbort)
+ return
+ }
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources associated
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
@@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() {
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
- if !e.workerRunning {
+ switch e.EndpointState() {
+ // Sockets in StateSynRecv state(passive connections) are closed when
+ // the handshake fails or if the listening socket is closed while
+ // handshake was in progress. In such cases the handshake goroutine
+ // is already gone by the time Close is called and we need to cleanup
+ // here.
+ case StateInitial, StateBound, StateSynRecv:
e.cleanupLocked()
- } else {
+ e.setEndpointState(StateClose)
+ case StateError, StateClose:
+ // do nothing.
+ default:
e.workerCleanup = true
e.notifyProtocolGoroutine(notifyClose)
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 958c06fa7..73098d904 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -194,7 +194,7 @@ func replyWithReset(s *segment) {
sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
-// SetOption implements TransportProtocol.SetOption.
+// SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case SACKEnabled:
@@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
}
-// Option implements TransportProtocol.Option.
+// Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
@@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements stack.TransportProtocol.Close.
+func (p *protocol) Close() {
+ p.dispatcher.close()
+}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (p *protocol) Wait() {
+ p.dispatcher.wait()
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index cc118c993..5b2b16afa 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -543,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) {
),
)
- // Wait for the TIME-WAIT state to transition to CLOSED.
- time.Sleep(1 * time.Second)
+ // Wait for a little more than the TIME-WAIT duration for the socket to
+ // transition to CLOSED state.
+ time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index c9cbed8f4..1c6a600b8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -29,9 +29,11 @@ import (
type udpPacket struct {
udpPacketEntry
senderAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- tos uint8
+ // tos stores either the receiveTOS or receiveTClass value.
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -118,6 +120,13 @@ type endpoint struct {
// as ancillary data to ControlMessages on Read.
receiveTOS bool
+ // receiveTClass determines if the incoming IPv6 TClass header field is
+ // passed as ancillary data to ControlMessages on Read.
+ receiveTClass bool
+
+ // receiveIPPacketInfo determines if the packet info is returned by Read.
+ receiveIPPacketInfo bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -177,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -254,11 +268,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
}
e.mu.RLock()
receiveTOS := e.receiveTOS
+ receiveTClass := e.receiveTClass
+ receiveIPPacketInfo := e.receiveIPPacketInfo
e.mu.RUnlock()
if receiveTOS {
cm.HasTOS = true
cm.TOS = p.tos
}
+ if receiveTClass {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
+ if receiveIPPacketInfo {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
return p.data.ToView(), cm, nil
}
@@ -480,6 +505,17 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrNotSupported
+ }
+
+ e.mu.Lock()
+ e.receiveTClass = v
+ e.mu.Unlock()
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -495,6 +531,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
}
e.v6only = v
+ return nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.Lock()
+ e.receiveIPPacketInfo = v
+ e.mu.Unlock()
+ return nil
}
return nil
@@ -692,6 +735,17 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrNotSupported
+ }
+
+ e.mu.RLock()
+ v := e.receiveTClass
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -703,6 +757,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.RLock()
+ v := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ return v, nil
}
return false, tcpip.ErrUnknownProtocolOption
@@ -1247,6 +1307,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
switch r.NetProto {
case header.IPv4ProtocolNumber:
packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.RemoteAddress
+ packet.packetInfo.NIC = r.NICID()
+ case header.IPv6ProtocolNumber:
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
packet.timestamp = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 259c3072a..8df089d22 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index f0ff3fe71..34b7c2360 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -409,6 +409,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
@@ -1336,7 +1337,7 @@ func TestSetTTL(t *testing.T) {
}
}
-func TestTOSV4(t *testing.T) {
+func TestSetTOS(t *testing.T) {
for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1347,23 +1348,23 @@ func TestTOSV4(t *testing.T) {
const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
}
if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err)
}
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
if want := tcpip.IPv4TOSOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
}
testWrite(c, flow, checker.TOS(tos, 0))
@@ -1371,7 +1372,7 @@ func TestTOSV4(t *testing.T) {
}
}
-func TestTOSV6(t *testing.T) {
+func TestSetTClass(t *testing.T) {
for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1379,71 +1380,92 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = testTOS
+ const tClass = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt failed: %s", err)
+ if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil {
+ c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err)
}
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if want := tcpip.IPv6TrafficClassOption(tClass); v != want {
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
}
- testWrite(c, flow, checker.TOS(tos, 0))
+ // The header getter for TClass is called TOS, so use that checker.
+ testWrite(c, flow, checker.TOS(tClass, 0))
})
}
}
-func TestReceiveTOSV4(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, broadcast} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+func TestReceiveTosTClass(t *testing.T) {
+ testCases := []struct {
+ name string
+ getReceiveOption tcpip.SockOptBool
+ tests []testFlow
+ }{
+ {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
+ {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ }
+ for _, testCase := range testCases {
+ for _, flow := range testCase.tests {
+ t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- c.createEndpointForFlow(flow)
+ c.createEndpointForFlow(flow)
+ option := testCase.getReceiveOption
+ name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
- }
- // Test for expected default value.
- if v != false {
- c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false)
- }
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
+ }
- want := true
- if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil {
- c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err)
- }
+ want := true
+ if err := c.ep.SetSockOptBool(option, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
+ }
- got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
- }
- if got != want {
- c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want)
- }
+ got, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ }
- // Verify that the correct received TOS is handed through as
- // ancillary data to the ControlMessages struct.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
- testRead(c, flow, checker.ReceiveTOS(testTOS))
- })
+ if got != want {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
+ }
+
+ // Verify that the correct received TOS or TClass is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+ switch option {
+ case tcpip.ReceiveTClassOption:
+ testRead(c, flow, checker.ReceiveTClass(testTOS))
+ case tcpip.ReceiveTOSOption:
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ default:
+ t.Fatalf("unknown test variant: %s", name)
+ }
+ })
+ }
}
}
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
index ff8b9e91a..6c9ada9c7 100644
--- a/pkg/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -25,7 +25,6 @@ go_library(
"bytes_io_unsafe.go",
"usermem.go",
"usermem_arm64.go",
- "usermem_unsafe.go",
"usermem_x86.go",
],
visibility = ["//:sandbox"],
@@ -33,6 +32,7 @@ go_library(
"//pkg/atomicbitops",
"//pkg/binary",
"//pkg/context",
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/safemem",
"//pkg/syserror",
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 71fd4e155..d2f4403b0 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -251,7 +252,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
}
end, ok := addr.AddLength(uint64(readlen))
if !ok {
- return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT
}
// Shorten the read to avoid crossing page boundaries, since faulting
// in a page unnecessarily is expensive. This also ensures that partial
@@ -272,16 +273,16 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
// Look for the terminating zero byte, which may have occurred before
// hitting err.
if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
+ return gohacks.StringFromImmutableBytes(buf[:done+i]), nil
}
done += n
if err != nil {
- return stringFromImmutableBytes(buf[:done]), err
+ return gohacks.StringFromImmutableBytes(buf[:done]), err
}
addr = end
}
- return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
+ return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG
}
// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The