summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/elf.go14
-rw-r--r--pkg/abi/linux/file.go2
-rw-r--r--pkg/abi/linux/netlink.go6
-rw-r--r--pkg/abi/linux/netlink_route.go136
-rw-r--r--pkg/abi/linux/socket.go21
-rw-r--r--pkg/cpuid/cpuid.go203
-rw-r--r--pkg/cpuid/cpuid_test.go24
-rw-r--r--pkg/eventchannel/BUILD13
-rw-r--r--pkg/eventchannel/event.go64
-rw-r--r--pkg/eventchannel/event_test.go146
-rw-r--r--pkg/eventchannel/rate.go54
-rw-r--r--pkg/fdnotifier/BUILD5
-rw-r--r--pkg/fdnotifier/fdnotifier.go3
-rw-r--r--pkg/fdnotifier/poll_unsafe.go8
-rw-r--r--pkg/flipcall/BUILD5
-rw-r--r--pkg/flipcall/ctrl_futex.go146
-rw-r--r--pkg/flipcall/endpoint_unsafe.go238
-rw-r--r--pkg/flipcall/flipcall.go219
-rw-r--r--pkg/flipcall/flipcall_example_test.go20
-rw-r--r--pkg/flipcall/flipcall_test.go299
-rw-r--r--pkg/flipcall/flipcall_unsafe.go69
-rw-r--r--pkg/flipcall/futex_linux.go103
-rw-r--r--pkg/flipcall/io.go113
-rw-r--r--pkg/flipcall/packet_window_allocator.go6
-rw-r--r--pkg/refs/refcounter.go22
-rw-r--r--pkg/seccomp/seccomp_test_victim.go2
-rw-r--r--pkg/sentry/control/proc.go17
-rw-r--r--pkg/sentry/fs/attr.go44
-rw-r--r--pkg/sentry/fs/context.go24
-rw-r--r--pkg/sentry/fs/ext/BUILD54
-rw-r--r--pkg/sentry/fs/ext/ext.go97
-rw-r--r--pkg/sentry/fs/ext/ext_test.go407
-rw-r--r--pkg/sentry/fs/ext/filesystem.go137
-rw-r--r--pkg/sentry/fs/ext/inode.go209
-rw-r--r--pkg/sentry/fs/fdpipe/pipe.go2
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go111
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached_test.go10
-rw-r--r--pkg/sentry/fs/gofer/fs.go22
-rw-r--r--pkg/sentry/fs/gofer/session.go27
-rw-r--r--pkg/sentry/fs/host/inode.go6
-rw-r--r--pkg/sentry/fs/host/socket.go8
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go18
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go9
-rw-r--r--pkg/sentry/fs/mounts.go16
-rw-r--r--pkg/sentry/fs/proc/net.go4
-rw-r--r--pkg/sentry/fs/ramfs/dir.go23
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go6
-rw-r--r--pkg/sentry/fs/tty/BUILD1
-rw-r--r--pkg/sentry/fs/tty/dir.go3
-rw-r--r--pkg/sentry/fs/tty/master.go17
-rw-r--r--pkg/sentry/fs/tty/slave.go13
-rw-r--r--pkg/sentry/fs/tty/terminal.go92
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD86
-rw-r--r--pkg/sentry/fsimpl/ext/README.md117
-rw-r--r--pkg/sentry/fsimpl/ext/assets/README.md (renamed from pkg/sentry/fs/ext/assets/README.md)0
-rw-r--r--pkg/sentry/fsimpl/ext/assets/bigfile.txt (renamed from pkg/sentry/fs/ext/assets/bigfile.txt)0
-rw-r--r--pkg/sentry/fsimpl/ext/assets/file.txt (renamed from pkg/sentry/fs/ext/assets/file.txt)0
l---------pkg/sentry/fsimpl/ext/assets/symlink.txt (renamed from pkg/sentry/fs/ext/assets/symlink.txt)0
-rw-r--r--pkg/sentry/fsimpl/ext/assets/tiny.ext2 (renamed from pkg/sentry/fs/ext/assets/tiny.ext2)bin65536 -> 65536 bytes
-rw-r--r--pkg/sentry/fsimpl/ext/assets/tiny.ext3 (renamed from pkg/sentry/fs/ext/assets/tiny.ext3)bin65536 -> 65536 bytes
-rw-r--r--pkg/sentry/fsimpl/ext/assets/tiny.ext4 (renamed from pkg/sentry/fs/ext/assets/tiny.ext4)bin65536 -> 65536 bytes
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/BUILD16
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go193
-rwxr-xr-xpkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh72
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_file.go200
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_test.go157
-rw-r--r--pkg/sentry/fsimpl/ext/dentry.go (renamed from pkg/sentry/fs/ext/dentry.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go308
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD (renamed from pkg/sentry/fs/ext/disklayout/BUILD)2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group.go (renamed from pkg/sentry/fs/ext/disklayout/block_group.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_32.go (renamed from pkg/sentry/fs/ext/disklayout/block_group_32.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_64.go (renamed from pkg/sentry/fs/ext/disklayout/block_group_64.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_test.go (renamed from pkg/sentry/fs/ext/disklayout/block_group_test.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent.go (renamed from pkg/sentry/fs/ext/disklayout/dirent.go)3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_new.go (renamed from pkg/sentry/fs/ext/disklayout/dirent_new.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_old.go (renamed from pkg/sentry/fs/ext/disklayout/dirent_old.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_test.go (renamed from pkg/sentry/fs/ext/disklayout/dirent_test.go)6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/disklayout.go (renamed from pkg/sentry/fs/ext/disklayout/disklayout.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent.go (renamed from pkg/sentry/fs/ext/disklayout/extent.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent_test.go (renamed from pkg/sentry/fs/ext/disklayout/extent_test.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode.go (renamed from pkg/sentry/fs/ext/disklayout/inode.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_new.go (renamed from pkg/sentry/fs/ext/disklayout/inode_new.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_old.go (renamed from pkg/sentry/fs/ext/disklayout/inode_old.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_test.go (renamed from pkg/sentry/fs/ext/disklayout/inode_test.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock.go (renamed from pkg/sentry/fs/ext/disklayout/superblock.go)2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_32.go (renamed from pkg/sentry/fs/ext/disklayout/superblock_32.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_64.go (renamed from pkg/sentry/fs/ext/disklayout/superblock_64.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_old.go (renamed from pkg/sentry/fs/ext/disklayout/superblock_old.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_test.go (renamed from pkg/sentry/fs/ext/disklayout/superblock_test.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/test_utils.go (renamed from pkg/sentry/fs/ext/disklayout/test_utils.go)0
-rw-r--r--pkg/sentry/fsimpl/ext/ext.go135
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go917
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go237
-rw-r--r--pkg/sentry/fsimpl/ext/extent_test.go (renamed from pkg/sentry/fs/ext/extent_test.go)136
-rw-r--r--pkg/sentry/fsimpl/ext/file_description.go86
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go443
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go219
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go159
-rw-r--r--pkg/sentry/fsimpl/ext/symlink.go111
-rw-r--r--pkg/sentry/fsimpl/ext/utils.go (renamed from pkg/sentry/fs/ext/utils.go)35
-rw-r--r--pkg/sentry/fsimpl/memfs/BUILD4
-rw-r--r--pkg/sentry/fsimpl/memfs/directory.go55
-rw-r--r--pkg/sentry/fsimpl/memfs/filesystem.go68
-rw-r--r--pkg/sentry/fsimpl/memfs/memfs.go93
-rw-r--r--pkg/sentry/fsimpl/memfs/regular_file.go7
-rw-r--r--pkg/sentry/fsimpl/memfs/symlink.go4
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD49
-rw-r--r--pkg/sentry/fsimpl/proc/filesystems.go25
-rw-r--r--pkg/sentry/fsimpl/proc/loadavg.go40
-rw-r--r--pkg/sentry/fsimpl/proc/meminfo.go77
-rw-r--r--pkg/sentry/fsimpl/proc/mounts.go33
-rw-r--r--pkg/sentry/fsimpl/proc/net.go338
-rw-r--r--pkg/sentry/fsimpl/proc/net_test.go78
-rw-r--r--pkg/sentry/fsimpl/proc/proc.go16
-rw-r--r--pkg/sentry/fsimpl/proc/stat.go127
-rw-r--r--pkg/sentry/fsimpl/proc/sys.go51
-rw-r--r--pkg/sentry/fsimpl/proc/task.go261
-rw-r--r--pkg/sentry/fsimpl/proc/version.go68
-rw-r--r--pkg/sentry/inet/inet.go52
-rw-r--r--pkg/sentry/inet/test_stack.go10
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/kernel.go157
-rw-r--r--pkg/sentry/kernel/sessions.go12
-rw-r--r--pkg/sentry/kernel/task_block.go12
-rw-r--r--pkg/sentry/kernel/task_context.go11
-rw-r--r--pkg/sentry/kernel/task_start.go3
-rw-r--r--pkg/sentry/kernel/thread_group.go179
-rw-r--r--pkg/sentry/kernel/tty.go28
-rw-r--r--pkg/sentry/loader/elf.go2
-rw-r--r--pkg/sentry/loader/loader.go107
-rw-r--r--pkg/sentry/mm/procfs.go92
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go13
-rw-r--r--pkg/sentry/platform/ptrace/BUILD6
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go2
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_amd64.go (renamed from pkg/flipcall/endpoint_futex.go)34
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_arm64.go30
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go46
-rw-r--r--pkg/sentry/platform/ptrace/stub_arm64.s106
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go15
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go24
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go126
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go (renamed from pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go)3
-rw-r--r--pkg/sentry/safemem/io.go55
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/epsocket/BUILD2
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go136
-rw-r--r--pkg/sentry/socket/epsocket/stack.go57
-rw-r--r--pkg/sentry/socket/hostinet/socket.go37
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go10
-rw-r--r--pkg/sentry/socket/hostinet/stack.go93
-rw-r--r--pkg/sentry/socket/netfilter/BUILD24
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go286
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go62
-rw-r--r--pkg/sentry/socket/netlink/socket.go45
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/BUILD3
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/notifier.go3
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go23
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go31
-rw-r--r--pkg/sentry/socket/socket.go40
-rw-r--r--pkg/sentry/socket/unix/io.go4
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go2
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go2
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go36
-rw-r--r--pkg/sentry/socket/unix/unix.go20
-rw-r--r--pkg/sentry/state/BUILD1
-rw-r--r--pkg/sentry/state/state.go5
-rw-r--r--pkg/sentry/strace/socket.go2
-rw-r--r--pkg/sentry/syscalls/linux/error.go4
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go19
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go24
-rw-r--r--pkg/sentry/syscalls/linux/sys_mount.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go27
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go3
-rw-r--r--pkg/sentry/vfs/BUILD10
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go122
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go141
-rw-r--r--pkg/sentry/vfs/testutil.go139
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go44
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go65
-rw-r--r--pkg/tcpip/header/ipv4.go9
-rw-r--r--pkg/tcpip/header/ipv6.go23
-rw-r--r--pkg/tcpip/iptables/BUILD5
-rw-r--r--pkg/tcpip/iptables/iptables.go4
-rw-r--r--pkg/tcpip/iptables/types.go19
-rw-r--r--pkg/tcpip/link/fdbased/BUILD4
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go179
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64.go194
-rw-r--r--pkg/tcpip/link/fdbased/mmap_stub.go23
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go (renamed from pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go)2
-rw-r--r--pkg/tcpip/link/rawfile/BUILD4
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s7
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s42
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go31
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_unsafe.go8
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go)10
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go6
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go5
-rw-r--r--pkg/tcpip/network/arp/arp.go5
-rw-r--r--pkg/tcpip/network/arp/arp_test.go63
-rw-r--r--pkg/tcpip/network/ip_test.go6
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go1
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go28
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go6
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go33
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD1
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go5
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go9
-rw-r--r--pkg/tcpip/stack/BUILD24
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go253
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go79
-rw-r--r--pkg/tcpip/stack/nic.go306
-rw-r--r--pkg/tcpip/stack/route.go10
-rw-r--r--pkg/tcpip/stack/stack.go78
-rw-r--r--pkg/tcpip/stack/stack_test.go734
-rw-r--r--pkg/tcpip/stack/transport_test.go61
-rw-r--r--pkg/tcpip/tcpip.go101
-rw-r--r--pkg/tcpip/tcpip_test.go32
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go22
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go11
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go26
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go16
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go89
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go86
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go161
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/snd.go41
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go8
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go60
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go1104
-rw-r--r--pkg/unet/BUILD2
-rw-r--r--pkg/unet/unet_unsafe.go19
244 files changed, 11525 insertions, 3597 deletions
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go
index fb1c679d2..40f0459a0 100644
--- a/pkg/abi/linux/elf.go
+++ b/pkg/abi/linux/elf.go
@@ -89,3 +89,17 @@ const (
// AT_SYSINFO_EHDR is the address of the VDSO.
AT_SYSINFO_EHDR = 33
)
+
+// ELF ET_CORE and ptrace GETREGSET/SETREGSET register set types.
+//
+// See include/uapi/linux/elf.h.
+const (
+ // NT_PRSTATUS is for general purpose registers.
+ NT_PRSTATUS = 0x1
+
+ // NT_PRFPREG is for floating point registers.
+ NT_PRFPREG = 0x2
+
+ // NT_X86_XSTATE is for x86 extended state using xsave.
+ NT_X86_XSTATE = 0x202
+)
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index 615e72646..7d742871a 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -178,6 +178,8 @@ const (
// Values for preadv2/pwritev2.
const (
+ // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is
+ // accepted as a valid flag argument for preadv2/pwritev2.
RWF_HIPRI = 0x00000001
RWF_DSYNC = 0x00000002
RWF_SYNC = 0x00000004
diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go
index e8b6544b4..0ba086c76 100644
--- a/pkg/abi/linux/netlink.go
+++ b/pkg/abi/linux/netlink.go
@@ -122,3 +122,9 @@ const (
NETLINK_EXT_ACK = 11
NETLINK_DUMP_STRICT_CHK = 12
)
+
+// NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h.
+type NetlinkErrorMessage struct {
+ Error int32
+ Header NetlinkMessageHeader
+}
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index dd698e2bc..152f6b144 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -189,3 +189,139 @@ const (
const (
ARPHRD_LOOPBACK = 772
)
+
+// RouteMessage struct rtmsg, from uapi/linux/rtnetlink.h.
+type RouteMessage struct {
+ Family uint8
+ DstLen uint8
+ SrcLen uint8
+ TOS uint8
+
+ Table uint8
+ Protocol uint8
+ Scope uint8
+ Type uint8
+
+ Flags uint32
+}
+
+// Route types, from uapi/linux/rtnetlink.h.
+//
+// The values are sequential (0 through 11) and must match the kernel's
+// enumeration exactly, since they are exchanged with userspace in
+// RouteMessage.Type.
+const (
+ // RTN_UNSPEC represents an unspecified route type.
+ RTN_UNSPEC = 0
+
+ // RTN_UNICAST represents a unicast route.
+ RTN_UNICAST = 1
+
+ // RTN_LOCAL represents a route that is accepted locally.
+ RTN_LOCAL = 2
+
+ // RTN_BROADCAST represents a broadcast route (Traffic is accepted locally
+ // as broadcast, and sent as broadcast).
+ RTN_BROADCAST = 3
+
+ // RTN_ANYCAST represents an anycast route (Traffic is accepted locally as
+ // broadcast but sent as unicast).
+ RTN_ANYCAST = 4
+
+ // RTN_MULTICAST represents a multicast route.
+ RTN_MULTICAST = 5
+
+ // RTN_BLACKHOLE represents a route where all traffic is dropped.
+ RTN_BLACKHOLE = 6
+
+ // RTN_UNREACHABLE represents a route where the destination is unreachable.
+ RTN_UNREACHABLE = 7
+
+ RTN_PROHIBIT = 8
+ RTN_THROW = 9
+ RTN_NAT = 10
+ RTN_XRESOLVE = 11
+)
+
+// Route protocols/origins, from uapi/linux/rtnetlink.h.
+const (
+ RTPROT_UNSPEC = 0
+ RTPROT_REDIRECT = 1
+ RTPROT_KERNEL = 2
+ RTPROT_BOOT = 3
+ RTPROT_STATIC = 4
+ RTPROT_GATED = 8
+ RTPROT_RA = 9
+ RTPROT_MRT = 10
+ RTPROT_ZEBRA = 11
+ RTPROT_BIRD = 12
+ RTPROT_DNROUTED = 13
+ RTPROT_XORP = 14
+ RTPROT_NTK = 15
+ RTPROT_DHCP = 16
+ RTPROT_MROUTED = 17
+ RTPROT_BABEL = 42
+ RTPROT_BGP = 186
+ RTPROT_ISIS = 187
+ RTPROT_OSPF = 188
+ RTPROT_RIP = 189
+ RTPROT_EIGRP = 192
+)
+
+// Route scopes, from uapi/linux/rtnetlink.h.
+const (
+ RT_SCOPE_UNIVERSE = 0
+ RT_SCOPE_SITE = 200
+ RT_SCOPE_LINK = 253
+ RT_SCOPE_HOST = 254
+ RT_SCOPE_NOWHERE = 255
+)
+
+// Route flags, from uapi/linux/rtnetlink.h.
+const (
+ RTM_F_NOTIFY = 0x100
+ RTM_F_CLONED = 0x200
+ RTM_F_EQUALIZE = 0x400
+ RTM_F_PREFIX = 0x800
+ RTM_F_LOOKUP_TABLE = 0x1000
+ RTM_F_FIB_MATCH = 0x2000
+)
+
+// Route tables, from uapi/linux/rtnetlink.h.
+const (
+ RT_TABLE_UNSPEC = 0
+ RT_TABLE_COMPAT = 252
+ RT_TABLE_DEFAULT = 253
+ RT_TABLE_MAIN = 254
+ RT_TABLE_LOCAL = 255
+)
+
+// Route attributes, from uapi/linux/rtnetlink.h.
+const (
+ RTA_UNSPEC = 0
+ RTA_DST = 1
+ RTA_SRC = 2
+ RTA_IIF = 3
+ RTA_OIF = 4
+ RTA_GATEWAY = 5
+ RTA_PRIORITY = 6
+ RTA_PREFSRC = 7
+ RTA_METRICS = 8
+ RTA_MULTIPATH = 9
+ RTA_PROTOINFO = 10
+ RTA_FLOW = 11
+ RTA_CACHEINFO = 12
+ RTA_SESSION = 13
+ RTA_MP_ALGO = 14
+ RTA_TABLE = 15
+ RTA_MARK = 16
+ RTA_MFC_STATS = 17
+ RTA_VIA = 18
+ RTA_NEWDST = 19
+ RTA_PREF = 20
+ RTA_ENCAP_TYPE = 21
+ RTA_ENCAP = 22
+ RTA_EXPIRES = 23
+ RTA_PAD = 24
+ RTA_UID = 25
+ RTA_TTL_PROPAGATE = 26
+ RTA_IP_PROTO = 27
+ RTA_SPORT = 28
+ RTA_DPORT = 29
+)
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 6d22002c4..d5b731390 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -267,6 +267,20 @@ type SockAddrUnix struct {
Path [UnixPathMax]int8
}
+// SockAddr represents a union of valid socket address types. This is logically
+// equivalent to struct sockaddr. SockAddr ensures that a well-defined set of
+// types can be used as socket addresses.
+type SockAddr interface {
+ // implementsSockAddr exists purely to allow a type to indicate that they
+ // implement this interface. This method is a no-op and shouldn't be called.
+ implementsSockAddr()
+}
+
+func (s *SockAddrInet) implementsSockAddr() {}
+func (s *SockAddrInet6) implementsSockAddr() {}
+func (s *SockAddrUnix) implementsSockAddr() {}
+func (s *SockAddrNetlink) implementsSockAddr() {}
+
// Linger is struct linger, from include/linux/socket.h.
type Linger struct {
OnOff int32
@@ -278,7 +292,10 @@ const SizeOfLinger = 8
// TCPInfo is a collection of TCP statistics.
//
-// From uapi/linux/tcp.h.
+// From uapi/linux/tcp.h. Newer versions of Linux continue to add new fields to
+// the end of this struct or within existing unused space, so its size grows
+// over time. The current iteration is based on linux v4.17. New versions are
+// always backwards compatible.
type TCPInfo struct {
State uint8
CaState uint8
@@ -352,7 +369,7 @@ type TCPInfo struct {
}
// SizeOfTCPInfo is the binary size of a TCPInfo struct.
-const SizeOfTCPInfo = 104
+var SizeOfTCPInfo = int(binary.Size(TCPInfo{}))
// Control message types, from linux/socket.h.
const (
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 3fabaf445..5d61dc2ff 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -418,6 +418,73 @@ var x86FeatureParseOnlyStrings = map[Feature]string{
X86FeaturePREFETCHWT1: "prefetchwt1",
}
+// intelCacheDescriptors describe the caches and TLBs on the system. They are
+// returned in the registers for eax=2. Intel only.
+type intelCacheDescriptor uint8
+
+// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
+// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
+const (
+ intelNullDescriptor intelCacheDescriptor = 0
+ intelNoTLBDescriptor intelCacheDescriptor = 0xfe
+ intelNoCacheDescriptor intelCacheDescriptor = 0xff
+
+ // Most descriptors omitted for brevity as they are currently unused.
+)
+
+// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
+type CacheType uint8
+
+const (
+ // cacheNull indicates that there are no more entries.
+ cacheNull CacheType = iota
+
+ // CacheData is a data cache.
+ CacheData
+
+ // CacheInstruction is an instruction cache.
+ CacheInstruction
+
+ // CacheUnified is a unified instruction and data cache.
+ CacheUnified
+)
+
+// Cache describes the parameters of a single cache on the system.
+//
+// +stateify savable
+type Cache struct {
+ // Level is the hierarchical level of this cache (L1, L2, etc).
+ Level uint32
+
+ // Type is the type of cache.
+ Type CacheType
+
+ // FullyAssociative indicates that entries may be placed in any block.
+ FullyAssociative bool
+
+ // Partitions is the number of physical partitions in the cache.
+ Partitions uint32
+
+ // Ways is the number of ways of associativity in the cache.
+ Ways uint32
+
+ // Sets is the number of sets in the cache.
+ Sets uint32
+
+ // InvalidateHierarchical indicates that WBINVD/INVD from threads
+ // sharing this cache acts upon lower level caches for threads sharing
+ // this cache.
+ InvalidateHierarchical bool
+
+ // Inclusive indicates that this cache is inclusive of lower cache
+ // levels.
+ Inclusive bool
+
+ // DirectMapped indicates that this cache is directly mapped from
+ // address, rather than using a hash function.
+ DirectMapped bool
+}
+
// Just a way to wrap cpuid function numbers.
type cpuidFunction uint32
@@ -494,7 +561,7 @@ func (f Feature) flagString(cpuinfoOnly bool) string {
return ""
}
-// FeatureSet is a set of Features for a cpu.
+// FeatureSet is a set of Features for a CPU.
//
// +stateify savable
type FeatureSet struct {
@@ -521,6 +588,15 @@ type FeatureSet struct {
// SteppingID is part of the processor signature.
SteppingID uint8
+
+ // Caches describes the caches on the CPU.
+ Caches []Cache
+
+ // CacheLine is the size of a cache line in bytes.
+ //
+ // All caches use the same line size. This is not enforced in the CPUID
+ // encoding, but is true on all known x86 processors.
+ CacheLine uint32
}
// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
@@ -557,22 +633,27 @@ func (fs FeatureSet) CPUInfo(cpu uint) string {
fmt.Fprintln(&b, "wp\t\t: yes")
fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
- fmt.Fprintf(&b, "clflush size\t: %d\n", 64)
- fmt.Fprintf(&b, "cache_alignment\t: %d\n", 64)
+ fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine)
fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
return b.String()
}
+const (
+ amdVendorID = "AuthenticAMD"
+ intelVendorID = "GenuineIntel"
+)
+
// AMD returns true if fs describes an AMD CPU.
func (fs *FeatureSet) AMD() bool {
- return fs.VendorID == "AuthenticAMD"
+ return fs.VendorID == amdVendorID
}
// Intel returns true if fs describes an Intel CPU.
func (fs *FeatureSet) Intel() bool {
- return fs.VendorID == "GenuineIntel"
+ return fs.VendorID == intelVendorID
}
// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
@@ -589,9 +670,18 @@ func (e ErrIncompatible) Error() string {
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
func (fs *FeatureSet) CheckHostCompatible() error {
hfs := HostFeatureSet()
+
if diff := fs.Subtract(hfs); diff != nil {
return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
}
+
+ // The size of a cache line must match, as it is critical to correctly
+ // utilizing CLFLUSH. Other cache properties are allowed to change, as
+ // they are not important to correctness.
+ if fs.CacheLine != hfs.CacheLine {
+ return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
+ }
+
return nil
}
@@ -732,14 +822,6 @@ func (fs *FeatureSet) HasFeature(feature Feature) bool {
return fs.Set[feature]
}
-// IsSubset returns true if the FeatureSet is a subset of the FeatureSet passed in.
-// This is useful if you want to see if a FeatureSet is compatible with another
-// FeatureSet, since you can only run with a given FeatureSet if it's a subset of
-// the host's.
-func (fs *FeatureSet) IsSubset(other *FeatureSet) bool {
- return fs.Subtract(other) == nil
-}
-
// Subtract returns the features present in fs that are not present in other.
// If all features in fs are present in other, Subtract returns nil.
func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
@@ -755,17 +837,6 @@ func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
return
}
-// TakeFeatureIntersection will set the features in `fs` to the intersection of
-// the features in `fs` and `other` (effectively clearing any feature bits on
-// `fs` that are not also set in `other`).
-func (fs *FeatureSet) TakeFeatureIntersection(other *FeatureSet) {
- for f := range fs.Set {
- if !other.Set[f] {
- delete(fs.Set, f)
- }
- }
-}
-
// EmulateID emulates a cpuid instruction based on the feature set.
func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
switch cpuidFunction(origAx) {
@@ -773,9 +844,8 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
bx, dx, cx = fs.vendorIDRegs()
case featureInfo:
- // clflush line size (ebx bits[15:8]) hardcoded as 8. This
- // means cache lines of size 64 bytes.
- bx = 8 << 8
+ // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
+ bx = (fs.CacheLine / 8) << 8
cx = fs.blockMask(block(0))
dx = fs.blockMask(block(1))
ax = fs.signature()
@@ -789,10 +859,46 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
// will always return 01H. Software should ignore this value
// and not interpret it as an informational descriptor." - SDM
//
- // We do not support exposing cache information, but we do set
- // this fixed field because some language runtimes (dlang) get
- // confused by ax = 0 and will loop infinitely.
- ax = 1
+ // We only support reporting cache parameters via
+ // intelDeterministicCacheParams; report as much here.
+ //
+ // We do not support exposing TLB information at all.
+ ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
+ case intelDeterministicCacheParams:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // cx is the index of the cache to describe.
+ if int(origCx) >= len(fs.Caches) {
+ return uint32(cacheNull), 0, 0, 0
+ }
+ c := fs.Caches[origCx]
+
+ ax = uint32(c.Type)
+ ax |= c.Level << 5
+ ax |= 1 << 8 // Always claim the cache is "self-initializing".
+ if c.FullyAssociative {
+ ax |= 1 << 9
+ }
+ // Processor topology not supported.
+
+ bx = fs.CacheLine - 1
+ bx |= (c.Partitions - 1) << 12
+ bx |= (c.Ways - 1) << 22
+
+ cx = c.Sets - 1
+
+ if !c.InvalidateHierarchical {
+ dx |= 1
+ }
+ if c.Inclusive {
+ dx |= 1 << 1
+ }
+ if !c.DirectMapped {
+ dx |= 1 << 2
+ }
case xSaveInfo:
if !fs.UseXsave() {
return 0, 0, 0, 0
@@ -845,10 +951,41 @@ func HostFeatureSet() *FeatureSet {
vendorID := vendorIDFromRegs(bx, cx, dx)
// eax=1 gets basic features in ecx:edx.
- ax, _, cx, dx := HostID(1, 0)
+ ax, bx, cx, dx := HostID(1, 0)
featureBlock0 := cx
featureBlock1 := dx
ef, em, pt, f, m, sid := signatureSplit(ax)
+ cacheLine := 8 * (bx >> 8) & 0xff
+
+ // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
+ var caches []Cache
+ if vendorID == intelVendorID {
+ // ecx selects the cache index until a null type is returned.
+ for i := uint32(0); ; i++ {
+ ax, bx, cx, dx := HostID(4, i)
+ t := CacheType(ax & 0xf)
+ if t == cacheNull {
+ break
+ }
+
+ lineSize := (bx & 0xfff) + 1
+ if lineSize != cacheLine {
+ panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
+ }
+
+ caches = append(caches, Cache{
+ Type: t,
+ Level: (ax >> 5) & 0x7,
+ FullyAssociative: ((ax >> 9) & 1) == 1,
+ Partitions: ((bx >> 12) & 0x3ff) + 1,
+ Ways: ((bx >> 22) & 0x3ff) + 1,
+ Sets: cx + 1,
+ InvalidateHierarchical: (dx & 1) == 0,
+ Inclusive: ((dx >> 1) & 1) == 1,
+ DirectMapped: ((dx >> 2) & 1) == 0,
+ })
+ }
+ }
// eax=7, ecx=0 gets extended features in ecx:ebx.
_, bx, cx, _ = HostID(7, 0)
@@ -883,6 +1020,8 @@ func HostFeatureSet() *FeatureSet {
Family: f,
Model: m,
SteppingID: sid,
+ CacheLine: cacheLine,
+ Caches: caches,
}
}
diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_test.go
index 6ae14d2da..a707ebb55 100644
--- a/pkg/cpuid/cpuid_test.go
+++ b/pkg/cpuid/cpuid_test.go
@@ -57,24 +57,13 @@ var justFPUandPAE = &FeatureSet{
X86FeaturePAE: true,
}}
-func TestIsSubset(t *testing.T) {
- if !justFPU.IsSubset(justFPUandPAE) {
- t.Errorf("Got %v is not subset of %v, want IsSubset being true", justFPU, justFPUandPAE)
+func TestSubtract(t *testing.T) {
+ if diff := justFPU.Subtract(justFPUandPAE); diff != nil {
+ t.Errorf("Got %v is not subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff)
}
- if justFPUandPAE.IsSubset(justFPU) {
- t.Errorf("Got %v is a subset of %v, want IsSubset being false", justFPU, justFPUandPAE)
- }
-}
-
-func TestTakeFeatureIntersection(t *testing.T) {
- testFeatures := HostFeatureSet()
- testFeatures.TakeFeatureIntersection(justFPU)
- if !testFeatures.IsSubset(justFPU) {
- t.Errorf("Got more features than expected after intersecting host features with justFPU: %v, want %v", testFeatures.Set, justFPU.Set)
- }
- if !testFeatures.HasFeature(X86FeatureFPU) {
- t.Errorf("Got no features in testFeatures after intersecting, want %v", X86FeatureFPU)
+ if justFPUandPAE.Subtract(justFPU) == nil {
+ t.Errorf("Got %v is a subset of %v, want diff to be nil", justFPU, justFPUandPAE)
}
}
@@ -83,7 +72,7 @@ func TestTakeFeatureIntersection(t *testing.T) {
// if HostFeatureSet gives back junk bits.
func TestHostFeatureSet(t *testing.T) {
hostFeatures := HostFeatureSet()
- if !justFPUandPAE.IsSubset(hostFeatures) {
+ if justFPUandPAE.Subtract(hostFeatures) != nil {
t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
}
}
@@ -175,6 +164,7 @@ func TestEmulateIDBasicFeatures(t *testing.T) {
testFeatures := newEmptyFeatureSet()
testFeatures.Add(X86FeatureCLFSH)
testFeatures.Add(X86FeatureAVX)
+ testFeatures.CacheLine = 64
ax, bx, cx, dx := testFeatures.EmulateID(1, 0)
ECXAVXBit := uint32(1 << uint(X86FeatureAVX))
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 4c336ea84..9961baaa9 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
package(licenses = ["notice"])
@@ -7,6 +7,7 @@ go_library(
name = "eventchannel",
srcs = [
"event.go",
+ "rate.go",
],
importpath = "gvisor.dev/gvisor/pkg/eventchannel",
visibility = ["//:sandbox"],
@@ -16,6 +17,7 @@ go_library(
"//pkg/unet",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_golang_protobuf//ptypes:go_default_library_gen",
+ "@org_golang_x_time//rate:go_default_library",
],
)
@@ -30,3 +32,12 @@ go_proto_library(
proto = ":eventchannel_proto",
visibility = ["//:sandbox"],
)
+
+go_test(
+ name = "eventchannel_test",
+ srcs = ["event_test.go"],
+ embed = [":eventchannel"],
+ deps = [
+ "@com_github_golang_protobuf//proto:go_default_library",
+ ],
+)
diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go
index f6d26532b..d37ad0428 100644
--- a/pkg/eventchannel/event.go
+++ b/pkg/eventchannel/event.go
@@ -43,18 +43,36 @@ type Emitter interface {
Close() error
}
-var (
- mu sync.Mutex
- emitters = make(map[Emitter]struct{})
-)
+// DefaultEmitter is the default emitter. Calls to Emit and AddEmitter are sent
+// to this Emitter.
+var DefaultEmitter = &multiEmitter{}
-// Emit emits a message using all added emitters.
+// Emit is a helper method that calls DefaultEmitter.Emit.
func Emit(msg proto.Message) error {
- mu.Lock()
- defer mu.Unlock()
+ _, err := DefaultEmitter.Emit(msg)
+ return err
+}
+
+// AddEmitter is a helper method that calls DefaultEmitter.AddEmitter.
+func AddEmitter(e Emitter) {
+ DefaultEmitter.AddEmitter(e)
+}
+
+// multiEmitter is an Emitter that forwards messages to multiple Emitters.
+type multiEmitter struct {
+ // mu protects emitters.
+ mu sync.Mutex
+ // emitters is initialized lazily in AddEmitter.
+ emitters map[Emitter]struct{}
+}
+
+// Emit emits a message using all added emitters.
+func (me *multiEmitter) Emit(msg proto.Message) (bool, error) {
+ me.mu.Lock()
+ defer me.mu.Unlock()
var err error
- for e := range emitters {
+ for e := range me.emitters {
hangup, eerr := e.Emit(msg)
if eerr != nil {
if err == nil {
@@ -68,18 +86,36 @@ func Emit(msg proto.Message) error {
}
if hangup {
log.Infof("Hangup on eventchannel emitter %v.", e)
- delete(emitters, e)
+ delete(me.emitters, e)
}
}
- return err
+ return false, err
}
// AddEmitter adds a new emitter.
-func AddEmitter(e Emitter) {
- mu.Lock()
- defer mu.Unlock()
- emitters[e] = struct{}{}
+func (me *multiEmitter) AddEmitter(e Emitter) {
+ me.mu.Lock()
+ defer me.mu.Unlock()
+ if me.emitters == nil {
+ me.emitters = make(map[Emitter]struct{})
+ }
+ me.emitters[e] = struct{}{}
+}
+
+// Close closes all emitters. If any Close call errors, it returns the first
+// one encountered.
+func (me *multiEmitter) Close() error {
+ me.mu.Lock()
+ defer me.mu.Unlock()
+ var err error
+ for e := range me.emitters {
+ if eerr := e.Close(); err == nil && eerr != nil {
+ err = eerr
+ }
+ delete(me.emitters, e)
+ }
+ return err
}
func marshal(msg proto.Message) ([]byte, error) {
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
new file mode 100644
index 000000000..3649097d6
--- /dev/null
+++ b/pkg/eventchannel/event_test.go
@@ -0,0 +1,146 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventchannel
+
+import (
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/golang/protobuf/proto"
+)
+
+// testEmitter is an emitter that can be used in tests. It records all events
+// emitted, and whether it has been closed.
+type testEmitter struct {
+ // mu protects all fields below.
+ mu sync.Mutex
+
+ // events contains all emitted events.
+ events []proto.Message
+
+ // closed records whether Close() was called.
+ closed bool
+}
+
+// Emit implements Emitter.Emit.
+func (te *testEmitter) Emit(msg proto.Message) (bool, error) {
+ te.mu.Lock()
+ defer te.mu.Unlock()
+ te.events = append(te.events, msg)
+ return false, nil
+}
+
+// Close implements Emitter.Close.
+func (te *testEmitter) Close() error {
+ te.mu.Lock()
+ defer te.mu.Unlock()
+ if te.closed {
+ return fmt.Errorf("closed called twice")
+ }
+ te.closed = true
+ return nil
+}
+
+// testMessage implements proto.Message for testing.
+type testMessage struct {
+ proto.Message
+
+ // name is the name of the message, used by tests to compare messages.
+ name string
+}
+
+func TestMultiEmitter(t *testing.T) {
+ // Create three testEmitters, tied together in a multiEmitter.
+ me := &multiEmitter{}
+ var emitters []*testEmitter
+ for i := 0; i < 3; i++ {
+ te := &testEmitter{}
+ emitters = append(emitters, te)
+ me.AddEmitter(te)
+ }
+
+ // Emit three messages to multiEmitter.
+ names := []string{"foo", "bar", "baz"}
+ for _, name := range names {
+ m := testMessage{name: name}
+ if _, err := me.Emit(m); err != nil {
+ t.Fatalf("me.Emit(%v) failed: %v", m, err)
+ }
+ }
+
+ // All three emitters should have all three events.
+ for _, te := range emitters {
+ if got, want := len(te.events), len(names); got != want {
+ t.Fatalf("emitter got %d events, want %d", got, want)
+ }
+ for i, name := range names {
+ if got := te.events[i].(testMessage).name; got != name {
+ t.Errorf("emitter got message with name %q, want %q", got, name)
+ }
+ }
+ }
+
+ // Close multiEmitter.
+ if err := me.Close(); err != nil {
+ t.Fatalf("me.Close() failed: %v", err)
+ }
+
+ // All testEmitters should be closed.
+ for _, te := range emitters {
+ if !te.closed {
+ t.Errorf("te.closed got false, want true")
+ }
+ }
+}
+
+func TestRateLimitedEmitter(t *testing.T) {
+ // Create a rate-limited emitter that wraps a testEmitter.
+ te := &testEmitter{}
+ max := float64(5) // events per second
+ burst := 10 // events
+ rle := RateLimitedEmitterFrom(te, max, burst)
+
+ // Send 50 messages in one shot.
+ for i := 0; i < 50; i++ {
+ if _, err := rle.Emit(testMessage{}); err != nil {
+ t.Fatalf("rle.Emit failed: %v", err)
+ }
+ }
+
+ // We should have received only 10 messages.
+ if got, want := len(te.events), 10; got != want {
+ t.Errorf("got %d events, want %d", got, want)
+ }
+
+ // Sleep for a second and then send another 50.
+ time.Sleep(1 * time.Second)
+ for i := 0; i < 50; i++ {
+ if _, err := rle.Emit(testMessage{}); err != nil {
+ t.Fatalf("rle.Emit failed: %v", err)
+ }
+ }
+
+ // We should have at least 5 more messages, plus maybe a few more if the
+ // test ran slowly.
+ got, wantAtLeast, wantAtMost := len(te.events), 15, 20
+ if got < wantAtLeast {
+ t.Errorf("got %d events, want at least %d", got, wantAtLeast)
+ }
+ if got > wantAtMost {
+ t.Errorf("got %d events, want at most %d", got, wantAtMost)
+ }
+}
diff --git a/pkg/eventchannel/rate.go b/pkg/eventchannel/rate.go
new file mode 100644
index 000000000..179226c92
--- /dev/null
+++ b/pkg/eventchannel/rate.go
@@ -0,0 +1,54 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventchannel
+
+import (
+ "github.com/golang/protobuf/proto"
+ "golang.org/x/time/rate"
+)
+
+// rateLimitedEmitter wraps an emitter and limits events to the given limits.
+// Events that would exceed the limit are discarded.
+type rateLimitedEmitter struct {
+ inner Emitter
+ limiter *rate.Limiter
+}
+
+// RateLimitedEmitterFrom creates a new event channel emitter that wraps the
+// existing emitter and enforces rate limits. The limits are imposed via a
+// token bucket, with `maxRate` events per second, with burst size of `burst`
+// events. See the golang.org/x/time/rate package and
+// https://en.wikipedia.org/wiki/Token_bucket for more information about token
+// buckets generally.
+func RateLimitedEmitterFrom(inner Emitter, maxRate float64, burst int) Emitter {
+ return &rateLimitedEmitter{
+ inner: inner,
+ limiter: rate.NewLimiter(rate.Limit(maxRate), burst),
+ }
+}
+
+// Emit implements Emitter.Emit.
+func (rle *rateLimitedEmitter) Emit(msg proto.Message) (bool, error) {
+ if !rle.limiter.Allow() {
+ // Drop event.
+ return false, nil
+ }
+ return rle.inner.Emit(msg)
+}
+
+// Close implements Emitter.Close.
+func (rle *rateLimitedEmitter) Close() error {
+ return rle.inner.Close()
+}
diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD
index d0552c06e..aca2d8a82 100644
--- a/pkg/fdnotifier/BUILD
+++ b/pkg/fdnotifier/BUILD
@@ -10,5 +10,8 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/fdnotifier",
visibility = ["//:sandbox"],
- deps = ["//pkg/waiter"],
+ deps = [
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
)
diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go
index 58529f99f..f4aae1953 100644
--- a/pkg/fdnotifier/fdnotifier.go
+++ b/pkg/fdnotifier/fdnotifier.go
@@ -25,6 +25,7 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -72,7 +73,7 @@ func (n *notifier) waitFD(fd int32, fi *fdInfo, mask waiter.EventMask) error {
}
e := syscall.EpollEvent{
- Events: mask.ToLinux() | -syscall.EPOLLET,
+ Events: mask.ToLinux() | unix.EPOLLET,
Fd: fd,
}
diff --git a/pkg/fdnotifier/poll_unsafe.go b/pkg/fdnotifier/poll_unsafe.go
index ab8857b5e..4225b04dd 100644
--- a/pkg/fdnotifier/poll_unsafe.go
+++ b/pkg/fdnotifier/poll_unsafe.go
@@ -35,8 +35,14 @@ func NonBlockingPoll(fd int32, mask waiter.EventMask) waiter.EventMask {
events: int16(mask.ToLinux()),
}
+ ts := syscall.Timespec{
+ Sec: 0,
+ Nsec: 0,
+ }
+
for {
- n, _, err := syscall.RawSyscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(&e)), 1, 0)
+ n, _, err := syscall.RawSyscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(&e)), 1,
+ uintptr(unsafe.Pointer(&ts)), 0, 0, 0)
// Interrupted by signal, try again.
if err == syscall.EINTR {
continue
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index 7126fc45f..bd1d614b6 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -5,10 +5,11 @@ package(licenses = ["notice"])
go_library(
name = "flipcall",
srcs = [
- "endpoint_futex.go",
- "endpoint_unsafe.go",
+ "ctrl_futex.go",
"flipcall.go",
+ "flipcall_unsafe.go",
"futex_linux.go",
+ "io.go",
"packet_window_allocator.go",
],
importpath = "gvisor.dev/gvisor/pkg/flipcall",
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
new file mode 100644
index 000000000..865b6f640
--- /dev/null
+++ b/pkg/flipcall/ctrl_futex.go
@@ -0,0 +1,146 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+type endpointControlImpl struct {
+ state int32
+}
+
+// Bits in endpointControlImpl.state.
+const (
+ epsBlocked = 1 << iota
+ epsShutdown
+)
+
+func (ep *Endpoint) ctrlInit(opts ...EndpointOption) error {
+ if len(opts) != 0 {
+ return fmt.Errorf("unknown EndpointOption: %T", opts[0])
+ }
+ return nil
+}
+
+type ctrlHandshakeRequest struct{}
+
+type ctrlHandshakeResponse struct{}
+
+func (ep *Endpoint) ctrlConnect() error {
+ if err := ep.enterFutexWait(); err != nil {
+ return err
+ }
+ _, err := ep.futexConnect(&ctrlHandshakeRequest{})
+ ep.exitFutexWait()
+ return err
+}
+
+func (ep *Endpoint) ctrlWaitFirst() error {
+ if err := ep.enterFutexWait(); err != nil {
+ return err
+ }
+ defer ep.exitFutexWait()
+
+ // Wait for the handshake request.
+ if err := ep.futexSwitchFromPeer(); err != nil {
+ return err
+ }
+
+ // Read the handshake request.
+ reqLen := atomic.LoadUint32(ep.dataLen())
+ if reqLen > ep.dataCap {
+ return fmt.Errorf("invalid handshake request length %d (maximum %d)", reqLen, ep.dataCap)
+ }
+ var req ctrlHandshakeRequest
+ if err := json.NewDecoder(ep.NewReader(reqLen)).Decode(&req); err != nil {
+ return fmt.Errorf("error reading handshake request: %v", err)
+ }
+
+ // Write the handshake response.
+ w := ep.NewWriter()
+ if err := json.NewEncoder(w).Encode(ctrlHandshakeResponse{}); err != nil {
+ return fmt.Errorf("error writing handshake response: %v", err)
+ }
+ *ep.dataLen() = w.Len()
+
+ // Return control to the client.
+ if err := ep.futexSwitchToPeer(); err != nil {
+ return err
+ }
+
+ // Wait for the first non-handshake message.
+ return ep.futexSwitchFromPeer()
+}
+
+func (ep *Endpoint) ctrlRoundTrip() error {
+ if err := ep.futexSwitchToPeer(); err != nil {
+ return err
+ }
+ if err := ep.enterFutexWait(); err != nil {
+ return err
+ }
+ err := ep.futexSwitchFromPeer()
+ ep.exitFutexWait()
+ return err
+}
+
+func (ep *Endpoint) ctrlWakeLast() error {
+ return ep.futexSwitchToPeer()
+}
+
+func (ep *Endpoint) enterFutexWait() error {
+ switch eps := atomic.AddInt32(&ep.ctrl.state, epsBlocked); eps {
+ case epsBlocked:
+ return nil
+ case epsBlocked | epsShutdown:
+ atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
+ return shutdownError{}
+ default:
+ // Most likely due to ep.enterFutexWait() being called concurrently
+ // from multiple goroutines.
+ panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state before flipcall.Endpoint.enterFutexWait(): %v", eps-epsBlocked))
+ }
+}
+
+func (ep *Endpoint) exitFutexWait() {
+ atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
+}
+
+func (ep *Endpoint) ctrlShutdown() {
+ // Set epsShutdown to ensure that future calls to ep.enterFutexWait() fail.
+ if atomic.AddInt32(&ep.ctrl.state, epsShutdown)&epsBlocked != 0 {
+ // Wake the blocked thread. This must loop because it's possible that
+ // FUTEX_WAKE occurs after the waiter sets epsBlocked, but before it
+ // blocks in FUTEX_WAIT.
+ for {
+ // Wake MaxInt32 threads to prevent a broken or malicious peer from
+ // swallowing our wakeup by FUTEX_WAITing from multiple threads.
+ if err := ep.futexWakeConnState(math.MaxInt32); err != nil {
+ log.Warningf("failed to FUTEX_WAKE Endpoints: %v", err)
+ break
+ }
+ yieldThread()
+ if atomic.LoadInt32(&ep.ctrl.state)&epsBlocked == 0 {
+ break
+ }
+ }
+ }
+}
diff --git a/pkg/flipcall/endpoint_unsafe.go b/pkg/flipcall/endpoint_unsafe.go
deleted file mode 100644
index 8319955e0..000000000
--- a/pkg/flipcall/endpoint_unsafe.go
+++ /dev/null
@@ -1,238 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package flipcall
-
-import (
- "fmt"
- "math"
- "reflect"
- "sync/atomic"
- "syscall"
- "unsafe"
-)
-
-// An Endpoint provides the ability to synchronously transfer data and control
-// to a connected peer Endpoint, which may be in another process.
-//
-// Since the Endpoint control transfer model is synchronous, at any given time
-// one Endpoint "has control" (designated the *active* Endpoint), and the other
-// is "waiting for control" (designated the *inactive* Endpoint). Users of the
-// flipcall package arbitrarily designate one Endpoint as initially-active, and
-// the other as initially-inactive; in a client/server protocol, the client
-// Endpoint is usually initially-active (able to send a request) and the server
-// Endpoint is usually initially-inactive (waiting for a request). The
-// initially-active Endpoint writes data to be sent to Endpoint.Data(), and
-// then synchronously transfers control to the inactive Endpoint by calling
-// Endpoint.SendRecv(), becoming the inactive Endpoint in the process. The
-// initially-inactive Endpoint waits for control by calling
-// Endpoint.RecvFirst(); receiving control causes it to become the active
-// Endpoint. After this, the protocol is symmetric: the active Endpoint reads
-// data sent by the peer by reading from Endpoint.Data(), writes data to be
-// sent to the peer into Endpoint.Data(), and then calls Endpoint.SendRecv() to
-// exchange roles with the peer, which blocks until the peer has done the same.
-type Endpoint struct {
- // shutdown is non-zero if Endpoint.Shutdown() has been called. shutdown is
- // accessed using atomic memory operations.
- shutdown uint32
-
- // dataCap is the size of the datagram part of the packet window in bytes.
- // dataCap is immutable.
- dataCap uint32
-
- // packet is the beginning of the packet window. packet is immutable.
- packet unsafe.Pointer
-
- ctrl endpointControlState
-}
-
-// Init must be called on zero-value Endpoints before first use. If it
-// succeeds, Destroy() must be called once the Endpoint is no longer in use.
-//
-// ctrlMode specifies how connected Endpoints will exchange control. Both
-// connected Endpoints must specify the same value for ctrlMode.
-//
-// pwd represents the packet window used to exchange data with the peer
-// Endpoint. FD may differ between Endpoints if they are in different
-// processes, but must represent the same file. The packet window must
-// initially be filled with zero bytes.
-func (ep *Endpoint) Init(ctrlMode ControlMode, pwd PacketWindowDescriptor) error {
- if pwd.Length < pageSize {
- return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
- }
- if pwd.Length > math.MaxUint32 {
- return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
- }
- m, _, e := syscall.Syscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
- if e != 0 {
- return fmt.Errorf("failed to mmap packet window: %v", e)
- }
- ep.dataCap = uint32(pwd.Length) - uint32(packetHeaderBytes)
- ep.packet = (unsafe.Pointer)(m)
- if err := ep.initControlState(ctrlMode); err != nil {
- ep.unmapPacket()
- return err
- }
- return nil
-}
-
-// NewEndpoint is a convenience function that returns an initialized Endpoint
-// allocated on the heap.
-func NewEndpoint(ctrlMode ControlMode, pwd PacketWindowDescriptor) (*Endpoint, error) {
- var ep Endpoint
- if err := ep.Init(ctrlMode, pwd); err != nil {
- return nil, err
- }
- return &ep, nil
-}
-
-func (ep *Endpoint) unmapPacket() {
- syscall.Syscall(syscall.SYS_MUNMAP, uintptr(ep.packet), uintptr(ep.dataCap)+packetHeaderBytes, 0)
- ep.dataCap = 0
- ep.packet = nil
-}
-
-// Destroy releases resources owned by ep. No other Endpoint methods may be
-// called after Destroy.
-func (ep *Endpoint) Destroy() {
- ep.unmapPacket()
-}
-
-// Packets consist of an 8-byte header followed by an arbitrarily-sized
-// datagram. The header consists of:
-//
-// - A 4-byte native-endian sequence number, which is incremented by the active
-// Endpoint after it finishes writing to the packet window. The sequence number
-// is needed to handle spurious wakeups.
-//
-// - A 4-byte native-endian datagram length in bytes.
-const (
- sizeofUint32 = unsafe.Sizeof(uint32(0))
- packetHeaderBytes = 2 * sizeofUint32
-)
-
-func (ep *Endpoint) seq() *uint32 {
- return (*uint32)(ep.packet)
-}
-
-func (ep *Endpoint) dataLen() *uint32 {
- return (*uint32)((unsafe.Pointer)(uintptr(ep.packet) + sizeofUint32))
-}
-
-// DataCap returns the maximum datagram size supported by ep in bytes.
-func (ep *Endpoint) DataCap() uint32 {
- return ep.dataCap
-}
-
-func (ep *Endpoint) data() unsafe.Pointer {
- return unsafe.Pointer(uintptr(ep.packet) + packetHeaderBytes)
-}
-
-// Data returns the datagram part of ep's packet window as a byte slice.
-//
-// Note that the packet window is shared with the potentially-untrusted peer
-// Endpoint, which may concurrently mutate the contents of the packet window.
-// Thus:
-//
-// - Readers must not assume that two reads of the same byte in Data() will
-// return the same result. In other words, readers should read any given byte
-// in Data() at most once.
-//
-// - Writers must not assume that they will read back the same data that they
-// have written. In other words, writers should avoid reading from Data() at
-// all.
-func (ep *Endpoint) Data() []byte {
- var bs []byte
- bsReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&bs))
- bsReflect.Data = uintptr(ep.data())
- bsReflect.Len = int(ep.DataCap())
- bsReflect.Cap = bsReflect.Len
- return bs
-}
-
-// SendRecv transfers control to the peer Endpoint, causing its call to
-// Endpoint.SendRecv() or Endpoint.RecvFirst() to return with the given
-// datagram length, then blocks until the peer Endpoint calls
-// Endpoint.SendRecv() or Endpoint.SendLast().
-//
-// Preconditions: No previous call to ep.SendRecv() or ep.RecvFirst() has
-// returned an error. ep.SendLast() has never been called.
-func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
- dataCap := ep.DataCap()
- if dataLen > dataCap {
- return 0, fmt.Errorf("can't send packet with datagram length %d (maximum %d)", dataLen, dataCap)
- }
- atomic.StoreUint32(ep.dataLen(), dataLen)
- if err := ep.doRoundTrip(); err != nil {
- return 0, err
- }
- recvDataLen := atomic.LoadUint32(ep.dataLen())
- if recvDataLen > dataCap {
- return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, dataCap)
- }
- return recvDataLen, nil
-}
-
-// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
-// returns the datagram length specified by that call.
-//
-// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never
-// been called.
-func (ep *Endpoint) RecvFirst() (uint32, error) {
- if err := ep.doWaitFirst(); err != nil {
- return 0, err
- }
- recvDataLen := atomic.LoadUint32(ep.dataLen())
- if dataCap := ep.DataCap(); recvDataLen > dataCap {
- return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, dataCap)
- }
- return recvDataLen, nil
-}
-
-// SendLast causes the peer Endpoint's call to Endpoint.SendRecv() or
-// Endpoint.RecvFirst() to return with the given datagram length.
-//
-// Preconditions: No previous call to ep.SendRecv() or ep.RecvFirst() has
-// returned an error. ep.SendLast() has never been called.
-func (ep *Endpoint) SendLast(dataLen uint32) error {
- dataCap := ep.DataCap()
- if dataLen > dataCap {
- return fmt.Errorf("can't send packet with datagram length %d (maximum %d)", dataLen, dataCap)
- }
- atomic.StoreUint32(ep.dataLen(), dataLen)
- if err := ep.doNotifyLast(); err != nil {
- return err
- }
- return nil
-}
-
-// Shutdown causes concurrent and future calls to ep.SendRecv(),
-// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not
-// wait for concurrent calls to return.
-func (ep *Endpoint) Shutdown() {
- if atomic.SwapUint32(&ep.shutdown, 1) == 0 {
- ep.interruptForShutdown()
- }
-}
-
-func (ep *Endpoint) isShutdown() bool {
- return atomic.LoadUint32(&ep.shutdown) != 0
-}
-
-type endpointShutdownError struct{}
-
-// Error implements error.Error.
-func (endpointShutdownError) Error() string {
- return "Endpoint.Shutdown() has been called"
-}
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 79a1e418a..5c9212c33 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -13,20 +13,217 @@
// limitations under the License.
// Package flipcall implements a protocol providing Fast Local Interprocess
-// Procedure Calls.
+// Procedure Calls between mutually-distrusting processes.
package flipcall
-// ControlMode defines how control is exchanged across a connection.
-type ControlMode uint8
+import (
+ "fmt"
+ "math"
+ "sync/atomic"
+ "syscall"
+)
-const (
- // ControlModeInvalid is invalid, and exists so that ControlMode fields in
- // structs must be explicitly initialized.
- ControlModeInvalid ControlMode = iota
+// An Endpoint provides the ability to synchronously transfer data and control
+// to a connected peer Endpoint, which may be in another process.
+//
+// Since the Endpoint control transfer model is synchronous, at any given time
+// one Endpoint "has control" (designated the active Endpoint), and the other
+// is "waiting for control" (designated the inactive Endpoint). Users of the
+// flipcall package designate one Endpoint as the client, which is initially
+// active, and the other as the server, which is initially inactive. See
+// flipcall_example_test.go for usage.
+type Endpoint struct {
+ // packet is a pointer to the beginning of the packet window. (Since this
+ // is a raw OS memory mapping and not a Go object, it does not need to be
+ // represented as an unsafe.Pointer.) packet is immutable.
+ packet uintptr
+
+ // dataCap is the size of the datagram part of the packet window in bytes.
+ // dataCap is immutable.
+ dataCap uint32
+
+ // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
+ // Endpoint has acknowledged shutdown initiated by the peer. shutdown is
+ // accessed using atomic memory operations.
+ shutdown uint32
+
+ // activeState is csClientActive if this is a client Endpoint and
+ // csServerActive if this is a server Endpoint.
+ activeState uint32
+
+ // inactiveState is csServerActive if this is a client Endpoint and
+ // csClientActive if this is a server Endpoint.
+ inactiveState uint32
+
+ ctrl endpointControlImpl
+}
+
+// Init must be called on zero-value Endpoints before first use. If it
+// succeeds, ep.Destroy() must be called once the Endpoint is no longer in use.
+//
+// pwd represents the packet window used to exchange data with the peer
+// Endpoint. FD may differ between Endpoints if they are in different
+// processes, but must represent the same file. The packet window must
+// initially be filled with zero bytes.
+func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+ if pwd.Length < pageSize {
+ return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
+ }
+ if pwd.Length > math.MaxUint32 {
+ return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
+ }
+ m, _, e := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ if e != 0 {
+ return fmt.Errorf("failed to mmap packet window: %v", e)
+ }
+ ep.packet = m
+ ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes)
+ // These will be overwritten by ep.Connect() for client Endpoints.
+ ep.activeState = csServerActive
+ ep.inactiveState = csClientActive
+ if err := ep.ctrlInit(opts...); err != nil {
+ ep.unmapPacket()
+ return err
+ }
+ return nil
+}
+
+// NewEndpoint is a convenience function that returns an initialized Endpoint
+// allocated on the heap.
+func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
+ var ep Endpoint
+ if err := ep.Init(pwd, opts...); err != nil {
+ return nil, err
+ }
+ return &ep, nil
+}
+
+// An EndpointOption configures an Endpoint.
+type EndpointOption interface {
+ isEndpointOption()
+}
+
+// Destroy releases resources owned by ep. No other Endpoint methods may be
+// called after Destroy.
+func (ep *Endpoint) Destroy() {
+ ep.unmapPacket()
+}
+
+func (ep *Endpoint) unmapPacket() {
+ syscall.RawSyscall(syscall.SYS_MUNMAP, ep.packet, uintptr(ep.dataCap)+PacketHeaderBytes, 0)
+ ep.packet = 0
+}
- // ControlModeFutex uses shared futex operations on packet control words.
- ControlModeFutex
+// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
+// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not
+// wait for concurrent calls to return. The effect of Shutdown on the peer
+// Endpoint is unspecified. Successive calls to Shutdown have no effect.
+//
+// Shutdown is the only Endpoint method that may be called concurrently with
+// other methods on the same Endpoint.
+func (ep *Endpoint) Shutdown() {
+ if atomic.SwapUint32(&ep.shutdown, 1) != 0 {
+ // ep.Shutdown() has previously been called.
+ return
+ }
+ ep.ctrlShutdown()
+}
+
+// isShutdownLocally returns true if ep.Shutdown() has been called.
+func (ep *Endpoint) isShutdownLocally() bool {
+ return atomic.LoadUint32(&ep.shutdown) != 0
+}
+
+type shutdownError struct{}
+
+// Error implements error.Error.
+func (shutdownError) Error() string {
+ return "flipcall connection shutdown"
+}
- // controlModeCount is the number of ControlModes in this list.
- controlModeCount
+// DataCap returns the maximum datagram size supported by ep. Equivalently,
+// DataCap returns len(ep.Data()).
+func (ep *Endpoint) DataCap() uint32 {
+ return ep.dataCap
+}
+
+// Connection state.
+const (
+ // The client is, by definition, initially active, so this must be 0.
+ csClientActive = 0
+ csServerActive = 1
)
+
+// Connect designates ep as a client Endpoint and blocks until the peer
+// Endpoint has called Endpoint.RecvFirst().
+//
+// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and
+// ep.SendLast() have never been called.
+func (ep *Endpoint) Connect() error {
+ ep.activeState = csClientActive
+ ep.inactiveState = csServerActive
+ return ep.ctrlConnect()
+}
+
+// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
+// returns the datagram length specified by that call.
+//
+// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never
+// been called.
+func (ep *Endpoint) RecvFirst() (uint32, error) {
+ if err := ep.ctrlWaitFirst(); err != nil {
+ return 0, err
+ }
+ recvDataLen := atomic.LoadUint32(ep.dataLen())
+ if recvDataLen > ep.dataCap {
+ return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
+ }
+ return recvDataLen, nil
+}
+
+// SendRecv transfers control to the peer Endpoint, causing its call to
+// Endpoint.SendRecv() or Endpoint.RecvFirst() to return with the given
+// datagram length, then blocks until the peer Endpoint calls
+// Endpoint.SendRecv() or Endpoint.SendLast().
+//
+// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or
+// ep.RecvFirst() has returned an error. ep.SendLast() has never been called.
+// If ep is a client Endpoint, ep.Connect() has previously been called and
+// returned nil.
+func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
+ if dataLen > ep.dataCap {
+ panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
+ }
+ // This store can safely be non-atomic: Under correct operation we should
+ // be the only thread writing ep.dataLen(), and ep.ctrlRoundTrip() will
+ // synchronize with the receiver. We will not read from ep.dataLen() until
+ // after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then
+ // they can only shoot themselves in the foot.
+ *ep.dataLen() = dataLen
+ if err := ep.ctrlRoundTrip(); err != nil {
+ return 0, err
+ }
+ recvDataLen := atomic.LoadUint32(ep.dataLen())
+ if recvDataLen > ep.dataCap {
+ return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
+ }
+ return recvDataLen, nil
+}
+
+// SendLast causes the peer Endpoint's call to Endpoint.SendRecv() or
+// Endpoint.RecvFirst() to return with the given datagram length.
+//
+// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or
+// ep.RecvFirst() has returned an error. ep.SendLast() has never been called.
+// If ep is a client Endpoint, ep.Connect() has previously been called and
+// returned nil.
+func (ep *Endpoint) SendLast(dataLen uint32) error {
+ if dataLen > ep.dataCap {
+ panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
+ }
+ *ep.dataLen() = dataLen
+ if err := ep.ctrlWakeLast(); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
index 572a1f119..edb6a8bef 100644
--- a/pkg/flipcall/flipcall_example_test.go
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -17,6 +17,7 @@ package flipcall
import (
"bytes"
"fmt"
+ "sync"
)
func Example() {
@@ -36,20 +37,21 @@ func Example() {
if err != nil {
panic(err)
}
- clientEP, err := NewEndpoint(ControlModeFutex, pwd)
- if err != nil {
+ var clientEP Endpoint
+ if err := clientEP.Init(pwd); err != nil {
panic(err)
}
defer clientEP.Destroy()
- serverEP, err := NewEndpoint(ControlModeFutex, pwd)
- if err != nil {
+ var serverEP Endpoint
+ if err := serverEP.Init(pwd); err != nil {
panic(err)
}
defer serverEP.Destroy()
- serverDone := make(chan struct{})
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
go func() {
- defer func() { serverDone <- struct{}{} }()
+ defer serverRun.Done()
i := 0
var buf bytes.Buffer
// wait for first request
@@ -76,9 +78,13 @@ func Example() {
}()
defer func() {
serverEP.Shutdown()
- <-serverDone
+ serverRun.Wait()
}()
+ // establish connection as client
+ if err := clientEP.Connect(); err != nil {
+ panic(err)
+ }
var buf bytes.Buffer
for i := 0; i < count; i++ {
// write request
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index 20d3002f0..da9d736ab 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -15,197 +15,240 @@
package flipcall
import (
+ "runtime"
+ "sync"
"testing"
"time"
)
var testPacketWindowSize = pageSize
-func testSendRecv(t *testing.T, ctrlMode ControlMode) {
- pwa, err := NewPacketWindowAllocator()
- if err != nil {
- t.Fatalf("failed to create PacketWindowAllocator: %v", err)
+type testConnection struct {
+ pwa PacketWindowAllocator
+ clientEP Endpoint
+ serverEP Endpoint
+}
+
+func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []EndpointOption) *testConnection {
+ c := &testConnection{}
+ if err := c.pwa.Init(); err != nil {
+ tb.Fatalf("failed to create PacketWindowAllocator: %v", err)
}
- defer pwa.Destroy()
- pwd, err := pwa.Allocate(testPacketWindowSize)
+ pwd, err := c.pwa.Allocate(testPacketWindowSize)
if err != nil {
- t.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
+ c.pwa.Destroy()
+ tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
}
-
- sendEP, err := NewEndpoint(ctrlMode, pwd)
- if err != nil {
- t.Fatalf("failed to create Endpoint: %v", err)
+ if err := c.clientEP.Init(pwd, clientOpts...); err != nil {
+ c.pwa.Destroy()
+ tb.Fatalf("failed to create client Endpoint: %v", err)
}
- defer sendEP.Destroy()
- recvEP, err := NewEndpoint(ctrlMode, pwd)
- if err != nil {
- t.Fatalf("failed to create Endpoint: %v", err)
+ if err := c.serverEP.Init(pwd, serverOpts...); err != nil {
+ c.pwa.Destroy()
+ c.clientEP.Destroy()
+ tb.Fatalf("failed to create server Endpoint: %v", err)
}
- defer recvEP.Destroy()
+ return c
+}
- otherThreadDone := make(chan struct{})
+func newTestConnection(tb testing.TB) *testConnection {
+ return newTestConnectionWithOptions(tb, nil, nil)
+}
+
+func (c *testConnection) destroy() {
+ c.pwa.Destroy()
+ c.clientEP.Destroy()
+ c.serverEP.Destroy()
+}
+
+func testSendRecv(t *testing.T, c *testConnection) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
go func() {
- defer func() { otherThreadDone <- struct{}{} }()
- t.Logf("initially-inactive Endpoint waiting for packet 1")
- if _, err := recvEP.RecvFirst(); err != nil {
- t.Fatalf("initially-inactive Endpoint.RecvFirst() failed: %v", err)
+ defer serverRun.Done()
+ t.Logf("server Endpoint waiting for packet 1")
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
}
- t.Logf("initially-inactive Endpoint got packet 1, sending packet 2 and waiting for packet 3")
- if _, err := recvEP.SendRecv(0); err != nil {
- t.Fatalf("initially-inactive Endpoint.SendRecv() failed: %v", err)
+ t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
+ if _, err := c.serverEP.SendRecv(0); err != nil {
+ t.Fatalf("server Endpoint.SendRecv() failed: %v", err)
}
- t.Logf("initially-inactive Endpoint got packet 3")
+ t.Logf("server Endpoint got packet 3")
}()
defer func() {
- t.Logf("waiting for initially-inactive Endpoint goroutine to complete")
- <-otherThreadDone
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
}()
- t.Logf("initially-active Endpoint sending packet 1 and waiting for packet 2")
- if _, err := sendEP.SendRecv(0); err != nil {
- t.Fatalf("initially-active Endpoint.SendRecv() failed: %v", err)
+ t.Logf("client Endpoint establishing connection")
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
- t.Logf("initially-active Endpoint got packet 2, sending packet 3")
- if err := sendEP.SendLast(0); err != nil {
- t.Fatalf("initially-active Endpoint.SendLast() failed: %v", err)
+ t.Logf("client Endpoint sending packet 1 and waiting for packet 2")
+ if _, err := c.clientEP.SendRecv(0); err != nil {
+ t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
+ t.Logf("client Endpoint got packet 2, sending packet 3")
+ if err := c.clientEP.SendLast(0); err != nil {
+ t.Fatalf("client Endpoint.SendLast() failed: %v", err)
+ }
+ t.Logf("waiting for server goroutine to complete")
+ serverRun.Wait()
}
-func TestFutexSendRecv(t *testing.T) {
- testSendRecv(t, ControlModeFutex)
+func TestSendRecv(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testSendRecv(t, c)
}
-func testRecvFirstShutdown(t *testing.T, ctrlMode ControlMode) {
- pwa, err := NewPacketWindowAllocator()
- if err != nil {
- t.Fatalf("failed to create PacketWindowAllocator: %v", err)
- }
- defer pwa.Destroy()
- pwd, err := pwa.Allocate(testPacketWindowSize)
- if err != nil {
- t.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
- }
+func testShutdownConnect(t *testing.T, c *testConnection) {
+ var clientRun sync.WaitGroup
+ clientRun.Add(1)
+ go func() {
+ defer clientRun.Done()
+ if err := c.clientEP.Connect(); err == nil {
+ t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
+ }
+ }()
+ time.Sleep(time.Second) // to allow c.clientEP.Connect() to block
+ c.clientEP.Shutdown()
+ clientRun.Wait()
+}
- ep, err := NewEndpoint(ctrlMode, pwd)
- if err != nil {
- t.Fatalf("failed to create Endpoint: %v", err)
- }
- defer ep.Destroy()
+func TestShutdownConnect(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownConnect(t, c)
+}
- otherThreadDone := make(chan struct{})
+func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
go func() {
- defer func() { otherThreadDone <- struct{}{} }()
- _, err := ep.RecvFirst()
+ defer serverRun.Done()
+ _, err := c.serverEP.RecvFirst()
if err == nil {
- t.Errorf("Endpoint.RecvFirst() succeeded unexpectedly")
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
-
- time.Sleep(time.Second) // to ensure ep.RecvFirst() has blocked
- ep.Shutdown()
- <-otherThreadDone
+ time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block
+ c.serverEP.Shutdown()
+ serverRun.Wait()
}
-func TestFutexRecvFirstShutdown(t *testing.T) {
- testRecvFirstShutdown(t, ControlModeFutex)
+func TestShutdownRecvFirstBeforeConnect(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownRecvFirstBeforeConnect(t, c)
}
-func testSendRecvShutdown(t *testing.T, ctrlMode ControlMode) {
- pwa, err := NewPacketWindowAllocator()
- if err != nil {
- t.Fatalf("failed to create PacketWindowAllocator: %v", err)
- }
- defer pwa.Destroy()
- pwd, err := pwa.Allocate(testPacketWindowSize)
- if err != nil {
- t.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
- }
-
- sendEP, err := NewEndpoint(ctrlMode, pwd)
- if err != nil {
- t.Fatalf("failed to create Endpoint: %v", err)
- }
- defer sendEP.Destroy()
- recvEP, err := NewEndpoint(ctrlMode, pwd)
- if err != nil {
- t.Fatalf("failed to create Endpoint: %v", err)
- }
- defer recvEP.Destroy()
-
- otherThreadDone := make(chan struct{})
+func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
go func() {
- defer func() { otherThreadDone <- struct{}{} }()
- if _, err := recvEP.RecvFirst(); err != nil {
- t.Fatalf("initially-inactive Endpoint.RecvFirst() failed: %v", err)
- }
- if _, err := recvEP.SendRecv(0); err == nil {
- t.Errorf("initially-inactive Endpoint.SendRecv() succeeded unexpectedly")
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err == nil {
+ t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
-
- if _, err := sendEP.SendRecv(0); err != nil {
- t.Fatalf("initially-active Endpoint.SendRecv() failed: %v", err)
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
- time.Sleep(time.Second) // to ensure recvEP.SendRecv() has blocked
- recvEP.Shutdown()
- <-otherThreadDone
+ c.serverEP.Shutdown()
+ serverRun.Wait()
}
-func TestFutexSendRecvShutdown(t *testing.T) {
- testSendRecvShutdown(t, ControlModeFutex)
+func TestShutdownRecvFirstAfterConnect(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownRecvFirstAfterConnect(t, c)
}
-func benchmarkSendRecv(b *testing.B, ctrlMode ControlMode) {
- pwa, err := NewPacketWindowAllocator()
- if err != nil {
- b.Fatalf("failed to create PacketWindowAllocator: %v", err)
+func testShutdownSendRecv(t *testing.T, c *testConnection) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ }
+ if _, err := c.serverEP.SendRecv(0); err == nil {
+ t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly")
+ }
+ }()
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
- defer pwa.Destroy()
- pwd, err := pwa.Allocate(testPacketWindowSize)
- if err != nil {
- b.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
+ if _, err := c.clientEP.SendRecv(0); err != nil {
+ t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
+ time.Sleep(time.Second) // to allow c.serverEP.SendRecv() to block
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+}
- sendEP, err := NewEndpoint(ctrlMode, pwd)
- if err != nil {
- b.Fatalf("failed to create Endpoint: %v", err)
- }
- defer sendEP.Destroy()
- recvEP, err := NewEndpoint(ctrlMode, pwd)
- if err != nil {
- b.Fatalf("failed to create Endpoint: %v", err)
- }
- defer recvEP.Destroy()
+func TestShutdownSendRecv(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownSendRecv(t, c)
+}
- otherThreadDone := make(chan struct{})
+func benchmarkSendRecv(b *testing.B, c *testConnection) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
go func() {
- defer func() { otherThreadDone <- struct{}{} }()
+ defer serverRun.Done()
if b.N == 0 {
return
}
- if _, err := recvEP.RecvFirst(); err != nil {
- b.Fatalf("initially-inactive Endpoint.RecvFirst() failed: %v", err)
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ b.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
}
for i := 1; i < b.N; i++ {
- if _, err := recvEP.SendRecv(0); err != nil {
- b.Fatalf("initially-inactive Endpoint.SendRecv() failed: %v", err)
+ if _, err := c.serverEP.SendRecv(0); err != nil {
+ b.Fatalf("server Endpoint.SendRecv() failed: %v", err)
}
}
- if err := recvEP.SendLast(0); err != nil {
- b.Fatalf("initially-inactive Endpoint.SendLast() failed: %v", err)
+ if err := c.serverEP.SendLast(0); err != nil {
+ b.Fatalf("server Endpoint.SendLast() failed: %v", err)
}
}()
- defer func() { <-otherThreadDone }()
+ defer func() {
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ b.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ runtime.GC()
b.ResetTimer()
for i := 0; i < b.N; i++ {
- if _, err := sendEP.SendRecv(0); err != nil {
- b.Fatalf("initially-active Endpoint.SendRecv() failed: %v", err)
+ if _, err := c.clientEP.SendRecv(0); err != nil {
+ b.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
}
b.StopTimer()
}
-func BenchmarkFutexSendRecv(b *testing.B) {
- benchmarkSendRecv(b, ControlModeFutex)
+func BenchmarkSendRecv(b *testing.B) {
+ c := newTestConnection(b)
+ defer c.destroy()
+ benchmarkSendRecv(b, c)
}
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
new file mode 100644
index 000000000..7c8977893
--- /dev/null
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -0,0 +1,69 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Packets consist of an 8-byte header followed by an arbitrarily-sized
+// datagram. The header consists of:
+//
+// - A 4-byte native-endian connection state.
+//
+// - A 4-byte native-endian datagram length in bytes.
+const (
+ sizeofUint32 = unsafe.Sizeof(uint32(0))
+
+ // PacketHeaderBytes is the size of a flipcall packet header in bytes. The
+ // maximum datagram size supported by a flipcall connection is equal to the
+ // length of the packet window minus PacketHeaderBytes.
+ //
+ // PacketHeaderBytes is exported to support its use in constant
+ // expressions. Non-constant expressions may prefer to use
+ // PacketWindowLengthForDataCap().
+ PacketHeaderBytes = 2 * sizeofUint32
+)
+
+func (ep *Endpoint) connState() *uint32 {
+ return (*uint32)((unsafe.Pointer)(ep.packet))
+}
+
+func (ep *Endpoint) dataLen() *uint32 {
+ return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32))
+}
+
+// Data returns the datagram part of ep's packet window as a byte slice.
+//
+// Note that the packet window is shared with the potentially-untrusted peer
+// Endpoint, which may concurrently mutate the contents of the packet window.
+// Thus:
+//
+// - Readers must not assume that two reads of the same byte in Data() will
+// return the same result. In other words, readers should read any given byte
+// in Data() at most once.
+//
+// - Writers must not assume that they will read back the same data that they
+// have written. In other words, writers should avoid reading from Data() at
+// all.
+func (ep *Endpoint) Data() []byte {
+ var bs []byte
+ bsReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&bs))
+ bsReflect.Data = ep.packet + PacketHeaderBytes
+ bsReflect.Len = int(ep.dataCap)
+ bsReflect.Cap = int(ep.dataCap)
+ return bs
+}
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go
index 3f592ad16..e7dd812b3 100644
--- a/pkg/flipcall/futex_linux.go
+++ b/pkg/flipcall/futex_linux.go
@@ -17,78 +17,95 @@
package flipcall
import (
+ "encoding/json"
"fmt"
- "math"
+ "runtime"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/log"
)
-func (ep *Endpoint) doFutexRoundTrip() error {
- ourSeq, err := ep.doFutexNotifySeq()
- if err != nil {
- return err
+func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeResponse, error) {
+ var resp ctrlHandshakeResponse
+
+ // Write the handshake request.
+ w := ep.NewWriter()
+ if err := json.NewEncoder(w).Encode(req); err != nil {
+ return resp, fmt.Errorf("error writing handshake request: %v", err)
}
- return ep.doFutexWaitSeq(ourSeq)
-}
+ *ep.dataLen() = w.Len()
-func (ep *Endpoint) doFutexWaitFirst() error {
- return ep.doFutexWaitSeq(0)
-}
+ // Exchange control with the server.
+ if err := ep.futexSwitchToPeer(); err != nil {
+ return resp, err
+ }
+ if err := ep.futexSwitchFromPeer(); err != nil {
+ return resp, err
+ }
-func (ep *Endpoint) doFutexNotifyLast() error {
- _, err := ep.doFutexNotifySeq()
- return err
+ // Read the handshake response.
+ respLen := atomic.LoadUint32(ep.dataLen())
+ if respLen > ep.dataCap {
+ return resp, fmt.Errorf("invalid handshake response length %d (maximum %d)", respLen, ep.dataCap)
+ }
+ if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil {
+ return resp, fmt.Errorf("error reading handshake response: %v", err)
+ }
+
+ return resp, nil
}
-func (ep *Endpoint) doFutexNotifySeq() (uint32, error) {
- ourSeq := atomic.AddUint32(ep.seq(), 1)
- if err := ep.futexWake(1); err != nil {
- return ourSeq, fmt.Errorf("failed to FUTEX_WAKE peer Endpoint: %v", err)
+func (ep *Endpoint) futexSwitchToPeer() error {
+ // Update connection state to indicate that the peer should be active.
+ if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
+ return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState()))
}
- return ourSeq, nil
+
+ // Wake the peer's Endpoint.futexSwitchFromPeer().
+ if err := ep.futexWakeConnState(1); err != nil {
+ return fmt.Errorf("failed to FUTEX_WAKE peer Endpoint: %v", err)
+ }
+ return nil
}
-func (ep *Endpoint) doFutexWaitSeq(prevSeq uint32) error {
- nextSeq := prevSeq + 1
+func (ep *Endpoint) futexSwitchFromPeer() error {
for {
- if ep.isShutdown() {
- return endpointShutdownError{}
- }
- if err := ep.futexWait(prevSeq); err != nil {
- return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
- }
- seq := atomic.LoadUint32(ep.seq())
- if seq == nextSeq {
+ switch cs := atomic.LoadUint32(ep.connState()); cs {
+ case ep.activeState:
return nil
+ case ep.inactiveState:
+ // Continue to FUTEX_WAIT.
+ default:
+ return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
}
- if seq != prevSeq {
- return fmt.Errorf("invalid packet sequence number %d (expected %d or %d)", seq, prevSeq, nextSeq)
+ if ep.isShutdownLocally() {
+ return shutdownError{}
+ }
+ if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
+ return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
}
}
}
-func (ep *Endpoint) doFutexInterruptForShutdown() {
- // Wake MaxInt32 threads to prevent a malicious or broken peer from
- // swallowing our wakeup by FUTEX_WAITing from multiple threads.
- if err := ep.futexWake(math.MaxInt32); err != nil {
- log.Warningf("failed to FUTEX_WAKE Endpoint: %v", err)
- }
-}
-
-func (ep *Endpoint) futexWake(numThreads int32) error {
- if _, _, e := syscall.RawSyscall(syscall.SYS_FUTEX, uintptr(ep.packet), linux.FUTEX_WAKE, uintptr(numThreads)); e != 0 {
+func (ep *Endpoint) futexWakeConnState(numThreads int32) error {
+ if _, _, e := syscall.RawSyscall(syscall.SYS_FUTEX, ep.packet, linux.FUTEX_WAKE, uintptr(numThreads)); e != 0 {
return e
}
return nil
}
-func (ep *Endpoint) futexWait(seq uint32) error {
- _, _, e := syscall.Syscall6(syscall.SYS_FUTEX, uintptr(ep.packet), linux.FUTEX_WAIT, uintptr(seq), 0, 0, 0)
+func (ep *Endpoint) futexWaitConnState(curState uint32) error {
+ _, _, e := syscall.Syscall6(syscall.SYS_FUTEX, ep.packet, linux.FUTEX_WAIT, uintptr(curState), 0, 0, 0)
if e != 0 && e != syscall.EAGAIN && e != syscall.EINTR {
return e
}
return nil
}
+
+func yieldThread() {
+ syscall.Syscall(syscall.SYS_SCHED_YIELD, 0, 0, 0)
+ // The thread we're trying to yield to may be waiting for a Go runtime P.
+ // runtime.Gosched() will hand off ours if necessary.
+ runtime.Gosched()
+}
diff --git a/pkg/flipcall/io.go b/pkg/flipcall/io.go
new file mode 100644
index 000000000..85e40b932
--- /dev/null
+++ b/pkg/flipcall/io.go
@@ -0,0 +1,113 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "fmt"
+ "io"
+)
+
+// DatagramReader implements io.Reader by reading a datagram from an Endpoint's
+// packet window. Its use is optional; users that can use Endpoint.Data() more
+// efficiently are advised to do so.
+type DatagramReader struct {
+ ep *Endpoint
+ off uint32
+ end uint32
+}
+
+// Init must be called on zero-value DatagramReaders before first use.
+//
+// Preconditions: dataLen is 0, or was returned by a previous call to
+// ep.RecvFirst() or ep.SendRecv().
+func (r *DatagramReader) Init(ep *Endpoint, dataLen uint32) {
+ r.ep = ep
+ r.Reset(dataLen)
+}
+
+// Reset causes r to begin reading a new datagram of the given length from the
+// associated Endpoint.
+//
+// Preconditions: dataLen is 0, or was returned by a previous call to the
+// associated Endpoint's RecvFirst() or SendRecv() methods.
+func (r *DatagramReader) Reset(dataLen uint32) {
+ if dataLen > r.ep.dataCap {
+ panic(fmt.Sprintf("invalid dataLen (%d) > ep.dataCap (%d)", dataLen, r.ep.dataCap))
+ }
+ r.off = 0
+ r.end = dataLen
+}
+
+// NewReader is a convenience function that returns an initialized
+// DatagramReader allocated on the heap.
+//
+// Preconditions: dataLen is 0, or was returned by a previous call to
+// ep.RecvFirst() or ep.SendRecv().
+func (ep *Endpoint) NewReader(dataLen uint32) *DatagramReader {
+ r := &DatagramReader{}
+ r.Init(ep, dataLen)
+ return r
+}
+
+// Read implements io.Reader.Read.
+func (r *DatagramReader) Read(dst []byte) (int, error) {
+ n := copy(dst, r.ep.Data()[r.off:r.end])
+ r.off += uint32(n)
+ if r.off == r.end {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// DatagramWriter implements io.Writer by writing a datagram to an Endpoint's
+// packet window. Its use is optional; users that can use Endpoint.Data() more
+// efficiently are advised to do so.
+type DatagramWriter struct {
+ ep *Endpoint
+ off uint32
+}
+
+// Init must be called on zero-value DatagramWriters before first use.
+func (w *DatagramWriter) Init(ep *Endpoint) {
+ w.ep = ep
+}
+
+// Reset causes w to begin writing a new datagram to the associated Endpoint.
+func (w *DatagramWriter) Reset() {
+ w.off = 0
+}
+
+// NewWriter is a convenience function that returns an initialized
+// DatagramWriter allocated on the heap.
+func (ep *Endpoint) NewWriter() *DatagramWriter {
+ w := &DatagramWriter{}
+ w.Init(ep)
+ return w
+}
+
+// Write implements io.Writer.Write.
+func (w *DatagramWriter) Write(src []byte) (int, error) {
+ n := copy(w.ep.Data()[w.off:w.ep.dataCap], src)
+ w.off += uint32(n)
+ if n != len(src) {
+ return n, fmt.Errorf("datagram would exceed maximum size of %d bytes", w.ep.dataCap)
+ }
+ return n, nil
+}
+
+// Len returns the length of the written datagram.
+func (w *DatagramWriter) Len() uint32 {
+ return w.off
+}
diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window_allocator.go
index 7b455b24d..ccb918fab 100644
--- a/pkg/flipcall/packet_window_allocator.go
+++ b/pkg/flipcall/packet_window_allocator.go
@@ -34,10 +34,10 @@ func init() {
// This is depended on by roundUpToPage().
panic(fmt.Sprintf("system page size (%d) is not a power of 2", pageSize))
}
- if uintptr(pageSize) < packetHeaderBytes {
+ if uintptr(pageSize) < PacketHeaderBytes {
// This is required since Endpoint.Init() imposes a minimum packet
// window size of 1 page.
- panic(fmt.Sprintf("system page size (%d) is less than packet header size (%d)", pageSize, packetHeaderBytes))
+ panic(fmt.Sprintf("system page size (%d) is less than packet header size (%d)", pageSize, PacketHeaderBytes))
}
}
@@ -59,7 +59,7 @@ type PacketWindowDescriptor struct {
// PacketWindowLengthForDataCap returns the minimum packet window size required
// to accommodate datagrams of the given size in bytes.
func PacketWindowLengthForDataCap(dataCap uint32) int {
- return roundUpToPage(int(dataCap) + int(packetHeaderBytes))
+ return roundUpToPage(int(dataCap) + int(PacketHeaderBytes))
}
func roundUpToPage(x int) int {
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 8c3e3d5ab..ad69e0757 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -215,8 +215,8 @@ type AtomicRefCount struct {
type LeakMode uint32
const (
- // uninitializedLeakChecking indicates that the leak checker has not yet been initialized.
- uninitializedLeakChecking LeakMode = iota
+ // UninitializedLeakChecking indicates that the leak checker has not yet been initialized.
+ UninitializedLeakChecking LeakMode = iota
// NoLeakChecking indicates that no effort should be made to check for
// leaks.
@@ -231,20 +231,6 @@ const (
LeaksLogTraces
)
-// String returns LeakMode's string representation.
-func (l LeakMode) String() string {
- switch l {
- case NoLeakChecking:
- return "disabled"
- case LeaksLogWarning:
- return "log-names"
- case LeaksLogTraces:
- return "log-traces"
- default:
- panic(fmt.Sprintf("Invalid leakmode: %d", l))
- }
-}
-
// leakMode stores the current mode for the reference leak checker.
//
// Values must be one of the LeakMode values.
@@ -332,13 +318,15 @@ func (r *AtomicRefCount) finalize() {
switch LeakMode(atomic.LoadUint32(&leakMode)) {
case NoLeakChecking:
return
- case uninitializedLeakChecking:
+ case UninitializedLeakChecking:
note = "(Leak checker uninitialized): "
}
if n := r.ReadRefs(); n != 0 {
msg := fmt.Sprintf("%sAtomicRefCount %p owned by %q garbage collected with ref count of %d (want 0)", note, r, r.name, n)
if len(r.stack) != 0 {
msg += ":\nCaller:\n" + formatStack(r.stack)
+ } else {
+ msg += " (enable trace logging to debug)"
}
log.Warningf(msg)
}
diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go
index 62ae1fd9f..48413f1fb 100644
--- a/pkg/seccomp/seccomp_test_victim.go
+++ b/pkg/seccomp/seccomp_test_victim.go
@@ -70,7 +70,7 @@ func main() {
syscall.SYS_NANOSLEEP: {},
syscall.SYS_NEWFSTATAT: {},
syscall.SYS_OPEN: {},
- syscall.SYS_POLL: {},
+ syscall.SYS_PPOLL: {},
syscall.SYS_PREAD64: {},
syscall.SYS_PSELECT6: {},
syscall.SYS_PWRITE64: {},
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index 3f9772b87..c35faeb4c 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -56,15 +56,10 @@ type ExecArgs struct {
// MountNamespace is the mount namespace to execute the new process in.
// A reference on MountNamespace must be held for the lifetime of the
- // ExecArgs. If MountNamespace is nil, it will default to the kernel's
- // root MountNamespace.
+ // ExecArgs. If MountNamespace is nil, it will default to the init
+ // process's MountNamespace.
MountNamespace *fs.MountNamespace
- // Root defines the root directory for the new process. A reference on
- // Root must be held for the lifetime of the ExecArgs. If Root is nil,
- // it will default to the VFS root.
- Root *fs.Dirent
-
// WorkingDirectory defines the working directory for the new process.
WorkingDirectory string `json:"wd"`
@@ -155,7 +150,6 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
Envv: args.Envv,
WorkingDirectory: args.WorkingDirectory,
MountNamespace: args.MountNamespace,
- Root: args.Root,
Credentials: creds,
FDTable: fdTable,
Umask: 0022,
@@ -167,11 +161,6 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
ContainerID: args.ContainerID,
PIDNamespace: args.PIDNamespace,
}
- if initArgs.Root != nil {
- // initArgs must hold a reference on Root, which will be
- // donated to the new process in CreateProcess.
- initArgs.Root.IncRef()
- }
if initArgs.MountNamespace != nil {
// initArgs must hold a reference on MountNamespace, which will
// be donated to the new process in CreateProcess.
@@ -184,7 +173,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
paths := fs.GetPath(initArgs.Envv)
mns := initArgs.MountNamespace
if mns == nil {
- mns = proc.Kernel.RootMountNamespace()
+ mns = proc.Kernel.GlobalInit().Leader().MountNamespace()
}
f, err := mns.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
if err != nil {
diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go
index 9fc6a5bc2..4f3d6410e 100644
--- a/pkg/sentry/fs/attr.go
+++ b/pkg/sentry/fs/attr.go
@@ -111,6 +111,50 @@ func (n InodeType) LinuxType() uint32 {
}
}
+// ToDirentType converts an InodeType to a linux dirent type field.
+func ToDirentType(nodeType InodeType) uint8 {
+ switch nodeType {
+ case RegularFile, SpecialFile:
+ return linux.DT_REG
+ case Symlink:
+ return linux.DT_LNK
+ case Directory, SpecialDirectory:
+ return linux.DT_DIR
+ case Pipe:
+ return linux.DT_FIFO
+ case CharacterDevice:
+ return linux.DT_CHR
+ case BlockDevice:
+ return linux.DT_BLK
+ case Socket:
+ return linux.DT_SOCK
+ default:
+ return linux.DT_UNKNOWN
+ }
+}
+
+// ToInodeType converts a linux file type to InodeType.
+func ToInodeType(linuxFileType linux.FileMode) InodeType {
+ switch linuxFileType {
+ case linux.ModeRegular:
+ return RegularFile
+ case linux.ModeDirectory:
+ return Directory
+ case linux.ModeSymlink:
+ return Symlink
+ case linux.ModeNamedPipe:
+ return Pipe
+ case linux.ModeCharacterDevice:
+ return CharacterDevice
+ case linux.ModeBlockDevice:
+ return BlockDevice
+ case linux.ModeSocket:
+ return Socket
+ default:
+ panic(fmt.Sprintf("unknown file mode: %d", linuxFileType))
+ }
+}
+
// StableAttr contains Inode attributes that will be stable throughout the
// lifetime of the Inode.
//
diff --git a/pkg/sentry/fs/context.go b/pkg/sentry/fs/context.go
index 51b4c7ee1..dd427de5d 100644
--- a/pkg/sentry/fs/context.go
+++ b/pkg/sentry/fs/context.go
@@ -112,3 +112,27 @@ func DirentCacheLimiterFromContext(ctx context.Context) *DirentCacheLimiter {
}
return nil
}
+
+type rootContext struct {
+ context.Context
+ root *Dirent
+}
+
+// WithRoot returns a copy of ctx with the given root.
+func WithRoot(ctx context.Context, root *Dirent) context.Context {
+ return &rootContext{
+ Context: ctx,
+ root: root,
+ }
+}
+
+// Value implements Context.Value.
+func (rc rootContext) Value(key interface{}) interface{} {
+ switch key {
+ case CtxRoot:
+ rc.root.IncRef()
+ return rc.root
+ default:
+ return rc.Context.Value(key)
+ }
+}
diff --git a/pkg/sentry/fs/ext/BUILD b/pkg/sentry/fs/ext/BUILD
deleted file mode 100644
index 2c15875f5..000000000
--- a/pkg/sentry/fs/ext/BUILD
+++ /dev/null
@@ -1,54 +0,0 @@
-package(licenses = ["notice"])
-
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-
-go_library(
- name = "ext",
- srcs = [
- "dentry.go",
- "ext.go",
- "filesystem.go",
- "inode.go",
- "utils.go",
- ],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ext",
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/abi/linux",
- "//pkg/binary",
- "//pkg/sentry/context",
- "//pkg/sentry/fs/ext/disklayout",
- "//pkg/sentry/kernel/auth",
- "//pkg/sentry/vfs",
- "//pkg/syserror",
- ],
-)
-
-go_test(
- name = "ext_test",
- size = "small",
- srcs = [
- "ext_test.go",
- "extent_test.go",
- ],
- data = [
- "//pkg/sentry/fs/ext:assets/bigfile.txt",
- "//pkg/sentry/fs/ext:assets/file.txt",
- "//pkg/sentry/fs/ext:assets/tiny.ext2",
- "//pkg/sentry/fs/ext:assets/tiny.ext3",
- "//pkg/sentry/fs/ext:assets/tiny.ext4",
- ],
- embed = [":ext"],
- deps = [
- "//pkg/abi/linux",
- "//pkg/binary",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
- "//pkg/sentry/fs/ext/disklayout",
- "//pkg/sentry/kernel/auth",
- "//pkg/sentry/vfs",
- "//runsc/test/testutil",
- "@com_github_google_go-cmp//cmp:go_default_library",
- "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
- ],
-)
diff --git a/pkg/sentry/fs/ext/ext.go b/pkg/sentry/fs/ext/ext.go
deleted file mode 100644
index 10e235fb1..000000000
--- a/pkg/sentry/fs/ext/ext.go
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package ext implements readonly ext(2/3/4) filesystems.
-package ext
-
-import (
- "errors"
- "fmt"
- "io"
- "os"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// filesystemType implements vfs.FilesystemType.
-type filesystemType struct{}
-
-// Compiles only if filesystemType implements vfs.FilesystemType.
-var _ vfs.FilesystemType = (*filesystemType)(nil)
-
-// getDeviceFd returns the read seeker to the underlying device.
-// Currently there are two ways of mounting an ext(2/3/4) fs:
-// 1. Specify a mount with our internal special MountType in the OCI spec.
-// 2. Expose the device to the container and mount it from application layer.
-func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReadSeeker, error) {
- if opts.InternalData == nil {
- // User mount call.
- // TODO(b/134676337): Open the device specified by `source` and return that.
- panic("unimplemented")
- }
-
- // NewFilesystem call originated from within the sentry.
- fd, ok := opts.InternalData.(uintptr)
- if !ok {
- return nil, errors.New("internal data for ext fs must be a uintptr containing the file descriptor to device")
- }
-
- // We do not close this file because that would close the underlying device
- // file descriptor (which is required for reading the fs from disk).
- // TODO(b/134676337): Use pkg/fd instead.
- deviceFile := os.NewFile(fd, source)
- if deviceFile == nil {
- return nil, fmt.Errorf("ext4 device file descriptor is not valid: %d", fd)
- }
-
- return deviceFile, nil
-}
-
-// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
-func (fstype filesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- dev, err := getDeviceFd(source, opts)
- if err != nil {
- return nil, nil, err
- }
-
- fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
- fs.vfsfs.Init(&fs)
- fs.sb, err = readSuperBlock(dev)
- if err != nil {
- return nil, nil, err
- }
-
- if fs.sb.Magic() != linux.EXT_SUPER_MAGIC {
- // mount(2) specifies that EINVAL should be returned if the superblock is
- // invalid.
- return nil, nil, syserror.EINVAL
- }
-
- fs.bgs, err = readBlockGroups(dev, fs.sb)
- if err != nil {
- return nil, nil, err
- }
-
- rootInode, err := fs.getOrCreateInode(disklayout.RootDirInode)
- if err != nil {
- return nil, nil, err
- }
-
- return &fs.vfsfs, &newDentry(rootInode).vfsd, nil
-}
diff --git a/pkg/sentry/fs/ext/ext_test.go b/pkg/sentry/fs/ext/ext_test.go
deleted file mode 100644
index ee7f7907c..000000000
--- a/pkg/sentry/fs/ext/ext_test.go
+++ /dev/null
@@ -1,407 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ext
-
-import (
- "fmt"
- "os"
- "path"
- "testing"
-
- "github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-
- "gvisor.dev/gvisor/runsc/test/testutil"
-)
-
-const (
- assetsDir = "pkg/sentry/fs/ext/assets"
-)
-
-var (
- ext2ImagePath = path.Join(assetsDir, "tiny.ext2")
- ext3ImagePath = path.Join(assetsDir, "tiny.ext3")
- ext4ImagePath = path.Join(assetsDir, "tiny.ext4")
-)
-
-func beginning(_ uint64) uint64 {
- return 0
-}
-
-func middle(i uint64) uint64 {
- return i / 2
-}
-
-func end(i uint64) uint64 {
- return i
-}
-
-// setUp opens imagePath as an ext Filesystem and returns all necessary
-// elements required to run tests. If error is nil, it also returns a tear
-// down function which must be called after the test is run for clean up.
-func setUp(t *testing.T, imagePath string) (context.Context, *vfs.Filesystem, *vfs.Dentry, func(), error) {
- localImagePath, err := testutil.FindFile(imagePath)
- if err != nil {
- return nil, nil, nil, nil, fmt.Errorf("failed to open local image at path %s: %v", imagePath, err)
- }
-
- f, err := os.Open(localImagePath)
- if err != nil {
- return nil, nil, nil, nil, err
- }
-
- // Mount the ext fs and retrieve its vfs.Filesystem and root vfs.Dentry.
- mockCtx := contexttest.Context(t)
- fs, d, err := filesystemType{}.NewFilesystem(mockCtx, nil, localImagePath, vfs.NewFilesystemOptions{InternalData: f.Fd()})
- if err != nil {
- f.Close()
- return nil, nil, nil, nil, err
- }
-
- tearDown := func() {
- if err := f.Close(); err != nil {
- t.Fatalf("tearDown failed: %v", err)
- }
- }
- return mockCtx, fs, d, tearDown, nil
-}
-
-// TestRootDir tests that the root directory inode is correctly initialized and
-// returned from setUp.
-func TestRootDir(t *testing.T) {
- type inodeProps struct {
- Mode linux.FileMode
- UID auth.KUID
- GID auth.KGID
- Size uint64
- InodeSize uint16
- Links uint16
- Flags disklayout.InodeFlags
- }
-
- type rootDirTest struct {
- name string
- image string
- wantInode inodeProps
- }
-
- tests := []rootDirTest{
- {
- name: "ext4 root dir",
- image: ext4ImagePath,
- wantInode: inodeProps{
- Mode: linux.ModeDirectory | 0755,
- Size: 0x400,
- InodeSize: 0x80,
- Links: 3,
- Flags: disklayout.InodeFlags{Extents: true},
- },
- },
- {
- name: "ext3 root dir",
- image: ext3ImagePath,
- wantInode: inodeProps{
- Mode: linux.ModeDirectory | 0755,
- Size: 0x400,
- InodeSize: 0x80,
- Links: 3,
- },
- },
- {
- name: "ext2 root dir",
- image: ext2ImagePath,
- wantInode: inodeProps{
- Mode: linux.ModeDirectory | 0755,
- Size: 0x400,
- InodeSize: 0x80,
- Links: 3,
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- _, _, vfsd, tearDown, err := setUp(t, test.image)
- if err != nil {
- t.Fatalf("setUp failed: %v", err)
- }
- defer tearDown()
-
- d, ok := vfsd.Impl().(*dentry)
- if !ok {
- t.Fatalf("ext dentry of incorrect type: %T", vfsd.Impl())
- }
-
- // Offload inode contents into local structs for comparison.
- gotInode := inodeProps{
- Mode: d.inode.diskInode.Mode(),
- UID: d.inode.diskInode.UID(),
- GID: d.inode.diskInode.GID(),
- Size: d.inode.diskInode.Size(),
- InodeSize: d.inode.diskInode.InodeSize(),
- Links: d.inode.diskInode.LinksCount(),
- Flags: d.inode.diskInode.Flags(),
- }
-
- if diff := cmp.Diff(gotInode, test.wantInode); diff != "" {
- t.Errorf("inode mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
-
-// TestFilesystemInit tests that the filesystem superblock and block group
-// descriptors are correctly read in and initialized.
-func TestFilesystemInit(t *testing.T) {
- // sb only contains the immutable properties of the superblock.
- type sb struct {
- InodesCount uint32
- BlocksCount uint64
- MaxMountCount uint16
- FirstDataBlock uint32
- BlockSize uint64
- BlocksPerGroup uint32
- ClusterSize uint64
- ClustersPerGroup uint32
- InodeSize uint16
- InodesPerGroup uint32
- BgDescSize uint16
- Magic uint16
- Revision disklayout.SbRevision
- CompatFeatures disklayout.CompatFeatures
- IncompatFeatures disklayout.IncompatFeatures
- RoCompatFeatures disklayout.RoCompatFeatures
- }
-
- // bg only contains the immutable properties of the block group descriptor.
- type bg struct {
- InodeTable uint64
- BlockBitmap uint64
- InodeBitmap uint64
- ExclusionBitmap uint64
- Flags disklayout.BGFlags
- }
-
- type fsInitTest struct {
- name string
- image string
- wantSb sb
- wantBgs []bg
- }
-
- tests := []fsInitTest{
- {
- name: "ext4 filesystem init",
- image: ext4ImagePath,
- wantSb: sb{
- InodesCount: 0x10,
- BlocksCount: 0x40,
- MaxMountCount: 0xffff,
- FirstDataBlock: 0x1,
- BlockSize: 0x400,
- BlocksPerGroup: 0x2000,
- ClusterSize: 0x400,
- ClustersPerGroup: 0x2000,
- InodeSize: 0x80,
- InodesPerGroup: 0x10,
- BgDescSize: 0x40,
- Magic: linux.EXT_SUPER_MAGIC,
- Revision: disklayout.DynamicRev,
- CompatFeatures: disklayout.CompatFeatures{
- ExtAttr: true,
- ResizeInode: true,
- DirIndex: true,
- },
- IncompatFeatures: disklayout.IncompatFeatures{
- DirentFileType: true,
- Extents: true,
- Is64Bit: true,
- FlexBg: true,
- },
- RoCompatFeatures: disklayout.RoCompatFeatures{
- Sparse: true,
- LargeFile: true,
- HugeFile: true,
- DirNlink: true,
- ExtraIsize: true,
- MetadataCsum: true,
- },
- },
- wantBgs: []bg{
- {
- InodeTable: 0x23,
- BlockBitmap: 0x3,
- InodeBitmap: 0x13,
- Flags: disklayout.BGFlags{
- InodeZeroed: true,
- },
- },
- },
- },
- {
- name: "ext3 filesystem init",
- image: ext3ImagePath,
- wantSb: sb{
- InodesCount: 0x10,
- BlocksCount: 0x40,
- MaxMountCount: 0xffff,
- FirstDataBlock: 0x1,
- BlockSize: 0x400,
- BlocksPerGroup: 0x2000,
- ClusterSize: 0x400,
- ClustersPerGroup: 0x2000,
- InodeSize: 0x80,
- InodesPerGroup: 0x10,
- BgDescSize: 0x20,
- Magic: linux.EXT_SUPER_MAGIC,
- Revision: disklayout.DynamicRev,
- CompatFeatures: disklayout.CompatFeatures{
- ExtAttr: true,
- ResizeInode: true,
- DirIndex: true,
- },
- IncompatFeatures: disklayout.IncompatFeatures{
- DirentFileType: true,
- },
- RoCompatFeatures: disklayout.RoCompatFeatures{
- Sparse: true,
- LargeFile: true,
- },
- },
- wantBgs: []bg{
- {
- InodeTable: 0x5,
- BlockBitmap: 0x3,
- InodeBitmap: 0x4,
- Flags: disklayout.BGFlags{
- InodeZeroed: true,
- },
- },
- },
- },
- {
- name: "ext2 filesystem init",
- image: ext2ImagePath,
- wantSb: sb{
- InodesCount: 0x10,
- BlocksCount: 0x40,
- MaxMountCount: 0xffff,
- FirstDataBlock: 0x1,
- BlockSize: 0x400,
- BlocksPerGroup: 0x2000,
- ClusterSize: 0x400,
- ClustersPerGroup: 0x2000,
- InodeSize: 0x80,
- InodesPerGroup: 0x10,
- BgDescSize: 0x20,
- Magic: linux.EXT_SUPER_MAGIC,
- Revision: disklayout.DynamicRev,
- CompatFeatures: disklayout.CompatFeatures{
- ExtAttr: true,
- ResizeInode: true,
- DirIndex: true,
- },
- IncompatFeatures: disklayout.IncompatFeatures{
- DirentFileType: true,
- },
- RoCompatFeatures: disklayout.RoCompatFeatures{
- Sparse: true,
- LargeFile: true,
- },
- },
- wantBgs: []bg{
- {
- InodeTable: 0x5,
- BlockBitmap: 0x3,
- InodeBitmap: 0x4,
- Flags: disklayout.BGFlags{
- InodeZeroed: true,
- },
- },
- },
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- _, vfsfs, _, tearDown, err := setUp(t, test.image)
- if err != nil {
- t.Fatalf("setUp failed: %v", err)
- }
- defer tearDown()
-
- fs, ok := vfsfs.Impl().(*filesystem)
- if !ok {
- t.Fatalf("ext filesystem of incorrect type: %T", vfsfs.Impl())
- }
-
- // Offload superblock and block group descriptors contents into
- // local structs for comparison.
- totalFreeInodes := uint32(0)
- totalFreeBlocks := uint64(0)
- gotSb := sb{
- InodesCount: fs.sb.InodesCount(),
- BlocksCount: fs.sb.BlocksCount(),
- MaxMountCount: fs.sb.MaxMountCount(),
- FirstDataBlock: fs.sb.FirstDataBlock(),
- BlockSize: fs.sb.BlockSize(),
- BlocksPerGroup: fs.sb.BlocksPerGroup(),
- ClusterSize: fs.sb.ClusterSize(),
- ClustersPerGroup: fs.sb.ClustersPerGroup(),
- InodeSize: fs.sb.InodeSize(),
- InodesPerGroup: fs.sb.InodesPerGroup(),
- BgDescSize: fs.sb.BgDescSize(),
- Magic: fs.sb.Magic(),
- Revision: fs.sb.Revision(),
- CompatFeatures: fs.sb.CompatibleFeatures(),
- IncompatFeatures: fs.sb.IncompatibleFeatures(),
- RoCompatFeatures: fs.sb.ReadOnlyCompatibleFeatures(),
- }
- gotNumBgs := len(fs.bgs)
- gotBgs := make([]bg, gotNumBgs)
- for i := 0; i < gotNumBgs; i++ {
- gotBgs[i].InodeTable = fs.bgs[i].InodeTable()
- gotBgs[i].BlockBitmap = fs.bgs[i].BlockBitmap()
- gotBgs[i].InodeBitmap = fs.bgs[i].InodeBitmap()
- gotBgs[i].ExclusionBitmap = fs.bgs[i].ExclusionBitmap()
- gotBgs[i].Flags = fs.bgs[i].Flags()
-
- totalFreeInodes += fs.bgs[i].FreeInodesCount()
- totalFreeBlocks += uint64(fs.bgs[i].FreeBlocksCount())
- }
-
- if diff := cmp.Diff(gotSb, test.wantSb); diff != "" {
- t.Errorf("superblock mismatch (-want +got):\n%s", diff)
- }
-
- if diff := cmp.Diff(gotBgs, test.wantBgs); diff != "" {
- t.Errorf("block group descriptors mismatch (-want +got):\n%s", diff)
- }
-
- if diff := cmp.Diff(totalFreeInodes, fs.sb.FreeInodesCount()); diff != "" {
- t.Errorf("total free inodes mismatch (-want +got):\n%s", diff)
- }
-
- if diff := cmp.Diff(totalFreeBlocks, fs.sb.FreeBlocksCount()); diff != "" {
- t.Errorf("total free blocks mismatch (-want +got):\n%s", diff)
- }
- })
- }
-}
diff --git a/pkg/sentry/fs/ext/filesystem.go b/pkg/sentry/fs/ext/filesystem.go
deleted file mode 100644
index 7150e75a5..000000000
--- a/pkg/sentry/fs/ext/filesystem.go
+++ /dev/null
@@ -1,137 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ext
-
-import (
- "io"
- "sync"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// filesystem implements vfs.FilesystemImpl.
-type filesystem struct {
- // TODO(b/134676337): Remove when all methods have been implemented.
- vfs.FilesystemImpl
-
- vfsfs vfs.Filesystem
-
- // mu serializes changes to the Dentry tree and the usage of the read seeker.
- mu sync.Mutex
-
- // dev is the ReadSeeker for the underlying fs device. It is protected by mu.
- //
- // The ext filesystems aim to maximize locality, i.e. place all the data
- // blocks of a file close together. On a spinning disk, locality reduces the
- // amount of movement of the head hence speeding up IO operations. On an SSD
- // there are no moving parts but locality increases the size of each transfer
- // request. Hence, having mutual exclusion on the read seeker while reading a
- // file *should* help in achieving the intended performance gains.
- //
- // Note: This synchronization was not coupled with the ReadSeeker itself
- // because we want to synchronize across read/seek operations for the
- // performance gains mentioned above. This helps enforce one-file-at-a-time IO.
- dev io.ReadSeeker
-
- // inodeCache maps absolute inode numbers to the corresponding Inode struct.
- // Inodes should be removed from this once their reference count hits 0.
- //
- // Protected by mu because every addition and removal from this corresponds to
- // a change in the dentry tree.
- inodeCache map[uint32]*inode
-
- // sb represents the filesystem superblock. Immutable after initialization.
- sb disklayout.SuperBlock
-
- // bgs represents all the block group descriptors for the filesystem.
- // Immutable after initialization.
- bgs []disklayout.BlockGroup
-}
-
-// Compiles only if filesystem implements vfs.FilesystemImpl.
-var _ vfs.FilesystemImpl = (*filesystem)(nil)
-
-// getOrCreateInode gets the inode corresponding to the inode number passed in.
-// It creates a new one with the given inode number if one does not exist.
-//
-// Preconditions: must be holding fs.mu.
-func (fs *filesystem) getOrCreateInode(inodeNum uint32) (*inode, error) {
- if in, ok := fs.inodeCache[inodeNum]; ok {
- return in, nil
- }
-
- in, err := newInode(fs.dev, fs.sb, fs.bgs, inodeNum)
- if err != nil {
- return nil, err
- }
-
- fs.inodeCache[inodeNum] = in
- return in, nil
-}
-
-// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
-}
-
-// Sync implements vfs.FilesystemImpl.Sync.
-func (fs *filesystem) Sync(ctx context.Context) error {
- // This is a readonly filesystem for now.
- return nil
-}
-
-// The vfs.FilesystemImpl functions below return EROFS because their respective
-// man pages say that EROFS must be returned if the path resolves to a file on
-// a read-only filesystem.
-
-// TODO(b/134676337): Implement path traversal and return EROFS only if the
-// path resolves to a Dentry within ext fs.
-
-// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
-func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
- return syserror.EROFS
-}
-
-// MknodAt implements vfs.FilesystemImpl.MknodAt.
-func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- return syserror.EROFS
-}
-
-// RenameAt implements vfs.FilesystemImpl.RenameAt.
-func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
- return syserror.EROFS
-}
-
-// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
-func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- return syserror.EROFS
-}
-
-// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
-func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
- return syserror.EROFS
-}
-
-// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
-func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- return syserror.EROFS
-}
-
-// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
-func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- return syserror.EROFS
-}
diff --git a/pkg/sentry/fs/ext/inode.go b/pkg/sentry/fs/ext/inode.go
deleted file mode 100644
index df1ea0bda..000000000
--- a/pkg/sentry/fs/ext/inode.go
+++ /dev/null
@@ -1,209 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ext
-
-import (
- "io"
- "sync/atomic"
-
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// inode represents an ext inode.
-type inode struct {
- // refs is a reference count. refs is accessed using atomic memory operations.
- refs int64
-
- // inodeNum is the inode number of this inode on disk. This is used to
- // identify inodes within the ext filesystem.
- inodeNum uint32
-
- // diskInode gives us access to the inode struct on disk. Immutable.
- diskInode disklayout.Inode
-
- // root is the root extent node. This lives in the 60 byte diskInode.Data().
- // Immutable. Nil if the inode does not use extents.
- root *disklayout.ExtentNode
-}
-
-// incRef increments the inode ref count.
-func (in *inode) incRef() {
- atomic.AddInt64(&in.refs, 1)
-}
-
-// tryIncRef tries to increment the ref count. Returns true if successful.
-func (in *inode) tryIncRef() bool {
- for {
- refs := atomic.LoadInt64(&in.refs)
- if refs == 0 {
- return false
- }
- if atomic.CompareAndSwapInt64(&in.refs, refs, refs+1) {
- return true
- }
- }
-}
-
-// decRef decrements the inode ref count and releases the inode resources if
-// the ref count hits 0.
-//
-// Preconditions: Must have locked fs.mu.
-func (in *inode) decRef(fs *filesystem) {
- if refs := atomic.AddInt64(&in.refs, -1); refs == 0 {
- delete(fs.inodeCache, in.inodeNum)
- } else if refs < 0 {
- panic("ext.inode.decRef() called without holding a reference")
- }
-}
-
-// newInode is the inode constructor. Reads the inode off disk. Identifies
-// inodes based on the absolute inode number on disk.
-//
-// Preconditions: Must hold the mutex of the filesystem containing dev.
-func newInode(dev io.ReadSeeker, sb disklayout.SuperBlock, bgs []disklayout.BlockGroup, inodeNum uint32) (*inode, error) {
- if inodeNum == 0 {
- panic("inode number 0 on ext filesystems is not possible")
- }
-
- in := &inode{refs: 1, inodeNum: inodeNum}
- inodeRecordSize := sb.InodeSize()
- if inodeRecordSize == disklayout.OldInodeSize {
- in.diskInode = &disklayout.InodeOld{}
- } else {
- in.diskInode = &disklayout.InodeNew{}
- }
-
- // Calculate where the inode is actually placed.
- inodesPerGrp := sb.InodesPerGroup()
- blkSize := sb.BlockSize()
- inodeTableOff := bgs[getBGNum(inodeNum, inodesPerGrp)].InodeTable() * blkSize
- inodeOff := inodeTableOff + uint64(uint32(inodeRecordSize)*getBGOff(inodeNum, inodesPerGrp))
-
- // Read it from disk and figure out which type of inode this is.
- if err := readFromDisk(dev, int64(inodeOff), in.diskInode); err != nil {
- return nil, err
- }
-
- if in.diskInode.Flags().Extents {
- in.buildExtTree(dev, blkSize)
- }
-
- return in, nil
-}
-
-// getBGNum returns the block group number that a given inode belongs to.
-func getBGNum(inodeNum uint32, inodesPerGrp uint32) uint32 {
- return (inodeNum - 1) / inodesPerGrp
-}
-
-// getBGOff returns the offset at which the given inode lives in the block
-// group's inode table, i.e. the index of the inode in the inode table.
-func getBGOff(inodeNum uint32, inodesPerGrp uint32) uint32 {
- return (inodeNum - 1) % inodesPerGrp
-}
-
-// buildExtTree builds the extent tree by reading it from disk while
-// running a simple DFS. It first reads the root node from the inode struct in
-// memory. Then it recursively builds the rest of the tree by reading it off
-// disk.
-//
-// Preconditions:
-// - Must hold the mutex of the filesystem containing dev.
-// - Inode flag Extents must be set.
-func (in *inode) buildExtTree(dev io.ReadSeeker, blkSize uint64) error {
- rootNodeData := in.diskInode.Data()
-
- var rootHeader disklayout.ExtentHeader
- binary.Unmarshal(rootNodeData[:disklayout.ExtentStructsSize], binary.LittleEndian, &rootHeader)
-
- // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
- if rootHeader.NumEntries > 4 {
- // read(2) specifies that EINVAL should be returned if the file is unsuitable
- // for reading.
- return syserror.EINVAL
- }
-
- rootEntries := make([]disklayout.ExtentEntryPair, rootHeader.NumEntries)
- for i, off := uint16(0), disklayout.ExtentStructsSize; i < rootHeader.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
- var curEntry disklayout.ExtentEntry
- if rootHeader.Height == 0 {
- // Leaf node.
- curEntry = &disklayout.Extent{}
- } else {
- // Internal node.
- curEntry = &disklayout.ExtentIdx{}
- }
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, curEntry)
- rootEntries[i].Entry = curEntry
- }
-
- // If this node is internal, perform DFS.
- if rootHeader.Height > 0 {
- for i := uint16(0); i < rootHeader.NumEntries; i++ {
- var err error
- if rootEntries[i].Node, err = buildExtTreeFromDisk(dev, rootEntries[i].Entry, blkSize); err != nil {
- return err
- }
- }
- }
-
- in.root = &disklayout.ExtentNode{rootHeader, rootEntries}
- return nil
-}
-
-// buildExtTreeFromDisk reads the extent tree nodes from disk and recursively
-// builds the tree. Performs a simple DFS. It returns the ExtentNode pointed to
-// by the ExtentEntry.
-//
-// Preconditions: Must hold the mutex of the filesystem containing dev.
-func buildExtTreeFromDisk(dev io.ReadSeeker, entry disklayout.ExtentEntry, blkSize uint64) (*disklayout.ExtentNode, error) {
- var header disklayout.ExtentHeader
- off := entry.PhysicalBlock() * blkSize
- if err := readFromDisk(dev, int64(off), &header); err != nil {
- return nil, err
- }
-
- entries := make([]disklayout.ExtentEntryPair, header.NumEntries)
- for i, off := uint16(0), off+disklayout.ExtentStructsSize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
- var curEntry disklayout.ExtentEntry
- if header.Height == 0 {
- // Leaf node.
- curEntry = &disklayout.Extent{}
- } else {
- // Internal node.
- curEntry = &disklayout.ExtentIdx{}
- }
-
- if err := readFromDisk(dev, int64(off), curEntry); err != nil {
- return nil, err
- }
- entries[i].Entry = curEntry
- }
-
- // If this node is internal, perform DFS.
- if header.Height > 0 {
- for i := uint16(0); i < header.NumEntries; i++ {
- var err error
- entries[i].Node, err = buildExtTreeFromDisk(dev, entries[i].Entry, blkSize)
- if err != nil {
- return nil, err
- }
- }
- }
-
- return &disklayout.ExtentNode{header, entries}, nil
-}
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index 5a0a67eab..669ffcb75 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -87,7 +87,7 @@ func (p *pipeOperations) init() error {
log.Warningf("pipe: cannot stat fd %d: %v", p.file.FD(), err)
return syscall.EINVAL
}
- if s.Mode&syscall.S_IFIFO != syscall.S_IFIFO {
+ if (s.Mode & syscall.S_IFMT) != syscall.S_IFIFO {
log.Warningf("pipe: cannot load fd %d as pipe, file type: %o", p.file.FD(), s.Mode)
return syscall.EINVAL
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index ed62049a9..20cb9a367 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -66,10 +66,8 @@ type CachingInodeOperations struct {
// mfp is used to allocate memory that caches backingFile's contents.
mfp pgalloc.MemoryFileProvider
- // forcePageCache indicates the sentry page cache should be used regardless
- // of whether the platform supports host mapped I/O or not. This must not be
- // modified after inode creation.
- forcePageCache bool
+ // opts contains options. opts is immutable.
+ opts CachingInodeOperationsOptions
attrMu sync.Mutex `state:"nosave"`
@@ -116,6 +114,20 @@ type CachingInodeOperations struct {
refs frameRefSet
}
+// CachingInodeOperationsOptions configures a CachingInodeOperations.
+//
+// +stateify savable
+type CachingInodeOperationsOptions struct {
+ // If ForcePageCache is true, use the sentry page cache even if a host file
+ // descriptor is available.
+ ForcePageCache bool
+
+ // If LimitHostFDTranslation is true, apply maxFillRange() constraints to
+ // host file descriptor mappings returned by
+ // CachingInodeOperations.Translate().
+ LimitHostFDTranslation bool
+}
+
// CachedFileObject is a file that may require caching.
type CachedFileObject interface {
// ReadToBlocksAt reads up to dsts.NumBytes() bytes from the file to dsts,
@@ -159,7 +171,7 @@ type CachedFileObject interface {
// NewCachingInodeOperations returns a new CachingInodeOperations backed by
// a CachedFileObject and its initial unstable attributes.
-func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, forcePageCache bool) *CachingInodeOperations {
+func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, opts CachingInodeOperationsOptions) *CachingInodeOperations {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
if mfp == nil {
panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
@@ -167,7 +179,7 @@ func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject
return &CachingInodeOperations{
backingFile: backingFile,
mfp: mfp,
- forcePageCache: forcePageCache,
+ opts: opts,
attr: uattr,
hostFileMapper: NewHostFileMapper(),
}
@@ -568,21 +580,30 @@ type inodeReadWriter struct {
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ mem := rw.c.mfp.MemoryFile()
+ fillCache := !rw.c.useHostPageCache() && mem.ShouldCacheEvictable()
+
// Hot path. Avoid defers.
- rw.c.dataMu.RLock()
+ var unlock func()
+ if fillCache {
+ rw.c.dataMu.Lock()
+ unlock = rw.c.dataMu.Unlock
+ } else {
+ rw.c.dataMu.RLock()
+ unlock = rw.c.dataMu.RUnlock
+ }
// Compute the range to read.
if rw.offset >= rw.c.attr.Size {
- rw.c.dataMu.RUnlock()
+ unlock()
return 0, io.EOF
}
end := fs.ReadEndOffset(rw.offset, int64(dsts.NumBytes()), rw.c.attr.Size)
if end == rw.offset { // dsts.NumBytes() == 0?
- rw.c.dataMu.RUnlock()
+ unlock()
return 0, nil
}
- mem := rw.c.mfp.MemoryFile()
var done uint64
seg, gap := rw.c.cache.Find(uint64(rw.offset))
for rw.offset < end {
@@ -592,7 +613,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
// Get internal mappings from the cache.
ims, err := mem.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
if err != nil {
- rw.c.dataMu.RUnlock()
+ unlock()
return done, err
}
@@ -602,7 +623,7 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
rw.offset += int64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.c.dataMu.RUnlock()
+ unlock()
return done, err
}
@@ -610,27 +631,48 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
seg, gap = seg.NextNonEmpty()
case gap.Ok():
- // Read directly from the backing file.
- gapmr := gap.Range().Intersect(mr)
- dst := dsts.TakeFirst64(gapmr.Length())
- n, err := rw.c.backingFile.ReadToBlocksAt(rw.ctx, dst, gapmr.Start)
- done += n
- rw.offset += int64(n)
- dsts = dsts.DropFirst64(n)
- // Partial reads are fine. But we must stop reading.
- if n != dst.NumBytes() || err != nil {
- rw.c.dataMu.RUnlock()
- return done, err
+ gapMR := gap.Range().Intersect(mr)
+ if fillCache {
+ // Read into the cache, then re-enter the loop to read from the
+ // cache.
+ reqMR := memmap.MappableRange{
+ Start: uint64(usermem.Addr(gapMR.Start).RoundDown()),
+ End: fs.OffsetPageEnd(int64(gapMR.End)),
+ }
+ optMR := gap.Range()
+ err := rw.c.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mem, usage.PageCache, rw.c.backingFile.ReadToBlocksAt)
+ mem.MarkEvictable(rw.c, pgalloc.EvictableRange{optMR.Start, optMR.End})
+ seg, gap = rw.c.cache.Find(uint64(rw.offset))
+ if !seg.Ok() {
+ unlock()
+ return done, err
+ }
+ // err might have occurred in part of gap.Range() outside
+ // gapMR. Forget about it for now; if the error matters and
+ // persists, we'll run into it again in a later iteration of
+ // this loop.
+ } else {
+ // Read directly from the backing file.
+ dst := dsts.TakeFirst64(gapMR.Length())
+ n, err := rw.c.backingFile.ReadToBlocksAt(rw.ctx, dst, gapMR.Start)
+ done += n
+ rw.offset += int64(n)
+ dsts = dsts.DropFirst64(n)
+ // Partial reads are fine. But we must stop reading.
+ if n != dst.NumBytes() || err != nil {
+ unlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), FileRangeGapIterator{}
}
- // Continue.
- seg, gap = gap.NextSegment(), FileRangeGapIterator{}
-
default:
break
}
}
- rw.c.dataMu.RUnlock()
+ unlock()
return done, nil
}
@@ -700,7 +742,10 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
seg, gap = seg.NextNonEmpty()
case gap.Ok() && gap.Start() < mr.End:
- // Write directly to the backing file.
+ // Write directly to the backing file. At present, we never fill
+ // the cache when writing, since doing so can convert small writes
+ // into inefficient read-modify-write cycles, and we have no
+ // mechanism for detecting or avoiding this.
gapmr := gap.Range().Intersect(mr)
src := srcs.TakeFirst64(gapmr.Length())
n, err := rw.c.backingFile.WriteFromBlocksAt(rw.ctx, src, gapmr.Start)
@@ -730,7 +775,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
// and memory mappings, and false if c.cache may contain data cached from
// c.backingFile.
func (c *CachingInodeOperations) useHostPageCache() bool {
- return !c.forcePageCache && c.backingFile.FD() >= 0
+ return !c.opts.ForcePageCache && c.backingFile.FD() >= 0
}
// AddMapping implements memmap.Mappable.AddMapping.
@@ -802,11 +847,15 @@ func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.Mapp
func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
// Hot path. Avoid defer.
if c.useHostPageCache() {
+ mr := optional
+ if c.opts.LimitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
return []memmap.Translation{
{
- Source: optional,
+ Source: mr,
File: c,
- Offset: optional.Start,
+ Offset: mr.Start,
Perms: usermem.AnyAccess,
},
}, nil
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
index dc19255ed..eb5730c35 100644
--- a/pkg/sentry/fs/fsutil/inode_cached_test.go
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -61,7 +61,7 @@ func TestSetPermissions(t *testing.T) {
uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
Perms: fs.FilePermsFromMode(0444),
})
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
perms := fs.FilePermsFromMode(0777)
@@ -150,7 +150,7 @@ func TestSetTimestamps(t *testing.T) {
ModificationTime: epoch,
StatusChangeTime: epoch,
}
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil {
@@ -188,7 +188,7 @@ func TestTruncate(t *testing.T) {
uattr := fs.UnstableAttr{
Size: 0,
}
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
if err := iops.Truncate(ctx, nil, uattr.Size); err != nil {
@@ -280,7 +280,7 @@ func TestRead(t *testing.T) {
uattr := fs.UnstableAttr{
Size: int64(len(buf)),
}
- iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
defer iops.Release()
// Expect the cache to be initially empty.
@@ -336,7 +336,7 @@ func TestWrite(t *testing.T) {
uattr := fs.UnstableAttr{
Size: int64(len(buf)),
}
- iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
defer iops.Release()
// Expect the cache to be initially empty.
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
index 69999dc28..8f8ab5d29 100644
--- a/pkg/sentry/fs/gofer/fs.go
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -54,6 +54,10 @@ const (
// sandbox using files backed by the gofer. If set to false, unix sockets
// cannot be bound to gofer files without an overlay on top.
privateUnixSocketKey = "privateunixsocket"
+
+ // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to
+ // true.
+ limitHostFDTranslationKey = "limit_host_fd_translation"
)
// defaultAname is the default attach name.
@@ -134,12 +138,13 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
// opts are parsed 9p mount options.
type opts struct {
- fd int
- aname string
- policy cachePolicy
- msize uint32
- version string
- privateunixsocket bool
+ fd int
+ aname string
+ policy cachePolicy
+ msize uint32
+ version string
+ privateunixsocket bool
+ limitHostFDTranslation bool
}
// options parses mount(2) data into structured options.
@@ -237,6 +242,11 @@ func options(data string) (opts, error) {
delete(options, privateUnixSocketKey)
}
+ if _, ok := options[limitHostFDTranslationKey]; ok {
+ o.limitHostFDTranslation = true
+ delete(options, limitHostFDTranslationKey)
+ }
+
// Fail to attach if the caller wanted us to do something that we
// don't support.
if len(options) > 0 {
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 69d08a627..50da865c1 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -117,6 +117,11 @@ type session struct {
// Flags provided to the mount.
superBlockFlags fs.MountSourceFlags `state:"wait"`
+ // limitHostFDTranslation is the value used for
+ // CachingInodeOperationsOptions.LimitHostFDTranslation for all
+ // CachingInodeOperations created by the session.
+ limitHostFDTranslation bool
+
// connID is a unique identifier for the session connection.
connID string `state:"wait"`
@@ -218,8 +223,11 @@ func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p
uattr := unstable(ctx, valid, attr, s.mounter, s.client)
return sattr, &inodeOperations{
- fileState: fileState,
- cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, s.superBlockFlags.ForcePageCache),
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: s.superBlockFlags.ForcePageCache,
+ LimitHostFDTranslation: s.limitHostFDTranslation,
+ }),
}
}
@@ -242,13 +250,14 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
// Construct the session.
s := session{
- connID: dev,
- msize: o.msize,
- version: o.version,
- cachePolicy: o.policy,
- aname: o.aname,
- superBlockFlags: superBlockFlags,
- mounter: mounter,
+ connID: dev,
+ msize: o.msize,
+ version: o.version,
+ cachePolicy: o.policy,
+ aname: o.aname,
+ superBlockFlags: superBlockFlags,
+ limitHostFDTranslation: o.limitHostFDTranslation,
+ mounter: mounter,
}
s.EnableLeakCheck("gofer.session")
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 679d8321a..894ab01f0 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -200,8 +200,10 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool,
// Build the fs.InodeOperations.
uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s)
iops := &inodeOperations{
- fileState: fileState,
- cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, msrc.Flags.ForcePageCache),
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: msrc.Flags.ForcePageCache,
+ }),
}
// Return the fs.Inode.
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 44c4ee5f2..2392787cb 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -65,7 +65,7 @@ type ConnectedEndpoint struct {
// GetSockOpt and message splitting/rejection in SendMsg, but do not
// prevent lots of small messages from filling the real send buffer
// size on the host.
- sndbuf int `state:"nosave"`
+ sndbuf int64 `state:"nosave"`
// mu protects the fields below.
mu sync.RWMutex `state:"nosave"`
@@ -107,7 +107,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
}
c.stype = linux.SockType(stype)
- c.sndbuf = sndbuf
+ c.sndbuf = int64(sndbuf)
return nil
}
@@ -202,7 +202,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
}
// Send implements transport.ConnectedEndpoint.Send.
-func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -279,7 +279,7 @@ func (c *ConnectedEndpoint) EventUpdate() {
}
// Recv implements transport.Receiver.Recv.
-func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
index 05d7c79ad..af6955675 100644
--- a/pkg/sentry/fs/host/socket_iovec.go
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -55,19 +55,19 @@ func copyFromMulti(dst []byte, src [][]byte) {
//
// If intermediate != nil, iovecs references intermediate rather than bufs and
// the caller must copy to/from bufs as necessary.
-func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovecs []syscall.Iovec, intermediate []byte, err error) {
+func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovecs []syscall.Iovec, intermediate []byte, err error) {
var iovsRequired int
for _, b := range bufs {
- length += uintptr(len(b))
+ length += int64(len(b))
if len(b) > 0 {
iovsRequired++
}
}
stopLen := length
- if length > uintptr(maxlen) {
+ if length > maxlen {
if truncate {
- stopLen = uintptr(maxlen)
+ stopLen = maxlen
err = syserror.EAGAIN
} else {
return 0, nil, nil, syserror.EMSGSIZE
@@ -85,7 +85,7 @@ func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovec
}}, b, err
}
- var total uintptr
+ var total int64
iovecs = make([]syscall.Iovec, 0, iovsRequired)
for i := range bufs {
l := len(bufs[i])
@@ -93,9 +93,9 @@ func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovec
continue
}
- stop := l
- if total+uintptr(stop) > stopLen {
- stop = int(stopLen - total)
+ stop := int64(l)
+ if total+stop > stopLen {
+ stop = stopLen - total
}
iovecs = append(iovecs, syscall.Iovec{
@@ -103,7 +103,7 @@ func buildIovec(bufs [][]byte, maxlen int, truncate bool) (length uintptr, iovec
Len: uint64(stop),
})
- total += uintptr(stop)
+ total += stop
if total >= stopLen {
break
}
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index e57be0506..f3bbed7ea 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -23,7 +23,7 @@ import (
//
// If the total length of bufs is > maxlen, fdReadVec will do a partial read
// and err will indicate why the message was truncated.
-func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (readLen uintptr, msgLen uintptr, controlLen uint64, controlTrunc bool, err error) {
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (readLen int64, msgLen int64, controlLen uint64, controlTrunc bool, err error) {
flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
if peek {
flags |= syscall.MSG_PEEK
@@ -48,11 +48,12 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re
msg.Iovlen = uint64(len(iovecs))
}
- n, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
+ rawN, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
if e != 0 {
// N.B. prioritize the syscall error over the buildIovec error.
return 0, 0, 0, false, e
}
+ n := int64(rawN)
// Copy data back to bufs.
if intermediate != nil {
@@ -72,7 +73,7 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int) (re
//
// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
// partial write and err will indicate why the message was truncated.
-func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uintptr, error) {
+func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int64, error) {
length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
if err != nil && len(iovecs) == 0 {
// No partial write to do, return error immediately.
@@ -96,5 +97,5 @@ func fdWriteVec(fd int, bufs [][]byte, maxlen int, truncate bool) (uintptr, uint
return 0, length, e
}
- return n, length, err
+ return int64(n), length, err
}
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 693ffc760..ac0398bd9 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -171,8 +171,6 @@ type MountNamespace struct {
// NewMountNamespace returns a new MountNamespace, with the provided node at the
// root, and the given cache size. A root must always be provided.
func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error) {
- creds := auth.CredentialsFromContext(ctx)
-
// Set the root dirent and id on the root mount. The reference returned from
// NewDirent will be donated to the MountNamespace constructed below.
d := NewDirent(ctx, root, "/")
@@ -181,6 +179,7 @@ func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error
d: newRootMount(1, d),
}
+ creds := auth.CredentialsFromContext(ctx)
mns := MountNamespace{
userns: creds.UserNamespace,
root: d,
@@ -219,6 +218,13 @@ func (mns *MountNamespace) flushMountSourceRefsLocked() {
}
}
+ if mns.root == nil {
+ // No root? This MountNamespace must have already been destroyed.
+ // This can happen when a Save is triggered while a process is
+ // exiting. There is nothing to flush.
+ return
+ }
+
// Flush root's MountSource references.
mns.root.Inode.MountSource.FlushDirentRefs()
}
@@ -249,6 +255,10 @@ func (mns *MountNamespace) destroy() {
// Drop reference on the root.
mns.root.DecRef()
+ // Ensure that root cannot be accessed via this MountNamespace any
+ // more.
+ mns.root = nil
+
// Wait for asynchronous work (queued by dropping Dirent references
// above) to complete before destroying this MountNamespace.
AsyncBarrier()
@@ -678,7 +688,7 @@ func (mns *MountNamespace) ResolveExecutablePath(ctx context.Context, wd, name s
return "", syserror.ENOENT
}
-// GetPath returns the PATH as a slice of strings given the environemnt
+// GetPath returns the PATH as a slice of strings given the environment
// variables.
func GetPath(env []string) []string {
const prefix = "PATH="
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 6b839685b..9adb23608 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -348,7 +348,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
// Field: local_adddress.
var localAddr linux.SockAddrInet
if local, _, err := sops.GetSockName(t); err == nil {
- localAddr = local.(linux.SockAddrInet)
+ localAddr = *local.(*linux.SockAddrInet)
}
binary.LittleEndian.PutUint16(portBuf, localAddr.Port)
fmt.Fprintf(&buf, "%08X:%04X ",
@@ -358,7 +358,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
// Field: rem_address.
var remoteAddr linux.SockAddrInet
if remote, _, err := sops.GetPeerName(t); err == nil {
- remoteAddr = remote.(linux.SockAddrInet)
+ remoteAddr = *remote.(*linux.SockAddrInet)
}
binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port)
fmt.Fprintf(&buf, "%08X:%04X ",
diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go
index f3e984c24..78e082b8e 100644
--- a/pkg/sentry/fs/ramfs/dir.go
+++ b/pkg/sentry/fs/ramfs/dir.go
@@ -53,7 +53,6 @@ type Dir struct {
fsutil.InodeGenericChecker `state:"nosave"`
fsutil.InodeIsDirAllocate `state:"nosave"`
fsutil.InodeIsDirTruncate `state:"nosave"`
- fsutil.InodeNoopRelease `state:"nosave"`
fsutil.InodeNoopWriteOut `state:"nosave"`
fsutil.InodeNotMappable `state:"nosave"`
fsutil.InodeNotSocket `state:"nosave"`
@@ -84,7 +83,8 @@ type Dir struct {
var _ fs.InodeOperations = (*Dir)(nil)
-// NewDir returns a new Dir with the given contents and attributes.
+// NewDir returns a new Dir with the given contents and attributes. A reference
+// on each fs.Inode in the `contents` map will be donated to this Dir.
func NewDir(ctx context.Context, contents map[string]*fs.Inode, owner fs.FileOwner, perms fs.FilePermissions) *Dir {
d := &Dir{
InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, perms, linux.RAMFS_MAGIC),
@@ -138,7 +138,7 @@ func (d *Dir) addChildLocked(ctx context.Context, name string, inode *fs.Inode)
d.NotifyModificationAndStatusChange(ctx)
}
-// AddChild adds a child to this dir.
+// AddChild adds a child to this dir, inheriting its reference.
func (d *Dir) AddChild(ctx context.Context, name string, inode *fs.Inode) {
d.mu.Lock()
defer d.mu.Unlock()
@@ -172,7 +172,9 @@ func (d *Dir) Children() ([]string, map[string]fs.DentAttr) {
return namesCopy, entriesCopy
}
-// removeChildLocked attempts to remove an entry from this directory.
+// removeChildLocked attempts to remove an entry from this directory. It
+// returns the removed fs.Inode along with its reference, which callers are
+// responsible for decrementing.
func (d *Dir) removeChildLocked(ctx context.Context, name string) (*fs.Inode, error) {
inode, ok := d.children[name]
if !ok {
@@ -253,7 +255,8 @@ func (d *Dir) RemoveDirectory(ctx context.Context, _ *fs.Inode, name string) err
return nil
}
-// Lookup loads an inode at p into a Dirent.
+// Lookup loads an inode at p into a Dirent. It returns the fs.Dirent along
+// with a reference.
func (d *Dir) Lookup(ctx context.Context, _ *fs.Inode, p string) (*fs.Dirent, error) {
if len(p) > linux.NAME_MAX {
return nil, syserror.ENAMETOOLONG
@@ -408,6 +411,16 @@ func (*Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, ol
return Rename(ctx, oldParent.InodeOperations, oldName, newParent.InodeOperations, newName, replacement)
}
+// Release implements fs.InodeOperations.Release.
+func (d *Dir) Release(_ context.Context) {
+ // Drop references on all children.
+ d.mu.Lock()
+ for _, i := range d.children {
+ i.DecRef()
+ }
+ d.mu.Unlock()
+}
+
// dirFileOperations implements fs.FileOperations for a ramfs directory.
//
// +stateify savable
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 0f4497cd6..159fb7c08 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -56,7 +56,6 @@ func rename(ctx context.Context, oldParent *fs.Inode, oldName string, newParent
type Dir struct {
fsutil.InodeGenericChecker `state:"nosave"`
fsutil.InodeIsDirTruncate `state:"nosave"`
- fsutil.InodeNoopRelease `state:"nosave"`
fsutil.InodeNoopWriteOut `state:"nosave"`
fsutil.InodeNotMappable `state:"nosave"`
fsutil.InodeNotSocket `state:"nosave"`
@@ -252,6 +251,11 @@ func (d *Dir) Allocate(ctx context.Context, node *fs.Inode, offset, length int64
return d.ramfsDir.Allocate(ctx, node, offset, length)
}
+// Release implements fs.InodeOperations.Release.
+func (d *Dir) Release(ctx context.Context) {
+ d.ramfsDir.Release(ctx)
+}
+
// Symlink is a symlink.
//
// +stateify savable
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 5e9327aec..291164986 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -23,6 +23,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 1d128532b..2f639c823 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -129,6 +129,9 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode {
// Release implements fs.InodeOperations.Release.
func (d *dirInodeOperations) Release(ctx context.Context) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
d.master.DecRef()
if len(d.slaves) != 0 {
panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d))
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index 92ec1ca18..19b7557d5 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -172,6 +172,19 @@ func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io userme
return 0, mf.t.ld.windowSize(ctx, io, args)
case linux.TIOCSWINSZ:
return 0, mf.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
@@ -185,8 +198,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
linux.TCSETS,
linux.TCSETSW,
linux.TCSETSF,
- linux.TIOCGPGRP,
- linux.TIOCSPGRP,
linux.TIOCGWINSZ,
linux.TIOCSWINSZ,
linux.TIOCSETD,
@@ -200,8 +211,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
linux.TIOCEXCL,
linux.TIOCNXCL,
linux.TIOCGEXCL,
- linux.TIOCNOTTY,
- linux.TIOCSCTTY,
linux.TIOCGSID,
linux.TIOCGETD,
linux.TIOCVHANGUP,
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index e30266404..944c4ada1 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -152,9 +152,16 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem
case linux.TIOCSCTTY:
// Make the given terminal the controlling terminal of the
// calling process.
- // TODO(b/129283598): Implement once we have support for job
- // control.
- return 0, nil
+ return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index b7cecb2ed..ff8138820 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -17,7 +17,10 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
// Terminal is a pseudoterminal.
@@ -26,23 +29,100 @@ import (
type Terminal struct {
refs.AtomicRefCount
- // n is the terminal index.
+ // n is the terminal index. It is immutable.
n uint32
- // d is the containing directory.
+ // d is the containing directory. It is immutable.
d *dirInodeOperations
- // ld is the line discipline of the terminal.
+ // ld is the line discipline of the terminal. It is immutable.
ld *lineDiscipline
+
+ // masterKTTY contains the controlling process of the master end of
+ // this terminal. This field is immutable.
+ masterKTTY *kernel.TTY
+
+ // slaveKTTY contains the controlling process of the slave end of this
+ // terminal. This field is immutable.
+ slaveKTTY *kernel.TTY
}
func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal {
termios := linux.DefaultSlaveTermios
t := Terminal{
- d: d,
- n: n,
- ld: newLineDiscipline(termios),
+ d: d,
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{},
+ slaveKTTY: &kernel.TTY{},
}
t.EnableLeakCheck("tty.Terminal")
return &t
}
+
+// setControllingTTY makes tm the controlling terminal of the calling thread
+// group.
+func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int())
+}
+
+// releaseControllingTTY removes tm as the controlling terminal of the calling
+// thread group.
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("releaseControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster))
+}
+
+// foregroundProcessGroup gets the process group ID of tm's foreground process.
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("foregroundProcessGroup must be called from a task context")
+ }
+
+ ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster))
+ if err != nil {
+ return 0, err
+ }
+
+ // Write it out to *arg.
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setForegroundProcessGroup sets tm's foreground process group.
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setForegroundProcessGroup must be called from a task context")
+ }
+
+ // Read in the process group ID.
+ var pgid int32
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid))
+ return uintptr(ret), err
+}
+
+func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
+ if isMaster {
+ return tm.masterKTTY
+ }
+ return tm.slaveKTTY
+}
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
new file mode 100644
index 000000000..a41101339
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -0,0 +1,86 @@
+package(licenses = ["notice"])
+
+load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+go_template_instance(
+ name = "dirent_list",
+ out = "dirent_list.go",
+ package = "ext",
+ prefix = "dirent",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dirent",
+ "Linker": "*dirent",
+ },
+)
+
+go_library(
+ name = "ext",
+ srcs = [
+ "block_map_file.go",
+ "dentry.go",
+ "directory.go",
+ "dirent_list.go",
+ "ext.go",
+ "extent_file.go",
+ "file_description.go",
+ "filesystem.go",
+ "inode.go",
+ "regular_file.go",
+ "symlink.go",
+ "utils.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/fd",
+ "//pkg/log",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fsimpl/ext/disklayout",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/safemem",
+ "//pkg/sentry/syscalls/linux",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "ext_test",
+ size = "small",
+ srcs = [
+ "block_map_test.go",
+ "ext_test.go",
+ "extent_test.go",
+ ],
+ data = [
+ "//pkg/sentry/fsimpl/ext:assets/bigfile.txt",
+ "//pkg/sentry/fsimpl/ext:assets/file.txt",
+ "//pkg/sentry/fsimpl/ext:assets/tiny.ext2",
+ "//pkg/sentry/fsimpl/ext:assets/tiny.ext3",
+ "//pkg/sentry/fsimpl/ext:assets/tiny.ext4",
+ ],
+ embed = [":ext"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/context",
+ "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/fsimpl/ext/disklayout",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//runsc/test/testutil",
+ "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/ext/README.md b/pkg/sentry/fsimpl/ext/README.md
new file mode 100644
index 000000000..af00cfda8
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/README.md
@@ -0,0 +1,117 @@
+## EXT(2/3/4) File System
+
+This is a filesystem driver which supports ext2, ext3 and ext4 filesystems.
+Linux has specialized drivers for each variant but none which supports all. This
+library takes advantage of ext's backward compatibility and understands the
+internal organization of on-disk structures to support all variants.
+
+This driver implementation diverges from the Linux implementations in being more
+forgiving about versioning. For instance, if a filesystem contains both extent
+based inodes and classical block map based inodes, this driver will not complain
+and interpret them both correctly. While in Linux this would be an issue. This
+blurs the line between the three ext fs variants.
+
+Ext2 is considered deprecated as of Red Hat Enterprise Linux 7, and ext3 has
+been superseded by ext4 by large performance gains. Thus it is recommended to
+upgrade older filesystem images to ext4 using e2fsprogs for better performance.
+
+### Read Only
+
+This driver currently only allows read only operations. A lot of the design
+decisions are based on this feature. There are plans to implement write (the
+process for which is documented in the future work section).
+
+### Performance
+
+One of the biggest wins about this driver is that it directly talks to the
+underlying block device (or whatever persistent storage is being used), instead
+of making expensive RPCs to a gofer.
+
+Another advantage is that ext fs supports fast concurrent reads. Currently the
+device is represented using a `io.ReaderAt` which allows for concurrent reads.
+All reads are directly passed to the device driver which intelligently serves
+the read requests in the optimal order. There is no congestion due to locking
+while reading in the filesystem level.
+
+Reads are optimized further in the way file data is transferred over to user
+memory. Ext fs directly copies over file data from disk into user memory with no
+additional allocations on the way. We can only get faster by preloading file
+data into memory (see future work section).
+
+The internal structures used to represent files, inodes and file descriptors use
+a lot of inheritance. With the level of indirection that an interface adds with
+an internal pointer, it can quickly fragment a structure across memory. As this
+runs along side a full blown kernel (which is memory intensive), having a
+fragmented struct might hurt performance. Hence these internal structures,
+though interfaced, are tightly packed in memory using the same inheritance
+pattern that pkg/sentry/vfs uses. The pkg/sentry/fsimpl/ext/disklayout package
+makes an execption to this pattern for reasons documented in the package.
+
+### Security
+
+This driver also intends to help sandbox the container better by reducing the
+surface of the host kernel that the application touches. It prevents the
+application from exploiting vulnerabilities in the host filesystem driver. All
+`io.ReaderAt.ReadAt()` calls are translated to `pread(2)` which are directly
+passed to the device driver in the kernel. Hence this reduces the surface for
+attack.
+
+The application can not affect any host filesystems other than the one passed
+via block device by the user.
+
+### Future Work
+
+#### Write
+
+To support write operations we would need to modify the block device underneath.
+Currently, the driver does not modify the device at all, not even for updating
+the access times for reads. Modifying the filesystem incorrectly can corrupt it
+and render it unreadable for other correct ext(x) drivers. Hence caution must be
+maintained while modifying metadata structures.
+
+Ext4 specifically is built for performance and has added a lot of complexity as
+to how metadata structures are modified. For instance, files that are organized
+via an extent tree which must be balanced and file data blocks must be placed in
+the same extent as much as possible to increase locality. Such properties must
+be maintained while modifying the tree.
+
+Ext filesystems boast a lot about locality, which plays a big role in them being
+performant. The block allocation algorithm in Linux does a good job in keeping
+related data together. This behavior must be maintained as much as possible,
+else we might end up degrading the filesystem performance over time.
+
+Ext4 also supports a wide variety of features which are specialized for varying
+use cases. Implementing all of them can get difficult very quickly.
+
+Ext(x) checksums all its metadata structures to check for corruption, so
+modification of any metadata struct must correspond with re-checksumming the
+struct. Linux filesystem drivers also order on-disk updates intelligently to not
+corrupt the filesystem and also remain performant. The in-memory metadata
+structures must be kept in sync with what is on disk.
+
+There is also replication of some important structures across the filesystem.
+All replicas must be updated when their original copy is updated. There is also
+provisioning for snapshotting which must be kept in mind, although it should not
+affect this implementation unless we allow users to create filesystem snapshots.
+
+Ext4 also introduced journaling (jbd2). The journal must be updated
+appropriately.
+
+#### Performance
+
+To improve performance we should implement a buffer cache, and optionally, read
+ahead for small files. While doing so we must also keep in mind the memory usage
+and have a reasonable cap on how much file data we want to hold in memory.
+
+#### Features
+
+Our current implementation will work with most ext4 filesystems for readonly
+purposed. However, the following features are not supported yet:
+
+- Journal
+- Snapshotting
+- Extended Attributes
+- Hash Tree Directories
+- Meta Block Groups
+- Multiple Mount Protection
+- Bigalloc
diff --git a/pkg/sentry/fs/ext/assets/README.md b/pkg/sentry/fsimpl/ext/assets/README.md
index 6f1e81b3a..6f1e81b3a 100644
--- a/pkg/sentry/fs/ext/assets/README.md
+++ b/pkg/sentry/fsimpl/ext/assets/README.md
diff --git a/pkg/sentry/fs/ext/assets/bigfile.txt b/pkg/sentry/fsimpl/ext/assets/bigfile.txt
index 3857cf516..3857cf516 100644
--- a/pkg/sentry/fs/ext/assets/bigfile.txt
+++ b/pkg/sentry/fsimpl/ext/assets/bigfile.txt
diff --git a/pkg/sentry/fs/ext/assets/file.txt b/pkg/sentry/fsimpl/ext/assets/file.txt
index 980a0d5f1..980a0d5f1 100644
--- a/pkg/sentry/fs/ext/assets/file.txt
+++ b/pkg/sentry/fsimpl/ext/assets/file.txt
diff --git a/pkg/sentry/fs/ext/assets/symlink.txt b/pkg/sentry/fsimpl/ext/assets/symlink.txt
index 4c330738c..4c330738c 120000
--- a/pkg/sentry/fs/ext/assets/symlink.txt
+++ b/pkg/sentry/fsimpl/ext/assets/symlink.txt
diff --git a/pkg/sentry/fs/ext/assets/tiny.ext2 b/pkg/sentry/fsimpl/ext/assets/tiny.ext2
index 381ade9bf..381ade9bf 100644
--- a/pkg/sentry/fs/ext/assets/tiny.ext2
+++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext2
Binary files differ
diff --git a/pkg/sentry/fs/ext/assets/tiny.ext3 b/pkg/sentry/fsimpl/ext/assets/tiny.ext3
index 0e97a324c..0e97a324c 100644
--- a/pkg/sentry/fs/ext/assets/tiny.ext3
+++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext3
Binary files differ
diff --git a/pkg/sentry/fs/ext/assets/tiny.ext4 b/pkg/sentry/fsimpl/ext/assets/tiny.ext4
index a6859736d..a6859736d 100644
--- a/pkg/sentry/fs/ext/assets/tiny.ext4
+++ b/pkg/sentry/fsimpl/ext/assets/tiny.ext4
Binary files differ
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD
new file mode 100644
index 000000000..9fddb4c4c
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD
@@ -0,0 +1,16 @@
+load("//tools/go_stateify:defs.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "benchmark_test",
+ size = "small",
+ srcs = ["benchmark_test.go"],
+ deps = [
+ "//pkg/sentry/context",
+ "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/fsimpl/ext",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
new file mode 100644
index 000000000..10a8083a0
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -0,0 +1,193 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// These benchmarks emulate memfs benchmarks. Ext4 images must be created
+// before this benchmark is run using the `make_deep_ext4.sh` script at
+// /tmp/image-{depth}.ext4 for all the depths tested below.
+package benchmark_test
+
+import (
+ "fmt"
+ "os"
+ "runtime"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+var depths = []int{1, 2, 3, 8, 64, 100}
+
+const filename = "file.txt"
+
+// setUp opens imagePath as an ext Filesystem and returns all necessary
+// elements required to run tests. If error is nil, it also returns a tear
+// down function which must be called after the test is run for clean up.
+//
+// The returned values are, in order: the test context, the VFS object with
+// the "extfs" filesystem type registered, the root of a new mount namespace
+// created from the image, the tear down function and the error. On error all
+// other return values are nil and the image file, if it was opened, has been
+// closed.
+func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesystem, *vfs.VirtualDentry, func(), error) {
+ f, err := os.Open(imagePath)
+ if err != nil {
+ return nil, nil, nil, nil, err
+ }
+
+ ctx := contexttest.Context(b)
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Create VFS.
+ vfsObj := vfs.New()
+ vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{})
+ // The image is handed over as a raw file descriptor via InternalData, which
+ // is why f is only closed on failure here or later by tearDown.
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())})
+ if err != nil {
+ f.Close()
+ return nil, nil, nil, nil, err
+ }
+
+ root := mntns.Root()
+
+ // tearDown drops the reference held on root and then closes the image file.
+ tearDown := func() {
+ root.DecRef()
+
+ if err := f.Close(); err != nil {
+ b.Fatalf("tearDown failed: %v", err)
+ }
+ }
+ return ctx, vfsObj, &root, tearDown, nil
+}
+
+// mount mounts the ext image at imagePath as an extfs submount at the path
+// operation passed. It fails the benchmark if the image cannot be opened or
+// mounted. Returns a tear down function which must be called after the test is
+// run for clean up; the tear down function closes the image file.
+func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vfs.PathOperation) func() {
+	b.Helper()
+
+	f, err := os.Open(imagePath)
+	if err != nil {
+		b.Fatalf("could not open image at %s: %v", imagePath, err)
+	}
+
+	ctx := contexttest.Context(b)
+	creds := auth.CredentialsFromContext(ctx)
+
+	if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}); err != nil {
+		// The caller never receives a tear down function on this path, so
+		// release the image file here instead of leaking it.
+		f.Close()
+		b.Fatalf("failed to mount extfs submount: %v", err)
+	}
+	return func() {
+		if err := f.Close(); err != nil {
+			b.Fatalf("tearDown failed: %v", err)
+		}
+	}
+}
+
+// BenchmarkVFS2Ext4fsStat emulates BenchmarkVFS2MemfsStat.
+//
+// For every depth in depths it runs a sub-benchmark which opens
+// /tmp/image-{depth}.ext4 as the root filesystem and repeatedly stats the file
+// at /1/2/.../depth/file.txt.
+func BenchmarkVFS2Ext4fsStat(b *testing.B) {
+ for _, depth := range depths {
+ b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
+ ctx, vfsfs, root, tearDown, err := setUp(b, fmt.Sprintf("/tmp/image-%d.ext4", depth))
+ if err != nil {
+ b.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ creds := auth.CredentialsFromContext(ctx)
+ // Build the absolute path "/1/2/.../depth/file.txt".
+ var filePathBuilder strings.Builder
+ filePathBuilder.WriteByte('/')
+ for i := 1; i <= depth; i++ {
+ filePathBuilder.WriteString(fmt.Sprintf("%d", i))
+ filePathBuilder.WriteByte('/')
+ }
+ filePathBuilder.WriteString(filename)
+ filePath := filePathBuilder.String()
+
+ // Collect garbage and reset the timer so that the set up work above is
+ // not part of the measurement.
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: *root,
+ Start: *root,
+ Pathname: filePath,
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ b.Fatalf("stat(%q) failed: %v", filePath, err)
+ }
+ // Sanity check. The benchmark file is expected to be empty, so any
+ // non-zero size means the wrong file was found.
+ if stat.Size > 0 {
+ b.Fatalf("got wrong file size (%d)", stat.Size)
+ }
+ }
+ })
+ }
+}
+
+// BenchmarkVFS2ExtfsMountStat emulates BenchmarkVFS2MemfsMountStat.
+//
+// For every depth in depths it runs a sub-benchmark which uses the depth 1
+// image as the root filesystem, mounts /tmp/image-{depth}.ext4 on top of /1/
+// and repeatedly stats the file at /1/1/2/.../depth/file.txt, so that each
+// path walk has to cross the mount point.
+func BenchmarkVFS2ExtfsMountStat(b *testing.B) {
+ for _, depth := range depths {
+ b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
+ // Create root extfs with depth 1 so we can mount extfs again at /1/.
+ ctx, vfsfs, root, tearDown, err := setUp(b, fmt.Sprintf("/tmp/image-%d.ext4", 1))
+ if err != nil {
+ b.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ creds := auth.CredentialsFromContext(ctx)
+ mountPointName := "/1/"
+ pop := vfs.PathOperation{
+ Root: *root,
+ Start: *root,
+ Pathname: mountPointName,
+ }
+
+ // Save the mount point for later use.
+ mountPoint, err := vfsfs.GetDentryAt(ctx, creds, &pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ b.Fatalf("failed to walk to mount point: %v", err)
+ }
+ defer mountPoint.DecRef()
+
+ // Create extfs submount.
+ mountTearDown := mount(b, fmt.Sprintf("/tmp/image-%d.ext4", depth), vfsfs, &pop)
+ defer mountTearDown()
+
+ // Build the path of the file inside the submount:
+ // mountPointName + "1/2/.../depth/file.txt".
+ var filePathBuilder strings.Builder
+ filePathBuilder.WriteString(mountPointName)
+ for i := 1; i <= depth; i++ {
+ filePathBuilder.WriteString(fmt.Sprintf("%d", i))
+ filePathBuilder.WriteByte('/')
+ }
+ filePathBuilder.WriteString(filename)
+ filePath := filePathBuilder.String()
+
+ // Collect garbage and reset the timer so that the set up work above is
+ // not part of the measurement.
+ runtime.GC()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: *root,
+ Start: *root,
+ Pathname: filePath,
+ FollowFinalSymlink: true,
+ }, &vfs.StatOptions{})
+ if err != nil {
+ b.Fatalf("stat(%q) failed: %v", filePath, err)
+ }
+ // Sanity check. touch(1) always creates files of size 0 (empty).
+ if stat.Size > 0 {
+ b.Fatalf("got wrong file size (%d)", stat.Size)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh b/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh
new file mode 100755
index 000000000..d0910da1f
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/benchmark/make_deep_ext4.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script creates an ext4 image with $1 depth of directories and a file in
+# the inner most directory. The created file is at path /1/2/.../depth/file.txt.
+# The ext4 image is written to $2. The image is temporarily mounted at
+# /tmp/mountpoint. This script must be run with sudo privileges.
+#
+# The script exits with a non-zero status as soon as any step fails.
+
+# Usage:
+# sudo bash make_deep_ext4.sh {depth} {output path}
+
+# Check positional arguments.
+if [ "$#" -ne 2 ]; then
+  echo "Usage: sudo bash make_deep_ext4.sh {depth} {output path}"
+  exit 1
+fi
+
+# Make sure depth is a non-negative number.
+if ! [[ "$1" =~ ^[0-9]+$ ]]; then
+  echo "Depth must be a non-negative number."
+  exit 1
+fi
+
+# Create a 1 MB filesystem image at the requested output path.
+# Each command is tested directly: checking $? after the echo would report
+# the exit status of echo (always 0) rather than that of the failed command.
+rm -f "$2"
+if ! fallocate -l 1M "$2"; then
+  echo "fallocate failed"
+  exit 1
+fi
+
+# Convert that blank into an ext4 image.
+if ! mkfs.ext4 -j "$2"; then
+  echo "mkfs.ext4 failed"
+  exit 1
+fi
+
+# Mount the image.
+MOUNTPOINT=/tmp/mountpoint
+mkdir -p "$MOUNTPOINT"
+if ! mount -o loop "$2" "$MOUNTPOINT"; then
+  echo "mount failed"
+  exit 1
+fi
+
+# Clean up. From here on the image is mounted, so always unmount it when the
+# script exits, whether it succeeded or not. The mount point is only removed
+# once the unmount succeeded so that the contents of a still-mounted image are
+# never deleted.
+cleanup() {
+  umount "$MOUNTPOINT" && rm -rf "$MOUNTPOINT"
+}
+trap cleanup EXIT
+
+# Create nested directories and the file.
+if [ "$1" -eq 0 ]; then
+  FILEPATH=$MOUNTPOINT/file.txt
+else
+  FILEPATH=$MOUNTPOINT/$(seq -s '/' 1 "$1")/file.txt
+fi
+mkdir -p "$(dirname "$FILEPATH")" || exit 1
+touch "$FILEPATH" || exit 1
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
new file mode 100644
index 000000000..cea89bcd9
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -0,0 +1,200 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // numDirectBlks is the number of direct blocks in ext block map inodes.
+ numDirectBlks = 12
+)
+
+// blockMapFile is a type of regular file which uses direct/indirect block
+// addressing to store file data. This was deprecated in ext4.
+type blockMapFile struct {
+ // regFile is the regular file this block map belongs to. It is held by
+ // value; newBlockMapFile sets regFile.impl to the enclosing blockMapFile.
+ regFile regularFile
+
+ // directBlks are the direct block numbers. The physical blocks pointed to
+ // by these hold file data. Contains file blocks 0 to 11.
+ directBlks [numDirectBlks]uint32
+
+ // indirectBlk is the physical block which contains (blkSize/4) direct block
+ // numbers (as uint32 integers).
+ indirectBlk uint32
+
+ // doubleIndirectBlk is the physical block which contains (blkSize/4) indirect
+ // block numbers (as uint32 integers).
+ doubleIndirectBlk uint32
+
+ // tripleIndirectBlk is the physical block which contains (blkSize/4) doubly
+ // indirect block numbers (as uint32 integers).
+ tripleIndirectBlk uint32
+
+ // coverage at (i)th index indicates the amount of file data (in bytes) a
+ // node at height (i) covers. Height 0 is the direct block.
+ coverage [4]uint64
+}
+
+// Compiles only if blockMapFile implements io.ReaderAt.
+var _ io.ReaderAt = (*blockMapFile)(nil)
+
+// newBlockMapFile is the blockMapFile constructor. It initializes the file to
+// physical blocks map with (at most) the first 12 (direct) blocks and records
+// the block numbers of the indirect, double indirect and triple indirect
+// blocks. It also precomputes the coverage of a node at each height. The
+// returned error is always nil.
+func newBlockMapFile(regFile regularFile) (*blockMapFile, error) {
+ file := &blockMapFile{regFile: regFile}
+ file.regFile.impl = file
+
+ // Heights 0 through 3: data block, indirect, double and triple indirect.
+ for i := uint(0); i < 4; i++ {
+ file.coverage[i] = getCoverage(regFile.inode.blkSize, i)
+ }
+
+ // The inode data is decoded as consecutive little-endian uint32 block
+ // numbers: numDirectBlks direct entries followed by one entry each for the
+ // indirect, double indirect and triple indirect blocks.
+ // NOTE(review): assumes Data() is at least (numDirectBlks+3)*4 bytes long;
+ // confirm against the disklayout package.
+ blkMap := regFile.inode.diskInode.Data()
+ binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
+ binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
+ binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
+ binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk)
+ return file, nil
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+//
+// It copies file data starting at file offset off into dst and returns the
+// number of bytes copied. Reads never extend past the file size recorded in
+// the inode: when fewer than len(dst) bytes are available, the short count is
+// returned together with io.EOF. A negative off yields syserror.EINVAL and an
+// offset at or beyond the file size yields io.EOF.
+func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) {
+	if len(dst) == 0 {
+		return 0, nil
+	}
+
+	if off < 0 {
+		return 0, syserror.EINVAL
+	}
+
+	offset := uint64(off)
+	size := f.regFile.inode.diskInode.Size()
+	if offset >= size {
+		return 0, io.EOF
+	}
+
+	// dirBlksEnd is the file offset until which direct blocks cover file data.
+	// Direct blocks cover 0 <= file offset < dirBlksEnd.
+	dirBlksEnd := numDirectBlks * f.coverage[0]
+
+	// indirBlkEnd is the file offset until which the indirect block covers file
+	// data. The indirect block covers dirBlksEnd <= file offset < indirBlkEnd.
+	indirBlkEnd := dirBlksEnd + f.coverage[1]
+
+	// doubIndirBlkEnd is the file offset until which the double indirect block
+	// covers file data. The double indirect block covers the range
+	// indirBlkEnd <= file offset < doubIndirBlkEnd.
+	doubIndirBlkEnd := indirBlkEnd + f.coverage[2]
+
+	read := 0
+	toRead := len(dst)
+	if uint64(toRead)+offset > size {
+		toRead = int(size - offset)
+	}
+	for read < toRead {
+		var err error
+		var curR int
+
+		// Only hand out dst[read:toRead] so that a delegated read can never
+		// copy bytes which lie beyond the end of the file into dst.
+		buf := dst[read:toRead]
+
+		// Figure out which block to delegate the read to.
+		switch {
+		case offset < dirBlksEnd:
+			// Direct block.
+			curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, buf)
+		case offset < indirBlkEnd:
+			// Indirect block.
+			curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, buf)
+		case offset < doubIndirBlkEnd:
+			// Doubly indirect block.
+			curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, buf)
+		default:
+			// Triply indirect block.
+			curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, buf)
+		}
+
+		read += curR
+		offset += uint64(curR)
+		if err != nil {
+			return read, err
+		}
+		if curR == 0 {
+			// No progress was made, which happens when the inode claims a
+			// size larger than the triple indirect block can address. Fail
+			// the read instead of looping forever.
+			return read, syserror.EIO
+		}
+	}
+
+	if read < len(dst) {
+		return read, io.EOF
+	}
+	return read, nil
+}
+
+// read is the recursive step of the ReadAt function. It relies on knowing the
+// current node's location on disk (curPhyBlk) and its height in the block map
+// tree. A height of 0 shows that the current node is actually holding file
+// data. relFileOff tells the offset from which we need to start to reading
+// under the current node. It is completely relative to the current node.
+//
+// It returns the number of bytes copied into dst. At height 0 a short read
+// from the device is reported as syserror.EIO.
+func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, dst []byte) (int, error) {
+ // curPhyBlkOff is the byte offset of the current node on the device.
+ curPhyBlkOff := int64(curPhyBlk) * int64(f.regFile.inode.blkSize)
+ if height == 0 {
+ // Read at most up to the end of this data block.
+ toRead := int(f.regFile.inode.blkSize - relFileOff)
+ if len(dst) < toRead {
+ toRead = len(dst)
+ }
+
+ // The error returned by the device is discarded; only a short read is
+ // treated as a failure.
+ n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff))
+ if n < toRead {
+ return n, syserror.EIO
+ }
+ return n, nil
+ }
+
+ // Work out which child entries [startIdx, endIdx) of this node need to be
+ // visited. The node holds blkSize/4 entries and the range is capped at
+ // the last child that dst can reach.
+ childCov := f.coverage[height-1]
+ startIdx := relFileOff / childCov
+ endIdx := f.regFile.inode.blkSize / 4 // This is exclusive.
+ wantEndIdx := (relFileOff + uint64(len(dst))) / childCov
+ wantEndIdx++ // Make this exclusive.
+ if wantEndIdx < endIdx {
+ endIdx = wantEndIdx
+ }
+
+ read := 0
+ // Only the first child visited is entered at a non-zero offset.
+ curChildOff := relFileOff % childCov
+ for i := startIdx; i < endIdx; i++ {
+ // Fetch the i-th child's physical block number (4 bytes per entry) from
+ // this node.
+ var childPhyBlk uint32
+ err := readFromDisk(f.regFile.inode.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
+ if err != nil {
+ return read, err
+ }
+
+ n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:])
+ read += n
+ if err != nil {
+ return read, err
+ }
+
+ curChildOff = 0
+ }
+
+ return read, nil
+}
+
+// getCoverage returns the number of bytes a node at the given height covers.
+// Height 0 is the file data block itself. Height 1 is the indirect block.
+//
+// Formula: blkSize * ((blkSize / 4)^height)
+//
+// The power is computed in floating point with math.Pow and converted back to
+// uint64 before being multiplied by blkSize.
+func getCoverage(blkSize uint64, height uint) uint64 {
+ return blkSize * uint64(math.Pow(float64(blkSize/4), float64(height)))
+}
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
new file mode 100644
index 000000000..213aa3919
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -0,0 +1,157 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+)
+
+// These consts are for mocking the block map tree.
+const (
+ mockBMBlkSize = uint32(16)
+ mockBMDiskSize = 2500
+)
+
+// TestBlockMapReader stress tests block map reader functionality. It reads
+// from every possible starting offset in the block map structure through to
+// the end of the file and compares the result with the expected file data.
+func TestBlockMapReader(t *testing.T) {
+	mockBMFile, want := blockMapSetUp(t)
+	n := len(want)
+
+	for from := 0; from < n; from++ {
+		got := make([]byte, n-from)
+
+		if read, err := mockBMFile.ReadAt(got, int64(from)); err != nil {
+			t.Fatalf("file read operation from offset %d to %d only read %d bytes: %v", from, n, read, err)
+		}
+
+		// cmp.Diff(x, y) marks content only in x with '-' and content only in
+		// y with '+', so want must come first to match the message below.
+		if diff := cmp.Diff(want[from:], got); diff != "" {
+			t.Fatalf("file data from offset %d to %d mismatched (-want +got):\n%s", from, n, diff)
+		}
+	}
+}
+
+// blkNumGen is a number generator which gives block numbers for building the
+// block map file on disk. It gives unique numbers in a random order which
+// facilitates in creating an extremely fragmented filesystem.
+type blkNumGen struct {
+ // nums holds the block numbers which have not been handed out yet. next
+ // consumes it from the front.
+ nums []uint32
+}
+
+// newBlkNumGen is the blkNumGen constructor. The generator is filled with
+// every block number in [0, mockBMDiskSize/mockBMBlkSize), shuffled.
+func newBlkNumGen() *blkNumGen {
+ blkNums := &blkNumGen{}
+ // lim is the number of whole blocks which fit on the mock disk.
+ lim := mockBMDiskSize / mockBMBlkSize
+ blkNums.nums = make([]uint32, lim)
+ for i := uint32(0); i < lim; i++ {
+ blkNums.nums[i] = i
+ }
+
+ rand.Shuffle(int(lim), func(i, j int) {
+ blkNums.nums[i], blkNums.nums[j] = blkNums.nums[j], blkNums.nums[i]
+ })
+ return blkNums
+}
+
+// next returns the next random block number and removes it from the
+// generator. It panics (index out of range) once all numbers are used up.
+func (n *blkNumGen) next() uint32 {
+ ret := n.nums[0]
+ n.nums = n.nums[1:]
+ return ret
+}
+
+// blockMapSetUp creates a mock disk and a block map file. It initializes the
+// block map file with 12 direct block, 1 indirect block, 1 double indirect
+// block and 1 triple indirect block (basically fill it till the rim). It
+// initializes the disk to reflect the inode. Also returns the file data that
+// the inode covers and that is written to disk.
+func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
+ mockDisk := make([]byte, mockBMDiskSize)
+ regFile := regularFile{
+ inode: inode{
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: getMockBMFileFize(),
+ },
+ },
+ dev: bytes.NewReader(mockDisk),
+ blkSize: uint64(mockBMBlkSize),
+ },
+ }
+
+ // fileData accumulates the expected file contents in file order. data
+ // accumulates the serialized block numbers which are copied into the
+ // inode's data area below.
+ var fileData []byte
+ blkNums := newBlkNumGen()
+ var data []byte
+
+ // Write the direct blocks.
+ for i := 0; i < numDirectBlks; i++ {
+ curBlkNum := blkNums.next()
+ data = binary.Marshal(data, binary.LittleEndian, curBlkNum)
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...)
+ }
+
+ // Write to indirect block.
+ indirectBlk := blkNums.next()
+ data = binary.Marshal(data, binary.LittleEndian, indirectBlk)
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...)
+
+ // Write to doubly indirect block.
+ doublyIndirectBlk := blkNums.next()
+ data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk)
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...)
+
+ // Write to triply indirect block.
+ triplyIndirectBlk := blkNums.next()
+ data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
+
+ // Install the serialized block map into the mock inode.
+ copy(regFile.inode.diskInode.Data(), data)
+
+ mockFile, err := newBlockMapFile(regFile)
+ if err != nil {
+ t.Fatalf("newBlockMapFile failed: %v", err)
+ }
+ return mockFile, fileData
+}
+
+// writeFileDataToBlock writes random bytes to the block on disk.
+//
+// At height 0, block blkNum is filled with random bytes and the slice of disk
+// holding that block is returned. At greater heights, block blkNum is filled
+// with freshly generated child block numbers, each child is populated
+// recursively at height-1, and the file data of all children is returned
+// concatenated in order.
+func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkNumGen) []byte {
+ if height == 0 {
+ start := blkNum * mockBMBlkSize
+ end := start + mockBMBlkSize
+ rand.Read(disk[start:end])
+ return disk[start:end]
+ }
+
+ var fileData []byte
+ // Each entry in the block is a 4 byte little-endian block number.
+ for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 {
+ curBlkNum := blkNums.next()
+ copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum))
+ fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...)
+ }
+ return fileData
+}
+
+// getMockBMFileFize gets the size of the mock block map file which is used for
+// testing. The size is the sum of the coverage of numDirectBlks direct blocks
+// and of one indirect, one double indirect and one triple indirect block, all
+// computed for a block size of mockBMBlkSize.
+func getMockBMFileFize() uint32 {
+ return uint32(numDirectBlks*getCoverage(uint64(mockBMBlkSize), 0) + getCoverage(uint64(mockBMBlkSize), 1) + getCoverage(uint64(mockBMBlkSize), 2) + getCoverage(uint64(mockBMBlkSize), 3))
+}
diff --git a/pkg/sentry/fs/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go
index 054fb42b6..054fb42b6 100644
--- a/pkg/sentry/fs/ext/dentry.go
+++ b/pkg/sentry/fsimpl/ext/dentry.go
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
new file mode 100644
index 000000000..b51f3e18d
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -0,0 +1,308 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// directory represents a directory inode. It holds the childList in memory.
+// Children are tracked twice: in childList, which preserves order and is used
+// for iteration, and in childMap, which is keyed by file name.
+type directory struct {
+ // inode is this directory's inode, stored by value. newDirectroy sets
+ // inode.impl to point back at the enclosing directory.
+ inode inode
+
+ // mu serializes the changes to childList.
+ // Lock Order (outermost locks must be taken first):
+ // directory.mu
+ // filesystem.mu
+ mu sync.Mutex
+
+ // childList is a list containing (1) child dirents and (2) fake dirents
+ // (with diskDirent == nil) that represent the iteration position of
+ // directoryFDs. childList is used to support directoryFD.IterDirents()
+ // efficiently. childList is protected by mu.
+ childList direntList
+
+ // childMap maps the child's filename to the dirent structure stored in
+ // childList. This adds some data replication but helps in faster path
+ // traversal. For consistency, key == childMap[key].diskDirent.FileName().
+ // Immutable.
+ childMap map[string]*dirent
+}
+
+// newDirectroy is the directory constructor.
+//
+// It wraps inode in a directory and, for linear (non hash tree) directories,
+// reads every dirent from the underlying file data into childList and
+// childMap. newDirent selects the on-disk layout used to decode each record:
+// disklayout.DirentNew when true, disklayout.DirentOld otherwise.
+//
+// An error is returned if the file data cannot be read or if a record
+// advertises a zero record size, which would otherwise prevent the scan from
+// ever advancing.
+func newDirectroy(inode inode, newDirent bool) (*directory, error) {
+	file := &directory{inode: inode, childMap: make(map[string]*dirent)}
+	file.inode.impl = file
+
+	// Initialize childList by reading dirents from the underlying file.
+	if inode.diskInode.Flags().Index {
+		// TODO(b/134676337): Support hash tree directories. Until then such
+		// directories are returned with no children loaded.
+
+		// Users cannot navigate this hash tree directory yet.
+		log.Warningf("hash tree directory being used which is unsupported")
+		return file, nil
+	}
+
+	// The dirents are organized in a linear array in the file data.
+	// Extract the file data and decode the dirents.
+	regFile, err := newRegularFile(inode)
+	if err != nil {
+		return nil, err
+	}
+
+	// buf is used as scratch space for reading in dirents from disk and
+	// unmarshalling them into dirent structs.
+	buf := make([]byte, disklayout.DirentSize)
+	size := inode.diskInode.Size()
+	for off, inc := uint64(0), uint64(0); off < size; off += inc {
+		// The last record may be shorter than a full dirent structure.
+		toRead := size - off
+		if toRead > disklayout.DirentSize {
+			toRead = disklayout.DirentSize
+		}
+		if n, err := regFile.impl.ReadAt(buf[:toRead], int64(off)); uint64(n) < toRead {
+			if err == nil {
+				// Never hand the caller a nil directory with a nil error.
+				err = syserror.EIO
+			}
+			return nil, err
+		}
+
+		var curDirent dirent
+		if newDirent {
+			curDirent.diskDirent = &disklayout.DirentNew{}
+		} else {
+			curDirent.diskDirent = &disklayout.DirentOld{}
+		}
+		binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent)
+
+		// The next dirent is placed exactly after this dirent record on disk.
+		inc = uint64(curDirent.diskDirent.RecordSize())
+		if inc == 0 {
+			// A zero record size comes from a corrupt image; continuing would
+			// re-read the same record forever.
+			return nil, syserror.EIO
+		}
+
+		if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 {
+			// Inode number and name length fields being set to 0 is used to indicate
+			// an unused dirent.
+			file.childList.PushBack(&curDirent)
+			file.childMap[curDirent.diskDirent.FileName()] = &curDirent
+		}
+	}
+
+	return file, nil
+}
+
+func (i *inode) isDir() bool {
+ _, ok := i.impl.(*directory)
+ return ok
+}
+
+// dirent is the directory.childList node.
+type dirent struct {
+ // diskDirent is the decoded on-disk entry. It is nil for the placeholder
+ // dirents that directoryFDs link into childList to mark their iteration
+ // position.
+ diskDirent disklayout.Dirent
+
+ // direntEntry links dirents into their parent directory.childList.
+ direntEntry
+}
+
+// directoryFD represents a directory file description. It implements
+// vfs.FileDescriptionImpl.
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+
+ // Protected by directory.mu.
+ // iter is this FD's placeholder node in directory.childList; it stays nil
+ // until IterDirents or Seek first creates it. off is the offset reported
+ // for the next dirent; IterDirents advances it by one per entry.
+ iter *dirent
+ off int64
+}
+
+// Compiles only if directoryFD implements vfs.FileDescriptionImpl.
+var _ vfs.FileDescriptionImpl = (*directoryFD)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+//
+// If this file description has an iterator node, it is unlinked from the
+// directory's childList under directory.mu and then dropped.
+func (fd *directoryFD) Release() {
+	if fd.iter != nil {
+		parent := fd.inode().impl.(*directory)
+
+		parent.mu.Lock()
+		parent.childList.Remove(fd.iter)
+		parent.mu.Unlock()
+
+		fd.iter = nil
+	}
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+//
+// Iteration resumes from fd.iter, this FD's placeholder node in childList.
+// Each real dirent is passed to cb.Handle; iteration stops early when Handle
+// returns false or when a child inode cannot be loaded. In every exit path
+// fd.iter is re-linked into childList so that the next call continues from
+// the right place.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ extfs := fd.filesystem()
+ dir := fd.inode().impl.(*directory)
+
+ dir.mu.Lock()
+ defer dir.mu.Unlock()
+
+ // Ensure that fd.iter exists and is not linked into dir.childList.
+ var child *dirent
+ if fd.iter == nil {
+ // Start iteration at the beginning of dir.
+ child = dir.childList.Front()
+ fd.iter = &dirent{}
+ } else {
+ // Continue iteration from where we left off.
+ child = fd.iter.Next()
+ dir.childList.Remove(fd.iter)
+ }
+ for ; child != nil; child = child.Next() {
+ // Skip other directoryFD iterators.
+ if child.diskDirent != nil {
+ childType, ok := child.diskDirent.FileType()
+ if !ok {
+ // We will need to read the inode off disk. Do not increment
+ // ref count here because this inode is not being added to the
+ // dentry tree.
+ extfs.mu.Lock()
+ childInode, err := extfs.getOrCreateInodeLocked(child.diskDirent.Inode())
+ extfs.mu.Unlock()
+ if err != nil {
+ // Usage of the file description after the error is
+ // undefined. This implementation would continue reading
+ // from the next dirent.
+ fd.off++
+ dir.childList.InsertAfter(child, fd.iter)
+ return err
+ }
+ childType = fs.ToInodeType(childInode.diskInode.Mode().FileType())
+ }
+
+ if !cb.Handle(vfs.Dirent{
+ Name: child.diskDirent.FileName(),
+ Type: fs.ToDirentType(childType),
+ Ino: uint64(child.diskDirent.Inode()),
+ Off: fd.off,
+ }) {
+ // The callback rejected this dirent: park the iterator in front
+ // of it, without advancing fd.off, so it is offered again on the
+ // next call.
+ dir.childList.InsertBefore(child, fd.iter)
+ return nil
+ }
+ fd.off++
+ }
+ }
+ // Every dirent was consumed; park the iterator at the end of the list.
+ dir.childList.PushBack(fd.iter)
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+//
+// Only SEEK_SET and SEEK_CUR are supported; any other whence, or a negative
+// resulting offset, yields EINVAL. Offsets count directory entries. Seeking to
+// or past the number of entries is allowed: the requested offset is recorded
+// in fd.off and the iterator is parked at the end of childList.
+//
+// On success fd.iter is linked into childList directly before the dirent at
+// the requested offset, which is where IterDirents resumes from.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+	if whence != linux.SEEK_SET && whence != linux.SEEK_CUR {
+		return 0, syserror.EINVAL
+	}
+
+	dir := fd.inode().impl.(*directory)
+
+	dir.mu.Lock()
+	defer dir.mu.Unlock()
+
+	// Find resulting offset.
+	if whence == linux.SEEK_CUR {
+		offset += fd.off
+	}
+
+	if offset < 0 {
+		// lseek(2) specifies that EINVAL should be returned if the resulting offset
+		// is negative.
+		return 0, syserror.EINVAL
+	}
+
+	// Clamp both offsets to the number of entries so that distances below are
+	// measured in entries that actually exist.
+	n := int64(len(dir.childMap))
+	realWantOff := offset
+	if realWantOff > n {
+		realWantOff = n
+	}
+	realCurOff := fd.off
+	if realCurOff > n {
+		realCurOff = n
+	}
+
+	// Ensure that fd.iter exists and is linked into dir.childList so we can
+	// intelligently seek from the optimal position.
+	if fd.iter == nil {
+		fd.iter = &dirent{}
+		dir.childList.PushFront(fd.iter)
+	}
+
+	if realWantOff < n {
+		// The walk is described by a start node, a direction, and the number
+		// of real dirents to step over before reaching the target. Direction
+		// is tracked explicitly so that iterator nodes are always skipped in
+		// the direction of travel, even when no dirents remain to be skipped.
+
+		// Guess that iterating from the current position is optimal. fd.iter
+		// sits directly before the dirent at realCurOff, so walking backward
+		// the first real dirent reached is already at realCurOff-1.
+		child := fd.iter
+		forward := realWantOff >= realCurOff
+		dist := realWantOff - realCurOff
+		skip := dist
+		if !forward {
+			dist = -dist
+			skip = dist - 1
+		}
+
+		// See if starting from the beginning or end is better.
+		if dist > realWantOff {
+			// Starting from the beginning is best.
+			child = dir.childList.Front()
+			forward = true
+			skip = realWantOff
+		} else if dist > (n - realWantOff) {
+			// Starting from the end is best. The last real dirent represents
+			// the (n-1)th offset.
+			child = dir.childList.Back()
+			forward = false
+			skip = (n - 1) - realWantOff
+		}
+
+		for child != nil {
+			// Skip other directoryFD iterators (and fd.iter itself).
+			if child.diskDirent != nil {
+				if skip == 0 {
+					dir.childList.Remove(fd.iter)
+					dir.childList.InsertBefore(child, fd.iter)
+					fd.off = offset
+					return offset, nil
+				}
+				skip--
+			}
+
+			if forward {
+				child = child.Next()
+			} else {
+				child = child.Prev()
+			}
+		}
+	}
+
+	// Reaching here indicates that the offset is at or beyond the end of the
+	// childList.
+	dir.childList.Remove(fd.iter)
+	dir.childList.PushBack(fd.iter)
+	fd.off = offset
+	return offset, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+//
+// Directories cannot be memory mapped, so this always fails.
+func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+ // mmap(2) specifies that EACCES should be returned for non-regular file fds.
+ return syserror.EACCES
+}
diff --git a/pkg/sentry/fs/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index dde15110d..907d35b7e 100644
--- a/pkg/sentry/fs/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -22,7 +22,7 @@ go_library(
"superblock_old.go",
"test_utils.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout",
+ importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/fs/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
index ad6f4fef8..ad6f4fef8 100644
--- a/pkg/sentry/fs/ext/disklayout/block_group.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
diff --git a/pkg/sentry/fs/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
index 3e16c76db..3e16c76db 100644
--- a/pkg/sentry/fs/ext/disklayout/block_group_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
diff --git a/pkg/sentry/fs/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
index 9a809197a..9a809197a 100644
--- a/pkg/sentry/fs/ext/disklayout/block_group_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
diff --git a/pkg/sentry/fs/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
index 0ef4294c0..0ef4294c0 100644
--- a/pkg/sentry/fs/ext/disklayout/block_group_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
diff --git a/pkg/sentry/fs/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
index 685bf57b8..417b6cf65 100644
--- a/pkg/sentry/fs/ext/disklayout/dirent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
@@ -21,6 +21,9 @@ import (
const (
// MaxFileName is the maximum length of an ext fs file's name.
MaxFileName = 255
+
+ // DirentSize is the size of ext dirent structures.
+ DirentSize = 263
)
var (
diff --git a/pkg/sentry/fs/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
index 29ae4a5c2..29ae4a5c2 100644
--- a/pkg/sentry/fs/ext/disklayout/dirent_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
diff --git a/pkg/sentry/fs/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
index 6fff12a6e..6fff12a6e 100644
--- a/pkg/sentry/fs/ext/disklayout/dirent_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
diff --git a/pkg/sentry/fs/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
index cc6dff2c9..934919f8a 100644
--- a/pkg/sentry/fs/ext/disklayout/dirent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
@@ -21,8 +21,6 @@ import (
// TestDirentSize tests that the dirent structs are of the correct
// size.
func TestDirentSize(t *testing.T) {
- want := uintptr(263)
-
- assertSize(t, DirentOld{}, want)
- assertSize(t, DirentNew{}, want)
+ assertSize(t, DirentOld{}, uintptr(DirentSize))
+ assertSize(t, DirentNew{}, uintptr(DirentSize))
}
diff --git a/pkg/sentry/fs/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
index bdf4e2132..bdf4e2132 100644
--- a/pkg/sentry/fs/ext/disklayout/disklayout.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
diff --git a/pkg/sentry/fs/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 567523d32..567523d32 100644
--- a/pkg/sentry/fs/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
diff --git a/pkg/sentry/fs/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index b0fad9b71..b0fad9b71 100644
--- a/pkg/sentry/fs/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
diff --git a/pkg/sentry/fs/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go
index 88ae913f5..88ae913f5 100644
--- a/pkg/sentry/fs/ext/disklayout/inode.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go
diff --git a/pkg/sentry/fs/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
index 8f9f574ce..8f9f574ce 100644
--- a/pkg/sentry/fs/ext/disklayout/inode_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
diff --git a/pkg/sentry/fs/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
index db25b11b6..db25b11b6 100644
--- a/pkg/sentry/fs/ext/disklayout/inode_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
diff --git a/pkg/sentry/fs/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
index dd03ee50e..dd03ee50e 100644
--- a/pkg/sentry/fs/ext/disklayout/inode_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
diff --git a/pkg/sentry/fs/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
index 7a337a5e0..8bb327006 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
@@ -221,7 +221,7 @@ func CompatFeaturesFromInt(f uint32) CompatFeatures {
// This is not exhaustive, unused features are not listed.
const (
// SbDirentFileType indicates that directory entries record the file type.
- // We should use struct ext4_dir_entry_2 for dirents then.
+ // We should use struct DirentNew for dirents then.
SbDirentFileType = 0x2
// SbRecovery indicates that the filesystem needs recovery.
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
index 53e515fd3..53e515fd3 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
index 7c1053fb4..7c1053fb4 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
index 9221e0251..9221e0251 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
diff --git a/pkg/sentry/fs/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
index 463b5ba21..463b5ba21 100644
--- a/pkg/sentry/fs/ext/disklayout/superblock_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
diff --git a/pkg/sentry/fs/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
index 9c63f04c0..9c63f04c0 100644
--- a/pkg/sentry/fs/ext/disklayout/test_utils.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
new file mode 100644
index 000000000..f10accafc
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -0,0 +1,135 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ext implements readonly ext(2/3/4) filesystems.
+package ext
+
+import (
+ "errors"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// FilesystemType implements vfs.FilesystemType.
+//
+// It carries no state; all per-mount state is created in NewFilesystem.
+type FilesystemType struct{}
+
+// Compiles only if FilesystemType implements vfs.FilesystemType.
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
+
+// getDeviceFd returns an io.ReaderAt to the underlying device.
+//
+// There are currently two ways of mounting an ext(2/3/4) fs:
+// 1. Specify a mount with our internal special MountType in the OCI spec.
+// 2. Expose the device to the container and mount it from application layer.
+//
+// Only the first is implemented: opts.InternalData must hold the device's
+// file descriptor as a non-negative int. A nil InternalData (a user mount
+// call) panics.
+func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReaderAt, error) {
+	internal := opts.InternalData
+	if internal == nil {
+		// User mount call.
+		// TODO(b/134676337): Open the device specified by `source` and return that.
+		panic("unimplemented")
+	}
+
+	// NewFilesystem call originated from within the sentry.
+	devFd, isInt := internal.(int)
+	switch {
+	case !isInt:
+		return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device")
+	case devFd < 0:
+		return nil, fmt.Errorf("ext device file descriptor is not valid: %d", devFd)
+	}
+
+	// The fd.ReadWriter returned from fd.NewReadWriter() does not take ownership
+	// of the file descriptor and hence will not close it when it is garbage
+	// collected.
+	return fd.NewReadWriter(devFd), nil
+}
+
+// isCompatible reports whether the superblock's feature sets allow this
+// implementation to mount the filesystem.
+//
+// Only the incompatible feature set is inspected, because the filesystem is
+// mounted readonly and without journaling. The readonly compatible feature set
+// must be checked as well once read/write mounts or journaling are supported.
+// The first unsupported feature found is logged.
+func isCompatible(sb disklayout.SuperBlock) bool {
+	features := sb.IncompatibleFeatures()
+	switch {
+	case features.MetaBG:
+		log.Warningf("ext fs: meta block groups are not supported")
+	case features.MMP:
+		log.Warningf("ext fs: multiple mount protection is not supported")
+	case features.Encrypted:
+		log.Warningf("ext fs: encrypted inodes not supported")
+	case features.InlineData:
+		log.Warningf("ext fs: inline files not supported")
+	default:
+		return true
+	}
+	return false
+}
+
+// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
+//
+// It opens the backing device, reads and validates the superblock, reads the
+// block group descriptors, and returns the new filesystem together with a
+// dentry for the root directory inode, on which a reference has been taken.
+// EINVAL is returned when the superblock magic is wrong or the filesystem
+// uses unsupported features.
+func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+	// TODO(b/134676337): Ensure that the user is mounting readonly. If not,
+	// EACCESS should be returned according to mount(2). Filesystem independent
+	// flags (like readonly) are currently not available in pkg/sentry/vfs.
+
+	dev, err := getDeviceFd(source, opts)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	extfs := &filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
+	extfs.vfsfs.Init(extfs)
+
+	if extfs.sb, err = readSuperBlock(dev); err != nil {
+		return nil, nil, err
+	}
+
+	// mount(2) specifies that EINVAL should be returned if the superblock is
+	// invalid. Also refuse to mount if the filesystem is incompatible.
+	if extfs.sb.Magic() != linux.EXT_SUPER_MAGIC || !isCompatible(extfs.sb) {
+		return nil, nil, syserror.EINVAL
+	}
+
+	if extfs.bgs, err = readBlockGroups(dev, extfs.sb); err != nil {
+		return nil, nil, err
+	}
+
+	root, err := extfs.getOrCreateInodeLocked(disklayout.RootDirInode)
+	if err != nil {
+		return nil, nil, err
+	}
+	root.incRef()
+
+	return &extfs.vfsfs, &newDentry(root).vfsd, nil
+}
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
new file mode 100644
index 000000000..49b57a2d6
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -0,0 +1,917 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "path"
+ "sort"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+
+ "gvisor.dev/gvisor/runsc/test/testutil"
+)
+
+const (
+ // assetsDir is the directory holding the test disk images.
+ assetsDir = "pkg/sentry/fsimpl/ext/assets"
+)
+
+// Paths of the ext2/3/4 test images, all located under assetsDir.
+var (
+ ext2ImagePath = path.Join(assetsDir, "tiny.ext2")
+ ext3ImagePath = path.Join(assetsDir, "tiny.ext3")
+ ext4ImagePath = path.Join(assetsDir, "tiny.ext4")
+)
+
+// setUp opens imagePath as an ext Filesystem and returns all necessary
+// elements required to run tests. If the returned error is nil, it also
+// returns a tear down function which must be called after the test is run for
+// clean up; on error every other return value is nil.
+func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesystem, *vfs.VirtualDentry, func(), error) {
+	imgPath, err := testutil.FindFile(imagePath)
+	if err != nil {
+		return nil, nil, nil, nil, fmt.Errorf("failed to open local image at path %s: %v", imagePath, err)
+	}
+
+	img, err := os.Open(imgPath)
+	if err != nil {
+		return nil, nil, nil, nil, err
+	}
+
+	ctx := contexttest.Context(t)
+	creds := auth.CredentialsFromContext(ctx)
+
+	// Create VFS and mount the image, handing over the image's host file
+	// descriptor as internal data.
+	vfsObj := vfs.New()
+	vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{})
+	fsOpts := &vfs.NewFilesystemOptions{InternalData: int(img.Fd())}
+	mntns, err := vfsObj.NewMountNamespace(ctx, creds, imgPath, "extfs", fsOpts)
+	if err != nil {
+		img.Close()
+		return nil, nil, nil, nil, err
+	}
+
+	rootVD := mntns.Root()
+	cleanup := func() {
+		rootVD.DecRef()
+
+		if err := img.Close(); err != nil {
+			t.Fatalf("tearDown failed: %v", err)
+		}
+	}
+	return ctx, vfsObj, &rootVD, cleanup, nil
+}
+
+// TODO(b/134676337): Test vfs.FilesystemImpl.ReadlinkAt and
+// vfs.FilesystemImpl.StatFSAt which are not implemented in
+// vfs.VirtualFilesystem yet.
+
+// TestSeek tests vfs.FileDescriptionImpl.Seek functionality.
+//
+// For the root directory and a regular file in each image, it runs a fixed
+// sequence of seeks on one file description. The sequence is stateful: each
+// SEEK_CUR expectation depends on the position left by the previous seek.
+func TestSeek(t *testing.T) {
+ type seekTest struct {
+ name string
+ image string
+ path string
+ }
+
+ tests := []seekTest{
+ {
+ name: "ext4 root dir seek",
+ image: ext4ImagePath,
+ path: "/",
+ },
+ {
+ name: "ext3 root dir seek",
+ image: ext3ImagePath,
+ path: "/",
+ },
+ {
+ name: "ext2 root dir seek",
+ image: ext2ImagePath,
+ path: "/",
+ },
+ {
+ name: "ext4 reg file seek",
+ image: ext4ImagePath,
+ path: "/file.txt",
+ },
+ {
+ name: "ext3 reg file seek",
+ image: ext3ImagePath,
+ path: "/file.txt",
+ },
+ {
+ name: "ext2 reg file seek",
+ image: ext2ImagePath,
+ path: "/file.txt",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+ if err != nil {
+ t.Fatalf("setUp failed: %v", err)
+ }
+ defer tearDown()
+
+ fd, err := vfsfs.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt failed: %v", err)
+ }
+
+ if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
+ t.Errorf("expected seek position 0, got %d and error %v", n, err)
+ }
+
+ stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err)
+ }
+
+ // We should be able to seek beyond the end of file.
+ size := int64(stat.Size)
+ if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", size, n, err)
+ }
+
+ // EINVAL should be returned if the resulting offset is negative.
+ if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
+ t.Errorf("expected error EINVAL but got %v", err)
+ }
+
+ // The failed seek above must not have moved the position, so this
+ // lands at size+3.
+ if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err)
+ }
+
+ // Make sure negative offsets work with SEEK_CUR.
+ if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
+ }
+
+ // EINVAL should be returned if the resulting offset is negative.
+ if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
+ t.Errorf("expected error EINVAL but got %v", err)
+ }
+
+ // Make sure SEEK_END works with regular files.
+ switch fd.Impl().(type) {
+ case *regularFileFD:
+ // Seek back to 0.
+ if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", 0, n, err)
+ }
+
+ // Seek forward beyond EOF.
+ if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil {
+ t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
+ }
+
+ // EINVAL should be returned if the resulting offset is negative.
+ if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
+ t.Errorf("expected error EINVAL but got %v", err)
+ }
+ }
+ })
+ }
+}
+
+// TestStatAt tests filesystem.StatAt functionality.
+//
+// For a small file, a big file and a symlink in each of the ext2/3/4 images,
+// it compares the returned linux.Statx against the expected values, ignoring
+// fields that StatAt does not support yet or that depend on how the image was
+// built.
+func TestStatAt(t *testing.T) {
+	type statAtTest struct {
+		name  string
+		image string
+		path  string
+		want  linux.Statx
+	}
+
+	// The three images hold the same files, so the expectations are shared.
+	wantSmallFile := linux.Statx{
+		Blksize: 0x400,
+		Nlink:   1,
+		UID:     0,
+		GID:     0,
+		Mode:    0644 | linux.ModeRegular,
+		Size:    13,
+	}
+	wantBigFile := linux.Statx{
+		Blksize: 0x400,
+		Nlink:   1,
+		UID:     0,
+		GID:     0,
+		Mode:    0644 | linux.ModeRegular,
+		Size:    13042,
+	}
+	wantSymlink := linux.Statx{
+		Blksize: 0x400,
+		Nlink:   1,
+		UID:     0,
+		GID:     0,
+		Mode:    0777 | linux.ModeSymlink,
+		Size:    8,
+	}
+
+	tests := []statAtTest{
+		{
+			name:  "ext4 statx small file",
+			image: ext4ImagePath,
+			path:  "/file.txt",
+			want:  wantSmallFile,
+		},
+		{
+			name:  "ext3 statx small file",
+			image: ext3ImagePath,
+			path:  "/file.txt",
+			want:  wantSmallFile,
+		},
+		{
+			name:  "ext2 statx small file",
+			image: ext2ImagePath,
+			path:  "/file.txt",
+			want:  wantSmallFile,
+		},
+		{
+			name:  "ext4 statx big file",
+			image: ext4ImagePath,
+			path:  "/bigfile.txt",
+			want:  wantBigFile,
+		},
+		{
+			name:  "ext3 statx big file",
+			image: ext3ImagePath,
+			path:  "/bigfile.txt",
+			want:  wantBigFile,
+		},
+		{
+			name:  "ext2 statx big file",
+			image: ext2ImagePath,
+			path:  "/bigfile.txt",
+			want:  wantBigFile,
+		},
+		{
+			name:  "ext4 statx symlink file",
+			image: ext4ImagePath,
+			path:  "/symlink.txt",
+			want:  wantSymlink,
+		},
+		{
+			name:  "ext3 statx symlink file",
+			image: ext3ImagePath,
+			path:  "/symlink.txt",
+			want:  wantSymlink,
+		},
+		{
+			name:  "ext2 statx symlink file",
+			image: ext2ImagePath,
+			path:  "/symlink.txt",
+			want:  wantSymlink,
+		},
+	}
+
+	// Ignore the fields that are not supported by filesystem.StatAt yet and
+	// those which are likely to change as the image does.
+	ignoredFields := map[string]bool{
+		"Attributes":     true,
+		"AttributesMask": true,
+		"Atime":          true,
+		"Blocks":         true,
+		"Btime":          true,
+		"Ctime":          true,
+		"DevMajor":       true,
+		"DevMinor":       true,
+		"Ino":            true,
+		"Mask":           true,
+		"Mtime":          true,
+		"RdevMajor":      true,
+		"RdevMinor":      true,
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+			if err != nil {
+				t.Fatalf("setUp failed: %v", err)
+			}
+			defer tearDown()
+
+			got, err := vfsfs.StatAt(ctx,
+				auth.CredentialsFromContext(ctx),
+				&vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+				&vfs.StatOptions{},
+			)
+			if err != nil {
+				t.Fatalf("vfsfs.StatAt failed for file %s in image %s: %v", test.path, test.image, err)
+			}
+
+			cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
+				return ignoredFields[p.String()]
+			}, cmp.Ignore())
+			// cmp.Diff(x, y) labels x with '-' and y with '+', so want must
+			// come first to match the "(-want +got)" message.
+			if diff := cmp.Diff(test.want, got, cmpIgnoreFields, cmpopts.IgnoreUnexported(linux.Statx{})); diff != "" {
+				t.Errorf("stat mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// TestRead tests the read functionality for vfs file descriptions.
+//
+// Each file is read one byte at a time through the vfs file description and,
+// in lockstep, through a host file opened on the reference copy under
+// assetsDir; the bytes must match and both must reach EOF together.
+func TestRead(t *testing.T) {
+	type readTest struct {
+		name    string
+		image   string
+		absPath string
+	}
+
+	tests := []readTest{
+		{
+			name:    "ext4 read small file",
+			image:   ext4ImagePath,
+			absPath: "/file.txt",
+		},
+		{
+			name:    "ext3 read small file",
+			image:   ext3ImagePath,
+			absPath: "/file.txt",
+		},
+		{
+			name:    "ext2 read small file",
+			image:   ext2ImagePath,
+			absPath: "/file.txt",
+		},
+		{
+			name:    "ext4 read big file",
+			image:   ext4ImagePath,
+			absPath: "/bigfile.txt",
+		},
+		{
+			name:    "ext3 read big file",
+			image:   ext3ImagePath,
+			absPath: "/bigfile.txt",
+		},
+		{
+			name:    "ext2 read big file",
+			image:   ext2ImagePath,
+			absPath: "/bigfile.txt",
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+			if err != nil {
+				t.Fatalf("setUp failed: %v", err)
+			}
+			defer tearDown()
+
+			fd, err := vfsfs.OpenAt(
+				ctx,
+				auth.CredentialsFromContext(ctx),
+				&vfs.PathOperation{Root: *root, Start: *root, Pathname: test.absPath},
+				&vfs.OpenOptions{},
+			)
+			if err != nil {
+				t.Fatalf("vfsfs.OpenAt failed: %v", err)
+			}
+
+			// Get a local file descriptor and compare its functionality with a vfs file
+			// description for the same file.
+			localFile, err := testutil.FindFile(path.Join(assetsDir, test.absPath))
+			if err != nil {
+				t.Fatalf("testutil.FindFile failed for %s: %v", test.absPath, err)
+			}
+
+			f, err := os.Open(localFile)
+			if err != nil {
+				t.Fatalf("os.Open failed for %s: %v", localFile, err)
+			}
+			defer f.Close()
+
+			// Read the entire file by reading one byte repeatedly. Doing this stress
+			// tests the underlying file reader implementation.
+			got := make([]byte, 1)
+			want := make([]byte, 1)
+			for {
+				n, err := f.Read(want)
+				fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
+
+				// cmp.Diff(x, y) labels x with '-' and y with '+', so want
+				// must come first to match the "(-want +got)" message.
+				if diff := cmp.Diff(want, got); diff != "" {
+					t.Errorf("file data mismatch (-want +got):\n%s", diff)
+				}
+
+				// Make sure there is no more file data left after getting EOF.
+				if n == 0 || err == io.EOF {
+					if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
+						t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image)
+					}
+
+					break
+				}
+
+				if err != nil {
+					t.Fatalf("read failed: %v", err)
+				}
+			}
+		})
+	}
+}
+
+// iterDirentsCb is a simple callback which just keeps adding the dirents to an
+// internal list. Implements vfs.IterDirentsCallback.
+type iterDirentsCb struct {
+ dirents []vfs.Dirent
+}
+
+// Compiles only if iterDirentsCb implements vfs.IterDirentsCallback.
+var _ vfs.IterDirentsCallback = (*iterDirentsCb)(nil)
+
+// newIterDirentCb is the iterDirentsCb constructor; dirents starts empty but non-nil.
+func newIterDirentCb() *iterDirentsCb {
+ return &iterDirentsCb{dirents: make([]vfs.Dirent, 0)}
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle. It appends dirent to
+// cb.dirents and always returns true.
+func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) bool {
+ cb.dirents = append(cb.dirents, dirent)
+ return true
+}
+
+// TestIterDirents tests the FileDescriptionImpl.IterDirents functionality. For
+// each image it opens the root directory, collects every dirent reported
+// through the callback and compares the collected set (ignoring inode numbers
+// and offsets) against the expected entries.
+func TestIterDirents(t *testing.T) {
+	type iterDirentTest struct {
+		name  string
+		image string
+		path  string
+		want  []vfs.Dirent
+	}
+
+	wantDirents := []vfs.Dirent{
+		{Name: ".", Type: linux.DT_DIR},
+		{Name: "..", Type: linux.DT_DIR},
+		{Name: "lost+found", Type: linux.DT_DIR},
+		{Name: "file.txt", Type: linux.DT_REG},
+		{Name: "bigfile.txt", Type: linux.DT_REG},
+		{Name: "symlink.txt", Type: linux.DT_LNK},
+	}
+	tests := []iterDirentTest{
+		{
+			name:  "ext4 root dir iteration",
+			image: ext4ImagePath,
+			path:  "/",
+			want:  wantDirents,
+		},
+		{
+			name:  "ext3 root dir iteration",
+			image: ext3ImagePath,
+			path:  "/",
+			want:  wantDirents,
+		},
+		{
+			name:  "ext2 root dir iteration",
+			image: ext2ImagePath,
+			path:  "/",
+			want:  wantDirents,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			ctx, vfsfs, root, tearDown, err := setUp(t, test.image)
+			if err != nil {
+				t.Fatalf("setUp failed: %v", err)
+			}
+			defer tearDown()
+
+			fd, err := vfsfs.OpenAt(
+				ctx,
+				auth.CredentialsFromContext(ctx),
+				&vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+				&vfs.OpenOptions{},
+			)
+			if err != nil {
+				t.Fatalf("vfsfs.OpenAt failed: %v", err)
+			}
+
+			cb := newIterDirentCb()
+			if err := fd.Impl().IterDirents(ctx, cb); err != nil {
+				t.Fatalf("dir fd.IterDirents() failed: %v", err)
+			}
+
+			// The order in which dirents are reported is not part of the
+			// contract being tested, so compare both lists sorted by name.
+			sort.Slice(cb.dirents, func(i int, j int) bool { return cb.dirents[i].Name < cb.dirents[j].Name })
+			sort.Slice(test.want, func(i int, j int) bool { return test.want[i].Name < test.want[j].Name })
+
+			// Ignore the inode number and offset of dirents because those are likely to
+			// change as the underlying image changes.
+			cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
+				return p.String() == "Ino" || p.String() == "Off"
+			}, cmp.Ignore())
+			// cmp.Diff(x, y) marks x with '-' and y with '+', so want must be
+			// passed first for the "(-want +got)" label to be truthful.
+			if diff := cmp.Diff(test.want, cb.dirents, cmpIgnoreFields); diff != "" {
+				t.Errorf("dirents mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// TestRootDir tests that the root directory inode is correctly initialized and
+// returned from setUp. The on-disk inode properties of the root dentry are
+// copied into a local struct and compared against the expected values for
+// each image.
+func TestRootDir(t *testing.T) {
+	type inodeProps struct {
+		Mode      linux.FileMode
+		UID       auth.KUID
+		GID       auth.KGID
+		Size      uint64
+		InodeSize uint16
+		Links     uint16
+		Flags     disklayout.InodeFlags
+	}
+
+	type rootDirTest struct {
+		name      string
+		image     string
+		wantInode inodeProps
+	}
+
+	tests := []rootDirTest{
+		{
+			name:  "ext4 root dir",
+			image: ext4ImagePath,
+			wantInode: inodeProps{
+				Mode:      linux.ModeDirectory | 0755,
+				Size:      0x400,
+				InodeSize: 0x80,
+				Links:     3,
+				Flags:     disklayout.InodeFlags{Extents: true},
+			},
+		},
+		{
+			name:  "ext3 root dir",
+			image: ext3ImagePath,
+			wantInode: inodeProps{
+				Mode:      linux.ModeDirectory | 0755,
+				Size:      0x400,
+				InodeSize: 0x80,
+				Links:     3,
+			},
+		},
+		{
+			name:  "ext2 root dir",
+			image: ext2ImagePath,
+			wantInode: inodeProps{
+				Mode:      linux.ModeDirectory | 0755,
+				Size:      0x400,
+				InodeSize: 0x80,
+				Links:     3,
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			_, _, vd, tearDown, err := setUp(t, test.image)
+			if err != nil {
+				t.Fatalf("setUp failed: %v", err)
+			}
+			defer tearDown()
+
+			d, ok := vd.Dentry().Impl().(*dentry)
+			if !ok {
+				t.Fatalf("ext dentry of incorrect type: %T", vd.Dentry().Impl())
+			}
+
+			// Offload inode contents into local structs for comparison.
+			diskInode := d.inode.diskInode
+			gotInode := inodeProps{
+				Mode:      diskInode.Mode(),
+				UID:       diskInode.UID(),
+				GID:       diskInode.GID(),
+				Size:      diskInode.Size(),
+				InodeSize: diskInode.InodeSize(),
+				Links:     diskInode.LinksCount(),
+				Flags:     diskInode.Flags(),
+			}
+
+			// cmp.Diff(x, y) marks x with '-' and y with '+', so want goes
+			// first to match the "(-want +got)" label.
+			if diff := cmp.Diff(test.wantInode, gotInode); diff != "" {
+				t.Errorf("inode mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+// TestFilesystemInit tests that the filesystem superblock and block group
+// descriptors are correctly read in and initialized. It also cross-checks
+// that the free inode/block counts summed over the block group descriptors
+// agree with the totals recorded in the superblock.
+func TestFilesystemInit(t *testing.T) {
+	// sb only contains the immutable properties of the superblock.
+	type sb struct {
+		InodesCount      uint32
+		BlocksCount      uint64
+		MaxMountCount    uint16
+		FirstDataBlock   uint32
+		BlockSize        uint64
+		BlocksPerGroup   uint32
+		ClusterSize      uint64
+		ClustersPerGroup uint32
+		InodeSize        uint16
+		InodesPerGroup   uint32
+		BgDescSize       uint16
+		Magic            uint16
+		Revision         disklayout.SbRevision
+		CompatFeatures   disklayout.CompatFeatures
+		IncompatFeatures disklayout.IncompatFeatures
+		RoCompatFeatures disklayout.RoCompatFeatures
+	}
+
+	// bg only contains the immutable properties of the block group descriptor.
+	type bg struct {
+		InodeTable      uint64
+		BlockBitmap     uint64
+		InodeBitmap     uint64
+		ExclusionBitmap uint64
+		Flags           disklayout.BGFlags
+	}
+
+	type fsInitTest struct {
+		name    string
+		image   string
+		wantSb  sb
+		wantBgs []bg
+	}
+
+	tests := []fsInitTest{
+		{
+			name:  "ext4 filesystem init",
+			image: ext4ImagePath,
+			wantSb: sb{
+				InodesCount:      0x10,
+				BlocksCount:      0x40,
+				MaxMountCount:    0xffff,
+				FirstDataBlock:   0x1,
+				BlockSize:        0x400,
+				BlocksPerGroup:   0x2000,
+				ClusterSize:      0x400,
+				ClustersPerGroup: 0x2000,
+				InodeSize:        0x80,
+				InodesPerGroup:   0x10,
+				BgDescSize:       0x40,
+				Magic:            linux.EXT_SUPER_MAGIC,
+				Revision:         disklayout.DynamicRev,
+				CompatFeatures: disklayout.CompatFeatures{
+					ExtAttr:     true,
+					ResizeInode: true,
+					DirIndex:    true,
+				},
+				IncompatFeatures: disklayout.IncompatFeatures{
+					DirentFileType: true,
+					Extents:        true,
+					Is64Bit:        true,
+					FlexBg:         true,
+				},
+				RoCompatFeatures: disklayout.RoCompatFeatures{
+					Sparse:       true,
+					LargeFile:    true,
+					HugeFile:     true,
+					DirNlink:     true,
+					ExtraIsize:   true,
+					MetadataCsum: true,
+				},
+			},
+			wantBgs: []bg{
+				{
+					InodeTable:  0x23,
+					BlockBitmap: 0x3,
+					InodeBitmap: 0x13,
+					Flags: disklayout.BGFlags{
+						InodeZeroed: true,
+					},
+				},
+			},
+		},
+		{
+			name:  "ext3 filesystem init",
+			image: ext3ImagePath,
+			wantSb: sb{
+				InodesCount:      0x10,
+				BlocksCount:      0x40,
+				MaxMountCount:    0xffff,
+				FirstDataBlock:   0x1,
+				BlockSize:        0x400,
+				BlocksPerGroup:   0x2000,
+				ClusterSize:      0x400,
+				ClustersPerGroup: 0x2000,
+				InodeSize:        0x80,
+				InodesPerGroup:   0x10,
+				BgDescSize:       0x20,
+				Magic:            linux.EXT_SUPER_MAGIC,
+				Revision:         disklayout.DynamicRev,
+				CompatFeatures: disklayout.CompatFeatures{
+					ExtAttr:     true,
+					ResizeInode: true,
+					DirIndex:    true,
+				},
+				IncompatFeatures: disklayout.IncompatFeatures{
+					DirentFileType: true,
+				},
+				RoCompatFeatures: disklayout.RoCompatFeatures{
+					Sparse:    true,
+					LargeFile: true,
+				},
+			},
+			wantBgs: []bg{
+				{
+					InodeTable:  0x5,
+					BlockBitmap: 0x3,
+					InodeBitmap: 0x4,
+					Flags: disklayout.BGFlags{
+						InodeZeroed: true,
+					},
+				},
+			},
+		},
+		{
+			name:  "ext2 filesystem init",
+			image: ext2ImagePath,
+			wantSb: sb{
+				InodesCount:      0x10,
+				BlocksCount:      0x40,
+				MaxMountCount:    0xffff,
+				FirstDataBlock:   0x1,
+				BlockSize:        0x400,
+				BlocksPerGroup:   0x2000,
+				ClusterSize:      0x400,
+				ClustersPerGroup: 0x2000,
+				InodeSize:        0x80,
+				InodesPerGroup:   0x10,
+				BgDescSize:       0x20,
+				Magic:            linux.EXT_SUPER_MAGIC,
+				Revision:         disklayout.DynamicRev,
+				CompatFeatures: disklayout.CompatFeatures{
+					ExtAttr:     true,
+					ResizeInode: true,
+					DirIndex:    true,
+				},
+				IncompatFeatures: disklayout.IncompatFeatures{
+					DirentFileType: true,
+				},
+				RoCompatFeatures: disklayout.RoCompatFeatures{
+					Sparse:    true,
+					LargeFile: true,
+				},
+			},
+			wantBgs: []bg{
+				{
+					InodeTable:  0x5,
+					BlockBitmap: 0x3,
+					InodeBitmap: 0x4,
+					Flags: disklayout.BGFlags{
+						InodeZeroed: true,
+					},
+				},
+			},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			_, _, vd, tearDown, err := setUp(t, test.image)
+			if err != nil {
+				t.Fatalf("setUp failed: %v", err)
+			}
+			defer tearDown()
+
+			fs, ok := vd.Mount().Filesystem().Impl().(*filesystem)
+			if !ok {
+				t.Fatalf("ext filesystem of incorrect type: %T", vd.Mount().Filesystem().Impl())
+			}
+
+			// Offload superblock and block group descriptors contents into
+			// local structs for comparison.
+			totalFreeInodes := uint32(0)
+			totalFreeBlocks := uint64(0)
+			gotSb := sb{
+				InodesCount:      fs.sb.InodesCount(),
+				BlocksCount:      fs.sb.BlocksCount(),
+				MaxMountCount:    fs.sb.MaxMountCount(),
+				FirstDataBlock:   fs.sb.FirstDataBlock(),
+				BlockSize:        fs.sb.BlockSize(),
+				BlocksPerGroup:   fs.sb.BlocksPerGroup(),
+				ClusterSize:      fs.sb.ClusterSize(),
+				ClustersPerGroup: fs.sb.ClustersPerGroup(),
+				InodeSize:        fs.sb.InodeSize(),
+				InodesPerGroup:   fs.sb.InodesPerGroup(),
+				BgDescSize:       fs.sb.BgDescSize(),
+				Magic:            fs.sb.Magic(),
+				Revision:         fs.sb.Revision(),
+				CompatFeatures:   fs.sb.CompatibleFeatures(),
+				IncompatFeatures: fs.sb.IncompatibleFeatures(),
+				RoCompatFeatures: fs.sb.ReadOnlyCompatibleFeatures(),
+			}
+			gotNumBgs := len(fs.bgs)
+			gotBgs := make([]bg, gotNumBgs)
+			for i := 0; i < gotNumBgs; i++ {
+				gotBgs[i].InodeTable = fs.bgs[i].InodeTable()
+				gotBgs[i].BlockBitmap = fs.bgs[i].BlockBitmap()
+				gotBgs[i].InodeBitmap = fs.bgs[i].InodeBitmap()
+				gotBgs[i].ExclusionBitmap = fs.bgs[i].ExclusionBitmap()
+				gotBgs[i].Flags = fs.bgs[i].Flags()
+
+				totalFreeInodes += fs.bgs[i].FreeInodesCount()
+				totalFreeBlocks += uint64(fs.bgs[i].FreeBlocksCount())
+			}
+
+			// cmp.Diff(x, y) marks x with '-' and y with '+', so the expected
+			// value is always passed first to match the "(-want +got)" label.
+			if diff := cmp.Diff(test.wantSb, gotSb); diff != "" {
+				t.Errorf("superblock mismatch (-want +got):\n%s", diff)
+			}
+
+			if diff := cmp.Diff(test.wantBgs, gotBgs); diff != "" {
+				t.Errorf("block group descriptors mismatch (-want +got):\n%s", diff)
+			}
+
+			// The superblock totals are treated as the reference; the sums
+			// accumulated from the block group descriptors are what we got.
+			if diff := cmp.Diff(fs.sb.FreeInodesCount(), totalFreeInodes); diff != "" {
+				t.Errorf("total free inodes mismatch (-want +got):\n%s", diff)
+			}
+
+			if diff := cmp.Diff(fs.sb.FreeBlocksCount(), totalFreeBlocks); diff != "" {
+				t.Errorf("total free blocks mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
new file mode 100644
index 000000000..38b68a2d3
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -0,0 +1,237 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "io"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// extentFile is a type of regular file which uses extents to store file data.
+type extentFile struct {
+ // regFile holds the regular file state; newExtentFile sets regFile.impl to
+ // point back at this extentFile.
+ regFile regularFile
+
+ // root is the root extent node. This lives in the 60 byte diskInode.Data().
+ // Immutable.
+ root disklayout.ExtentNode
+}
+
+// Compiles only if extentFile implements io.ReaderAt.
+var _ io.ReaderAt = (*extentFile)(nil)
+
+// newExtentFile constructs an extentFile around regFile and eagerly loads the
+// whole extent tree into memory. If the tree cannot be built, the error from
+// buildExtTree is returned and no file is produced.
+// TODO(b/134676337): Build extent tree on demand to reduce memory usage.
+func newExtentFile(regFile regularFile) (*extentFile, error) {
+	ef := &extentFile{regFile: regFile}
+	ef.regFile.impl = ef
+	if err := ef.buildExtTree(); err != nil {
+		return nil, err
+	}
+	return ef, nil
+}
+
+// buildExtTree builds the extent tree by running a simple DFS. The root node
+// is decoded from the data area of the in-memory inode struct; if the root is
+// an internal node, the subtree under each of its entries is then read off
+// disk via buildExtTreeFromDisk.
+//
+// Returns EINVAL if the root header claims more entries than fit in the
+// inode, or the first error encountered while reading child nodes.
+//
+// Precondition: inode flag InExtents must be set.
+func (f *extentFile) buildExtTree() error {
+	data := f.regFile.inode.diskInode.Data()
+	hdr := &f.root.Header
+
+	binary.Unmarshal(data[:disklayout.ExtentStructsSize], binary.LittleEndian, hdr)
+
+	// Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
+	if hdr.NumEntries > 4 {
+		// read(2) specifies that EINVAL should be returned if the file is unsuitable
+		// for reading.
+		return syserror.EINVAL
+	}
+
+	// Entries are laid out back to back right after the header.
+	f.root.Entries = make([]disklayout.ExtentEntryPair, hdr.NumEntries)
+	off := disklayout.ExtentStructsSize
+	for i := range f.root.Entries {
+		var entry disklayout.ExtentEntry
+		if hdr.Height == 0 {
+			// Leaf node.
+			entry = &disklayout.Extent{}
+		} else {
+			// Internal node.
+			entry = &disklayout.ExtentIdx{}
+		}
+		binary.Unmarshal(data[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, entry)
+		f.root.Entries[i].Entry = entry
+		off += disklayout.ExtentStructsSize
+	}
+
+	// If this node is internal, perform DFS.
+	if hdr.Height > 0 {
+		for i := range f.root.Entries {
+			child, err := f.buildExtTreeFromDisk(f.root.Entries[i].Entry)
+			f.root.Entries[i].Node = child
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// buildExtTreeFromDisk reads the extent tree nodes from disk and recursively
+// builds the tree. Performs a simple DFS. It returns the ExtentNode pointed to
+// by the ExtentEntry.
+//
+// The node's header is read from the start of the physical block that entry
+// points to, followed by header.NumEntries entries laid out back to back
+// after it. Any read error aborts the build and is returned with a nil node.
+func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) {
+	var header disklayout.ExtentHeader
+	nodeOff := entry.PhysicalBlock() * f.regFile.inode.blkSize
+	if err := readFromDisk(f.regFile.inode.dev, int64(nodeOff), &header); err != nil {
+		return nil, err
+	}
+
+	entries := make([]disklayout.ExtentEntryPair, header.NumEntries)
+	entryOff := nodeOff + disklayout.ExtentStructsSize
+	for i := range entries {
+		var curEntry disklayout.ExtentEntry
+		if header.Height == 0 {
+			// Leaf node.
+			curEntry = &disklayout.Extent{}
+		} else {
+			// Internal node.
+			curEntry = &disklayout.ExtentIdx{}
+		}
+
+		if err := readFromDisk(f.regFile.inode.dev, int64(entryOff), curEntry); err != nil {
+			return nil, err
+		}
+		entries[i].Entry = curEntry
+		entryOff += disklayout.ExtentStructsSize
+	}
+
+	// If this node is internal, perform DFS.
+	if header.Height > 0 {
+		for i := range entries {
+			child, err := f.buildExtTreeFromDisk(entries[i].Entry)
+			if err != nil {
+				return nil, err
+			}
+			entries[i].Node = child
+		}
+	}
+
+	// Keyed fields keep this literal valid if disklayout.ExtentNode ever
+	// gains or reorders fields.
+	return &disklayout.ExtentNode{Header: header, Entries: entries}, nil
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+//
+// An empty dst is a successful no-op, a negative offset yields EINVAL and an
+// offset at or beyond the inode size yields io.EOF. Otherwise the extent tree
+// is walked from the root; a short read without an error is reported as
+// io.EOF.
+func (f *extentFile) ReadAt(dst []byte, off int64) (int, error) {
+	switch {
+	case len(dst) == 0:
+		return 0, nil
+	case off < 0:
+		return 0, syserror.EINVAL
+	case uint64(off) >= f.regFile.inode.diskInode.Size():
+		return 0, io.EOF
+	}
+
+	n, err := f.read(&f.root, uint64(off), dst)
+	if err == nil && n < len(dst) {
+		return n, io.EOF
+	}
+	return n, err
+}
+
+// read is the recursive step of extentFile.ReadAt which traverses the extent
+// tree from the node passed and reads file data starting at file offset off
+// into dst. It returns the number of bytes read and the first error hit.
+func (f *extentFile) read(node *disklayout.ExtentNode, off uint64, dst []byte) (int, error) {
+	// Binary search for the entry covering the file block containing off. A
+	// highly fragmented filesystem can have up to 340 entries, so a linear
+	// search should be avoided. The search locates the first entry starting
+	// beyond the wanted file block; the entry just before it is the one we
+	// want.
+	wantBlk := uint32(off / f.regFile.inode.blkSize)
+	numEntries := len(node.Entries)
+	firstPast := sort.Search(numEntries, func(i int) bool {
+		return node.Entries[i].Entry.FileBlock() > wantBlk
+	})
+
+	// We should be in this recursive step only if the data we want exists under
+	// the current node.
+	if firstPast == 0 {
+		panic("searching for a file block in an extent entry which does not cover it")
+	}
+
+	done := 0
+	for i := firstPast - 1; i < numEntries && done < len(dst); i++ {
+		var (
+			n   int
+			err error
+		)
+		if node.Header.Height == 0 {
+			n, err = f.readFromExtent(node.Entries[i].Entry.(*disklayout.Extent), off, dst[done:])
+		} else {
+			n, err = f.read(node.Entries[i].Node, off, dst[done:])
+		}
+
+		// Account for partial progress before surfacing any error.
+		done += n
+		off += uint64(n)
+		if err != nil {
+			return done, err
+		}
+	}
+
+	return done, nil
+}
+
+// readFromExtent reads file data from the extent. It takes advantage of the
+// sequential nature of extents and reads file data from multiple blocks in one
+// call: the read runs from the physical position matching off up to the end
+// of the extent or the end of dst, whichever comes first.
+//
+// A non-nil error indicates that this is a partial read and there is probably
+// more to read from this extent. The caller should propagate the error upward
+// and not move to the next extent in the tree.
+//
+// It panics if the file block containing off is not covered by ex.
+func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byte) (int, error) {
+	blkSize := f.regFile.inode.blkSize
+	fileBlk := uint32(off / blkSize)
+	firstBlk := ex.FileBlock()
+	endBlk := firstBlk + uint32(ex.Length) // This is exclusive.
+
+	// We should be in this recursive step only if the data we want exists under
+	// the current extent.
+	if !(firstBlk <= fileBlk && fileBlk < endBlk) {
+		panic("searching for a file block in an extent which does not cover it")
+	}
+
+	// Translate the file offset into an absolute device offset.
+	phyBlk := uint64(fileBlk-firstBlk) + ex.PhysicalBlock()
+	start := phyBlk*blkSize + (off % blkSize)
+
+	// Device offset just past the last byte of this extent (exclusive).
+	end := (ex.PhysicalBlock() + uint64(ex.Length)) * blkSize
+
+	want := int(end - start)
+	if len(dst) < want {
+		want = len(dst)
+	}
+
+	// The device's own error is dropped; any short read is reported as EIO.
+	got, _ := f.regFile.inode.dev.ReadAt(dst[:want], int64(start))
+	if got < want {
+		return got, syserror.EIO
+	}
+	return got, nil
+}
diff --git a/pkg/sentry/fs/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index b3f342c8e..42d0a484b 100644
--- a/pkg/sentry/fs/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -16,17 +16,23 @@ package ext
import (
"bytes"
+ "math/rand"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
-// TestExtentTree tests the extent tree building logic.
+const (
+ // mockExtentBlkSize is the mock block size used for testing.
+ // No block has more than 1 header + 4 entries.
+ mockExtentBlkSize = uint64(64)
+)
+
+// The tree described below looks like:
//
-// Test tree:
// 0.{Head}[Idx][Idx]
// / \
// / \
@@ -44,12 +50,8 @@ import (
//
// Please note that ext4 might not construct extent trees looking like this.
// This is purely for testing the tree traversal logic.
-func TestExtentTree(t *testing.T) {
- blkSize := uint64(64) // No block has more than 1 header + 4 entries.
- mockDisk := make([]byte, blkSize*10)
- mockInode := &inode{diskInode: &disklayout.InodeNew{}}
-
- node3 := &disklayout.ExtentNode{
+var (
+ node3 = &disklayout.ExtentNode{
Header: disklayout.ExtentHeader{
Magic: disklayout.ExtentMagic,
NumEntries: 1,
@@ -68,7 +70,7 @@ func TestExtentTree(t *testing.T) {
},
}
- node2 := &disklayout.ExtentNode{
+ node2 = &disklayout.ExtentNode{
Header: disklayout.ExtentHeader{
Magic: disklayout.ExtentMagic,
NumEntries: 1,
@@ -86,7 +88,7 @@ func TestExtentTree(t *testing.T) {
},
}
- node1 := &disklayout.ExtentNode{
+ node1 = &disklayout.ExtentNode{
Header: disklayout.ExtentHeader{
Magic: disklayout.ExtentMagic,
NumEntries: 2,
@@ -113,7 +115,7 @@ func TestExtentTree(t *testing.T) {
},
}
- node0 := &disklayout.ExtentNode{
+ node0 = &disklayout.ExtentNode{
Header: disklayout.ExtentHeader{
Magic: disklayout.ExtentMagic,
NumEntries: 2,
@@ -137,22 +139,69 @@ func TestExtentTree(t *testing.T) {
},
},
}
+)
- writeTree(mockInode, mockDisk, node0, blkSize)
+// TestExtentReader stress tests extentReader functionality. It performs random
+// length reads from all possible positions in the extent tree.
+func TestExtentReader(t *testing.T) {
+ mockExtentFile, want := extentTreeSetUp(t, node0)
+ n := len(want)
- r := bytes.NewReader(mockDisk)
- if err := mockInode.buildExtTree(r, blkSize); err != nil {
- t.Fatalf("inode.buildExtTree failed: %v", err)
+ for from := 0; from < n; from++ {
+ got := make([]byte, n-from)
+
+ if read, err := mockExtentFile.ReadAt(got, int64(from)); err != nil {
+ t.Fatalf("file read operation from offset %d to %d only read %d bytes: %v", from, n, read, err)
+ }
+
+ if diff := cmp.Diff(got, want[from:]); diff != "" {
+ t.Fatalf("file data from offset %d to %d mismatched (-want +got):\n%s", from, n, diff)
+ }
}
+}
+
+// TestBuildExtentTree tests the extent tree building logic.
+func TestBuildExtentTree(t *testing.T) {
+ mockExtentFile, _ := extentTreeSetUp(t, node0)
opt := cmpopts.IgnoreUnexported(disklayout.ExtentIdx{}, disklayout.ExtentHeader{})
- if diff := cmp.Diff(mockInode.root, node0, opt); diff != "" {
+ if diff := cmp.Diff(&mockExtentFile.root, node0, opt); diff != "" {
t.Errorf("extent tree mismatch (-want +got):\n%s", diff)
}
}
-// writeTree writes the tree represented by `root` to the inode and disk passed.
-func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, blkSize uint64) {
+// extentTreeSetUp writes the passed extent tree to a mock disk as an extent
+// tree. It also constructs a mock extent file with the same tree built in it.
+// It also writes random file data to the mock disk and returns it alongside
+// the mock extent file.
+func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []byte) {
+	t.Helper()
+
+	mockDisk := make([]byte, mockExtentBlkSize*10)
+	mockExtentFile := &extentFile{
+		regFile: regularFile{
+			inode: inode{
+				diskInode: &disklayout.InodeNew{
+					InodeOld: disklayout.InodeOld{
+						SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
+					},
+				},
+				blkSize: mockExtentBlkSize,
+				dev:     bytes.NewReader(mockDisk),
+			},
+		},
+	}
+
+	// Write the tree that was passed in, not a fixed package-level node, so
+	// that the on-disk tree matches the inode size computed from root above.
+	fileData := writeTree(&mockExtentFile.regFile.inode, mockDisk, root, mockExtentBlkSize)
+
+	if err := mockExtentFile.buildExtTree(); err != nil {
+		t.Fatalf("inode.buildExtTree failed: %v", err)
+	}
+	return mockExtentFile, fileData
+}
+
+// writeTree writes the tree represented by `root` to the inode and disk. It
+// also writes random file data on disk.
+func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte {
rootData := binary.Marshal(nil, binary.LittleEndian, root.Header)
for _, ep := range root.Entries {
rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry)
@@ -160,26 +209,57 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, blkSize uint
copy(in.diskInode.Data(), rootData)
- if root.Header.Height > 0 {
- for _, ep := range root.Entries {
- writeTreeToDisk(disk, ep, blkSize)
+ var fileData []byte
+ for _, ep := range root.Entries {
+ if root.Header.Height == 0 {
+ fileData = append(fileData, writeFileDataToExtent(disk, ep.Entry.(*disklayout.Extent))...)
+ } else {
+ fileData = append(fileData, writeTreeToDisk(disk, ep)...)
}
}
+ return fileData
}
// writeTreeToDisk is the recursive step for writeTree which writes the tree
-// on the disk only.
-func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair, blkSize uint64) {
+// on the disk only. Also writes random file data on disk.
+func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte {
nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header)
for _, ep := range curNode.Node.Entries {
nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry)
}
- copy(disk[curNode.Entry.PhysicalBlock()*blkSize:], nodeData)
+ copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData)
+
+ var fileData []byte
+ for _, ep := range curNode.Node.Entries {
+ if curNode.Node.Header.Height == 0 {
+ fileData = append(fileData, writeFileDataToExtent(disk, ep.Entry.(*disklayout.Extent))...)
+ } else {
+ fileData = append(fileData, writeTreeToDisk(disk, ep)...)
+ }
+ }
+ return fileData
+}
+
+// writeFileDataToExtent fills the region of disk that the passed extent points
+// to with random bytes and returns that region. The returned slice aliases
+// disk rather than copying it.
+func writeFileDataToExtent(disk []byte, ex *disklayout.Extent) []byte {
+	start := ex.PhysicalBlock() * mockExtentBlkSize
+	end := start + uint64(ex.Length)*mockExtentBlkSize
+	region := disk[start:end]
+	rand.Read(region)
+	return region
+}
- if curNode.Node.Header.Height > 0 {
- for _, ep := range curNode.Node.Entries {
- writeTreeToDisk(disk, ep, blkSize)
+// getNumPhyBlks returns the number of physical blocks covered under the node.
+func getNumPhyBlks(node *disklayout.ExtentNode) uint32 {
+ var res uint32
+ for _, ep := range node.Entries {
+ if node.Header.Height == 0 {
+ res += uint32(ep.Entry.(*disklayout.Extent).Length)
+ } else {
+ res += getNumPhyBlks(ep.Node)
}
}
+ return res
}
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
new file mode 100644
index 000000000..a0065343b
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -0,0 +1,86 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// fileDescription is embedded by ext implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ // vfsfd is the VFS file description; filesystem() and inode() resolve
+ // through its VirtualDentry.
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+
+ // flags is the same as vfs.OpenOptions.Flags which are passed to
+ // vfs.FilesystemImpl.OpenAt.
+ // TODO(b/134676337): syscalls like read(2), write(2), fchmod(2), fchown(2),
+ // fgetxattr(2), ioctl(2), mmap(2) should fail with EBADF if O_PATH is set.
+ // Only close(2), fstat(2), fstatfs(2) should work.
+ flags uint32
+}
+
+// filesystem returns the ext filesystem implementation of the mount fd was
+// opened on. It panics if the mount's filesystem is not an ext *filesystem.
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem)
+}
+
+// inode returns the inode of the ext dentry fd refers to. It panics if the
+// dentry implementation is not an ext *dentry.
+func (fd *fileDescription) inode() *inode {
+ return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose. It does nothing.
+func (fd *fileDescription) OnClose() error { return nil }
+
+// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags. It returns
+// fd.flags unchanged and never fails.
+func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
+ return fd.flags, nil
+}
+
+// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
+func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
+ // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
+ // no-op.
+ return nil
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat. The result is filled in by the
+// inode's statTo; opts is not consulted.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ fd.inode().statTo(&stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat. A request with an empty
+// mask succeeds trivially; any request that names a field to change is
+// rejected with EPERM.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+	if opts.Stat.Mask != 0 {
+		return syserror.EPERM
+	}
+	return nil
+}
+
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
+func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+ var stat linux.Statfs
+ fd.filesystem().statTo(&stat)
+ return stat, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync. It does nothing and returns nil.
+func (fd *fileDescription) Sync(ctx context.Context) error {
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
new file mode 100644
index 000000000..2d15e8aaf
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -0,0 +1,443 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "errors"
+ "io"
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var (
+ // errResolveDirent indicates that the vfs.ResolvingPath.Component() does
+ // not exist on the dentry tree but does exist on disk. So it has to be read in
+ // using the in-memory dirent and added to the dentry tree. Usually indicates
+ // the need to lock filesystem.mu for writing.
+ errResolveDirent = errors.New("resolve path component using dirent")
+)
+
+ // filesystem implements vfs.FilesystemImpl.
+ type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // mu serializes changes to the dentry tree.
+ mu sync.RWMutex
+
+ // dev represents the underlying fs device. It does not require protection
+ // because io.ReaderAt permits concurrent read calls to it. It translates to
+ // the pread syscall which passes on the read request directly to the device
+ // driver. Device drivers are intelligent in serving multiple concurrent read
+ // requests in the optimal order (taking locality into consideration).
+ dev io.ReaderAt
+
+ // inodeCache maps absolute inode numbers to the corresponding inode struct.
+ // Inodes should be removed from it once their reference count hits 0.
+ //
+ // Protected by mu because most additions (see IterDirents) and all removals
+ // from the cache correspond to a change in the dentry tree.
+ inodeCache map[uint32]*inode
+
+ // sb represents the filesystem superblock. Immutable after initialization.
+ sb disklayout.SuperBlock
+
+ // bgs represents all the block group descriptors for the filesystem.
+ // Immutable after initialization.
+ bgs []disklayout.BlockGroup
+ }
+
+ // Compile-time assertion that *filesystem implements vfs.FilesystemImpl.
+ var _ vfs.FilesystemImpl = (*filesystem)(nil)
+
+ // stepLocked resolves rp.Component() in parent directory vfsd. The write
+ // parameter tells whether the caller has acquired filesystem.mu for writing.
+ // If it is true, an inode that exists on disk but is not yet part of the dentry
+ // tree is loaded and inserted into the tree. If it is false and such an
+ // insertion would be needed, errResolveDirent is returned instead.
+ //
+ // On success the dentry and inode of the resolved component are returned and
+ // rp has been advanced past that component.
+ //
+ // stepLocked is loosely analogous to fs/namei.c:walk_component().
+ //
+ // Preconditions:
+ // - filesystem.mu must be locked (for writing if write param is true).
+ // - !rp.Done().
+ // - inode == vfsd.Impl().(*dentry).inode.
+ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) {
+ 	if !inode.isDir() {
+ 		return nil, nil, syserror.ENOTDIR
+ 	}
+ 	if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ 		return nil, nil, err
+ 	}
+
+ 	for {
+ 		nextVFSD, err := rp.ResolveComponent(vfsd)
+ 		if err != nil {
+ 			return nil, nil, err
+ 		}
+ 		if nextVFSD == nil {
+ 			// Since the Dentry tree is not the sole source of truth for extfs, if it's
+ 			// not in the Dentry tree, it might need to be pulled from disk.
+ 			childDirent, ok := inode.impl.(*directory).childMap[rp.Component()]
+ 			if !ok {
+ 				// The underlying inode does not exist on disk.
+ 				return nil, nil, syserror.ENOENT
+ 			}
+
+ 			if !write {
+ 				// filesystem.mu must be held for writing to add to the dentry tree.
+ 				return nil, nil, errResolveDirent
+ 			}
+
+ 			// Create and add the component's dirent to the dentry tree.
+ 			fs := rp.Mount().Filesystem().Impl().(*filesystem)
+ 			childInode, err := fs.getOrCreateInodeLocked(childDirent.diskDirent.Inode())
+ 			if err != nil {
+ 				return nil, nil, err
+ 			}
+ 			// incRef because this is being added to the dentry tree.
+ 			childInode.incRef()
+ 			child := newDentry(childInode)
+ 			vfsd.InsertChild(&child.vfsd, rp.Component())
+
+ 			// Continue as usual now that nextVFSD is not nil.
+ 			nextVFSD = &child.vfsd
+ 		}
+ 		nextInode := nextVFSD.Impl().(*dentry).inode
+ 		if nextInode.isSymlink() && rp.ShouldFollowSymlink() {
+ 			// The link target belongs to the child that was just resolved. inode is
+ 			// the parent directory at this point, so asserting it to *symlink would
+ 			// panic.
+ 			if err := rp.HandleSymlink(nextInode.impl.(*symlink).target); err != nil {
+ 				return nil, nil, err
+ 			}
+ 			continue
+ 		}
+ 		rp.Advance()
+ 		return nextVFSD, nextInode, nil
+ 	}
+ }
+
+// walkLocked resolves rp to an existing file. The write parameter
+// passed tells if the caller has acquired filesystem.mu for writing or not.
+// If set to true, additions can be made to the dentry tree while walking.
+// If errResolveDirent is returned, the walk needs to be continued with an
+// upgraded filesystem.mu.
+//
+// walkLocked is loosely analogous to Linux's fs/namei.c:path_lookupat().
+//
+// Preconditions:
+// - filesystem.mu must be locked (for writing if write param is true).
+func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+ vfsd := rp.Start()
+ inode := vfsd.Impl().(*dentry).inode
+ for !rp.Done() {
+ var err error
+ vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if rp.MustBeDir() && !inode.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, inode, nil
+}
+
+// walkParentLocked resolves all but the last path component of rp to an
+// existing directory. It does not check that the returned directory is
+// searchable by the provider of rp. The write parameter passed tells if the
+// caller has acquired filesystem.mu for writing or not. If set to true,
+// additions can be made to the dentry tree while walking.
+// If errResolveDirent is returned, the walk needs to be continued with an
+// upgraded filesystem.mu.
+//
+// walkParentLocked is loosely analogous to Linux's fs/namei.c:path_parentat().
+//
+// Preconditions:
+// - filesystem.mu must be locked (for writing if write param is true).
+// - !rp.Done().
+func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+ vfsd := rp.Start()
+ inode := vfsd.Impl().(*dentry).inode
+ for !rp.Final() {
+ var err error
+ vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if !inode.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, inode, nil
+}
+
+// walk resolves rp to an existing file. If parent is set to true, it resolves
+// the rp till the parent of the last component which should be an existing
+// directory. If parent is false then resolves rp entirely. Attemps to resolve
+// the path as far as it can with a read lock and upgrades the lock if needed.
+func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) {
+ var (
+ vfsd *vfs.Dentry
+ inode *inode
+ err error
+ )
+
+ // Try walking with the hopes that all dentries have already been pulled out
+ // of disk. This reduces congestion (allows concurrent walks).
+ fs.mu.RLock()
+ if parent {
+ vfsd, inode, err = walkParentLocked(rp, false)
+ } else {
+ vfsd, inode, err = walkLocked(rp, false)
+ }
+ fs.mu.RUnlock()
+
+ if err == errResolveDirent {
+ // Upgrade lock and continue walking. Lock upgrading in the middle of the
+ // walk is fine as this is a read only filesystem.
+ fs.mu.Lock()
+ if parent {
+ vfsd, inode, err = walkParentLocked(rp, true)
+ } else {
+ vfsd, inode, err = walkLocked(rp, true)
+ }
+ fs.mu.Unlock()
+ }
+
+ return vfsd, inode, err
+}
+
+ // getOrCreateInodeLocked returns the cached inode for inodeNum. On a cache miss
+ // it builds the inode with newInode and stores it in fs.inodeCache.
+ // The caller must increment the ref count if adding this to the dentry tree.
+ //
+ // Precondition: must be holding fs.mu for writing.
+ func (fs *filesystem) getOrCreateInodeLocked(inodeNum uint32) (*inode, error) {
+ 	cached, hit := fs.inodeCache[inodeNum]
+ 	if hit {
+ 		return cached, nil
+ 	}
+
+ 	created, err := newInode(fs, inodeNum)
+ 	if err != nil {
+ 		return nil, err
+ 	}
+ 	fs.inodeCache[inodeNum] = created
+ 	return created, nil
+ }
+
+ // statTo fills the statfs output parameter from the superblock.
+ func (fs *filesystem) statTo(stat *linux.Statfs) {
+ 	sb := fs.sb
+ 	blockSize := int64(sb.BlockSize())
+ 	freeBlocks := sb.FreeBlocksCount()
+
+ 	stat.Type = uint64(sb.Magic())
+ 	stat.BlockSize = blockSize
+ 	stat.FragmentSize = blockSize
+ 	stat.Blocks = sb.BlocksCount()
+ 	stat.BlocksFree = freeBlocks
+ 	stat.BlocksAvailable = freeBlocks
+ 	stat.Files = uint64(sb.InodesCount())
+ 	stat.FilesFree = uint64(sb.FreeInodesCount())
+ 	stat.NameLength = disklayout.MaxFileName
+ 	// TODO(b/134676337): Set Statfs.Flags and Statfs.FSID.
+ }
+
+ // GetDentryAt implements vfs.FilesystemImpl.GetDentryAt. On success the inode
+ // of the returned dentry has had its reference count incremented.
+ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ 	vfsd, in, err := fs.walk(rp, false)
+ 	if err != nil {
+ 		return nil, err
+ 	}
+
+ 	if opts.CheckSearchable {
+ 		// Only directories can be searched, and only with exec permission.
+ 		if !in.isDir() {
+ 			return nil, syserror.ENOTDIR
+ 		}
+ 		permErr := in.checkPermissions(rp.Credentials(), vfs.MayExec)
+ 		if permErr != nil {
+ 			return nil, permErr
+ 		}
+ 	}
+
+ 	in.incRef()
+ 	return vfsd, nil
+ }
+
+ // OpenAt implements vfs.FilesystemImpl.OpenAt. The path is resolved first; an
+ // open that needs write access then fails with EROFS.
+ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ 	vfsd, in, err := fs.walk(rp, false)
+ 	if err != nil {
+ 		return nil, err
+ 	}
+
+ 	// EROFS is returned if write access is needed.
+ 	flags := opts.Flags
+ 	if vfs.MayWriteFileWithOpenFlags(flags) || flags&(linux.O_CREAT|linux.O_EXCL|linux.O_TMPFILE) != 0 {
+ 		return nil, syserror.EROFS
+ 	}
+ 	return in.open(rp, vfsd, flags)
+ }
+
+ // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. It returns EINVAL if rp
+ // does not resolve to a symlink.
+ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ 	_, in, err := fs.walk(rp, false)
+ 	if err != nil {
+ 		return "", err
+ 	}
+ 	if link, isLink := in.impl.(*symlink); isLink {
+ 		return link.target, nil
+ 	}
+ 	return "", syserror.EINVAL
+ }
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ _, inode, err := fs.walk(rp, false)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ var stat linux.Statx
+ inode.statTo(&stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ if _, _, err := fs.walk(rp, false); err != nil {
+ return linux.Statfs{}, err
+ }
+
+ var stat linux.Statfs
+ fs.statTo(&stat)
+ return stat, nil
+}
+
+ // Release implements vfs.FilesystemImpl.Release. It is currently a no-op.
+ func (fs *filesystem) Release() {}
+
+ // Sync implements vfs.FilesystemImpl.Sync. It always returns nil.
+ func (fs *filesystem) Sync(ctx context.Context) error {
+ // This is a readonly filesystem for now, so there is nothing to flush.
+ return nil
+ }
+
+// The vfs.FilesystemImpl functions below return EROFS because their respective
+// man pages say that EROFS must be returned if the path resolves to a file on
+// this read-only filesystem.
+
+ // LinkAt implements vfs.FilesystemImpl.LinkAt. It fails with EEXIST when rp has
+ // no components left, with the walk error when the parent cannot be resolved,
+ // and with EROFS otherwise.
+ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ 	if rp.Done() {
+ 		return syserror.EEXIST
+ 	}
+
+ 	_, _, err := fs.walk(rp, true)
+ 	if err != nil {
+ 		return err
+ 	}
+ 	return syserror.EROFS
+ }
+
+ // MkdirAt implements vfs.FilesystemImpl.MkdirAt. It fails with EEXIST when rp
+ // has no components left, with the walk error when the parent cannot be
+ // resolved, and with EROFS otherwise.
+ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ 	if rp.Done() {
+ 		return syserror.EEXIST
+ 	}
+
+ 	_, _, err := fs.walk(rp, true)
+ 	if err != nil {
+ 		return err
+ 	}
+ 	return syserror.EROFS
+ }
+
+ // MknodAt implements vfs.FilesystemImpl.MknodAt. It fails with EEXIST when rp
+ // has no components left, with the walk error when the parent cannot be
+ // resolved, and with EROFS otherwise.
+ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ 	if rp.Done() {
+ 		return syserror.EEXIST
+ 	}
+
+ 	if _, _, err := fs.walk(rp, true); err != nil {
+ 		return err
+ 	}
+ 	return syserror.EROFS
+ }
+
+ // RenameAt implements vfs.FilesystemImpl.RenameAt. It fails with ENOENT when rp
+ // has no components left, with the walk error when rp cannot be resolved, and
+ // with EROFS otherwise.
+ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
+ 	if rp.Done() {
+ 		return syserror.ENOENT
+ 	}
+
+ 	if _, _, err := fs.walk(rp, false); err != nil {
+ 		return err
+ 	}
+ 	return syserror.EROFS
+ }
+
+ // RmdirAt implements vfs.FilesystemImpl.RmdirAt. It fails with the walk error
+ // when rp cannot be resolved, with ENOTDIR when rp is not a directory, and with
+ // EROFS otherwise.
+ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ 	_, in, err := fs.walk(rp, false)
+ 	switch {
+ 	case err != nil:
+ 		return err
+ 	case !in.isDir():
+ 		return syserror.ENOTDIR
+ 	default:
+ 		return syserror.EROFS
+ 	}
+ }
+
+ // SetStatAt implements vfs.FilesystemImpl.SetStatAt. It fails with the walk
+ // error when rp cannot be resolved and with EROFS otherwise.
+ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ 	if _, _, err := fs.walk(rp, false); err != nil {
+ 		return err
+ 	}
+ 	return syserror.EROFS
+ }
+
+ // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. It fails with EEXIST when
+ // rp has no components left, with the walk error when the parent cannot be
+ // resolved, and with EROFS otherwise.
+ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ 	if rp.Done() {
+ 		return syserror.EEXIST
+ 	}
+
+ 	if _, _, err := fs.walk(rp, true); err != nil {
+ 		return err
+ 	}
+ 	return syserror.EROFS
+ }
+
+ // UnlinkAt implements vfs.FilesystemImpl.UnlinkAt. It fails with the walk error
+ // when rp cannot be resolved, with EISDIR when rp is a directory, and with
+ // EROFS otherwise.
+ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ 	_, in, err := fs.walk(rp, false)
+ 	switch {
+ 	case err != nil:
+ 		return err
+ 	case in.isDir():
+ 		return syserror.EISDIR
+ 	default:
+ 		return syserror.EROFS
+ 	}
+ }
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
new file mode 100644
index 000000000..e6c847a71
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -0,0 +1,219 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+ // inode represents an ext inode.
+ //
+ // inode uses the same inheritance pattern that pkg/sentry/vfs structures use.
+ // This has been done to increase memory locality.
+ //
+ // Implementations:
+ // inode --
+ // |-- dir
+ // |-- symlink
+ // |-- regular--
+ // |-- extent file
+ // |-- block map file
+ type inode struct {
+ // refs is a reference count. refs is accessed using atomic memory operations.
+ refs int64
+
+ // inodeNum is the inode number of this inode on disk. This is used to
+ // identify inodes within the ext filesystem.
+ inodeNum uint32
+
+ // dev represents the underlying device. Same as filesystem.dev.
+ dev io.ReaderAt
+
+ // blkSize is the fs data block size. Same as filesystem.sb.BlockSize().
+ blkSize uint64
+
+ // diskInode gives us access to the inode struct on disk. Immutable.
+ diskInode disklayout.Inode
+
+ // impl is the type-specific implementation that contains this inode as its
+ // first field (e.g. *symlink or *regularFile). Immutable.
+ impl interface{}
+ }
+
+ // incRef atomically increments the inode ref count by one.
+ func (in *inode) incRef() {
+ atomic.AddInt64(&in.refs, 1)
+ }
+
+ // tryIncRef increments the ref count unless it is currently zero. It returns
+ // true if the increment happened and false if the count was zero.
+ func (in *inode) tryIncRef() bool {
+ 	for {
+ 		cur := atomic.LoadInt64(&in.refs)
+ 		switch {
+ 		case cur == 0:
+ 			return false
+ 		case atomic.CompareAndSwapInt64(&in.refs, cur, cur+1):
+ 			return true
+ 		}
+ 		// The count changed between the load and the swap; try again.
+ 	}
+ }
+
+ // decRef drops one reference from the inode. When the count reaches 0 the
+ // inode is removed from fs.inodeCache. It panics if the count goes negative.
+ //
+ // Precondition: Must have locked fs.mu.
+ func (in *inode) decRef(fs *filesystem) {
+ 	remaining := atomic.AddInt64(&in.refs, -1)
+ 	switch {
+ 	case remaining == 0:
+ 		delete(fs.inodeCache, in.inodeNum)
+ 	case remaining < 0:
+ 		panic("ext.inode.decRef() called without holding a reference")
+ 	}
+ }
+
+ // newInode is the inode constructor. It reads the inode record off disk and
+ // wraps it in the implementation matching its file type (symlink, regular file
+ // or directory). Inodes are identified by their absolute inode number on disk.
+ //
+ // It panics if inodeNum is 0. It returns EIO if the inode number does not map
+ // to a known block group or the record cannot be read, and EINVAL for file
+ // types that are not supported yet.
+ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
+ 	if inodeNum == 0 {
+ 		panic("inode number 0 on ext filesystems is not possible")
+ 	}
+
+ 	inodeRecordSize := fs.sb.InodeSize()
+ 	var diskInode disklayout.Inode
+ 	if inodeRecordSize == disklayout.OldInodeSize {
+ 		diskInode = &disklayout.InodeOld{}
+ 	} else {
+ 		diskInode = &disklayout.InodeNew{}
+ 	}
+
+ 	// Calculate where the inode is actually placed. The inode number and the
+ 	// superblock come from disk, so validate them rather than dividing by zero
+ 	// or indexing past the block group descriptors.
+ 	inodesPerGrp := fs.sb.InodesPerGroup()
+ 	if inodesPerGrp == 0 {
+ 		return nil, syserror.EIO
+ 	}
+ 	bgNum := getBGNum(inodeNum, inodesPerGrp)
+ 	if uint64(bgNum) >= uint64(len(fs.bgs)) {
+ 		return nil, syserror.EIO
+ 	}
+ 	blkSize := fs.sb.BlockSize()
+ 	inodeTableOff := fs.bgs[bgNum].InodeTable() * blkSize
+ 	// Multiply in 64 bits so that the byte offset into the table cannot wrap.
+ 	inodeOff := inodeTableOff + uint64(inodeRecordSize)*uint64(getBGOff(inodeNum, inodesPerGrp))
+
+ 	if err := readFromDisk(fs.dev, int64(inodeOff), diskInode); err != nil {
+ 		return nil, err
+ 	}
+
+ 	// Build the inode based on its type.
+ 	inode := inode{
+ 		inodeNum:  inodeNum,
+ 		dev:       fs.dev,
+ 		blkSize:   blkSize,
+ 		diskInode: diskInode,
+ 	}
+
+ 	switch diskInode.Mode().FileType() {
+ 	case linux.ModeSymlink:
+ 		f, err := newSymlink(inode)
+ 		if err != nil {
+ 			return nil, err
+ 		}
+ 		return &f.inode, nil
+ 	case linux.ModeRegular:
+ 		f, err := newRegularFile(inode)
+ 		if err != nil {
+ 			return nil, err
+ 		}
+ 		return &f.inode, nil
+ 	case linux.ModeDirectory:
+ 		f, err := newDirectroy(inode, fs.sb.IncompatibleFeatures().DirentFileType)
+ 		if err != nil {
+ 			return nil, err
+ 		}
+ 		return &f.inode, nil
+ 	default:
+ 		// TODO(b/134676337): Return appropriate errors for sockets, pipes and devices.
+ 		return nil, syserror.EINVAL
+ 	}
+ }
+
+ // open checks that the credentials in rp permit the access implied by flags
+ // and then returns a new file description for vfsd whose concrete type matches
+ // the inode type. It panics on an unknown inode type.
+ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ 	accessTypes := vfs.AccessTypesForOpenFlags(flags)
+ 	if err := in.checkPermissions(rp.Credentials(), accessTypes); err != nil {
+ 		return nil, err
+ 	}
+
+ 	switch in.impl.(type) {
+ 	case *regularFile:
+ 		fd := &regularFileFD{}
+ 		fd.flags = flags
+ 		fd.vfsfd.Init(fd, rp.Mount(), vfsd)
+ 		return &fd.vfsfd, nil
+ 	case *directory:
+ 		// Can't open directories writably. This check is not necessary for a read
+ 		// only filesystem but will be required when write is implemented.
+ 		if accessTypes&vfs.MayWrite != 0 {
+ 			return nil, syserror.EISDIR
+ 		}
+ 		fd := &directoryFD{}
+ 		fd.vfsfd.Init(fd, rp.Mount(), vfsd)
+ 		fd.flags = flags
+ 		return &fd.vfsfd, nil
+ 	case *symlink:
+ 		// Can't open symlinks without O_PATH.
+ 		if flags&linux.O_PATH == 0 {
+ 			return nil, syserror.ELOOP
+ 		}
+ 		fd := &symlinkFD{}
+ 		fd.flags = flags
+ 		fd.vfsfd.Init(fd, rp.Mount(), vfsd)
+ 		return &fd.vfsfd, nil
+ 	default:
+ 		panic(fmt.Sprintf("unknown inode type: %T", in.impl))
+ 	}
+ }
+
+ // checkPermissions checks the access types ats for creds against the mode, UID
+ // and GID stored in the on-disk inode, using vfs.GenericCheckPermissions.
+ func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ 	di := in.diskInode
+ 	return vfs.GenericCheckPermissions(creds, ats, in.isDir(), uint16(di.Mode()), di.UID(), di.GID())
+ }
+
+ // statTo fills the statx output parameter from the on-disk inode and sets
+ // stat.Mask to the fields that were filled in.
+ func (in *inode) statTo(stat *linux.Statx) {
+ 	di := in.diskInode
+
+ 	stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK |
+ 		linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE |
+ 		linux.STATX_ATIME | linux.STATX_CTIME | linux.STATX_MTIME
+
+ 	// Identity and ownership.
+ 	stat.Ino = uint64(in.inodeNum)
+ 	stat.Mode = uint16(di.Mode())
+ 	stat.Nlink = uint32(di.LinksCount())
+ 	stat.UID = uint32(di.UID())
+ 	stat.GID = uint32(di.GID())
+
+ 	// Sizes.
+ 	stat.Blksize = uint32(in.blkSize)
+ 	stat.Size = di.Size()
+
+ 	// Timestamps.
+ 	stat.Atime = di.AccessTime().StatxTimestamp()
+ 	stat.Ctime = di.ChangeTime().StatxTimestamp()
+ 	stat.Mtime = di.ModificationTime().StatxTimestamp()
+ 	// TODO(b/134676337): Set stat.Blocks which is the number of 512 byte blocks
+ 	// (including metadata blocks) required to represent this file.
+ }
+
+ // getBGNum returns the number of the block group that holds the given inode.
+ // Inode numbers start at 1, hence the adjustment before dividing.
+ func getBGNum(inodeNum uint32, inodesPerGrp uint32) uint32 {
+ 	zeroBased := inodeNum - 1
+ 	return zeroBased / inodesPerGrp
+ }
+
+ // getBGOff returns the index of the given inode within the inode table of its
+ // block group. Inode numbers start at 1, hence the adjustment before taking
+ // the remainder.
+ func getBGOff(inodeNum uint32, inodesPerGrp uint32) uint32 {
+ 	zeroBased := inodeNum - 1
+ 	return zeroBased % inodesPerGrp
+ }
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
new file mode 100644
index 000000000..ffc76ba5b
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -0,0 +1,159 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "io"
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+ // regularFile represents a regular file's inode. This too follows the
+ // inheritance pattern prevalent in the vfs layer described in
+ // pkg/sentry/vfs/README.md.
+ type regularFile struct {
+ inode inode
+
+ // impl reads the file's data. This is immutable. The first field of
+ // fileReader implementations must be regularFile.
+ // io.ReaderAt is more strict than io.Reader in the sense that a partial read
+ // is always accompanied by an error. If a read spans past the end of file, a
+ // partial read (within file range) is done and io.EOF is returned.
+ impl io.ReaderAt
+ }
+
+ // newRegularFile is the regularFile constructor. It inspects the on-disk inode
+ // flags to decide whether the file is extent based or block map based, builds
+ // the matching reader and points inode.impl at the resulting regularFile.
+ func newRegularFile(inode inode) (*regularFile, error) {
+ 	regFile := regularFile{inode: inode}
+
+ 	if !inode.diskInode.Flags().Extents {
+ 		bmFile, err := newBlockMapFile(regFile)
+ 		if err != nil {
+ 			return nil, err
+ 		}
+ 		bmFile.regFile.inode.impl = &bmFile.regFile
+ 		return &bmFile.regFile, nil
+ 	}
+
+ 	extFile, err := newExtentFile(regFile)
+ 	if err != nil {
+ 		return nil, err
+ 	}
+ 	extFile.regFile.inode.impl = &extFile.regFile
+ 	return &extFile.regFile, nil
+ }
+
+ // isRegular reports whether in.impl is a *regularFile.
+ func (in *inode) isRegular() bool {
+ 	if _, ok := in.impl.(*regularFile); ok {
+ 		return true
+ 	}
+ 	return false
+ }
+
+ // regularFileFD represents a regular file description. It implements
+ // vfs.FileDescriptionImpl.
+ type regularFileFD struct {
+ fileDescription
+
+ // off is the file offset. Updates to off are made with offMu held.
+ off int64
+
+ // offMu serializes operations that may mutate off.
+ offMu sync.Mutex
+ }
+
+ // Release implements vfs.FileDescriptionImpl.Release. It is a no-op.
+ func (fd *regularFileFD) Release() {}
+
+ // PRead implements vfs.FileDescriptionImpl.PRead. It reads from the regular
+ // file's io.ReaderAt starting at offset and does not touch fd.off.
+ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ 	file := fd.inode().impl.(*regularFile)
+ 	reader := safemem.FromIOReaderAt{ReaderAt: file.impl, Offset: offset}
+
+ 	// Copies data from disk directly into usermem without any intermediate
+ 	// allocations (if dst is converted into BlockSeq such that it does not need
+ 	// safe copying).
+ 	return dst.CopyOutFrom(ctx, reader)
+ }
+
+ // Read implements vfs.FileDescriptionImpl.Read. It reads at the current file
+ // offset and advances the offset by the number of bytes read.
+ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ 	// Hold offMu across both the read and the update. Sampling fd.off before
+ 	// taking the lock would race with Seek and would let concurrent Read calls
+ 	// start from the same offset.
+ 	fd.offMu.Lock()
+ 	defer fd.offMu.Unlock()
+ 	n, err := fd.PRead(ctx, dst, fd.off, opts)
+ 	fd.off += n
+ 	return n, err
+ }
+
+ // PWrite implements vfs.FileDescriptionImpl.PWrite. It always fails with EBADF.
+ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // write(2) specifies that EBADF must be returned if the fd is not open for
+ // writing.
+ return 0, syserror.EBADF
+ }
+
+ // Write implements vfs.FileDescriptionImpl.Write. It forwards to PWrite at the
+ // current file offset and advances the offset by the number of bytes written.
+ func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ 	// Hold offMu while fd.off is read and updated so that the offset is never
+ 	// accessed without the lock, matching Seek.
+ 	fd.offMu.Lock()
+ 	defer fd.offMu.Unlock()
+ 	n, err := fd.PWrite(ctx, src, fd.off, opts)
+ 	fd.off += n
+ 	return n, err
+ }
+
+ // IterDirents implements vfs.FileDescriptionImpl.IterDirents. Returns ENOTDIR.
+ func (fd *regularFileFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ return syserror.ENOTDIR
+ }
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.offMu.Lock()
+ defer fd.offMu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as specified.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ offset += int64(fd.inode().diskInode.Size())
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+ // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+ // TODO(b/134676337): Implement mmap(2).
+ return syserror.ENODEV
+ }
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
new file mode 100644
index 000000000..e06548a98
--- /dev/null
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -0,0 +1,111 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ext
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+ // symlink represents a symlink inode together with its link target.
+ type symlink struct {
+ inode inode
+ target string // immutable
+ }
+
+ // newSymlink is the symlink constructor. It reads out the symlink target from
+ // the inode (however it might have been stored) and points inode.impl at the
+ // resulting symlink.
+ func newSymlink(inode inode) (*symlink, error) {
+ 	size := inode.diskInode.Size()
+
+ 	// If the symlink target is shorter than 60 bytes, it is stored in
+ 	// inode.Data(). Otherwise either extents or block maps are used to store it.
+ 	var target []byte
+ 	if size >= 60 {
+ 		// Create a regular file out of this inode and read out the target.
+ 		regFile, err := newRegularFile(inode)
+ 		if err != nil {
+ 			return nil, err
+ 		}
+
+ 		target = make([]byte, size)
+ 		n, err := regFile.impl.ReadAt(target, 0)
+ 		if uint64(n) < size {
+ 			return nil, err
+ 		}
+ 	} else {
+ 		target = inode.diskInode.Data()[:size]
+ 	}
+
+ 	sl := &symlink{inode: inode, target: string(target)}
+ 	sl.inode.impl = sl
+ 	return sl, nil
+ }
+
+ // isSymlink reports whether in.impl is a *symlink.
+ func (in *inode) isSymlink() bool {
+ 	if _, ok := in.impl.(*symlink); ok {
+ 		return true
+ 	}
+ 	return false
+ }
+
+ // symlinkFD represents a symlink file description and implements
+ // vfs.FileDescriptionImpl. It may only be used if the open options contain
+ // O_PATH. For this reason most of the functions return EBADF.
+ type symlinkFD struct {
+ fileDescription
+ }
+
+ // Compile-time assertion that *symlinkFD implements vfs.FileDescriptionImpl.
+ var _ vfs.FileDescriptionImpl = (*symlinkFD)(nil)
+
+ // Release implements vfs.FileDescriptionImpl.Release. It is a no-op.
+ func (fd *symlinkFD) Release() {}
+
+ // PRead implements vfs.FileDescriptionImpl.PRead. It always fails with EBADF.
+ func (fd *symlinkFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.EBADF
+ }
+
+ // Read implements vfs.FileDescriptionImpl.Read. It always fails with EBADF.
+ func (fd *symlinkFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.EBADF
+ }
+
+ // PWrite implements vfs.FileDescriptionImpl.PWrite. It always fails with EBADF.
+ func (fd *symlinkFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+ }
+
+ // Write implements vfs.FileDescriptionImpl.Write. It always fails with EBADF.
+ func (fd *symlinkFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+ }
+
+ // IterDirents implements vfs.FileDescriptionImpl.IterDirents. Returns ENOTDIR.
+ func (fd *symlinkFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ return syserror.ENOTDIR
+ }
+
+ // Seek implements vfs.FileDescriptionImpl.Seek. It always fails with EBADF.
+ func (fd *symlinkFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, syserror.EBADF
+ }
+
+ // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+ func (fd *symlinkFD) ConfigureMMap(ctx context.Context, opts memmap.MMapOpts) error {
+ return syserror.EBADF
+ }
diff --git a/pkg/sentry/fs/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go
index 3472c5fa8..d8b728f8c 100644
--- a/pkg/sentry/fs/ext/utils.go
+++ b/pkg/sentry/fsimpl/ext/utils.go
@@ -15,38 +15,30 @@
package ext
import (
- "encoding/binary"
"io"
- "gvisor.dev/gvisor/pkg/sentry/fs/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
// readFromDisk performs a binary read from disk into the given struct from
// the absolute offset provided.
-//
-// All disk reads should use this helper so we avoid reading from stale
-// previously used offsets. This function forces the offset parameter.
-//
-// Precondition: Must hold the mutex of the filesystem containing dev.
-func readFromDisk(dev io.ReadSeeker, abOff int64, v interface{}) error {
- if _, err := dev.Seek(abOff, io.SeekStart); err != nil {
- return syserror.EIO
- }
-
- if err := binary.Read(dev, binary.LittleEndian, v); err != nil {
+func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error {
+ n := binary.Size(v)
+ buf := make([]byte, n)
+ if read, _ := dev.ReadAt(buf, abOff); read < int(n) {
return syserror.EIO
}
+ binary.Unmarshal(buf, binary.LittleEndian, v)
return nil
}
// readSuperBlock reads the SuperBlock from block group 0 in the underlying
// device. There are three versions of the superblock. This function identifies
// and returns the correct version.
-//
-// Precondition: Must hold the mutex of the filesystem containing dev.
-func readSuperBlock(dev io.ReadSeeker) (disklayout.SuperBlock, error) {
+func readSuperBlock(dev io.ReaderAt) (disklayout.SuperBlock, error) {
var sb disklayout.SuperBlock = &disklayout.SuperBlockOld{}
if err := readFromDisk(dev, disklayout.SbOffset, sb); err != nil {
return nil, err
@@ -76,19 +68,12 @@ func blockGroupsCount(sb disklayout.SuperBlock) uint64 {
blocksPerGroup := uint64(sb.BlocksPerGroup())
// Round up the result. float64 can compromise precision so do it manually.
- bgCount := blocksCount / blocksPerGroup
- if blocksCount%blocksPerGroup != 0 {
- bgCount++
- }
-
- return bgCount
+ return (blocksCount + blocksPerGroup - 1) / blocksPerGroup
}
// readBlockGroups reads the block group descriptor table from block group 0 in
// the underlying device.
-//
-// Precondition: Must hold the mutex of the filesystem containing dev.
-func readBlockGroups(dev io.ReadSeeker, sb disklayout.SuperBlock) ([]disklayout.BlockGroup, error) {
+func readBlockGroups(dev io.ReaderAt, sb disklayout.SuperBlock) ([]disklayout.BlockGroup, error) {
bgCount := blockGroupsCount(sb)
bgdSize := uint64(sb.BgDescSize())
is64Bit := sb.IncompatibleFeatures().Is64Bit
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index d5d4f68df..d2450e810 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -11,8 +11,8 @@ go_template_instance(
prefix = "dentry",
template = "//pkg/ilist:generic_list",
types = {
- "Element": "*Dentry",
- "Linker": "*Dentry",
+ "Element": "*dentry",
+ "Linker": "*dentry",
},
)
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go
index b0c3ea39a..c52dc781c 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/memfs/directory.go
@@ -23,23 +23,23 @@ import (
)
type directory struct {
- inode Inode
+ inode inode
// childList is a list containing (1) child Dentries and (2) fake Dentries
// (with inode == nil) that represent the iteration position of
// directoryFDs. childList is used to support directoryFD.IterDirents()
- // efficiently. childList is protected by Filesystem.mu.
+ // efficiently. childList is protected by filesystem.mu.
childList dentryList
}
-func (fs *Filesystem) newDirectory(creds *auth.Credentials, mode uint16) *Inode {
+func (fs *filesystem) newDirectory(creds *auth.Credentials, mode uint16) *inode {
dir := &directory{}
dir.inode.init(dir, fs, creds, mode)
dir.inode.nlink = 2 // from "." and parent directory or ".." for root
return &dir.inode
}
-func (i *Inode) isDir() bool {
+func (i *inode) isDir() bool {
_, ok := i.impl.(*directory)
return ok
}
@@ -48,8 +48,8 @@ type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
- // Protected by Filesystem.mu.
- iter *Dentry
+ // Protected by filesystem.mu.
+ iter *dentry
off int64
}
@@ -68,7 +68,7 @@ func (fd *directoryFD) Release() {
// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
fs := fd.filesystem()
- d := fd.vfsfd.VirtualDentry().Dentry()
+ vfsd := fd.vfsfd.VirtualDentry().Dentry()
fs.mu.Lock()
defer fs.mu.Unlock()
@@ -77,7 +77,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if !cb.Handle(vfs.Dirent{
Name: ".",
Type: linux.DT_DIR,
- Ino: d.Impl().(*Dentry).inode.ino,
+ Ino: vfsd.Impl().(*dentry).inode.ino,
Off: 0,
}) {
return nil
@@ -85,7 +85,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fd.off++
}
if fd.off == 1 {
- parentInode := d.ParentOrSelf().Impl().(*Dentry).inode
+ parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
if !cb.Handle(vfs.Dirent{
Name: "..",
Type: parentInode.direntType(),
@@ -97,12 +97,12 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fd.off++
}
- dir := d.Impl().(*Dentry).inode.impl.(*directory)
- var child *Dentry
+ dir := vfsd.Impl().(*dentry).inode.impl.(*directory)
+ var child *dentry
if fd.iter == nil {
// Start iteration at the beginning of dir.
child = dir.childList.Front()
- fd.iter = &Dentry{}
+ fd.iter = &dentry{}
} else {
// Continue iteration from where we left off.
child = fd.iter.Next()
@@ -130,32 +130,41 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
- if whence != linux.SEEK_SET {
- // TODO: Linux also allows SEEK_CUR.
+ fs := fd.filesystem()
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as given.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ default:
return 0, syserror.EINVAL
}
if offset < 0 {
return 0, syserror.EINVAL
}
+ // If the offset isn't changing (e.g. due to lseek(0, SEEK_CUR)), don't
+ // seek even if doing so might reposition the iterator due to concurrent
+ // mutation of the directory. Compare fs/libfs.c:dcache_dir_lseek().
+ if fd.off == offset {
+ return offset, nil
+ }
+
fd.off = offset
// Compensate for "." and "..".
- var remChildren int64
- if offset < 2 {
- remChildren = 0
- } else {
+ remChildren := int64(0)
+ if offset >= 2 {
remChildren = offset - 2
}
- fs := fd.filesystem()
dir := fd.inode().impl.(*directory)
- fs.mu.Lock()
- defer fs.mu.Unlock()
-
// Ensure that fd.iter exists and is not linked into dir.childList.
if fd.iter == nil {
- fd.iter = &Dentry{}
+ fd.iter = &dentry{}
} else {
dir.childList.Remove(fd.iter)
}
diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go
index 4d989eeaf..f79e2d9c8 100644
--- a/pkg/sentry/fsimpl/memfs/filesystem.go
+++ b/pkg/sentry/fsimpl/memfs/filesystem.go
@@ -28,9 +28,9 @@ import (
//
// stepLocked is loosely analogous to fs/namei.c:walk_component().
//
-// Preconditions: Filesystem.mu must be locked. !rp.Done(). inode ==
-// vfsd.Impl().(*Dentry).inode.
-func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *Inode) (*vfs.Dentry, *Inode, error) {
+// Preconditions: filesystem.mu must be locked. !rp.Done(). inode ==
+// vfsd.Impl().(*dentry).inode.
+func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode) (*vfs.Dentry, *inode, error) {
if !inode.isDir() {
return nil, nil, syserror.ENOTDIR
}
@@ -47,7 +47,7 @@ afterSymlink:
// not in the Dentry tree, it doesn't exist.
return nil, nil, syserror.ENOENT
}
- nextInode := nextVFSD.Impl().(*Dentry).inode
+ nextInode := nextVFSD.Impl().(*dentry).inode
if symlink, ok := nextInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
// TODO: symlink traversals update access time
if err := rp.HandleSymlink(symlink.target); err != nil {
@@ -64,10 +64,10 @@ afterSymlink:
// walkExistingLocked is loosely analogous to Linux's
// fs/namei.c:path_lookupat().
//
-// Preconditions: Filesystem.mu must be locked.
-func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *Inode, error) {
+// Preconditions: filesystem.mu must be locked.
+func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) {
vfsd := rp.Start()
- inode := vfsd.Impl().(*Dentry).inode
+ inode := vfsd.Impl().(*dentry).inode
for !rp.Done() {
var err error
vfsd, inode, err = stepLocked(rp, vfsd, inode)
@@ -88,10 +88,10 @@ func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *Inode, error) {
// walkParentDirLocked is loosely analogous to Linux's
// fs/namei.c:path_parentat().
//
-// Preconditions: Filesystem.mu must be locked. !rp.Done().
-func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *Inode, error) {
+// Preconditions: filesystem.mu must be locked. !rp.Done().
+func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) {
vfsd := rp.Start()
- inode := vfsd.Impl().(*Dentry).inode
+ inode := vfsd.Impl().(*dentry).inode
for !rp.Final() {
var err error
vfsd, inode, err = stepLocked(rp, vfsd, inode)
@@ -108,9 +108,9 @@ func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *Inode, error) {
// checkCreateLocked checks that a file named rp.Component() may be created in
// directory parentVFSD, then returns rp.Component().
//
-// Preconditions: Filesystem.mu must be locked. parentInode ==
-// parentVFSD.Impl().(*Dentry).inode. parentInode.isDir() == true.
-func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode *Inode) (string, error) {
+// Preconditions: filesystem.mu must be locked. parentInode ==
+// parentVFSD.Impl().(*dentry).inode. parentInode.isDir() == true.
+func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode *inode) (string, error) {
if err := parentInode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
return "", err
}
@@ -144,7 +144,7 @@ func checkDeleteLocked(vfsd *vfs.Dentry) error {
}
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
-func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
vfsd, inode, err := walkExistingLocked(rp)
@@ -164,7 +164,7 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
}
// LinkAt implements vfs.FilesystemImpl.LinkAt.
-func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
if rp.Done() {
return syserror.EEXIST
}
@@ -185,7 +185,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return err
}
defer rp.Mount().EndWrite()
- d := vd.Dentry().Impl().(*Dentry)
+ d := vd.Dentry().Impl().(*dentry)
if d.inode.isDir() {
return syserror.EPERM
}
@@ -197,7 +197,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
}
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
-func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
if rp.Done() {
return syserror.EEXIST
}
@@ -223,7 +223,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
// MknodAt implements vfs.FilesystemImpl.MknodAt.
-func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
if rp.Done() {
return syserror.EEXIST
}
@@ -246,7 +246,7 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
// OpenAt implements vfs.FilesystemImpl.OpenAt.
-func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
// Filter out flags that are not supported by memfs. O_DIRECTORY and
// O_NOFOLLOW have no effect here (they're handled by VFS by setting
// appropriate bits in rp), but are returned by
@@ -265,11 +265,10 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
mustCreate := opts.Flags&linux.O_EXCL != 0
vfsd := rp.Start()
- inode := vfsd.Impl().(*Dentry).inode
+ inode := vfsd.Impl().(*dentry).inode
fs.mu.Lock()
defer fs.mu.Unlock()
if rp.Done() {
- // FIXME: ???
if rp.MustBeDir() {
return nil, syserror.EISDIR
}
@@ -327,7 +326,7 @@ afterTrailingSymlink:
if mustCreate {
return nil, syserror.EEXIST
}
- childInode := childVFSD.Impl().(*Dentry).inode
+ childInode := childVFSD.Impl().(*dentry).inode
if symlink, ok := childInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
// TODO: symlink traversals update access time
if err := rp.HandleSymlink(symlink.target); err != nil {
@@ -340,7 +339,7 @@ afterTrailingSymlink:
return childInode.open(rp, childVFSD, opts.Flags, false)
}
-func (i *Inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
+func (i *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
ats := vfs.AccessTypesForOpenFlags(flags)
if !afterCreate {
if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil {
@@ -385,7 +384,7 @@ func (i *Inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afte
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
-func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
fs.mu.RLock()
_, inode, err := walkExistingLocked(rp)
fs.mu.RUnlock()
@@ -400,9 +399,8 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
}
// RenameAt implements vfs.FilesystemImpl.RenameAt.
-func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
if rp.Done() {
- // FIXME
return syserror.ENOENT
}
fs.mu.Lock()
@@ -424,7 +422,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vf
}
// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
-func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
defer fs.mu.Unlock()
vfsd, inode, err := walkExistingLocked(rp)
@@ -447,12 +445,14 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
return err
}
+ // Remove from parent directory's childList.
+ vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry))
inode.decRef()
return nil
}
// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
-func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
fs.mu.RLock()
_, _, err := walkExistingLocked(rp)
fs.mu.RUnlock()
@@ -462,12 +462,12 @@ func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
if opts.Stat.Mask == 0 {
return nil
}
- // TODO: implement Inode.setStat
+ // TODO: implement inode.setStat
return syserror.EPERM
}
// StatAt implements vfs.FilesystemImpl.StatAt.
-func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
fs.mu.RLock()
_, inode, err := walkExistingLocked(rp)
fs.mu.RUnlock()
@@ -480,7 +480,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
}
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
-func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
fs.mu.RLock()
_, _, err := walkExistingLocked(rp)
fs.mu.RUnlock()
@@ -492,7 +492,7 @@ func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
}
// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
-func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
if rp.Done() {
return syserror.EEXIST
}
@@ -517,7 +517,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
}
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
-func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
fs.mu.Lock()
defer fs.mu.Unlock()
vfsd, inode, err := walkExistingLocked(rp)
@@ -537,6 +537,8 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
return err
}
+ // Remove from parent directory's childList.
+ vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry))
inode.decLinksLocked()
return nil
}
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go
index f381e1a88..45cd42b3e 100644
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ b/pkg/sentry/fsimpl/memfs/memfs.go
@@ -21,10 +21,10 @@
//
// Lock order:
//
-// Filesystem.mu
+// filesystem.mu
// regularFileFD.offMu
// regularFile.mu
-// Inode.mu
+// inode.mu
package memfs
import (
@@ -42,8 +42,8 @@ import (
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
-// Filesystem implements vfs.FilesystemImpl.
-type Filesystem struct {
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
vfsfs vfs.Filesystem
// mu serializes changes to the Dentry tree.
@@ -54,44 +54,44 @@ type Filesystem struct {
// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
func (fstype FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- var fs Filesystem
+ var fs filesystem
fs.vfsfs.Init(&fs)
root := fs.newDentry(fs.newDirectory(creds, 01777))
return &fs.vfsfs, &root.vfsd, nil
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *Filesystem) Release() {
+func (fs *filesystem) Release() {
}
// Sync implements vfs.FilesystemImpl.Sync.
-func (fs *Filesystem) Sync(ctx context.Context) error {
+func (fs *filesystem) Sync(ctx context.Context) error {
// All filesystem state is in-memory.
return nil
}
-// Dentry implements vfs.DentryImpl.
-type Dentry struct {
+// dentry implements vfs.DentryImpl.
+type dentry struct {
vfsd vfs.Dentry
- // inode is the inode represented by this Dentry. Multiple Dentries may
- // share a single non-directory Inode (with hard links). inode is
+ // inode is the inode represented by this dentry. Multiple dentries may
+ // share a single non-directory inode (with hard links). inode is
// immutable.
- inode *Inode
+ inode *inode
- // memfs doesn't count references on Dentries; because the Dentry tree is
+ // memfs doesn't count references on dentries; because the dentry tree is
// the sole source of truth, it is by definition always consistent with the
- // state of the filesystem. However, it does count references on Inodes,
- // because Inode resources are released when all references are dropped.
+ // state of the filesystem. However, it does count references on inodes,
+ // because inode resources are released when all references are dropped.
// (memfs doesn't really have resources to release, but we implement
// reference counting because tmpfs regular files will.)
- // dentryEntry (ugh) links Dentries into their parent directory.childList.
+ // dentryEntry (ugh) links dentries into their parent directory.childList.
dentryEntry
}
-func (fs *Filesystem) newDentry(inode *Inode) *Dentry {
- d := &Dentry{
+func (fs *filesystem) newDentry(inode *inode) *dentry {
+ d := &dentry{
inode: inode,
}
d.vfsd.Init(d)
@@ -99,37 +99,37 @@ func (fs *Filesystem) newDentry(inode *Inode) *Dentry {
}
// IncRef implements vfs.DentryImpl.IncRef.
-func (d *Dentry) IncRef(vfsfs *vfs.Filesystem) {
+func (d *dentry) IncRef(vfsfs *vfs.Filesystem) {
d.inode.incRef()
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
-func (d *Dentry) TryIncRef(vfsfs *vfs.Filesystem) bool {
+func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool {
return d.inode.tryIncRef()
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *Dentry) DecRef(vfsfs *vfs.Filesystem) {
+func (d *dentry) DecRef(vfsfs *vfs.Filesystem) {
d.inode.decRef()
}
-// Inode represents a filesystem object.
-type Inode struct {
+// inode represents a filesystem object.
+type inode struct {
// refs is a reference count. refs is accessed using atomic memory
// operations.
//
- // A reference is held on all Inodes that are reachable in the filesystem
+ // A reference is held on all inodes that are reachable in the filesystem
// tree. For non-directories (which may have multiple hard links), this
// means that a reference is dropped when nlink reaches 0. For directories,
// nlink never reaches 0 due to the "." entry; instead,
- // Filesystem.RmdirAt() drops the reference.
+ // filesystem.RmdirAt() drops the reference.
refs int64
// Inode metadata; protected by mu and accessed using atomic memory
// operations unless otherwise specified.
mu sync.RWMutex
mode uint32 // excluding file type bits, which are based on impl
- nlink uint32 // protected by Filesystem.mu instead of Inode.mu
+ nlink uint32 // protected by filesystem.mu instead of inode.mu
uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
gid uint32 // auth.KGID, but ...
ino uint64 // immutable
@@ -137,7 +137,7 @@ type Inode struct {
impl interface{} // immutable
}
-func (i *Inode) init(impl interface{}, fs *Filesystem, creds *auth.Credentials, mode uint16) {
+func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode uint16) {
i.refs = 1
i.mode = uint32(mode)
i.uid = uint32(creds.EffectiveKUID)
@@ -147,29 +147,29 @@ func (i *Inode) init(impl interface{}, fs *Filesystem, creds *auth.Credentials,
i.impl = impl
}
-// Preconditions: Filesystem.mu must be locked for writing.
-func (i *Inode) incLinksLocked() {
+// Preconditions: filesystem.mu must be locked for writing.
+func (i *inode) incLinksLocked() {
if atomic.AddUint32(&i.nlink, 1) <= 1 {
- panic("memfs.Inode.incLinksLocked() called with no existing links")
+ panic("memfs.inode.incLinksLocked() called with no existing links")
}
}
-// Preconditions: Filesystem.mu must be locked for writing.
-func (i *Inode) decLinksLocked() {
+// Preconditions: filesystem.mu must be locked for writing.
+func (i *inode) decLinksLocked() {
if nlink := atomic.AddUint32(&i.nlink, ^uint32(0)); nlink == 0 {
i.decRef()
} else if nlink == ^uint32(0) { // negative overflow
- panic("memfs.Inode.decLinksLocked() called with no existing links")
+ panic("memfs.inode.decLinksLocked() called with no existing links")
}
}
-func (i *Inode) incRef() {
+func (i *inode) incRef() {
if atomic.AddInt64(&i.refs, 1) <= 1 {
- panic("memfs.Inode.incRef() called without holding a reference")
+ panic("memfs.inode.incRef() called without holding a reference")
}
}
-func (i *Inode) tryIncRef() bool {
+func (i *inode) tryIncRef() bool {
for {
refs := atomic.LoadInt64(&i.refs)
if refs == 0 {
@@ -181,7 +181,7 @@ func (i *Inode) tryIncRef() bool {
}
}
-func (i *Inode) decRef() {
+func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
// This is unnecessary; it's mostly to simulate what tmpfs would do.
if regfile, ok := i.impl.(*regularFile); ok {
@@ -191,18 +191,18 @@ func (i *Inode) decRef() {
regfile.mu.Unlock()
}
} else if refs < 0 {
- panic("memfs.Inode.decRef() called without holding a reference")
+ panic("memfs.inode.decRef() called without holding a reference")
}
}
-func (i *Inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
+func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
}
// Go won't inline this function, and returning linux.Statx (which is quite
// big) means spending a lot of time in runtime.duffcopy(), so instead it's an
// output parameter.
-func (i *Inode) statTo(stat *linux.Statx) {
+func (i *inode) statTo(stat *linux.Statx) {
stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
stat.Blksize = 1 // usermem.PageSize in tmpfs
stat.Nlink = atomic.LoadUint32(&i.nlink)
@@ -241,7 +241,7 @@ func allocatedBlocksForSize(size uint64) uint64 {
return (size + 511) / 512
}
-func (i *Inode) direntType() uint8 {
+func (i *inode) direntType() uint8 {
switch i.impl.(type) {
case *regularFile:
return linux.DT_REG
@@ -258,16 +258,17 @@ func (i *Inode) direntType() uint8 {
// vfs.FileDescriptionImpl.
type fileDescription struct {
vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
flags uint32 // status flags; immutable
}
-func (fd *fileDescription) filesystem() *Filesystem {
- return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*Filesystem)
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem)
}
-func (fd *fileDescription) inode() *Inode {
- return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+func (fd *fileDescription) inode() *inode {
+ return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
}
// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
@@ -294,6 +295,6 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
if opts.Stat.Mask == 0 {
return nil
}
- // TODO: implement Inode.setStat
+ // TODO: implement inode.setStat
return syserror.EPERM
}
diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go
index 4a3603cc8..55f869798 100644
--- a/pkg/sentry/fsimpl/memfs/regular_file.go
+++ b/pkg/sentry/fsimpl/memfs/regular_file.go
@@ -28,16 +28,16 @@ import (
)
type regularFile struct {
- inode Inode
+ inode inode
mu sync.RWMutex
data []byte
// dataLen is len(data), but accessed using atomic memory operations to
- // avoid locking in Inode.stat().
+ // avoid locking in inode.stat().
dataLen int64
}
-func (fs *Filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *Inode {
+func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *inode {
file := &regularFile{}
file.inode.init(file, fs, creds, mode)
file.inode.nlink = 1 // from parent directory
@@ -46,7 +46,6 @@ func (fs *Filesystem) newRegularFile(creds *auth.Credentials, mode uint16) *Inod
type regularFileFD struct {
fileDescription
- vfs.FileDescriptionDefaultImpl
// These are immutable.
readable bool
diff --git a/pkg/sentry/fsimpl/memfs/symlink.go b/pkg/sentry/fsimpl/memfs/symlink.go
index e002d1727..b2ac2cbeb 100644
--- a/pkg/sentry/fsimpl/memfs/symlink.go
+++ b/pkg/sentry/fsimpl/memfs/symlink.go
@@ -19,11 +19,11 @@ import (
)
type symlink struct {
- inode Inode
+ inode inode
target string // immutable
}
-func (fs *Filesystem) newSymlink(creds *auth.Credentials, target string) *Inode {
+func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode {
link := &symlink{
target: target,
}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
new file mode 100644
index 000000000..3d8a4deaf
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -0,0 +1,49 @@
+load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "proc",
+ srcs = [
+ "filesystems.go",
+ "loadavg.go",
+ "meminfo.go",
+ "mounts.go",
+ "net.go",
+ "proc.go",
+ "stat.go",
+ "sys.go",
+ "task.go",
+ "version.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/log",
+ "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ ],
+)
+
+go_test(
+ name = "proc_test",
+ size = "small",
+ srcs = ["net_test.go"],
+ embed = [":proc"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/inet",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/sentry/fsimpl/proc/filesystems.go
new file mode 100644
index 000000000..c36c4aff5
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/filesystems.go
@@ -0,0 +1,25 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+// filesystemsData implements vfs.DynamicBytesSource for /proc/filesystems.
+//
+// +stateify savable
+type filesystemsData struct{}
+
+// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for
+// filesystemsData. We would need to retrieve filesystem names from
+// vfs.VirtualFilesystem. Also needs vfs replacement for
+// fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev.
diff --git a/pkg/sentry/fsimpl/proc/loadavg.go b/pkg/sentry/fsimpl/proc/loadavg.go
new file mode 100644
index 000000000..9135afef1
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/loadavg.go
@@ -0,0 +1,40 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// loadavgData backs /proc/loadavg.
+//
+// +stateify savable
+type loadavgData struct{}
+
+var _ vfs.DynamicBytesSource = (*loadavgData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/62345059): Include real data in fields.
+ // Column 1-3: CPU and IO utilization of the last 1, 5, and 15 minute periods.
+ // Column 4-5: currently running processes and the total number of processes.
+ // Column 6: the last process ID used.
+ fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/meminfo.go b/pkg/sentry/fsimpl/proc/meminfo.go
new file mode 100644
index 000000000..9a827cd66
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/meminfo.go
@@ -0,0 +1,77 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
+//
+// +stateify savable
+type meminfoData struct {
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+var _ vfs.DynamicBytesSource = (*meminfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ mf := d.k.MemoryFile()
+ mf.UpdateUsage()
+ snapshot, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ anon := snapshot.Anonymous + snapshot.Tmpfs
+ file := snapshot.PageCache + snapshot.Mapped
+ // We don't actually have active/inactive LRUs, so just make up numbers.
+ activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ inactiveFile := file - activeFile
+
+ fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
+ memFree := (totalSize - totalUsage) / 1024
+ // We use MemFree as MemAvailable because we don't swap.
+ // TODO(rahat): When reclaim is implemented the value of MemAvailable
+ // should change.
+ fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree)
+ fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree)
+ fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices
+ fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
+ // Emulate a system with no swap, which disables inactivation of anon pages.
+ fmt.Fprintf(buf, "SwapCache: 0 kB\n")
+ fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024)
+ fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Inactive(anon): 0 kB\n")
+ fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024)
+ fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "SwapTotal: 0 kB\n")
+ fmt.Fprintf(buf, "SwapFree: 0 kB\n")
+ fmt.Fprintf(buf, "Dirty: 0 kB\n")
+ fmt.Fprintf(buf, "Writeback: 0 kB\n")
+ fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
+ fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/fsimpl/proc/mounts.go
new file mode 100644
index 000000000..e81b1e910
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/mounts.go
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import "gvisor.dev/gvisor/pkg/sentry/kernel"
+
+// TODO(b/138862512): Implement mountInfoFile and mountsFile.
+
+// mountInfoFile will implement vfs.DynamicBytesSource for /proc/[pid]/mountinfo (Generate is not implemented yet; see the TODO above).
+//
+// +stateify savable
+type mountInfoFile struct {
+ t *kernel.Task // NOTE(review): presumably the task identified by [pid]; confirm once Generate exists
+}
+
+// mountsFile will implement vfs.DynamicBytesSource for /proc/[pid]/mounts (Generate is not implemented yet; see the TODO above).
+//
+// +stateify savable
+type mountsFile struct {
+ t *kernel.Task // NOTE(review): presumably the task identified by [pid]; confirm once Generate exists
+}
diff --git a/pkg/sentry/fsimpl/proc/net.go b/pkg/sentry/fsimpl/proc/net.go
new file mode 100644
index 000000000..fd46eebf8
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/net.go
@@ -0,0 +1,338 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6.
+//
+// +stateify savable
+type ifinet6 struct {
+ s inet.Stack // s is the network stack whose interfaces and addresses are listed.
+}
+
+var _ vfs.DynamicBytesSource = (*ifinet6)(nil)
+
+func (n *ifinet6) contents() []string { // returns one formatted line per IPv6 interface address
+ var lines []string
+ nics := n.s.Interfaces()
+ for id, naddrs := range n.s.InterfaceAddrs() {
+ nic, ok := nics[id]
+ if !ok {
+ // NIC was added after Interfaces() was called above. We'll just
+ // ignore it.
+ continue
+ }
+
+ for _, a := range naddrs {
+ // IPv6 only.
+ if a.Family != linux.AF_INET6 {
+ continue
+ }
+
+ // Fields:
+ // IPv6 address displayed in 32 hexadecimal chars without colons
+ // Netlink device number (interface index) in hexadecimal (use nic id)
+ // Prefix length in hexadecimal
+ // Scope value (always reported as 0)
+ // Interface flags in hexadecimal
+ // Device name
+ lines = append(lines, fmt.Sprintf("%032x %02x %02x %02x %02x %8s\n", a.Addr, id, a.PrefixLen, 0, a.Flags, nic.Name))
+ }
+ }
+ return lines
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate by writing every line from contents() to buf; it always returns nil.
+func (n *ifinet6) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ for _, l := range n.contents() {
+ buf.WriteString(l)
+ }
+ return nil
+}
+
+// netDev implements vfs.DynamicBytesSource for /proc/net/dev.
+//
+// +stateify savable
+type netDev struct {
+ s inet.Stack // s is the network stack queried for interfaces and per-interface statistics.
+}
+
+var _ vfs.DynamicBytesSource = (*netDev)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. Interfaces whose statistics cannot be read are logged and skipped.
+func (n *netDev) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ interfaces := n.s.Interfaces()
+ buf.WriteString("Inter-| Receive | Transmit\n")
+ buf.WriteString(" face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n")
+
+ for _, i := range interfaces {
+ // Implements the same format as
+ // net/core/net-procfs.c:dev_seq_printf_stats.
+ var stats inet.StatDev
+ if err := n.s.Statistics(&stats, i.Name); err != nil {
+ log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err)
+ continue
+ }
+ fmt.Fprintf(
+ buf,
+ "%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n",
+ i.Name,
+ // Received (column names follow the header written above)
+ stats[0], // bytes
+ stats[1], // packets
+ stats[2], // errors
+ stats[3], // dropped
+ stats[4], // fifo
+ stats[5], // frame
+ stats[6], // compressed
+ stats[7], // multicast
+ // Transmitted (column names follow the header written above)
+ stats[8], // bytes
+ stats[9], // packets
+ stats[10], // errors
+ stats[11], // dropped
+ stats[12], // fifo
+ stats[13], // colls
+ stats[14], // carrier
+ stats[15], // compressed
+ )
+ }
+
+ return nil
+}
+
+// netUnix implements vfs.DynamicBytesSource for /proc/net/unix.
+//
+// +stateify savable
+type netUnix struct {
+ k *kernel.Kernel // k is the kernel whose socket table (ListSockets) is enumerated.
+}
+
+var _ vfs.DynamicBytesSource = (*netUnix)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. It emits one line per AF_UNIX socket in the kernel's socket table.
+func (n *netUnix) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n")
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
+ continue
+ }
+ sfile := s.(*fs.File)
+ if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
+ s.DecRef()
+ // Not a unix socket.
+ continue
+ }
+ sops := sfile.FileOperations.(*unix.SocketOperations)
+
+ addr, err := sops.Endpoint().GetLocalAddress()
+ if err != nil {
+ log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err)
+ addr.Addr = "<unknown>"
+ }
+
+ sockFlags := 0
+ if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok {
+ if ce.Listening() {
+ // For unix domain sockets, Linux reports a single flag
+ // value, __SO_ACCEPTCON, if the socket is listening.
+ sockFlags = linux.SO_ACCEPTCON
+ }
+ }
+
+ // In the socket entry below, the value for the 'Num' field requires
+ // some consideration. Linux prints the address of the struct
+ // unix_sock representing a socket in the kernel, but may redact the
+ // value for unprivileged users depending on the kptr_restrict
+ // sysctl.
+ //
+ // One use for this field is to allow a privileged user to
+ // introspect into the kernel memory to determine information about
+ // a socket not available through procfs, such as the socket's peer.
+ //
+ // In gVisor, returning a pointer to our internal structures would
+ // be pointless, as it wouldn't match the memory layout for struct
+ // unix_sock, making introspection difficult. We could populate a
+ // struct unix_sock with the appropriate data, but even that
+ // requires consideration for which kernel version to emulate, as
+ // the definition of this struct changes over time.
+ //
+ // For now, we always redact this pointer by printing a nil pointer.
+ fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %5d",
+ (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
+ sfile.ReadRefs()-1, // RefCount, don't count our own ref.
+ 0, // Protocol, always 0 for UDS.
+ sockFlags, // Flags.
+ sops.Endpoint().Type(), // Type.
+ sops.State(), // State.
+ sfile.InodeID(), // Inode.
+ )
+
+ // Path
+ if len(addr.Addr) != 0 {
+ if addr.Addr[0] == 0 {
+ // Abstract path: the leading NUL byte is shown as '@'.
+ fmt.Fprintf(buf, " @%s", string(addr.Addr[1:]))
+ } else {
+ fmt.Fprintf(buf, " %s", string(addr.Addr))
+ }
+ }
+ fmt.Fprintf(buf, "\n")
+
+ s.DecRef()
+ }
+ return nil
+}
+
+// netTCP implements vfs.DynamicBytesSource for /proc/net/tcp.
+//
+// +stateify savable
+type netTCP struct {
+ k *kernel.Kernel // k is the kernel whose socket table (ListSockets) is enumerated.
+}
+
+var _ vfs.DynamicBytesSource = (*netTCP)(nil)
+
+func (n *netTCP) Generate(ctx context.Context, buf *bytes.Buffer) error { // Generate implements vfs.DynamicBytesSource.Generate.
+ t := kernel.TaskFromContext(ctx) // NOTE(review): t.UserNamespace() is called below without a nil check; confirm ctx always carries a task
+ buf.WriteString(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n")
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref %+v in socket table, racing with destruction?", se.Sock)
+ continue
+ }
+ sfile := s.(*fs.File)
+ sops, ok := sfile.FileOperations.(socket.Socket)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ }
+ if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) {
+ s.DecRef()
+ // Not a tcp4 socket.
+ continue
+ }
+
+ // Linux's documentation for the fields below can be found at
+ // https://www.kernel.org/doc/Documentation/networking/proc_net_tcp.txt.
+ // For Linux's implementation, see net/ipv4/tcp_ipv4.c:get_tcp4_sock().
+ // Note that the header doesn't contain labels for all the fields.
+
+ // Field: sl; entry number.
+ fmt.Fprintf(buf, "%4d: ", se.ID)
+
+ portBuf := make([]byte, 2)
+
+ // Field: local_address. Left zeroed if GetSockName fails.
+ var localAddr linux.SockAddrInet
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
+ binary.LittleEndian.PutUint16(portBuf, localAddr.Port)
+ fmt.Fprintf(buf, "%08X:%04X ",
+ binary.LittleEndian.Uint32(localAddr.Addr[:]),
+ portBuf)
+
+ // Field: rem_address. Left zeroed if GetPeerName fails.
+ var remoteAddr linux.SockAddrInet
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
+ binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port)
+ fmt.Fprintf(buf, "%08X:%04X ",
+ binary.LittleEndian.Uint32(remoteAddr.Addr[:]),
+ portBuf)
+
+ // Field: state; socket state.
+ fmt.Fprintf(buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when; timer active state and number of jiffies
+ // until timer expires. Unimplemented.
+ fmt.Fprintf(buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt; number of unrecovered RTO timeouts.
+ // Unimplemented.
+ fmt.Fprintf(buf, "%08X ", 0)
+
+ // Field: uid. Reported as 0 if the socket file's attributes cannot be read.
+ uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ fmt.Fprintf(buf, "%5d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
+ }
+
+ // Field: timeout; number of unanswered 0-window probes.
+ // Unimplemented.
+ fmt.Fprintf(buf, "%8d ", 0)
+
+ // Field: inode.
+ fmt.Fprintf(buf, "%8d ", sfile.InodeID())
+
+ // Field: refcount. Don't count the ref we obtain while dereferencing
+ // the weakref to this socket.
+ fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.Generate.
+ fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: retransmit timeout. Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: predicted tick of soft clock (delayed ACK control data).
+ // Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: (ack.quick<<1)|ack.pingpong, Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: sending congestion window, Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: Slow start size threshold, -1 if threshold >= 0xFFFF.
+ // Unimplemented, report as large threshold.
+ fmt.Fprintf(buf, "%d", -1)
+
+ fmt.Fprintf(buf, "\n")
+
+ s.DecRef()
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/net_test.go b/pkg/sentry/fsimpl/proc/net_test.go
new file mode 100644
index 000000000..20a77a8ca
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/net_test.go
@@ -0,0 +1,78 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+)
+
+func newIPv6TestStack() *inet.TestStack { // returns a fresh inet.TestStack with SupportsIPv6Flag set
+ s := inet.NewTestStack()
+ s.SupportsIPv6Flag = true
+ return s
+}
+
+func TestIfinet6NoAddresses(t *testing.T) { // a stack with no interfaces must produce empty output
+ n := &ifinet6{s: newIPv6TestStack()}
+ var buf bytes.Buffer
+ n.Generate(contexttest.Context(t), &buf) // NOTE(review): the returned error is ignored here
+ if buf.Len() > 0 {
+ t.Errorf("n.Generate() generated = %v, want = %v", buf.Bytes(), []byte{})
+ }
+}
+
+func TestIfinet6(t *testing.T) { // two NICs with one IPv6 address each must yield exactly two lines
+ s := newIPv6TestStack()
+ s.InterfacesMap[1] = inet.Interface{Name: "eth0"}
+ s.InterfaceAddrsMap[1] = []inet.InterfaceAddr{
+ {
+ Family: linux.AF_INET6,
+ PrefixLen: 128,
+ Addr: []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"),
+ },
+ }
+ s.InterfacesMap[2] = inet.Interface{Name: "eth1"}
+ s.InterfaceAddrsMap[2] = []inet.InterfaceAddr{
+ {
+ Family: linux.AF_INET6,
+ PrefixLen: 128,
+ Addr: []byte("\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"),
+ },
+ }
+ want := map[string]struct{}{ // compared as a set, so the order of lines does not matter
+ "000102030405060708090a0b0c0d0e0f 01 80 00 00 eth0\n": {},
+ "101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {},
+ }
+
+ n := &ifinet6{s: s}
+ contents := n.contents()
+ if len(contents) != len(want) {
+ t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want))
+ }
+ got := map[string]struct{}{}
+ for _, l := range contents {
+ got[l] = struct{}{}
+ }
+
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Got n.contents() = %v, want = %v", got, want)
+ }
+}
diff --git a/pkg/sentry/fsimpl/proc/proc.go b/pkg/sentry/fsimpl/proc/proc.go
new file mode 100644
index 000000000..31dec36de
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/proc.go
@@ -0,0 +1,16 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package proc implements a partial in-memory file system for procfs.
+package proc
diff --git a/pkg/sentry/fsimpl/proc/stat.go b/pkg/sentry/fsimpl/proc/stat.go
new file mode 100644
index 000000000..720db3828
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/stat.go
@@ -0,0 +1,127 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// cpuStats contains the breakdown of CPU time printed on one "cpu" line of /proc/stat.
+type cpuStats struct {
+ // user is time spent in userspace tasks with non-positive niceness.
+ user uint64
+
+ // nice is time spent in userspace tasks with positive niceness.
+ nice uint64
+
+ // system is time spent in non-interrupt kernel context.
+ system uint64
+
+ // idle is time spent idle.
+ idle uint64
+
+ // ioWait is time spent waiting for IO.
+ ioWait uint64
+
+ // irq is time spent in interrupt context.
+ irq uint64
+
+ // softirq is time spent in software interrupt context.
+ softirq uint64
+
+ // steal is involuntary wait time.
+ steal uint64
+
+ // guest is time spent in guests with non-positive niceness.
+ guest uint64
+
+ // guestNice is time spent in guests with positive niceness.
+ guestNice uint64
+}
+
+// String implements fmt.Stringer, printing all ten fields space-separated in declaration order.
+func (c cpuStats) String() string {
+ return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
+}
+
+// statData implements vfs.DynamicBytesSource for /proc/stat.
+//
+// +stateify savable
+type statData struct {
+ // k is the owning Kernel; Generate reads its core count and boot time.
+ k *kernel.Kernel
+}
+
+var _ vfs.DynamicBytesSource = (*statData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. Apart from the per-core line count and btime, every value is a placeholder zero.
+func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/37226836): We currently export only zero CPU stats. We could
+ // at least provide some aggregate stats.
+ var cpu cpuStats
+ fmt.Fprintf(buf, "cpu %s\n", cpu)
+
+ for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
+ fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
+ }
+
+ // The total number of interrupts is dependent on the CPUs and PCI
+ // devices on the system. See arch_probe_nr_irqs.
+ //
+ // Since we don't report real interrupt stats, just choose an arbitrary
+ // value from a representative VM.
+ const numInterrupts = 256
+
+ // The Kernel doesn't handle real interrupts, so report all zeroes.
+ // TODO(b/37226836): We could count page faults as #PF.
+ fmt.Fprintf(buf, "intr 0") // total
+ for i := 0; i < numInterrupts; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+
+ // Total number of context switches.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "ctxt 0\n")
+
+ // Boot time as a CLOCK_REALTIME timestamp, in seconds.
+ fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
+
+ // Total number of clones.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "processes 0\n")
+
+ // Number of runnable tasks.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_running 0\n")
+
+ // Number of tasks waiting on IO.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_blocked 0\n")
+
+ // Number of each softirq handled; one zero per linux.NumSoftIRQ entry.
+ fmt.Fprintf(buf, "softirq 0") // total
+ for i := 0; i < linux.NumSoftIRQ; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/sys.go b/pkg/sentry/fsimpl/proc/sys.go
new file mode 100644
index 000000000..b88256e12
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/sys.go
@@ -0,0 +1,51 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// mmapMinAddrData implements vfs.DynamicBytesSource for
+// /proc/sys/vm/mmap_min_addr.
+//
+// +stateify savable
+type mmapMinAddrData struct {
+ k *kernel.Kernel // k is the kernel whose Platform supplies the minimum user address.
+}
+
+var _ vfs.DynamicBytesSource = (*mmapMinAddrData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate, printing the platform's MinUserAddress in decimal.
+func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress())
+ return nil
+}
+
+// +stateify savable
+type overcommitMemory struct{} // vfs.DynamicBytesSource that always reports "0"; presumably backs /proc/sys/vm/overcommit_memory (confirm at the registration site)
+
+var _ vfs.DynamicBytesSource = (*overcommitMemory)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate; the content is the constant "0\n".
+func (d *overcommitMemory) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "0\n")
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
new file mode 100644
index 000000000..c46e05c3a
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -0,0 +1,261 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// mapsCommon is embedded by mapsData and smapsData and provides their shared mm() accessor.
+type mapsCommon struct {
+ t *kernel.Task // t is the task whose MemoryManager is read.
+}
+
+// mm gets the kernel task's MemoryManager, or nil if the task has none. No additional
+// reference is taken on mm here. This is safe because MemoryManager.destroy is required
+// to leave the MemoryManager in a state where it's still usable as a DynamicBytesSource.
+func (md *mapsCommon) mm() *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ md.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
+//
+// +stateify savable
+type mapsData struct {
+ mapsCommon // supplies the task and the mm() accessor
+}
+
+var _ vfs.DynamicBytesSource = (*mapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. Nothing is written if the task has no MemoryManager.
+func (md *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := md.mm(); mm != nil {
+ mm.ReadMapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps.
+//
+// +stateify savable
+type smapsData struct {
+ mapsCommon // supplies the task and the mm() accessor
+}
+
+var _ vfs.DynamicBytesSource = (*smapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. Nothing is written if the task has no MemoryManager.
+func (sd *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := sd.mm(); mm != nil {
+ mm.ReadSmapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// +stateify savable
+type taskStatData struct {
+ t *kernel.Task // t is the task being described.
+
+ // If tgstats is true, accumulate fault stats (not implemented) and CPU
+ // time across all tasks in t's thread group.
+ tgstats bool
+
+ // pidns is the PID namespace associated with the proc filesystem that
+ // includes the file using this taskStatData.
+ pidns *kernel.PIDNamespace
+}
+
+var _ vfs.DynamicBytesSource = (*taskStatData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. Fields that are not tracked are emitted as literal zeroes.
+func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t))
+ fmt.Fprintf(buf, "(%s) ", s.t.Name())
+ fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0])
+ ppid := kernel.ThreadID(0)
+ if parent := s.t.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "%d ", ppid)
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup()))
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session()))
+ fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
+ fmt.Fprintf(buf, "0 " /* flags */)
+ fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
+ var cputime usage.CPUStats
+ if s.tgstats {
+ cputime = s.t.ThreadGroup().CPUStats()
+ } else {
+ cputime = s.t.CPUStats()
+ }
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ cputime = s.t.ThreadGroup().JoinedChildCPUStats() // reused for the children's user/system time pair
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness())
+ fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count())
+
+ // itrealvalue. Since kernel 2.6.17, this field is no longer
+ // maintained, and is hard coded as 0.
+ fmt.Fprintf(buf, "0 ")
+
+ // Start time is relative to boot time, expressed in clock ticks.
+ fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
+
+ var vss, rss uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+ fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize) // vss is printed as-is, rss is converted to pages
+
+ // rsslim.
+ fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
+
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
+ fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
+ terminationSignal := linux.Signal(0) // stays 0 unless t is its thread group's leader
+ if s.t == s.t.ThreadGroup().Leader() {
+ terminationSignal = s.t.ThreadGroup().TerminationSignal()
+ }
+ fmt.Fprintf(buf, "%d ", terminationSignal)
+ fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
+ fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
+ fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
+ fmt.Fprintf(buf, "0\n" /* exit_code */)
+
+ return nil
+}
+
+// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm.
+//
+// +stateify savable
+type statmData struct {
+ t *kernel.Task // t is the task whose memory sizes are reported.
+}
+
+var _ vfs.DynamicBytesSource = (*statmData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. Both sizes are divided by the page size; the other five fields are hard-coded to 0.
+func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var vss, rss uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+
+ fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+ return nil
+}
+
+// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
+//
+// +stateify savable
+type statusData struct {
+ t *kernel.Task // t is the task being described.
+ pidns *kernel.PIDNamespace // pidns is the namespace used to translate task and thread-group IDs.
+}
+
+var _ vfs.DynamicBytesSource = (*statusData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. FD and memory values stay 0 when the task has no FDTable or MemoryManager.
+func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name())
+ fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus())
+ fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup()))
+ fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t))
+ ppid := kernel.ThreadID(0) // stays 0 when the task has no parent
+ if parent := s.t.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
+ tpid := kernel.ThreadID(0) // stays 0 when the task is not being traced
+ if tracer := s.t.Tracer(); tracer != nil {
+ tpid = s.pidns.IDOfTask(tracer)
+ }
+ fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
+ var fds int
+ var vss, rss, data uint64
+ s.t.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.Size()
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
+ })
+ fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
+ fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10) // >>10 divides by 1024 for the kB unit
+ fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
+ fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
+ fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count())
+ creds := s.t.Credentials()
+ fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
+ fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
+ fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
+ fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
+ fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ return nil
+}
+
+// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
+type ioUsage interface {
+ // IOUsage returns the I/O usage counters that ioData.Generate prints.
+ IOUsage() *usage.IO
+}
+
+// +stateify savable
+type ioData struct {
+ ioUsage // embedded provider of the I/O counters printed by Generate
+}
+
+var _ vfs.DynamicBytesSource = (*ioData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. Field names follow Linux's /proc/[pid]/io, which starts with "rchar".
+func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ io := usage.IO{}
+ io.Accumulate(i.IOUsage()) // accumulate the provider's counters into a local value before printing
+
+ fmt.Fprintf(buf, "rchar: %d\n", io.CharsRead)
+ fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
+ fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
+ fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
+ fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
+ fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
+ fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/version.go b/pkg/sentry/fsimpl/proc/version.go
new file mode 100644
index 000000000..e1643d4e0
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/version.go
@@ -0,0 +1,68 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// versionData implements vfs.DynamicBytesSource for /proc/version.
+//
+// +stateify savable
+type versionData struct {
+ // k is the owning Kernel; its global init task supplies the version strings.
+ k *kernel.Kernel
+}
+
+var _ vfs.DynamicBytesSource = (*versionData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate. It panics if the kernel has no global init task yet.
+func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ init := v.k.GlobalInit()
+ if init == nil {
+ // Attempted to read before the init Task is created. This can
+ // only occur during startup, which should never need to read
+ // this file.
+ panic("Attempted to read version before initial Task is available")
+ }
+
+ // /proc/version takes the form:
+ //
+ // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
+ // (COMPILER_VERSION) VERSION"
+ //
+ // where:
+ // - SYSNAME, RELEASE, and VERSION are the same as returned by
+ // sys_utsname
+ // - COMPILE_USER is the user that built the kernel
+ // - COMPILE_HOST is the hostname of the machine on which the kernel
+ // was built
+ // - COMPILER_VERSION is the version reported by the building compiler
+ //
+ // Since we don't really want to expose build information to
+ // applications, those fields are omitted.
+ //
+ // FIXME(mpratt): Using Version from the init task SyscallTable
+ // disregards the different version a task may have (e.g., in a uts
+ // namespace).
+ ver := init.Leader().SyscallTable().Version
+ fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
+ return nil
+}
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 5b75a4a06..80f227dbe 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -52,12 +52,16 @@ type Stack interface {
// Statistics reports stack statistics.
Statistics(stat interface{}, arg string) error
+
+ // RouteTable returns the network stack's route table.
+ RouteTable() []Route
+
+ // Resume restarts the network stack after restore.
+ Resume()
}
// Interface contains information about a network interface.
type Interface struct {
- // Keep these fields sorted in the order they appear in rtnetlink(7).
-
// DeviceType is the device type, a Linux ARPHRD_* constant.
DeviceType uint16
@@ -77,8 +81,6 @@ type Interface struct {
// InterfaceAddr contains information about a network interface address.
type InterfaceAddr struct {
- // Keep these fields sorted in the order they appear in rtnetlink(7).
-
// Family is the address family, a Linux AF_* constant.
Family uint8
@@ -109,3 +111,45 @@ type TCPBufferSize struct {
// StatDev describes one line of /proc/net/dev, i.e., stats for one network
// interface.
type StatDev [16]uint64
+
+// Route contains information about a network route.
+type Route struct {
+ // Family is the address family, a Linux AF_* constant.
+ Family uint8
+
+ // DstLen is the length of the destination address.
+ DstLen uint8
+
+ // SrcLen is the length of the source address.
+ SrcLen uint8
+
+ // TOS is the Type of Service filter.
+ TOS uint8
+
+ // Table is the routing table ID.
+ Table uint8
+
+ // Protocol is the route origin, a Linux RTPROT_* constant.
+ Protocol uint8
+
+ // Scope is the distance to destination, a Linux RT_SCOPE_* constant.
+ Scope uint8
+
+ // Type is the route type, a Linux RTN_* constant.
+ Type uint8
+
+ // Flags are route flags. See rtnetlink(7) under "rtm_flags".
+ Flags uint32
+
+ // DstAddr is the route destination address (RTA_DST).
+ DstAddr []byte
+
+ // SrcAddr is the route source address (RTA_SRC).
+ SrcAddr []byte
+
+ // OutputInterface is the output interface index (RTA_OIF).
+ OutputInterface int32
+
+ // GatewayAddr is the route gateway address (RTA_GATEWAY).
+ GatewayAddr []byte
+}
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 75f9e7a77..b9eed7c3a 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -18,6 +18,7 @@ package inet
type TestStack struct {
InterfacesMap map[int32]Interface
InterfaceAddrsMap map[int32][]InterfaceAddr
+ RouteList []Route
SupportsIPv6Flag bool
TCPRecvBufSize TCPBufferSize
TCPSendBufSize TCPBufferSize
@@ -86,3 +87,12 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error {
func (s *TestStack) Statistics(stat interface{}, arg string) error {
return nil
}
+
+// RouteTable implements Stack.RouteTable.
+//
+// It returns s.RouteList directly, without copying it.
+func (s *TestStack) RouteTable() []Route {
+ return s.RouteList
+}
+
+// Resume implements Stack.Resume.
+//
+// It is a no-op for TestStack.
+func (s *TestStack) Resume() {
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e61d39c82..41bee9a22 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -144,6 +144,7 @@ go_library(
"threads.go",
"timekeeper.go",
"timekeeper_state.go",
+ "tty.go",
"uts_namespace.go",
"vdso.go",
"version.go",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 4c2d48e65..8c1f79ab5 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -112,11 +112,6 @@ type Kernel struct {
rootIPCNamespace *IPCNamespace
rootAbstractSocketNamespace *AbstractSocketNamespace
- // mounts holds the state of the virtual filesystem. mounts is initially
- // nil, and must be set by calling Kernel.SetRootMountNamespace before
- // Kernel.CreateProcess can succeed.
- mounts *fs.MountNamespace
-
// futexes is the "root" futex.Manager, from which all others are forked.
// This is necessary to ensure that shared futexes are coherent across all
// tasks, including those created by CreateProcess.
@@ -197,6 +192,15 @@ type Kernel struct {
// caches. Not all caches use it, only the caches that use host resources use
// the limiter. It may be nil if disabled.
DirentCacheLimiter *fs.DirentCacheLimiter
+
+ // unimplementedSyscallEmitterOnce is used in the initialization of
+ // unimplementedSyscallEmitter.
+ unimplementedSyscallEmitterOnce sync.Once `state:"nosave"`
+
+ // unimplementedSyscallEmitter is used to emit unimplemented syscall
+ // events. This is initialized lazily on the first unimplemented
+ // syscall.
+ unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"`
}
// InitKernelArgs holds arguments to Init.
@@ -290,7 +294,6 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
k.futexes = futex.NewManager()
k.netlinkPorts = port.New()
-
return nil
}
@@ -384,11 +387,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// flushMountSourceRefs flushes the MountSources for all mounted filesystems
// and open FDs.
func (k *Kernel) flushMountSourceRefs() error {
- // Flush all mount sources for currently mounted filesystems in the
- // root mount namespace.
- k.mounts.FlushMountSourceRefs()
-
- // Some tasks may have other mount namespaces; flush those as well.
+ // Flush all mount sources for currently mounted filesystems in each task.
flushed := make(map[*fs.MountNamespace]struct{})
k.tasks.mu.RLock()
k.tasks.forEachThreadGroupLocked(func(tg *ThreadGroup) {
@@ -497,7 +496,7 @@ func (ts *TaskSet) unregisterEpollWaiters() {
}
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error {
+func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
k.networkStack = net
@@ -541,6 +540,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error {
log.Infof("Overall load took [%s]", time.Since(loadStart))
+ k.Timekeeper().SetClocks(clocks)
+ if net != nil {
+ net.Resume()
+ }
+
// Ensure that all pending asynchronous work is complete:
// - namedpipe opening
// - inode file opening
@@ -550,7 +554,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error {
tcpip.AsyncLoading.Wait()
- log.Infof("Overall load took [%s]", time.Since(loadStart))
+ log.Infof("Overall load took [%s] after async work", time.Since(loadStart))
// Applications may size per-cpu structures based on k.applicationCores, so
// it can't change across save/restore. When we are virtualizing CPU
@@ -565,16 +569,6 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack) error {
return nil
}
-// Destroy releases resources owned by k.
-//
-// Preconditions: There must be no task goroutines running in k.
-func (k *Kernel) Destroy() {
- if k.mounts != nil {
- k.mounts.DecRef()
- k.mounts = nil
- }
-}
-
// UniqueID returns a unique identifier.
func (k *Kernel) UniqueID() uint64 {
id := atomic.AddUint64(&k.uniqueID, 1)
@@ -586,11 +580,17 @@ func (k *Kernel) UniqueID() uint64 {
// CreateProcessArgs holds arguments to kernel.CreateProcess.
type CreateProcessArgs struct {
- // Filename is the filename to load.
+ // Filename is the filename to load as the init binary.
//
- // If this is provided as "", then the file will be guessed via Argv[0].
+ // If this is provided as "", File will be checked, then the file will be
+ // guessed via Argv[0].
Filename string
+ // File is a passed host FD pointing to a file to load as the init binary.
+ //
+ // This is checked if and only if Filename is "".
+ File *fs.File
+
// Argvv is a list of arguments.
Argv []string
@@ -632,19 +632,12 @@ type CreateProcessArgs struct {
AbstractSocketNamespace *AbstractSocketNamespace
// MountNamespace optionally contains the mount namespace for this
- // process. If nil, the kernel's mount namespace is used.
+ // process. If nil, the init process's mount namespace is used.
//
// Anyone setting MountNamespace must donate a reference (i.e.
// increment it).
MountNamespace *fs.MountNamespace
- // Root optionally contains the dirent that serves as the root for the
- // process. If nil, the mount namespace's root is used as the process'
- // root.
- //
- // Anyone setting Root must donate a reference (i.e. increment it).
- Root *fs.Dirent
-
// ContainerID is the container that the process belongs to.
ContainerID string
}
@@ -682,16 +675,10 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
case auth.CtxCredentials:
return ctx.args.Credentials
case fs.CtxRoot:
- if ctx.args.Root != nil {
- // Take a reference on the root dirent that will be
- // given to the caller.
- ctx.args.Root.IncRef()
- return ctx.args.Root
- }
- if ctx.k.mounts != nil {
- // MountNamespace.Root() will take a reference on the
- // root dirent for us.
- return ctx.k.mounts.Root()
+ if ctx.args.MountNamespace != nil {
+ // MountNamespace.Root() will take a reference on the root
+ // dirent for us.
+ return ctx.args.MountNamespace.Root()
}
return nil
case fs.CtxDirentCacheLimiter:
@@ -735,30 +722,18 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
defer k.extMu.Unlock()
log.Infof("EXEC: %v", args.Argv)
- if k.mounts == nil {
- return nil, 0, fmt.Errorf("no kernel MountNamespace")
- }
-
// Grab the mount namespace.
mounts := args.MountNamespace
if mounts == nil {
- // If no MountNamespace was configured, then use the kernel's
- // root mount namespace, with an extra reference that will be
- // donated to the task.
- mounts = k.mounts
+ mounts = k.GlobalInit().Leader().MountNamespace()
mounts.IncRef()
}
tg := k.newThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits, k.monotonicClock)
ctx := args.NewContext(k)
- // Grab the root directory.
- root := args.Root
- if root == nil {
- // If no Root was configured, then get it from the
- // MountNamespace.
- root = mounts.Root()
- }
+ // Get the root directory from the MountNamespace.
+ root := mounts.Root()
// The call to newFSContext below will take a reference on root, so we
// don't need to hold this one.
defer root.DecRef()
@@ -768,15 +743,23 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
wd := root // Default.
if args.WorkingDirectory != "" {
var err error
- wd, err = k.mounts.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
+ wd, err = mounts.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
if err != nil {
return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
}
defer wd.DecRef()
}
- if args.Filename == "" {
- // Was anything provided?
+ // Check which file to start from.
+ switch {
+ case args.Filename != "":
+ // If a filename is given, take that.
+ // Set File to nil so we resolve the path in LoadTaskImage.
+ args.File = nil
+ case args.File != nil:
+ // If File is set, take the File provided directly.
+ default:
+ // Otherwise look at Argv and see if the first argument is a valid path.
if len(args.Argv) == 0 {
return nil, 0, fmt.Errorf("no filename or command provided")
}
@@ -788,7 +771,8 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
// Create a fresh task context.
remainingTraversals = uint(args.MaxSymlinkTraversals)
- tc, se := k.LoadTaskImage(ctx, k.mounts, root, wd, &remainingTraversals, args.Filename, args.Argv, args.Envv, k.featureSet)
+
+ tc, se := k.LoadTaskImage(ctx, mounts, root, wd, &remainingTraversals, args.Filename, args.File, args.Argv, args.Envv, k.featureSet)
if se != nil {
return nil, 0, errors.New(se.String())
}
@@ -1032,20 +1016,6 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
return k.rootAbstractSocketNamespace
}
-// RootMountNamespace returns the MountNamespace.
-func (k *Kernel) RootMountNamespace() *fs.MountNamespace {
- k.extMu.Lock()
- defer k.extMu.Unlock()
- return k.mounts
-}
-
-// SetRootMountNamespace sets the MountNamespace.
-func (k *Kernel) SetRootMountNamespace(mounts *fs.MountNamespace) {
- k.extMu.Lock()
- defer k.extMu.Unlock()
- k.mounts = mounts
-}
-
// NetworkStack returns the network stack. NetworkStack may return nil if no
// network stack is available.
func (k *Kernel) NetworkStack() inet.Stack {
@@ -1168,16 +1138,6 @@ func (k *Kernel) SupervisorContext() context.Context {
}
}
-// EmitUnimplementedEvent emits an UnimplementedSyscall event via the event
-// channel.
-func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
- t := TaskFromContext(ctx)
- eventchannel.Emit(&uspb.UnimplementedSyscall{
- Tid: int32(t.ThreadID()),
- Registers: t.Arch().StateData().Proto(),
- })
-}
-
// SocketEntry represents a socket recorded in Kernel.sockets. It implements
// refs.WeakRefUser for sockets stored in the socket table.
//
@@ -1246,7 +1206,10 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
// The supervisor context is global root.
return auth.NewRootCredentials(ctx.k.rootUserNamespace)
case fs.CtxRoot:
- return ctx.k.mounts.Root()
+ if ctx.k.globalInit != nil {
+ return ctx.k.globalInit.mounts.Root()
+ }
+ return nil
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case ktime.CtxRealtimeClock:
@@ -1272,3 +1235,23 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
return nil
}
}
+
+// Rate limits for the number of unimplemented syscall events.
+const (
+ unimplementedSyscallsMaxRate = 100 // events per second
+ unimplementedSyscallBurst = 1000 // events
+)
+
+// EmitUnimplementedEvent emits an UnimplementedSyscall event via the event
+// channel.
+//
+// Events go through a rate-limited emitter that is created on first use from
+// eventchannel.DefaultEmitter, using unimplementedSyscallsMaxRate and
+// unimplementedSyscallBurst as its limits. The event carries the thread ID
+// and register state of the Task found in ctx.
+func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
+ // Build the emitter exactly once, no matter how many callers race here.
+ k.unimplementedSyscallEmitterOnce.Do(func() {
+ k.unimplementedSyscallEmitter = eventchannel.RateLimitedEmitterFrom(eventchannel.DefaultEmitter, unimplementedSyscallsMaxRate, unimplementedSyscallBurst)
+ })
+
+ // NOTE(review): t is dereferenced unconditionally below; presumably ctx
+ // always carries a Task when this is called - confirm against callers.
+ t := TaskFromContext(ctx)
+ k.unimplementedSyscallEmitter.Emit(&uspb.UnimplementedSyscall{
+ Tid: int32(t.ThreadID()),
+ Registers: t.Arch().StateData().Proto(),
+ })
+}
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 81fcd8258..e5f297478 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -47,6 +47,11 @@ type Session struct {
// The id is immutable.
id SessionID
+ // foreground is the foreground process group.
+ //
+ // This is protected by TaskSet.mu.
+ foreground *ProcessGroup
+
// ProcessGroups is a list of process groups in this Session. This is
// protected by TaskSet.mu.
processGroups processGroupList
@@ -260,12 +265,14 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
func (tg *ThreadGroup) CreateSession() error {
tg.pidns.owner.mu.Lock()
defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
return tg.createSession()
}
// createSession creates a new session for a threadgroup.
//
-// Precondition: callers must hold TaskSet.mu for writing.
+// Precondition: callers must hold TaskSet.mu and the signal mutex for writing.
func (tg *ThreadGroup) createSession() error {
// Get the ID for this thread in the current namespace.
id := tg.pidns.tgids[tg]
@@ -346,6 +353,9 @@ func (tg *ThreadGroup) createSession() error {
ns.processGroups[ProcessGroupID(local)] = pg
}
+ // Disconnect from the controlling terminal.
+ tg.tty = nil
+
return nil
}
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index 2a2e6f662..dd69939f9 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -15,6 +15,7 @@
package kernel
import (
+ "runtime"
"time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -121,6 +122,17 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
// Deactive our address space, we don't need it.
interrupt := t.SleepStart()
+ // If the request is not completed, but the timer has already expired,
+ // then ensure that we run through a scheduler cycle. This is because
+ // we may see applications relying on timer slack to yield the thread.
+ // For example, they may attempt to sleep for some number of nanoseconds,
+ // and expect that this will actually yield the CPU and sleep for at
+ // least microseconds, e.g.:
+ // https://github.com/LMAX-Exchange/disruptor/commit/6ca210f2bcd23f703c479804d583718e16f43c07
+ if len(timerChan) > 0 {
+ runtime.Gosched()
+ }
+
select {
case <-C:
t.SleepFinish(true)
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 54b1676b0..8639d379f 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -140,15 +140,22 @@ func (t *Task) Stack() *arch.Stack {
// * wd: Working directory to lookup filename under
// * maxTraversals: maximum number of symlinks to follow
// * filename: path to binary to load
+// * file: an open fs.File object of the binary to load. If set,
+// file will be loaded and not filename.
// * argv: Binary argv
// * envv: Binary envv
// * fs: Binary FeatureSet
-func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) {
+func (k *Kernel) LoadTaskImage(ctx context.Context, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, filename string, file *fs.File, argv, envv []string, fs *cpuid.FeatureSet) (*TaskContext, *syserr.Error) {
+ // If File is not nil, we should load that instead of resolving filename.
+ if file != nil {
+ filename = file.MappedName(ctx)
+ }
+
// Prepare a new user address space to load into.
m := mm.NewMemoryManager(k, k)
defer m.DecUsers(ctx)
- os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, argv, envv, k.extraAuxv, k.vdso)
+ os, ac, name, err := loader.Load(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv, envv, k.extraAuxv, k.vdso)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index d60cd62c7..ae6fc4025 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -172,9 +172,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
if parentPG := tg.parentPG(); parentPG == nil {
tg.createSession()
} else {
- // Inherit the process group.
+ // Inherit the process group and terminal.
parentPG.incRefWithParent(parentPG)
tg.processGroup = parentPG
+ tg.tty = t.parent.tg.tty
}
}
tg.tasks.PushBack(t)
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 2a97e3e8e..0eef24bfb 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -19,10 +19,13 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// A ThreadGroup is a logical grouping of tasks that has widespread
@@ -245,6 +248,12 @@ type ThreadGroup struct {
//
// mounts is immutable.
mounts *fs.MountNamespace
+
+ // tty is the thread group's controlling terminal. If nil, there is no
+ // controlling terminal.
+ //
+ // tty is protected by the signal mutex.
+ tty *TTY
}
// newThreadGroup returns a new, empty thread group in PID namespace ns. The
@@ -324,6 +333,176 @@ func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) {
}
}
+// SetControllingTTY sets tty as the controlling terminal of tg.
+//
+// arg is only consulted when tty already belongs to a different session: the
+// terminal is stolen if arg is 1 and tg's leader has CAP_SYS_ADMIN.
+//
+// It returns EINVAL if tg is not a session leader or already has a
+// controlling terminal, and EPERM if tty is controlled by another session
+// and the conditions for stealing it are not met.
+func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ // We might be asked to set the controlling terminal of multiple
+ // processes, so we lock both the TaskSet and SignalHandlers.
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // "The calling process must be a session leader and not have a
+ // controlling terminal already." - tty_ioctl(4)
+ if tg.processGroup.session.leader != tg || tg.tty != nil {
+ return syserror.EINVAL
+ }
+
+ // "If this terminal is already the controlling terminal of a different
+ // session group, then the ioctl fails with EPERM, unless the caller
+ // has the CAP_SYS_ADMIN capability and arg equals 1, in which case the
+ // terminal is stolen, and all processes that had it as controlling
+ // terminal lose it." - tty_ioctl(4)
+ if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session {
+ if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 {
+ return syserror.EPERM
+ }
+ // Steal the TTY away. Unlike TIOCNOTTY, don't send signals.
+ for othertg := range tg.pidns.owner.Root.tgids {
+ // This won't deadlock by locking tg.signalHandlers
+ // because at this point:
+ // - We only lock signalHandlers if it's in the same
+ // session as the tty's controlling thread group.
+ // - We know that the calling thread group is not in
+ // the same session as the tty's controlling thread
+ // group.
+ if othertg.processGroup.session == tty.tg.processGroup.session {
+ othertg.signalHandlers.mu.Lock()
+ othertg.tty = nil
+ othertg.signalHandlers.mu.Unlock()
+ }
+ }
+ }
+
+ // Set the controlling terminal and foreground process group.
+ tg.tty = tty
+ tg.processGroup.session.foreground = tg.processGroup
+ // Set this as the controlling process of the terminal.
+ tty.tg = tg
+
+ return nil
+}
+
+// ReleaseControllingTTY gives up tty as the controlling tty of tg.
+//
+// It returns ENOTTY if tty is not tg's controlling terminal. If tg is the
+// thread group recorded in tty.tg, every thread group in tg's session loses
+// its controlling terminal and the session's foreground process group is sent
+// SIGHUP and SIGCONT; the last error from sending those signals, if any, is
+// returned. Otherwise only tg's own controlling terminal is cleared.
+func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ // We might clear the controlling terminal of multiple thread groups,
+ // so we hold the TaskSet lock (for reading) while walking them; each
+ // thread group's SignalHandlers is locked individually below.
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+
+ // Just below, we may re-lock signalHandlers in order to send signals.
+ // Thus we can't defer Unlock here.
+ tg.signalHandlers.mu.Lock()
+
+ if tg.tty == nil || tg.tty != tty {
+ tg.signalHandlers.mu.Unlock()
+ return syserror.ENOTTY
+ }
+
+ // "If the process was session leader, then send SIGHUP and SIGCONT to
+ // the foreground process group and all processes in the current
+ // session lose their controlling terminal." - tty_ioctl(4)
+ // Remove tty as the controlling tty for each process in the session,
+ // and send SIGHUP and SIGCONT to the foreground process group.
+
+ // If we're not the session leader, we don't have to do much.
+ if tty.tg != tg {
+ tg.tty = nil
+ tg.signalHandlers.mu.Unlock()
+ return nil
+ }
+
+ tg.signalHandlers.mu.Unlock()
+
+ // We're the session leader. SIGHUP and SIGCONT the foreground process
+ // group and remove all controlling terminals in the session.
+ //
+ // NOTE(review): tty.tg is left pointing at tg after release; confirm
+ // that this is intended.
+ var lastErr error
+ for othertg := range tg.pidns.owner.Root.tgids {
+ if othertg.processGroup.session == tg.processGroup.session {
+ othertg.signalHandlers.mu.Lock()
+ othertg.tty = nil
+ if othertg.processGroup == tg.processGroup.session.foreground {
+ if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil {
+ lastErr = err
+ }
+ if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil {
+ lastErr = err
+ }
+ }
+ othertg.signalHandlers.mu.Unlock()
+ }
+ }
+
+ return lastErr
+}
+
+// ForegroundProcessGroup returns the process group ID of the foreground
+// process group.
+//
+// It returns -1 and ENOTTY if tty is not tg's controlling terminal.
+func (tg *ThreadGroup) ForegroundProcessGroup(tty *TTY) (int32, error) {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // "When fd does not refer to the controlling terminal of the calling
+ // process, -1 is returned" - tcgetpgrp(3)
+ if tg.tty != tty {
+ return -1, syserror.ENOTTY
+ }
+
+ // NOTE(review): assumes the session's foreground group is non-nil
+ // whenever tg has a controlling terminal - confirm.
+ return int32(tg.processGroup.session.foreground.id), nil
+}
+
+// SetForegroundProcessGroup sets the foreground process group of tty to pgid.
+//
+// It returns 0 on success. On failure it returns -1 and: ENOTTY if tty is not
+// tg's controlling terminal, EINVAL if pgid is negative, ESRCH if no process
+// group with that ID exists in tg's PID namespace, or EPERM if the process
+// group belongs to a different session.
+func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) (int32, error) {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // TODO(b/129283598): "If tcsetpgrp() is called by a member of a
+ // background process group in its session, and the calling process is
+ // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all
+ // members of this background process group."
+
+ // tty must be the controlling terminal.
+ if tg.tty != tty {
+ return -1, syserror.ENOTTY
+ }
+
+ // pgid must not be negative.
+ if pgid < 0 {
+ return -1, syserror.EINVAL
+ }
+
+ // pg must not be empty. Empty process groups are removed from their
+ // pid namespaces.
+ pg, ok := tg.pidns.processGroups[pgid]
+ if !ok {
+ return -1, syserror.ESRCH
+ }
+
+ // pg must be part of this process's session.
+ if tg.processGroup.session != pg.session {
+ return -1, syserror.EPERM
+ }
+
+ // Point the session at the requested group. Writing pgid into
+ // foreground.id would instead renumber whichever group is currently in
+ // the foreground, leaving the session's foreground group unchanged.
+ tg.processGroup.session.foreground = pg
+ return 0, nil
+}
+
// itimerRealListener implements ktime.Listener for ITIMER_REAL expirations.
//
// +stateify savable
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
new file mode 100644
index 000000000..34f84487a
--- /dev/null
+++ b/pkg/sentry/kernel/tty.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "sync"
+
+// TTY defines the relationship between a thread group and its controlling
+// terminal.
+//
+// +stateify savable
+type TTY struct {
+ // mu protects tg.
+ mu sync.Mutex `state:"nosave"`
+
+ // tg is the thread group recorded as the controlling process of this
+ // terminal by ThreadGroup.SetControllingTTY.
+ //
+ // tg is protected by mu.
+ tg *ThreadGroup
+}
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index bc5b841fb..ba9c9ce12 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -464,7 +464,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
// base address big enough to fit all segments, so we first create a
// mapping for the total size just to find a region that is big enough.
//
- // It is safe to unmap it immediately with racing with another mapping
+ // It is safe to unmap it immediately without racing with another mapping
// because we are the only one in control of the MemoryManager.
//
// Note that the vaddr of the first PT_LOAD segment is ignored when
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index baa12d9a0..f6f1ae762 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -67,8 +67,64 @@ func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, m
if err != nil {
return nil, nil, err
}
+
+ // Open file will take a reference to Dirent, so destroy this one.
defer d.DecRef()
+ return openFile(ctx, nil, d, name)
+}
+
+// openFile performs checks on a file to be executed. If provided a *fs.File,
+// openFile takes that file's Dirent and performs checks on it. If provided a
+// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's
+// Inode and performs checks on that.
+//
+// openFile returns the *fs.Dirent and the *fs.File, in that order, and the
+// caller takes ownership of a reference on each.
+//
+// "dirent" and "file" must not both be nil; whichever is used must point to a
+// readable, executable, regular file. name is only passed to checkFile.
+func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) {
+ // file and dirent must not both be nil.
+ if dirent == nil && file == nil {
+ ctx.Infof("dirent and file cannot both be nil.")
+ return nil, nil, syserror.ENOENT
+ }
+
+ if file != nil {
+ dirent = file.Dirent
+ }
+
+ // Perform permissions checks on the file.
+ if err := checkFile(ctx, dirent, name); err != nil {
+ return nil, nil, err
+ }
+
+ if file == nil {
+ var ferr error
+ if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil {
+ return nil, nil, ferr
+ }
+ } else {
+ // GetFile takes a reference to the created file, so make one in the case
+ // that the file reference already existed.
+ file.IncRef()
+ }
+
+ // We must be able to read at arbitrary offsets.
+ if !file.Flags().Pread {
+ // Log before dropping our reference: when the file was created by
+ // GetFile above, DecRef may release it.
+ ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags())
+ file.DecRef()
+ return nil, nil, syserror.EACCES
+ }
+
+ // Grab reference for caller.
+ dirent.IncRef()
+ return dirent, file, nil
+}
+
+// checkFile performs file permissions checks for binaries called in openPath
+// and openFile
+func checkFile(ctx context.Context, d *fs.Dirent, name string) error {
perms := fs.PermMask{
// TODO(gvisor.dev/issue/160): Linux requires only execute
// permission, not read. However, our backing filesystems may
@@ -80,7 +136,7 @@ func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, m
Execute: true,
}
if err := d.Inode.CheckPermission(ctx, perms); err != nil {
- return nil, nil, err
+ return err
}
// If they claim it's a directory, then make sure.
@@ -88,31 +144,17 @@ func openPath(ctx context.Context, mm *fs.MountNamespace, root, wd *fs.Dirent, m
// N.B. we reject directories below, but we must first reject
// non-directories passed as directories.
if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) {
- return nil, nil, syserror.ENOTDIR
+ return syserror.ENOTDIR
}
// No exec-ing directories, pipes, etc!
if !fs.IsRegular(d.Inode.StableAttr) {
ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr)
- return nil, nil, syserror.EACCES
+ return syserror.EACCES
}
- // Create a new file.
- file, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- return nil, nil, err
- }
+ return nil
- // We must be able to read at arbitrary offsets.
- if !file.Flags().Pread {
- file.DecRef()
- ctx.Infof("%s cannot be read at an offset: %+v", name, file.Flags())
- return nil, nil, syserror.EACCES
- }
-
- // Grab a reference for the caller.
- d.IncRef()
- return d, file, nil
}
// allocStack allocates and maps a stack in to any available part of the address space.
@@ -131,16 +173,30 @@ const (
maxLoaderAttempts = 6
)
-// loadPath resolves filename to a binary and loads it.
+// loadBinary loads a binary that is pointed to by "file". If nil, the path
+// "filename" is resolved and loaded.
//
// It returns:
// * loadedELF, description of the loaded binary
// * arch.Context matching the binary arch
// * fs.Dirent of the binary file
// * Possibly updated argv
-func loadPath(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, fs *cpuid.FeatureSet, filename string, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
+func loadBinary(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, remainingTraversals *uint, features *cpuid.FeatureSet, filename string, passedFile *fs.File, argv []string) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
for i := 0; i < maxLoaderAttempts; i++ {
- d, f, err := openPath(ctx, mounts, root, wd, remainingTraversals, filename)
+ var (
+ d *fs.Dirent
+ f *fs.File
+ err error
+ )
+ if passedFile == nil {
+ d, f, err = openPath(ctx, mounts, root, wd, remainingTraversals, filename)
+
+ } else {
+ d, f, err = openFile(ctx, passedFile, nil, "")
+ // Set to nil in case we loop on an interpreter script.
+ passedFile = nil
+ }
+
if err != nil {
ctx.Infof("Error opening %s: %v", filename, err)
return loadedELF{}, nil, nil, nil, err
@@ -165,7 +221,7 @@ func loadPath(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespac
switch {
case bytes.Equal(hdr[:], []byte(elfMagic)):
- loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, fs, f)
+ loaded, ac, err := loadELF(ctx, m, mounts, root, wd, remainingTraversals, features, f)
if err != nil {
ctx.Infof("Error loading ELF: %v", err)
return loadedELF{}, nil, nil, nil, err
@@ -190,7 +246,8 @@ func loadPath(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespac
return loadedELF{}, nil, nil, nil, syserror.ELOOP
}
-// Load loads filename into a MemoryManager.
+// Load loads "file" into a MemoryManager. If file is nil, the path "filename"
+// is resolved and loaded instead.
//
// If Load returns ErrSwitchFile it should be called again with the returned
// path and argv.
@@ -198,9 +255,9 @@ func loadPath(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespac
// Preconditions:
// * The Task MemoryManager is empty.
// * Load is called on the Task goroutine.
-func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
+func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, root, wd *fs.Dirent, maxTraversals *uint, fs *cpuid.FeatureSet, filename string, file *fs.File, argv, envv []string, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
// Load the binary itself.
- loaded, ac, d, argv, err := loadPath(ctx, m, mounts, root, wd, maxTraversals, fs, filename, argv)
+ loaded, ac, d, argv, err := loadBinary(ctx, m, mounts, root, wd, maxTraversals, fs, filename, file, argv)
if err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", filename, err), syserr.FromError(err).ToLinux())
}
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
index a8819aa84..8c2246bb4 100644
--- a/pkg/sentry/mm/procfs.go
+++ b/pkg/sentry/mm/procfs.go
@@ -58,6 +58,34 @@ func (mm *MemoryManager) NeedsUpdate(generation int64) bool {
return true
}
+// ReadMapsDataInto is called by fsimpl/proc.mapsData.Generate to
+// implement /proc/[pid]/maps.
+func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ var start usermem.Addr
+
+ for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
+ // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
+ // "panic: autosave error: type usermem.Addr is not registered".
+ mm.appendVMAMapsEntryLocked(ctx, vseg, buf)
+ }
+
+ // We always emulate vsyscall, so advertise it here. Everything about a
+ // vsyscall region is static, so just hard code the maps entry since we
+ // don't have a real vma backing it. The vsyscall region is at the end of
+ // the virtual address space so nothing should be mapped after it (if
+ // something is really mapped in the tiny ~10 MiB segment afterwards, we'll
+ // get the sorting on the maps file wrong at worst; but that's not possible
+ // on any current platform).
+ //
+ // Artificially adjust the seqfile handle so we only output vsyscall entry once.
+ if start != vsyscallEnd {
+ // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
+ buf.WriteString(vsyscallMapsEntry)
+ }
+}
+
// ReadMapsSeqFileData is called by fs/proc.mapsData.ReadSeqFileData to
// implement /proc/[pid]/maps.
func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
@@ -151,6 +179,27 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI
b.WriteString("\n")
}
+// ReadSmapsDataInto is called by fsimpl/proc.smapsData.Generate to
+// implement /proc/[pid]/smaps.
+func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffer) {
+ mm.mappingMu.RLock()
+ defer mm.mappingMu.RUnlock()
+ var start usermem.Addr
+
+ for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
+ // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
+ // "panic: autosave error: type usermem.Addr is not registered".
+ mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf)
+ }
+
+ // We always emulate vsyscall, so advertise it here. See
+ // ReadMapsSeqFileData for additional commentary.
+ if start != vsyscallEnd {
+ // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
+ buf.WriteString(vsyscallSmapsEntry)
+ }
+}
+
// ReadSmapsSeqFileData is called by fs/proc.smapsData.ReadSeqFileData to
// implement /proc/[pid]/smaps.
func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
@@ -190,7 +239,12 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
// Preconditions: mm.mappingMu must be locked.
func (mm *MemoryManager) vmaSmapsEntryLocked(ctx context.Context, vseg vmaIterator) []byte {
var b bytes.Buffer
- mm.appendVMAMapsEntryLocked(ctx, vseg, &b)
+ mm.vmaSmapsEntryIntoLocked(ctx, vseg, &b)
+ return b.Bytes()
+}
+
+func (mm *MemoryManager) vmaSmapsEntryIntoLocked(ctx context.Context, vseg vmaIterator, b *bytes.Buffer) {
+ mm.appendVMAMapsEntryLocked(ctx, vseg, b)
vma := vseg.ValuePtr()
// We take mm.activeMu here in each call to vmaSmapsEntryLocked, instead of
@@ -211,40 +265,40 @@ func (mm *MemoryManager) vmaSmapsEntryLocked(ctx context.Context, vseg vmaIterat
}
mm.activeMu.RUnlock()
- fmt.Fprintf(&b, "Size: %8d kB\n", vseg.Range().Length()/1024)
- fmt.Fprintf(&b, "Rss: %8d kB\n", rss/1024)
+ fmt.Fprintf(b, "Size: %8d kB\n", vseg.Range().Length()/1024)
+ fmt.Fprintf(b, "Rss: %8d kB\n", rss/1024)
// Currently we report PSS = RSS, i.e. we pretend each page mapped by a pma
// is only mapped by that pma. This avoids having to query memmap.Mappables
// for reference count information on each page. As a corollary, all pages
// are accounted as "private" whether or not the vma is private; compare
// Linux's fs/proc/task_mmu.c:smaps_account().
- fmt.Fprintf(&b, "Pss: %8d kB\n", rss/1024)
- fmt.Fprintf(&b, "Shared_Clean: %8d kB\n", 0)
- fmt.Fprintf(&b, "Shared_Dirty: %8d kB\n", 0)
+ fmt.Fprintf(b, "Pss: %8d kB\n", rss/1024)
+ fmt.Fprintf(b, "Shared_Clean: %8d kB\n", 0)
+ fmt.Fprintf(b, "Shared_Dirty: %8d kB\n", 0)
// Pretend that all pages are dirty if the vma is writable, and clean otherwise.
clean := rss
if vma.effectivePerms.Write {
clean = 0
}
- fmt.Fprintf(&b, "Private_Clean: %8d kB\n", clean/1024)
- fmt.Fprintf(&b, "Private_Dirty: %8d kB\n", (rss-clean)/1024)
+ fmt.Fprintf(b, "Private_Clean: %8d kB\n", clean/1024)
+ fmt.Fprintf(b, "Private_Dirty: %8d kB\n", (rss-clean)/1024)
// Pretend that all pages are "referenced" (recently touched).
- fmt.Fprintf(&b, "Referenced: %8d kB\n", rss/1024)
- fmt.Fprintf(&b, "Anonymous: %8d kB\n", anon/1024)
+ fmt.Fprintf(b, "Referenced: %8d kB\n", rss/1024)
+ fmt.Fprintf(b, "Anonymous: %8d kB\n", anon/1024)
// Hugepages (hugetlb and THP) are not implemented.
- fmt.Fprintf(&b, "AnonHugePages: %8d kB\n", 0)
- fmt.Fprintf(&b, "Shared_Hugetlb: %8d kB\n", 0)
- fmt.Fprintf(&b, "Private_Hugetlb: %7d kB\n", 0)
+ fmt.Fprintf(b, "AnonHugePages: %8d kB\n", 0)
+ fmt.Fprintf(b, "Shared_Hugetlb: %8d kB\n", 0)
+ fmt.Fprintf(b, "Private_Hugetlb: %7d kB\n", 0)
// Swap is not implemented.
- fmt.Fprintf(&b, "Swap: %8d kB\n", 0)
- fmt.Fprintf(&b, "SwapPss: %8d kB\n", 0)
- fmt.Fprintf(&b, "KernelPageSize: %8d kB\n", usermem.PageSize/1024)
- fmt.Fprintf(&b, "MMUPageSize: %8d kB\n", usermem.PageSize/1024)
+ fmt.Fprintf(b, "Swap: %8d kB\n", 0)
+ fmt.Fprintf(b, "SwapPss: %8d kB\n", 0)
+ fmt.Fprintf(b, "KernelPageSize: %8d kB\n", usermem.PageSize/1024)
+ fmt.Fprintf(b, "MMUPageSize: %8d kB\n", usermem.PageSize/1024)
locked := rss
if vma.mlockMode == memmap.MLockNone {
locked = 0
}
- fmt.Fprintf(&b, "Locked: %8d kB\n", locked/1024)
+ fmt.Fprintf(b, "Locked: %8d kB\n", locked/1024)
b.WriteString("VmFlags: ")
if vma.realPerms.Read {
@@ -284,6 +338,4 @@ func (mm *MemoryManager) vmaSmapsEntryLocked(ctx context.Context, vseg vmaIterat
b.WriteString("ac ")
}
b.WriteString("\n")
-
- return b.Bytes()
}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 8bd3e885d..f7f7298c4 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -285,7 +285,10 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
switch opts.DelayedEviction {
case DelayedEvictionDefault:
opts.DelayedEviction = DelayedEvictionEnabled
- case DelayedEvictionDisabled, DelayedEvictionEnabled, DelayedEvictionManual:
+ case DelayedEvictionDisabled, DelayedEvictionManual:
+ opts.UseHostMemcgPressure = false
+ case DelayedEvictionEnabled:
+ // ok
default:
return nil, fmt.Errorf("invalid MemoryFileOpts.DelayedEviction: %v", opts.DelayedEviction)
}
@@ -777,6 +780,14 @@ func (f *MemoryFile) MarkAllUnevictable(user EvictableMemoryUser) {
}
}
+// ShouldCacheEvictable returns true if f is meaningfully delaying evictions of
+// evictable memory, such that it may be advantageous to cache data in
+// evictable memory. The value returned by ShouldCacheEvictable may change
+// between calls.
+func (f *MemoryFile) ShouldCacheEvictable() bool {
+ return f.opts.DelayedEviction == DelayedEvictionManual || f.opts.UseHostMemcgPressure
+}
+
// UpdateUsage ensures that the memory usage statistics in
// usage.MemoryAccounting are up to date.
func (f *MemoryFile) UpdateUsage() error {
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 1b6c54e96..ebcc8c098 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -7,13 +7,17 @@ go_library(
srcs = [
"filters.go",
"ptrace.go",
+ "ptrace_amd64.go",
+ "ptrace_arm64.go",
"ptrace_unsafe.go",
"stub_amd64.s",
+ "stub_arm64.s",
"stub_unsafe.go",
"subprocess.go",
"subprocess_amd64.go",
+ "subprocess_arm64.go",
"subprocess_linux.go",
- "subprocess_linux_amd64_unsafe.go",
+ "subprocess_linux_unsafe.go",
"subprocess_unsafe.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ptrace",
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index 6fd30ed25..7b120a15d 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -60,7 +60,7 @@ var (
// maximum user address. This is valid only after a call to stubInit.
//
// We attempt to link the stub here, and adjust downward as needed.
- stubStart uintptr = 0x7fffffff0000
+ stubStart uintptr = stubInitAddress
// stubEnd is the first byte past the end of the stub, as with
// stubStart this is valid only after a call to stubInit.
diff --git a/pkg/flipcall/endpoint_futex.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go
index 5cab02b1d..db0212538 100644
--- a/pkg/flipcall/endpoint_futex.go
+++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go
@@ -12,34 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package flipcall
+package ptrace
import (
- "fmt"
-)
+ "syscall"
-type endpointControlState struct{}
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
-func (ep *Endpoint) initControlState(ctrlMode ControlMode) error {
- if ctrlMode != ControlModeFutex {
- return fmt.Errorf("unsupported control mode: %v", ctrlMode)
+// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
+func fpRegSet(useXsave bool) uintptr {
+ if useXsave {
+ return linux.NT_X86_XSTATE
}
- return nil
-}
-
-func (ep *Endpoint) doRoundTrip() error {
- return ep.doFutexRoundTrip()
-}
-
-func (ep *Endpoint) doWaitFirst() error {
- return ep.doFutexWaitFirst()
-}
-
-func (ep *Endpoint) doNotifyLast() error {
- return ep.doFutexNotifyLast()
+ return linux.NT_PRFPREG
}
-// Preconditions: ep.isShutdown() == true.
-func (ep *Endpoint) interruptForShutdown() {
- ep.doFutexInterruptForShutdown()
+func stackPointer(r *syscall.PtraceRegs) uintptr {
+ return uintptr(r.Rsp)
}
diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64.go b/pkg/sentry/platform/ptrace/ptrace_arm64.go
new file mode 100644
index 000000000..4db28c534
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_arm64.go
@@ -0,0 +1,30 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ptrace
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
+func fpRegSet(_ bool) uintptr {
+ return linux.NT_PRFPREG
+}
+
+func stackPointer(r *syscall.PtraceRegs) uintptr {
+ return uintptr(r.Sp)
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 2706039a5..47957bb3b 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -18,37 +18,23 @@ import (
"syscall"
"unsafe"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
-// GETREGSET/SETREGSET register set types.
-//
-// See include/uapi/linux/elf.h.
-const (
- // _NT_PRFPREG is for x86 floating-point state without using xsave.
- _NT_PRFPREG = 0x2
-
- // _NT_X86_XSTATE is for x86 extended state using xsave.
- _NT_X86_XSTATE = 0x202
-)
-
-// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
-func fpRegSet(useXsave bool) uintptr {
- if useXsave {
- return _NT_X86_XSTATE
- }
- return _NT_PRFPREG
-}
-
-// getRegs sets the regular register set.
+// getRegs gets the general purpose register set.
func (t *thread) getRegs(regs *syscall.PtraceRegs) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(regs)),
+ Len: uint64(unsafe.Sizeof(*regs)),
+ }
_, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
- syscall.PTRACE_GETREGS,
+ syscall.PTRACE_GETREGSET,
uintptr(t.tid),
- 0,
- uintptr(unsafe.Pointer(regs)),
+ linux.NT_PRSTATUS,
+ uintptr(unsafe.Pointer(&iovec)),
0, 0)
if errno != 0 {
return errno
@@ -56,14 +42,18 @@ func (t *thread) getRegs(regs *syscall.PtraceRegs) error {
return nil
}
-// setRegs sets the regular register set.
+// setRegs sets the general purpose register set.
func (t *thread) setRegs(regs *syscall.PtraceRegs) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(regs)),
+ Len: uint64(unsafe.Sizeof(*regs)),
+ }
_, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
- syscall.PTRACE_SETREGS,
+ syscall.PTRACE_SETREGSET,
uintptr(t.tid),
- 0,
- uintptr(unsafe.Pointer(regs)),
+ linux.NT_PRSTATUS,
+ uintptr(unsafe.Pointer(&iovec)),
0, 0)
if errno != 0 {
return errno
@@ -131,7 +121,7 @@ func (t *thread) getSignalInfo(si *arch.SignalInfo) error {
//
// Precondition: the OS thread must be locked and own t.
func (t *thread) clone() (*thread, error) {
- r, ok := usermem.Addr(t.initRegs.Rsp).RoundUp()
+ r, ok := usermem.Addr(stackPointer(&t.initRegs)).RoundUp()
if !ok {
return nil, syscall.EINVAL
}
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
new file mode 100644
index 000000000..2c5e4d5cb
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -0,0 +1,106 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define SYS_GETPID 172
+#define SYS_EXIT 93
+#define SYS_KILL 129
+#define SYS_GETPPID 173
+#define SYS_PRCTL 167
+
+#define SIGKILL 9
+#define SIGSTOP 19
+
+#define PR_SET_PDEATHSIG 1
+
+// stub bootstraps the child and sends itself SIGSTOP to wait for attach.
+//
+// R7 contains the expected PPID.
+//
+// This should not be used outside the context of a new ptrace child (as the
+// function is otherwise a bunch of nonsense).
+TEXT ·stub(SB),NOSPLIT,$0
+begin:
+ // N.B. This loop only executes in the context of a single-threaded
+ // fork child.
+
+ MOVD $SYS_PRCTL, R8
+ MOVD $PR_SET_PDEATHSIG, R0
+ MOVD $SIGKILL, R1
+ SVC
+
+ CMN $4095, R0
+ BCS error
+
+ // If the parent already died before we called PR_SET_PDEATHSIG then
+ // we'll have an unexpected PPID.
+ MOVD $SYS_GETPPID, R8
+ SVC
+
+ CMP R0, R7
+ BNE parent_dead
+
+ MOVD $SYS_GETPID, R8
+ SVC
+
+ CMP $0x0, R0
+ BLT error
+
+ // SIGSTOP to wait for attach.
+ //
+ // The SVC instruction will be used for future syscall injection by
+ // thread.syscall.
+ MOVD $SYS_KILL, R8
+ MOVD $SIGSTOP, R1
+ SVC
+ // The tracer may "detach" and/or allow code execution here in three cases:
+ //
+ // 1. New (traced) stub threads are explicitly detached by the
+ // goroutine in newSubprocess. However, they are detached while in
+ // group-stop, so they do not execute code here.
+ //
+ // 2. If a tracer thread exits, it implicitly detaches from the stub,
+ // potentially allowing code execution here. However, the Go runtime
+ // never exits individual threads, so this case never occurs.
+ //
+ // 3. subprocess.createStub clones a new stub process that is untraced,
+ // thus executing this code. We setup the PDEATHSIG before SIGSTOPing
+ // ourselves for attach by the tracer.
+ //
+ // R7 has been updated with the expected PPID.
+ B begin
+
+error:
+ // Exit with -errno.
+ NEG R0, R0
+ MOVD $SYS_EXIT, R8
+ SVC
+ HLT
+
+parent_dead:
+ MOVD $SYS_EXIT, R8
+ MOVD $1, R0
+ SVC
+ HLT
+
+// stubCall calls the stub function at the given address with the given PPID.
+//
+// This is a distinct function because stub, above, may be mapped at any
+// arbitrary location, and stub has a specific binary API (see above).
+TEXT ·stubCall(SB),NOSPLIT,$0-16
+ MOVD addr+0(FP), R0
+ MOVD pid+8(FP), R7
+ B (R0)
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 15e84735e..6bf7cd097 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -28,6 +28,16 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
+// Linux kernel errnos which "should never be seen by user programs", but will
+// be revealed to ptrace syscall exit tracing.
+//
+// These constants are only used in subprocess.go.
+const (
+ ERESTARTSYS = syscall.Errno(512)
+ ERESTARTNOINTR = syscall.Errno(513)
+ ERESTARTNOHAND = syscall.Errno(514)
+)
+
// globalPool exists to solve two distinct problems:
//
// 1) Subprocesses can't always be killed properly (see Release).
@@ -282,7 +292,7 @@ func (t *thread) grabInitRegs() {
if err := t.getRegs(&t.initRegs); err != nil {
panic(fmt.Sprintf("ptrace get regs failed: %v", err))
}
- t.initRegs.Rip -= initRegsRipAdjustment
+ t.adjustInitRegsRip()
}
// detach detaches from the thread.
@@ -344,6 +354,9 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
continue // Spurious stop.
}
if stopSig == syscall.SIGTRAP {
+ if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
+ t.dumpAndPanic("wait failed: the process exited")
+ }
// Re-encode the trap cause the way it's expected.
return stopSig | syscall.Signal(status.TrapCause()<<8)
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index a70512913..4649a94a7 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -28,20 +28,13 @@ const (
// maximumUserAddress is the largest possible user address.
maximumUserAddress = 0x7ffffffff000
+ // stubInitAddress is the initial attempt link address for the stub.
+ stubInitAddress = 0x7fffffff0000
+
// initRegsRipAdjustment is the size of the syscall instruction.
initRegsRipAdjustment = 2
)
-// Linux kernel errnos which "should never be seen by user programs", but will
-// be revealed to ptrace syscall exit tracing.
-//
-// These constants are used in subprocess.go.
-const (
- ERESTARTSYS = syscall.Errno(512)
- ERESTARTNOINTR = syscall.Errno(513)
- ERESTARTNOHAND = syscall.Errno(514)
-)
-
// resetSysemuRegs sets up emulation registers.
//
// This should be called prior to calling sysemu.
@@ -139,3 +132,14 @@ func dumpRegs(regs *syscall.PtraceRegs) string {
return m.String()
}
+
+// adjustInitRegsRip adjusts the current register RIP value to
+// be just before the system call instruction execution.
+func (t *thread) adjustInitRegsRip() {
+ t.initRegs.Rip -= initRegsRipAdjustment
+}
+
+// initChildProcessPPID passes the expected PPID to the child via R15 when creating the stub process.
+func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
+ initregs.R15 = uint64(ppid)
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
new file mode 100644
index 000000000..bec884ba5
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -0,0 +1,126 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ptrace
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+const (
+ // maximumUserAddress is the largest possible user address.
+ maximumUserAddress = 0xfffffffff000
+
+ // stubInitAddress is the initial attempt link address for the stub.
+ // Only support 48bits VA currently.
+ stubInitAddress = 0xffffffff0000
+
+ // initRegsRipAdjustment is the size of the svc instruction.
+ initRegsRipAdjustment = 4
+)
+
+// resetSysemuRegs sets up emulation registers.
+//
+// This should be called prior to calling sysemu.
+func (s *subprocess) resetSysemuRegs(regs *syscall.PtraceRegs) {
+}
+
+// createSyscallRegs sets up syscall registers.
+//
+// This should be called to generate registers for a system call.
+func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs {
+ // Copy initial registers (Pc, Sp, etc.).
+ regs := *initRegs
+
+ // Set our syscall number.
+ // r8 for the syscall number.
+ // r0-r5 are used to store the parameters.
+ regs.Regs[8] = uint64(sysno)
+ if len(args) >= 1 {
+ regs.Regs[0] = args[0].Uint64()
+ }
+ if len(args) >= 2 {
+ regs.Regs[1] = args[1].Uint64()
+ }
+ if len(args) >= 3 {
+ regs.Regs[2] = args[2].Uint64()
+ }
+ if len(args) >= 4 {
+ regs.Regs[3] = args[3].Uint64()
+ }
+ if len(args) >= 5 {
+ regs.Regs[4] = args[4].Uint64()
+ }
+ if len(args) >= 6 {
+ regs.Regs[5] = args[5].Uint64()
+ }
+
+ return regs
+}
+
+// isSingleStepping determines if the registers indicate single-stepping.
+func isSingleStepping(regs *syscall.PtraceRegs) bool {
+ // Refer to the ARM SDM D2.12.3: software step state machine
+ // return (regs.Pstate.SS == 1) && (MDSCR_EL1.SS == 1).
+ //
+ // Since the host Linux kernel will set MDSCR_EL1.SS on our behalf
+ // when we call a single-step ptrace command, we only need to check
+ // the Pstate.SS bit here.
+ return (regs.Pstate & arch.ARMTrapFlag) != 0
+}
+
+// updateSyscallRegs updates registers after finishing sysemu.
+func updateSyscallRegs(regs *syscall.PtraceRegs) {
+ // No special work is necessary.
+ return
+}
+
+// syscallReturnValue extracts a sensible return from registers.
+func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
+ rval := int64(regs.Regs[0])
+ if rval < 0 {
+ return 0, syscall.Errno(-rval)
+ }
+ return uintptr(rval), nil
+}
+
+func dumpRegs(regs *syscall.PtraceRegs) string {
+ var m strings.Builder
+
+ fmt.Fprintf(&m, "Registers:\n")
+
+ for i := 0; i < 31; i++ {
+ fmt.Fprintf(&m, "\tRegs[%d]\t = %016x\n", i, regs.Regs[i])
+ }
+ fmt.Fprintf(&m, "\tSp\t = %016x\n", regs.Sp)
+ fmt.Fprintf(&m, "\tPc\t = %016x\n", regs.Pc)
+ fmt.Fprintf(&m, "\tPstate\t = %016x\n", regs.Pstate)
+
+ return m.String()
+}
+
+// adjustInitRegsRip adjusts the current register PC value to
+// be just before the system call instruction execution.
+func (t *thread) adjustInitRegsRip() {
+ t.initRegs.Pc -= initRegsRipAdjustment
+}
+
+// initChildProcessPPID passes the expected PPID to the child via X7 when creating the stub process.
+func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
+ initregs.Regs[7] = uint64(ppid)
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index 87ded0bbd..f09b0b3d0 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -284,7 +284,7 @@ func (s *subprocess) createStub() (*thread, error) {
// Pass the expected PPID to the child via R15.
regs := t.initRegs
- regs.R15 = uint64(t.tgid)
+ initChildProcessPPID(&regs, t.tgid)
// Call fork in a subprocess.
//
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
index e977992f9..de6783fb0 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux_amd64_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 linux
+// +build linux
+// +build amd64 arm64
package ptrace
diff --git a/pkg/sentry/safemem/io.go b/pkg/sentry/safemem/io.go
index 5c3d73eb7..f039a5c34 100644
--- a/pkg/sentry/safemem/io.go
+++ b/pkg/sentry/safemem/io.go
@@ -157,7 +157,8 @@ func (w ToIOWriter) Write(src []byte) (int, error) {
}
// FromIOReader implements Reader for an io.Reader by repeatedly invoking
-// io.Reader.Read until it returns an error or partial read.
+// io.Reader.Read until it returns an error or partial read. This is not
+// thread-safe.
//
// FromIOReader will return a successful partial read iff Reader.Read does so.
type FromIOReader struct {
@@ -206,6 +207,58 @@ func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) {
return wbn, buf, rerr
}
+// FromIOReaderAt implements Reader for an io.ReaderAt. Does not repeatedly
+// invoke io.ReaderAt.ReadAt because ReadAt is more strict than Read. A partial
+// read indicates an error. This is not thread-safe.
+type FromIOReaderAt struct {
+ ReaderAt io.ReaderAt
+ Offset int64
+}
+
+// ReadToBlocks implements Reader.ReadToBlocks.
+func (r FromIOReaderAt) ReadToBlocks(dsts BlockSeq) (uint64, error) {
+ var buf []byte
+ var done uint64
+ for !dsts.IsEmpty() {
+ dst := dsts.Head()
+ var n int
+ var err error
+ n, buf, err = r.readToBlock(dst, buf)
+ done += uint64(n)
+ if n != dst.Len() {
+ return done, err
+ }
+ dsts = dsts.Tail()
+ if err != nil {
+ if dsts.IsEmpty() && err == io.EOF {
+ return done, nil
+ }
+ return done, err
+ }
+ }
+ return done, nil
+}
+
+func (r FromIOReaderAt) readToBlock(dst Block, buf []byte) (int, []byte, error) {
+ // io.ReaderAt isn't safecopy-aware, so we have to buffer Blocks that require
+ // safecopy.
+ if !dst.NeedSafecopy() {
+ n, err := r.ReaderAt.ReadAt(dst.ToSlice(), r.Offset)
+ r.Offset += int64(n)
+ return n, buf, err
+ }
+ if len(buf) < dst.Len() {
+ buf = make([]byte, dst.Len())
+ }
+ rn, rerr := r.ReaderAt.ReadAt(buf[:dst.Len()], r.Offset)
+ r.Offset += int64(rn)
+ wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn]))
+ if wberr != nil {
+ return wbn, buf, wberr
+ }
+ return wbn, buf, rerr
+}
+
// FromIOWriter implements Writer for an io.Writer by repeatedly invoking
// io.Writer.Write until it returns an error or partial write.
//
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 2b03ea87c..3300f9a6b 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -9,6 +9,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/binary",
"//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD
index 1f014f399..e927821e1 100644
--- a/pkg/sentry/socket/epsocket/BUILD
+++ b/pkg/sentry/socket/epsocket/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
+ "//pkg/sentry/socket/netfilter",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
"//pkg/syserr",
@@ -38,6 +39,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index e57aed927..635042263 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -43,6 +43,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
@@ -290,18 +291,22 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// GetAddress reads an sockaddr struct from the given address and converts it
-// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
-// addresses.
-func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
+// AddressAndFamily reads a sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
+// AF_INET6 addresses.
+//
+// strict indicates whether addresses with the AF_UNSPEC family are accepted or not.
+//
+// AddressAndFamily returns an address and its family.
+func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
}
// Get the rest of the fields based on the address family.
@@ -309,7 +314,7 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
// Drop the terminating NUL (if one exists) and everything after
// it for filesystem (non-abstract) addresses.
@@ -320,12 +325,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
}
return tcpip.FullAddress{
Addr: tcpip.Address(path),
- }, nil
+ }, family, nil
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@@ -333,12 +338,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
Addr: bytesToIPAddress(a.Addr[:]),
Port: ntohs(a.Port),
}
- return out, nil
+ return out, family, nil
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@@ -349,13 +354,13 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
if isLinkLocal(out.Addr) {
out.NIC = tcpip.NICID(a.Scope_id)
}
- return out, nil
+ return out, family, nil
case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, nil
+ return tcpip.FullAddress{}, family, nil
default:
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
}
}
@@ -428,6 +433,11 @@ func (i *ioSequencePayload) Size() int {
return int(i.src.NumBytes())
}
+// DropFirst drops the first n bytes from the underlying src.
+func (i *ioSequencePayload) DropFirst(n int) {
+ i.src = i.src.DropFirst(int(n))
+}
+
// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
f := &ioSequencePayload{ctx: ctx, src: src}
@@ -476,11 +486,18 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, false /* strict */)
+ addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
+ if family == linux.AF_UNSPEC {
+ err := s.Endpoint.Disconnect()
+ if err == tcpip.ErrNotSupported {
+ return syserr.ErrAddressFamilyNotSupported
+ }
+ return syserr.TranslateNetstackError(err)
+ }
// Always return right away in the non-blocking case.
if !blocking {
return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
@@ -509,7 +526,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, true /* strict */)
+ addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@@ -547,7 +564,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait
// Accept implements the linux syscall accept(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, wq, terr := s.Endpoint.Accept()
if terr != nil {
@@ -574,7 +591,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
ns.SetFlags(flags.Settable())
}
- var addr interface{}
+ var addr linux.SockAddr
var addrLen uint32
if peerRequested {
// Get address of the peer and write it to peer slice.
@@ -624,7 +641,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for epsocket.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -655,6 +672,33 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (
return val, nil
}
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ info, err := netfilter.GetInfo(t, s.Endpoint, outPtr)
+ if err != nil {
+ return nil, err
+ }
+ return info, nil
+
+ case linux.IPT_SO_GET_ENTRIES:
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return entries, nil
+
+ }
+ }
+
return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
}
@@ -1028,7 +1072,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(linux.SockAddrInet).Addr, nil
+ return a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
@@ -1658,7 +1702,7 @@ func isLinkLocal(addr tcpip.Address) bool {
}
// ConvertAddress converts the given address to a native format.
-func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
switch family {
case linux.AF_UNIX:
var out linux.SockAddrUnix
@@ -1674,15 +1718,15 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
// address length is the max. Abstract and empty paths always return
// the full exact length.
if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
- return out, uint32(2 + l)
+ return &out, uint32(2 + l)
}
- return out, uint32(3 + l)
+ return &out, uint32(3 + l)
case linux.AF_INET:
var out linux.SockAddrInet
copy(out.Addr[:], addr.Addr)
out.Family = linux.AF_INET
out.Port = htons(addr.Port)
- return out, uint32(binary.Size(out))
+ return &out, uint32(binary.Size(out))
case linux.AF_INET6:
var out linux.SockAddrInet6
if len(addr.Addr) == 4 {
@@ -1698,7 +1742,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
if isLinkLocal(addr.Addr) {
out.Scope_id = uint32(addr.NIC)
}
- return out, uint32(binary.Size(out))
+ return &out, uint32(binary.Size(out))
default:
return nil, 0
}
@@ -1706,7 +1750,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -1718,7 +1762,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -1791,7 +1835,7 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) {
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
-func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased()
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
@@ -1839,7 +1883,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
if err == nil {
s.updateTimestamp()
}
- var addr interface{}
+ var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
addr, addrLen = ConvertAddress(s.family, s.sender)
@@ -1914,7 +1958,7 @@ func (s *SocketOperations) updateTimestamp() {
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -1990,7 +2034,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, err := GetAddress(s.family, to, true /* strict */)
+ addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */)
if err != nil {
return 0, err
}
@@ -1998,28 +2042,22 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
addr = &addrBuf
}
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, syserr.FromError(err)
- }
-
opts := tcpip.WriteOptions{
To: addr,
More: flags&linux.MSG_MORE != 0,
EndOfRecord: flags&linux.MSG_EOR != 0,
}
- n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ v := &ioSequencePayload{t, src}
+ n, resCh, err := s.Endpoint.Write(v, opts)
if resCh != nil {
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err)
}
- n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ n, _, err = s.Endpoint.Write(v, opts)
}
dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= uintptr(len(v)) || dontWait) {
+ if err == nil && (n >= int64(v.Size()) || dontWait) {
// Complete write.
return int(n), nil
}
@@ -2033,18 +2071,18 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
s.EventRegister(&e, waiter.EventOut)
defer s.EventUnregister(&e)
- v.TrimFront(int(n))
+ v.DropFirst(int(n))
total := n
for {
- n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
- v.TrimFront(int(n))
+ n, _, err = s.Endpoint.Write(v, opts)
+ v.DropFirst(int(n))
total += n
if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
return 0, syserr.TranslateNetstackError(err)
}
- if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
return int(total), nil
}
@@ -2252,19 +2290,19 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
case syscall.SIOCGIFMAP:
// Gets the hardware parameters of the device.
- // TODO(b/71872867): Implement.
+ // TODO(gvisor.dev/issue/505): Implement.
case syscall.SIOCGIFTXQLEN:
// Gets the transmit queue length of the device.
- // TODO(b/71872867): Implement.
+ // TODO(gvisor.dev/issue/505): Implement.
case syscall.SIOCGIFDSTADDR:
// Gets the destination address of a point-to-point device.
- // TODO(b/71872867): Implement.
+ // TODO(gvisor.dev/issue/505): Implement.
case syscall.SIOCGIFBRDADDR:
// Gets the broadcast address of a device.
- // TODO(b/71872867): Implement.
+ // TODO(gvisor.dev/issue/505): Implement.
case syscall.SIOCGIFNETMASK:
// Gets the network mask of a device.
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go
index 8fe489c0e..7cf7ff735 100644
--- a/pkg/sentry/socket/epsocket/stack.go
+++ b/pkg/sentry/socket/epsocket/stack.go
@@ -18,7 +18,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -143,3 +146,57 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
func (s *Stack) Statistics(stat interface{}, arg string) error {
return syserr.ErrEndpointOperation.ToError()
}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ var routeTable []inet.Route
+
+ for _, rt := range s.Stack.GetRouteTable() {
+ var family uint8
+ switch len(rt.Destination.ID()) {
+ case header.IPv4AddressSize:
+ family = linux.AF_INET
+ case header.IPv6AddressSize:
+ family = linux.AF_INET6
+ default:
+ log.Warningf("Unknown network protocol in route %+v", rt)
+ continue
+ }
+
+ routeTable = append(routeTable, inet.Route{
+ Family: family,
+ DstLen: uint8(rt.Destination.Prefix()), // The CIDR prefix for the destination.
+
+ // Always return unspecified protocol since we have no notion of
+ // protocol for routes.
+ Protocol: linux.RTPROT_UNSPEC,
+ // Set statically to LINK scope for now.
+ //
+ // TODO(gvisor.dev/issue/595): Set scope for routes.
+ Scope: linux.RT_SCOPE_LINK,
+ Type: linux.RTN_UNICAST,
+
+ DstAddr: []byte(rt.Destination.ID()),
+ OutputInterface: int32(rt.NIC),
+ GatewayAddr: []byte(rt.Gateway),
+ })
+ }
+
+ return routeTable
+}
+
+// IPTables returns the stack's iptables.
+func (s *Stack) IPTables() (iptables.IPTables, error) {
+ return s.Stack.IPTables(), nil
+}
+
+// FillDefaultIPTables sets the stack's iptables to the default tables, which
+// allow all traffic without modifying it.
+func (s *Stack) FillDefaultIPTables() {
+ netfilter.FillDefaultIPTables(s.Stack)
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {
+ s.Stack.Resume()
+}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7f69406b7..92beb1bcf 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -189,15 +189,16 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
- var peerAddr []byte
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ var peerAddr linux.SockAddr
+ var peerAddrBuf []byte
var peerAddrlen uint32
var peerAddrPtr *byte
var peerAddrlenPtr *uint32
if peerRequested {
- peerAddr = make([]byte, sizeofSockaddr)
- peerAddrlen = uint32(len(peerAddr))
- peerAddrPtr = &peerAddr[0]
+ peerAddrBuf = make([]byte, sizeofSockaddr)
+ peerAddrlen = uint32(len(peerAddrBuf))
+ peerAddrPtr = &peerAddrBuf[0]
peerAddrlenPtr = &peerAddrlen
}
@@ -222,7 +223,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
if peerRequested {
- peerAddr = peerAddr[:peerAddrlen]
+ peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen])
}
if syscallErr != nil {
return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
@@ -272,7 +273,7 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -353,7 +354,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
// Whitelist flags.
//
// FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
@@ -363,9 +364,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
}
- var senderAddr []byte
+ var senderAddr linux.SockAddr
+ var senderAddrBuf []byte
if senderRequested {
- senderAddr = make([]byte, sizeofSockaddr)
+ senderAddrBuf = make([]byte, sizeofSockaddr)
}
var msgFlags int
@@ -384,7 +386,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if dsts.NumBlocks() == 1 {
// Skip allocating []syscall.Iovec.
- return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr)
+ return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf)
}
iovs := iovecsFromBlockSeq(dsts)
@@ -392,15 +394,15 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
Iov: &iovs[0],
Iovlen: uint64(len(iovs)),
}
- if len(senderAddr) != 0 {
- msg.Name = &senderAddr[0]
- msg.Namelen = uint32(len(senderAddr))
+ if len(senderAddrBuf) != 0 {
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(len(senderAddrBuf))
}
n, err := recvmsg(s.fd, &msg, sysflags)
if err != nil {
return 0, err
}
- senderAddr = senderAddr[:msg.Namelen]
+ senderAddrBuf = senderAddrBuf[:msg.Namelen]
msgFlags = int(msg.Flags)
return n, nil
})
@@ -431,7 +433,10 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// We don't allow control messages.
msgFlags &^= linux.MSG_CTRUNC
- return int(n), msgFlags, senderAddr, uint32(len(senderAddr)), socket.ControlMessages{}, syserr.FromError(err)
+ if senderRequested {
+ senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
+ }
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err)
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
index 6c69ba9c7..e69ec38c2 100644
--- a/pkg/sentry/socket/hostinet/socket_unsafe.go
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -18,10 +18,12 @@ import (
"syscall"
"unsafe"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
@@ -91,25 +93,25 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) {
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr := make([]byte, sizeofSockaddr)
addrlen := uint32(len(addr))
_, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
if errno != 0 {
return nil, 0, syserr.FromError(errno)
}
- return addr[:addrlen], addrlen, nil
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr := make([]byte, sizeofSockaddr)
addrlen := uint32(len(addr))
_, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
if errno != 0 {
return nil, 0, syserr.FromError(errno)
}
- return addr[:addrlen], addrlen, nil
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
}
func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) {
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index cc1f66fa1..3a4fdec47 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -46,6 +46,7 @@ type Stack struct {
// Stack is immutable.
interfaces map[int32]inet.Interface
interfaceAddrs map[int32][]inet.InterfaceAddr
+ routes []inet.Route
supportsIPv6 bool
tcpRecvBufSize inet.TCPBufferSize
tcpSendBufSize inet.TCPBufferSize
@@ -66,6 +67,10 @@ func (s *Stack) Configure() error {
return err
}
+ if err := addHostRoutes(s); err != nil {
+ return err
+ }
+
if _, err := os.Stat("/proc/net/if_inet6"); err == nil {
s.supportsIPv6 = true
}
@@ -161,6 +166,60 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
return nil
}
+// ExtractHostRoutes converts the given netlink route messages from the host
+// route table into a slice of routes.
+func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) {
+ var routes []inet.Route
+ for _, routeMsg := range routeMsgs {
+ if routeMsg.Header.Type != syscall.RTM_NEWROUTE {
+ continue
+ }
+
+ var ifRoute syscall.RtMsg
+ binary.Unmarshal(routeMsg.Data[:syscall.SizeofRtMsg], usermem.ByteOrder, &ifRoute)
+ inetRoute := inet.Route{
+ Family: ifRoute.Family,
+ DstLen: ifRoute.Dst_len,
+ SrcLen: ifRoute.Src_len,
+ TOS: ifRoute.Tos,
+ Table: ifRoute.Table,
+ Protocol: ifRoute.Protocol,
+ Scope: ifRoute.Scope,
+ Type: ifRoute.Type,
+ Flags: ifRoute.Flags,
+ }
+
+ // Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+ // syscall.NetlinkMessage.Header.Type and skip the struct rtmsg
+ // accordingly.
+ attrs, err := syscall.ParseNetlinkRouteAttr(&routeMsg)
+ if err != nil {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid rtattrs: %v", err)
+ }
+
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.RTA_DST:
+ inetRoute.DstAddr = attr.Value
+ case syscall.RTA_SRC:
+ inetRoute.SrcAddr = attr.Value
+ case syscall.RTA_GATEWAY:
+ inetRoute.GatewayAddr = attr.Value
+ case syscall.RTA_OIF:
+ expected := int(binary.Size(inetRoute.OutputInterface))
+ if len(attr.Value) != expected {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
+ }
+ binary.Unmarshal(attr.Value, usermem.ByteOrder, &inetRoute.OutputInterface)
+ }
+ }
+
+ routes = append(routes, inetRoute)
+ }
+
+ return routes, nil
+}
+
func addHostInterfaces(s *Stack) error {
links, err := doNetlinkRouteRequest(syscall.RTM_GETLINK)
if err != nil {
@@ -175,6 +234,20 @@ func addHostInterfaces(s *Stack) error {
return ExtractHostInterfaces(links, addrs, s.interfaces, s.interfaceAddrs)
}
+func addHostRoutes(s *Stack) error {
+ routes, err := doNetlinkRouteRequest(syscall.RTM_GETROUTE)
+ if err != nil {
+ return fmt.Errorf("RTM_GETROUTE failed: %v", err)
+ }
+
+ s.routes, err = ExtractHostRoutes(routes)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func doNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
data, err := syscall.NetlinkRIB(req, syscall.AF_UNSPEC)
if err != nil {
@@ -202,12 +275,20 @@ func readTCPBufferSizeFile(filename string) (inet.TCPBufferSize, error) {
// Interfaces implements inet.Stack.Interfaces.
func (s *Stack) Interfaces() map[int32]inet.Interface {
- return s.interfaces
+ interfaces := make(map[int32]inet.Interface)
+ for k, v := range s.interfaces {
+ interfaces[k] = v
+ }
+ return interfaces
}
// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
- return s.interfaceAddrs
+ addrs := make(map[int32][]inet.InterfaceAddr)
+ for k, v := range s.interfaceAddrs {
+ addrs[k] = append([]inet.InterfaceAddr(nil), v...)
+ }
+ return addrs
}
// SupportsIPv6 implements inet.Stack.SupportsIPv6.
@@ -249,3 +330,11 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
func (s *Stack) Statistics(stat interface{}, arg string) error {
return syserror.EOPNOTSUPP
}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ return append([]inet.Route(nil), s.routes...)
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {}
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
new file mode 100644
index 000000000..354a0d6ee
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -0,0 +1,24 @@
+package(licenses = ["notice"])
+
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_library(
+ name = "netfilter",
+ srcs = [
+ "netfilter.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netfilter",
+ # This target depends on netstack and should only be used by epsocket,
+ # which is allowed to depend on netstack.
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/usermem",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/tcpip/iptables",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
new file mode 100644
index 000000000..9f87c32f1
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -0,0 +1,286 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netfilter helps the sentry interact with netstack's netfilter
+// capabilities.
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// errorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const errorTargetName = "ERROR"
+
+// metadata is opaque to netstack. It holds data that we need to translate
+// between Linux's and netstack's iptables representations.
+type metadata struct {
+ HookEntry [linux.NF_INET_NUMHOOKS]uint32
+ Underflow [linux.NF_INET_NUMHOOKS]uint32
+ NumEntries uint32
+ Size uint32
+}
+
+// GetInfo returns information about iptables.
+func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
+ // Read in the struct and table name.
+ var info linux.IPTGetinfo
+ if _, err := t.CopyIn(outPtr, &info); err != nil {
+ return linux.IPTGetinfo{}, syserr.FromError(err)
+ }
+
+ // Find the appropriate table.
+ table, err := findTable(ep, info.TableName())
+ if err != nil {
+ return linux.IPTGetinfo{}, err
+ }
+
+ // Get the hooks that apply to this table.
+ info.ValidHooks = table.ValidHooks()
+
+ // Grab the metadata struct, which is used to store information (e.g.
+ // the number of entries) that applies to the user's encoding of
+ // iptables, but not netstack's.
+ metadata := table.Metadata().(metadata)
+
+ // Set values from metadata.
+ info.HookEntry = metadata.HookEntry
+ info.Underflow = metadata.Underflow
+ info.NumEntries = metadata.NumEntries
+ info.Size = metadata.Size
+
+ return info, nil
+}
+
+// GetEntries returns netstack's iptables rules encoded for the iptables tool.
+func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
+ // Read in the struct and table name.
+ var userEntries linux.IPTGetEntries
+ if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ return linux.KernelIPTGetEntries{}, syserr.FromError(err)
+ }
+
+ // Find the appropriate table.
+ table, err := findTable(ep, userEntries.TableName())
+ if err != nil {
+ return linux.KernelIPTGetEntries{}, err
+ }
+
+ // Convert netstack's iptables rules to something that the iptables
+ // tool can understand.
+ entries, _, err := convertNetstackToBinary(userEntries.TableName(), table)
+ if err != nil {
+ return linux.KernelIPTGetEntries{}, err
+ }
+ if binary.Size(entries) > uintptr(outLen) {
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ }
+
+ return entries, nil
+}
+
+func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) {
+ ipt, err := ep.IPTables()
+ if err != nil {
+ return iptables.Table{}, syserr.FromError(err)
+ }
+ table, ok := ipt.Tables[tableName]
+ if !ok {
+ return iptables.Table{}, syserr.ErrInvalidArgument
+ }
+ return table, nil
+}
+
+// FillDefaultIPTables sets stack's IPTables to the default tables and
+// populates them with metadata.
+func FillDefaultIPTables(stack *stack.Stack) {
+ ipt := iptables.DefaultTables()
+
+ // In order to fill in the metadata, we have to translate ipt from its
+ // netstack format to Linux's giant-binary-blob format.
+ for name, table := range ipt.Tables {
+ _, metadata, err := convertNetstackToBinary(name, table)
+ if err != nil {
+ panic(fmt.Errorf("Unable to set default IP tables: %v", err))
+ }
+ table.SetMetadata(metadata)
+ ipt.Tables[name] = table
+ }
+
+ stack.SetIPTables(ipt)
+}
+
+// convertNetstackToBinary converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a bit, reading some offsets,
+// jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
+ // Return values.
+ var entries linux.KernelIPTGetEntries
+ var meta metadata
+
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(name) {
+ return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ }
+ copy(entries.Name[:], name)
+
+ // Deal with the built in chains first (INPUT, OUTPUT, etc.). Each of
+ // these chains ends with an unconditional policy entry.
+ for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ {
+ chain, ok := table.BuiltinChains[hook]
+ if !ok {
+ // This table doesn't support this hook.
+ continue
+ }
+
+ // Sanity check.
+ if len(chain.Rules) < 1 {
+ return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ }
+
+ for ruleIdx, rule := range chain.Rules {
+ // If this is the first rule of a builtin chain, set
+ // the metadata hook entry point.
+ if ruleIdx == 0 {
+ meta.HookEntry[hook] = entries.Size
+ }
+
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+ entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+
+ // The underflow rule is the last rule in the chain,
+ // and is an unconditional rule (i.e. it matches any
+ // packet). This is enforced when saving iptables.
+ if ruleIdx == len(chain.Rules)-1 {
+ meta.Underflow[hook] = entries.Size
+ }
+
+ entries.Size += uint32(entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ meta.NumEntries++
+ }
+
+ }
+
+ // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of
+ // these starts with an error node holding the chain's name and ends
+ // with an unconditional return.
+
+ // Lastly, each table ends with an unconditional error target rule as
+ // its final entry.
+ errorEntry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+ var errorTarget linux.XTErrorTarget
+ errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget
+ copy(errorTarget.ErrorName[:], errorTargetName)
+ copy(errorTarget.Target.Name[:], errorTargetName)
+
+ // Serialize and add it to the list of entries.
+ errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget)
+ serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget)
+ errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...)
+ errorEntry.NextOffset += uint16(len(serializedErrorTarget))
+
+ entries.Size += uint32(errorEntry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, errorEntry)
+ meta.NumEntries++
+ meta.Size = entries.Size
+
+ return entries, meta, nil
+}
+
+func marshalMatcher(matcher iptables.Matcher) []byte {
+ switch matcher.(type) {
+ default:
+ // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so
+ // any call to marshalMatcher will panic.
+ panic(fmt.Errorf("unknown matcher of type %T", matcher))
+ }
+}
+
+func marshalTarget(target iptables.Target) []byte {
+ switch target.(type) {
+ case iptables.UnconditionalAcceptTarget:
+ return marshalUnconditionalAcceptTarget()
+ default:
+ panic(fmt.Errorf("unknown target of type %T", target))
+ }
+}
+
+func marshalUnconditionalAcceptTarget() []byte {
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ Verdict: translateStandardVerdict(iptables.Accept),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+// translateStandardVerdict translates verdicts the same way as the iptables
+// tool.
+func translateStandardVerdict(verdict iptables.Verdict) int32 {
+ switch verdict {
+ case iptables.Accept:
+ return -linux.NF_ACCEPT - 1
+ case iptables.Drop:
+ return -linux.NF_DROP - 1
+ case iptables.Queue:
+ return -linux.NF_QUEUE - 1
+ case iptables.Return:
+ return linux.NF_RETURN
+ case iptables.Jump:
+ // TODO(gvisor.dev/issue/170): Support Jump.
+ panic("Jump isn't supported yet")
+ default:
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+ }
+}
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index fb1ff329c..cc70ac237 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -110,7 +110,7 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader
m.PutAttr(linux.IFLA_ADDRESS, mac)
m.PutAttr(linux.IFLA_BROADCAST, brd)
- // TODO(b/68878065): There are many more attributes.
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
}
return nil
@@ -151,13 +151,69 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader
m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
- // TODO(b/68878065): There are many more attributes.
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
}
}
return nil
}
+// dumpRoutes handles RTM_GETROUTE + NLM_F_DUMP requests.
+//
+// It appends one RTM_NEWROUTE message to ms for every entry returned by the
+// network stack's RouteTable. hdr and data are not consulted. When ctx
+// carries no network stack, no route messages are added. It always returns
+// nil.
+func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+	// RTM_GETROUTE dump requests need not contain anything more than the
+	// netlink header and 1 byte protocol family common to all
+	// NETLINK_ROUTE requests.
+
+	// We always send back an NLMSG_DONE.
+	ms.Multi = true
+
+	stack := inet.StackFromContext(ctx)
+	if stack == nil {
+		// No network routes.
+		return nil
+	}
+
+	for _, rt := range stack.RouteTable() {
+		m := ms.AddMessage(linux.NetlinkMessageHeader{
+			Type: linux.RTM_NEWROUTE,
+		})
+
+		m.Put(linux.RouteMessage{
+			Family: rt.Family,
+			DstLen: rt.DstLen,
+			SrcLen: rt.SrcLen,
+			TOS:    rt.TOS,
+
+			// Always return the main table since we don't have multiple
+			// routing tables.
+			Table:    linux.RT_TABLE_MAIN,
+			Protocol: rt.Protocol,
+			Scope:    rt.Scope,
+			Type:     rt.Type,
+
+			Flags: rt.Flags,
+		})
+
+		// Only named RTA_* attributes are emitted, and each one only when
+		// the route actually carries a value for it.
+		if rt.DstLen > 0 {
+			m.PutAttr(linux.RTA_DST, rt.DstAddr)
+		}
+		if rt.SrcLen > 0 {
+			m.PutAttr(linux.RTA_SRC, rt.SrcAddr)
+		}
+		if rt.OutputInterface != 0 {
+			m.PutAttr(linux.RTA_OIF, rt.OutputInterface)
+		}
+		if len(rt.GatewayAddr) > 0 {
+			m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr)
+		}
+
+		// TODO(gvisor.dev/issue/578): There are many more attributes.
+	}
+
+	return nil
+}
+
// ProcessMessage implements netlink.Protocol.ProcessMessage.
func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
// All messages start with a 1 byte protocol family.
@@ -186,6 +242,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH
return p.dumpLinks(ctx, hdr, data, ms)
case linux.RTM_GETADDR:
return p.dumpAddrs(ctx, hdr, data, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, hdr, data, ms)
default:
return syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index f3d6c1e9b..d0aab293d 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -271,7 +271,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr
}
// Accept implements socket.Socket.Accept.
-func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Netlink sockets never support accept.
return 0, nil, 0, syserr.ErrNotSupported
}
@@ -289,7 +289,7 @@ func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -379,11 +379,11 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
s.mu.Lock()
defer s.mu.Unlock()
- sa := linux.SockAddrNetlink{
+ sa := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
PortID: uint32(s.portID),
}
@@ -391,8 +391,8 @@ func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
- sa := linux.SockAddrNetlink{
+func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ sa := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
// TODO(b/68878065): Support non-kernel peers. For now the peer
// must be the kernel.
@@ -402,8 +402,8 @@ func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
- from := linux.SockAddrNetlink{
+func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ from := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
PortID: 0,
}
@@ -511,6 +511,19 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
return nil
}
+// dumpErrorMesage appends an NLMSG_ERROR message to ms. The payload holds the
+// negated value of err.ToLinux().Number() together with hdr. It always
+// returns nil.
+func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error {
+	reply := ms.AddMessage(linux.NetlinkMessageHeader{Type: linux.NLMSG_ERROR})
+
+	payload := linux.NetlinkErrorMessage{
+		Error:  int32(-err.ToLinux().Number()),
+		Header: hdr,
+	}
+	reply.Put(payload)
+	return nil
+}
+
// processMessages handles each message in buf, passing it to the protocol
// handler for final handling.
func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
@@ -545,14 +558,20 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error
continue
}
+ ms := NewMessageSet(s.portID, hdr.Seq)
+ var err *syserr.Error
// TODO(b/68877377): ACKs not supported yet.
if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
- return syserr.ErrNotSupported
- }
+ err = syserr.ErrNotSupported
+ } else {
- ms := NewMessageSet(s.portID, hdr.Seq)
- if err := s.protocol.ProcessMessage(ctx, hdr, data, ms); err != nil {
- return err
+ err = s.protocol.ProcessMessage(ctx, hdr, data, ms)
+ }
+ if err != nil {
+ ms = NewMessageSet(s.portID, hdr.Seq)
+ if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil {
+ return err
+ }
}
if err := s.sendResponse(ctx, ms); err != nil {
diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD
index a536f2e44..a3585e10d 100644
--- a/pkg/sentry/socket/rpcinet/notifier/BUILD
+++ b/pkg/sentry/socket/rpcinet/notifier/BUILD
@@ -6,10 +6,11 @@ go_library(
name = "notifier",
srcs = ["notifier.go"],
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier",
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
"//pkg/sentry/socket/rpcinet/conn",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
index aa157dd51..7efe4301f 100644
--- a/pkg/sentry/socket/rpcinet/notifier/notifier.go
+++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go
@@ -20,6 +20,7 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
"gvisor.dev/gvisor/pkg/waiter"
@@ -76,7 +77,7 @@ func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error {
}
e := pb.EpollEvent{
- Events: mask.ToLinux() | -syscall.EPOLLET,
+ Events: mask.ToLinux() | unix.EPOLLET,
Fd: fd,
}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index ccaaddbfc..ddb76d9d4 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -285,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
payload, se := rpcAccept(t, s.fd, peerRequested)
// Check if we need to block.
@@ -328,6 +328,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
NonBlocking: flags&linux.SOCK_NONBLOCK != 0,
}
file := fs.NewFile(t, dirent, fileFlags, &socketOperations{
+ family: s.family,
+ stype: s.stype,
+ protocol: s.protocol,
wq: &wq,
fd: payload.Fd,
rpcConn: s.rpcConn,
@@ -344,7 +347,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
t.Kernel().RecordSocket(file)
if peerRequested {
- return fd, payload.Address.Address, payload.Address.Length, nil
+ return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil
}
return fd, nil, 0, nil
@@ -395,7 +398,7 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
// SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
// within the sentry.
if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
@@ -469,7 +472,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
<-c
@@ -480,11 +483,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy
}
addr := res.(*pb.GetPeerNameResponse_Address).Address
- return addr.Address, addr.Length, nil
+ return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
stack := t.NetworkContext().(*Stack)
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
<-c
@@ -495,7 +498,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy
}
addr := res.(*pb.GetSockNameResponse_Address).Address
- return addr.Address, addr.Length, nil
+ return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
}
func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) {
@@ -682,7 +685,7 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
Fd: s.fd,
Length: uint32(dst.NumBytes()),
@@ -703,7 +706,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
return 0, 0, nil, 0, socket.ControlMessages{}, err
@@ -727,7 +730,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
c := s.extractControlMessages(res)
- return int(res.Length), 0, res.Address.GetAddress(), res.Address.GetLength(), c, syserr.FromError(e)
+ return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
}
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
return 0, 0, nil, 0, socket.ControlMessages{}, err
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
index 49bd3a220..5dcb6b455 100644
--- a/pkg/sentry/socket/rpcinet/stack.go
+++ b/pkg/sentry/socket/rpcinet/stack.go
@@ -30,6 +30,7 @@ import (
type Stack struct {
interfaces map[int32]inet.Interface
interfaceAddrs map[int32][]inet.InterfaceAddr
+ routes []inet.Route
rpcConn *conn.RPCConnection
notifier *notifier.Notifier
}
@@ -69,6 +70,16 @@ func NewStack(fd int32) (*Stack, error) {
return nil, e
}
+ routes, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETROUTE)
+ if err != nil {
+ return nil, fmt.Errorf("RTM_GETROUTE failed: %v", err)
+ }
+
+ stack.routes, e = hostinet.ExtractHostRoutes(routes)
+ if e != nil {
+ return nil, e
+ }
+
return stack, nil
}
@@ -89,12 +100,20 @@ func (s *Stack) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
// Interfaces implements inet.Stack.Interfaces.
func (s *Stack) Interfaces() map[int32]inet.Interface {
- return s.interfaces
+ interfaces := make(map[int32]inet.Interface)
+ for k, v := range s.interfaces {
+ interfaces[k] = v
+ }
+ return interfaces
}
// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
- return s.interfaceAddrs
+ addrs := make(map[int32][]inet.InterfaceAddr)
+ for k, v := range s.interfaceAddrs {
+ addrs[k] = append([]inet.InterfaceAddr(nil), v...)
+ }
+ return addrs
}
// SupportsIPv6 implements inet.Stack.SupportsIPv6.
@@ -138,3 +157,11 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
func (s *Stack) Statistics(stat interface{}, arg string) error {
return syserr.ErrEndpointOperation.ToError()
}
+
+// RouteTable implements inet.Stack.RouteTable.
+//
+// The entries of s.routes are appended to a nil slice, so the returned slice
+// does not alias s.routes itself.
+func (s *Stack) RouteTable() []inet.Route {
+	var snapshot []inet.Route
+	snapshot = append(snapshot, s.routes...)
+	return snapshot
+}
+
+// Resume implements inet.Stack.Resume. It does nothing for this stack.
+func (s *Stack) Resume() {
+}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 0efa58a58..8c250c325 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -20,8 +20,10 @@ package socket
import (
"fmt"
"sync/atomic"
+ "syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -52,7 +54,7 @@ type Socket interface {
// Accept implements the accept4(2) linux syscall.
// Returns fd, real peer address length and error. Real peer address
// length is only set if len(peer) > 0.
- Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error)
+ Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error)
// Bind implements the bind(2) linux syscall.
Bind(t *kernel.Task, sockaddr []byte) *syserr.Error
@@ -64,7 +66,7 @@ type Socket interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux syscall.
- GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux syscall.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
@@ -73,13 +75,13 @@ type Socket interface {
//
// addrLen is the address length to be returned to the application, not
// necessarily the actual length of the address.
- GetSockName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error)
+ GetSockName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error)
// GetPeerName implements the getpeername(2) linux syscall.
//
// addrLen is the address length to be returned to the application, not
// necessarily the actual length of the address.
- GetPeerName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error)
+ GetPeerName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error)
// RecvMsg implements the recvmsg(2) linux syscall.
//
@@ -92,7 +94,7 @@ type Socket interface {
// msgFlags. In that case, the caller should set MSG_CTRUNC appropriately.
//
// If err != nil, the recv was not successful.
- RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
+ RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
// SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
// ownership of the ControlMessage on error.
@@ -340,3 +342,31 @@ func emitUnimplementedEvent(t *kernel.Task, name int) {
t.Kernel().EmitUnimplementedEvent(t)
}
}
+
+// UnmarshalSockAddr decodes the bytes of a struct sockaddr into the ABI
+// socket address type that corresponds to family.
+//
+// AF_INET, AF_INET6, AF_UNIX and AF_NETLINK are handled; any other family
+// panics.
+//
+// Precondition: data must be long enough to represent a socket address of the
+// given family.
+func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
+	// Choose the destination value and the number of bytes it consumes, then
+	// decode once below.
+	var (
+		addr linux.SockAddr
+		size int
+	)
+	switch family {
+	case syscall.AF_INET:
+		addr, size = &linux.SockAddrInet{}, syscall.SizeofSockaddrInet4
+	case syscall.AF_INET6:
+		addr, size = &linux.SockAddrInet6{}, syscall.SizeofSockaddrInet6
+	case syscall.AF_UNIX:
+		addr, size = &linux.SockAddrUnix{}, syscall.SizeofSockaddrUnix
+	case syscall.AF_NETLINK:
+		addr, size = &linux.SockAddrNetlink{}, syscall.SizeofSockaddrNetlink
+	default:
+		panic(fmt.Sprintf("Unsupported socket family %v", family))
+	}
+
+	binary.Unmarshal(data[:size], usermem.ByteOrder, addr)
+	return addr
+}
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 760c7beab..2ec1a662d 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -62,7 +62,7 @@ type EndpointReader struct {
Creds bool
// NumRights is the number of SCM_RIGHTS FDs requested.
- NumRights uintptr
+ NumRights int
// Peek indicates that the data should not be consumed from the
// endpoint.
@@ -70,7 +70,7 @@ type EndpointReader struct {
// MsgSize is the size of the message that was read from. For stream
// sockets, it is the amount read.
- MsgSize uintptr
+ MsgSize int64
// From, if not nil, will be set with the address read from.
From *tcpip.FullAddress
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 73d2df15d..4bd15808a 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -436,7 +436,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
-func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
// Stream sockets do not support specifying the endpoint. Seqpacket
// sockets ignore the passed endpoint.
if e.stype == linux.SOCK_STREAM && to != nil {
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index c7f7c5b16..0322dec0b 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -99,7 +99,7 @@ func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (Con
// SendMsg writes data and a control message to the specified endpoint.
// This method does not block if the data cannot be written.
-func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
if to == nil {
return e.baseEndpoint.SendMsg(ctx, data, c, nil)
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 7fb9cb1e0..2b0ad6395 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -121,13 +121,13 @@ type Endpoint interface {
// CMTruncated indicates that the numRights hint was used to receive fewer
// than the total available SCM_RIGHTS FDs. Additional truncation may be
// required by the caller.
- RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error)
+ RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, err *syserr.Error)
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
//
// SendMsg does not take ownership of any of its arguments on error.
- SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error)
+ SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (int64, *syserr.Error)
// Connect connects this endpoint directly to another.
//
@@ -291,7 +291,7 @@ type Receiver interface {
// See Endpoint.RecvMsg for documentation on shared arguments.
//
// notify indicates if RecvNotify should be called.
- Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
+ Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
// RecvNotify notifies the Receiver of a successful Recv. This must not be
// called while holding any endpoint locks.
@@ -331,7 +331,7 @@ type queueReceiver struct {
}
// Recv implements Receiver.Recv.
-func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
var m *message
var notify bool
var err *syserr.Error
@@ -344,13 +344,13 @@ func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek
return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
}
src := []byte(m.Data)
- var copied uintptr
+ var copied int64
for i := 0; i < len(data) && len(src) > 0; i++ {
n := copy(data[i], src)
- copied += uintptr(n)
+ copied += int64(n)
src = src[n:]
}
- return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil
+ return copied, int64(len(m.Data)), m.Control, false, m.Address, notify, nil
}
// RecvNotify implements Receiver.RecvNotify.
@@ -401,11 +401,11 @@ type streamQueueReceiver struct {
addr tcpip.FullAddress
}
-func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) {
- var copied uintptr
+func vecCopy(data [][]byte, buf []byte) (int64, [][]byte, []byte) {
+ var copied int64
for len(data) > 0 && len(buf) > 0 {
n := copy(data[0], buf)
- copied += uintptr(n)
+ copied += int64(n)
buf = buf[n:]
data[0] = data[0][n:]
if len(data[0]) == 0 {
@@ -443,7 +443,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
}
// Recv implements Receiver.Recv.
-func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -464,7 +464,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
q.addr = m.Address
}
- var copied uintptr
+ var copied int64
if peek {
// Don't consume control message if we are peeking.
c := q.control.Clone()
@@ -531,7 +531,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uint
break
}
- var cpd uintptr
+ var cpd int64
cpd, data, q.buffer = vecCopy(data, q.buffer)
copied += cpd
@@ -569,7 +569,7 @@ type ConnectedEndpoint interface {
//
// syserr.ErrWouldBlock can be returned along with a partial write if
// the caller should block to send the rest of the data.
- Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *syserr.Error)
+ Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error)
// SendNotify notifies the ConnectedEndpoint of a successful Send. This
// must not be called while holding any endpoint locks.
@@ -637,7 +637,7 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
}
// Send implements ConnectedEndpoint.Send.
-func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
+func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
var l int64
for _, d := range data {
l += int64(len(d))
@@ -665,7 +665,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages,
}
l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate)
- return uintptr(l), notify, err
+ return int64(l), notify, err
}
// SendNotify implements ConnectedEndpoint.SendNotify.
@@ -781,7 +781,7 @@ func (e *baseEndpoint) Connected() bool {
}
// RecvMsg reads data and a control message from the endpoint.
-func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) {
+func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (int64, int64, ControlMessages, bool, *syserr.Error) {
e.Lock()
if e.receiver == nil {
@@ -807,7 +807,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
-func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
e.Lock()
if !e.Connected() {
e.Unlock()
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index eb262ecaf..0d0cb68df 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}
@@ -137,7 +137,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) {
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -149,7 +149,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *sy
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -166,7 +166,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
@@ -199,7 +199,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, err := s.ep.Accept()
if err != nil {
@@ -223,7 +223,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
ns.SetFlags(flags.Settable())
}
- var addr interface{}
+ var addr linux.SockAddr
var addrLen uint32
if peerRequested {
// Get address of the peer.
@@ -505,7 +505,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -535,7 +535,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
Ctx: t,
Endpoint: s.ep,
Creds: wantCreds,
- NumRights: uintptr(numRights),
+ NumRights: numRights,
Peek: peek,
}
if senderRequested {
@@ -543,7 +543,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
var total int64
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
- var from interface{}
+ var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
@@ -578,7 +578,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
for {
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
- var from interface{}
+ var from linux.SockAddr
var fromLen uint32
if r.From != nil {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD
index f297ef3b7..88765f4d6 100644
--- a/pkg/sentry/state/BUILD
+++ b/pkg/sentry/state/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/log",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
+ "//pkg/sentry/time",
"//pkg/sentry/watchdog",
"//pkg/state/statefile",
"//pkg/syserror",
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
index 026549756..9eb626b76 100644
--- a/pkg/sentry/state/state.go
+++ b/pkg/sentry/state/state.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/state/statefile"
"gvisor.dev/gvisor/pkg/syserror"
@@ -104,7 +105,7 @@ type LoadOpts struct {
}
// Load loads the given kernel, setting the provided platform and stack.
-func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error {
+func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error {
// Open the file.
r, m, err := statefile.NewReader(opts.Source, opts.Key)
if err != nil {
@@ -114,5 +115,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack) error {
previousMetadata = m
// Restore the Kernel object graph.
- return k.LoadFrom(r, n)
+ return k.LoadFrom(r, n, clocks)
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 386b40af7..f779186ad 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, err := epsocket.GetAddress(int(family), b, true /* strict */)
+ fa, _, err := epsocket.AddressAndFamily(int(family), b, true /* strict */)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 264301bfa..1d9018c96 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -91,6 +91,10 @@ func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op strin
// TODO(gvisor.dev/issue/161): In some cases SIGPIPE should
// also be sent to the application.
return nil
+ case syserror.ENOSPC:
+ // Similar to EPIPE. Return what we wrote this time, and let
+ // ENOSPC be returned on the next call.
+ return nil
case syserror.ECONNRESET:
// For TCP sendfile connections, we may have a reset. But we
// should just return n as the result.
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 51db2d8f7..ed996ba51 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -30,8 +30,7 @@ import (
const _AUDIT_ARCH_X86_64 = 0xc000003e
// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall
-// numbers from Linux 4.4. The entries commented out are those syscalls we
-// don't currently support.
+// numbers from Linux 4.4.
var AMD64 = &kernel.SyscallTable{
OS: abi.Linux,
Arch: arch.AMD64,
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index 4a2b9f061..65b4a227b 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -107,19 +107,20 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// copyOutEvents copies epoll events from the kernel to user memory.
func copyOutEvents(t *kernel.Task, addr usermem.Addr, e []epoll.Event) error {
const itemLen = 12
- if _, ok := addr.AddLength(uint64(len(e)) * itemLen); !ok {
+ buffLen := len(e) * itemLen
+ if _, ok := addr.AddLength(uint64(buffLen)); !ok {
return syserror.EFAULT
}
- b := t.CopyScratchBuffer(itemLen)
+ b := t.CopyScratchBuffer(buffLen)
for i := range e {
- usermem.ByteOrder.PutUint32(b[0:], e[i].Events)
- usermem.ByteOrder.PutUint32(b[4:], uint32(e[i].Data[0]))
- usermem.ByteOrder.PutUint32(b[8:], uint32(e[i].Data[1]))
- if _, err := t.CopyOutBytes(addr, b); err != nil {
- return err
- }
- addr += itemLen
+ usermem.ByteOrder.PutUint32(b[i*itemLen:], e[i].Events)
+ usermem.ByteOrder.PutUint32(b[i*itemLen+4:], uint32(e[i].Data[0]))
+ usermem.ByteOrder.PutUint32(b[i*itemLen+8:], uint32(e[i].Data[1]))
+ }
+
+ if _, err := t.CopyOutBytes(addr, b); err != nil {
+ return err
}
return nil
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index 63e2c5a5d..912cbe4ff 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -120,7 +120,7 @@ func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent
Ino: attr.InodeID,
Off: offset,
},
- Typ: toType(attr.Type),
+ Typ: fs.ToDirentType(attr.Type),
},
Name: []byte(name),
}
@@ -142,28 +142,6 @@ func smallestDirent64(a arch.Context) uint {
return uint(binary.Size(d.Hdr)) + a.Width()
}
-// toType converts an fs.InodeOperationsInfo to a linux dirent typ field.
-func toType(nodeType fs.InodeType) uint8 {
- switch nodeType {
- case fs.RegularFile, fs.SpecialFile:
- return linux.DT_REG
- case fs.Symlink:
- return linux.DT_LNK
- case fs.Directory, fs.SpecialDirectory:
- return linux.DT_DIR
- case fs.Pipe:
- return linux.DT_FIFO
- case fs.CharacterDevice:
- return linux.DT_CHR
- case fs.BlockDevice:
- return linux.DT_BLK
- case fs.Socket:
- return linux.DT_SOCK
- default:
- return linux.DT_UNKNOWN
- }
-}
-
// padRec pads the name field until the rec length is a multiple of the width,
// which must be a power of 2. It returns the padded rec length.
func (d *dirent) padRec(width int) uint16 {
diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go
index 9080a10c3..8c13e2d82 100644
--- a/pkg/sentry/syscalls/linux/sys_mount.go
+++ b/pkg/sentry/syscalls/linux/sys_mount.go
@@ -109,9 +109,17 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, syserror.EINVAL
}
- return 0, nil, fileOpOn(t, linux.AT_FDCWD, targetPath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if err := fileOpOn(t, linux.AT_FDCWD, targetPath, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ // Mount will take a reference on rootInode if successful.
return t.MountNamespace().Mount(t, d, rootInode)
- })
+ }); err != nil {
+ // Something went wrong. Drop our ref on rootInode before
+ // returning the error.
+ rootInode.DecRef()
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
}
// Umount2 implements Linux syscall umount2(2).
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index b2474e60d..3ab54271c 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -191,7 +191,6 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
// Preadv2 implements linux syscall preadv2(2).
-// TODO(b/120162627): Implement RWF_HIPRI functionality.
func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
// While the syscall is
// preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
@@ -228,6 +227,8 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
// Check flags field.
+ // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is
+ // accepted as a valid flag argument for preadv2.
if flags&^linux.RWF_VALID != 0 {
return 0, nil, syserror.EOPNOTSUPP
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index fa568a660..3bac4d90d 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -460,7 +460,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
// Call syscall implementation then copy both value and value len out.
- v, e := getSockOpt(t, s, int(level), int(name), int(optLen))
+ v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen))
if e != nil {
return 0, nil, e.ToError()
}
@@ -483,7 +483,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -505,7 +505,7 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name, len int) (interfac
}
}
- return s.GetSockOpt(t, level, name, len)
+ return s.GetSockOpt(t, level, name, optValAddr, len)
}
// SetSockOpt implements the linux syscall setsockopt(2).
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index a7c98efcb..8a98fedcb 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -91,22 +91,29 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Get files.
+ inFile := t.GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ if !inFile.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
outFile := t.GetFile(outFD)
if outFile == nil {
return 0, nil, syserror.EBADF
}
defer outFile.DecRef()
- inFile := t.GetFile(inFD)
- if inFile == nil {
+ if !outFile.Flags().Write {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
- // Verify that the outfile Append flag is not set. Note that fs.Splice
- // itself validates that the output file is writable.
+ // Verify that the outfile Append flag is not set.
if outFile.Flags().Append {
- return 0, nil, syserror.EBADF
+ return 0, nil, syserror.EINVAL
}
// Verify that we have a regular infile. This is a requirement; the
@@ -207,6 +214,10 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ESPIPE
}
if outOffset != 0 {
+ if !outFile.Flags().Pwrite {
+ return 0, nil, syserror.EINVAL
+ }
+
var offset int64
if _, err := t.CopyIn(outOffset, &offset); err != nil {
return 0, nil, err
@@ -220,6 +231,10 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ESPIPE
}
if inOffset != 0 {
+ if !inFile.Flags().Pread {
+ return 0, nil, syserror.EINVAL
+ }
+
var offset int64
if _, err := t.CopyIn(inOffset, &offset); err != nil {
return 0, nil, err
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 595eb9155..8ab7ffa25 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -96,7 +96,7 @@ func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Load the new TaskContext.
maxTraversals := uint(linux.MaxSymlinkTraversals)
- tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, argv, envv, t.Arch().FeatureSet())
+ tc, se := t.Kernel().LoadTaskImage(t, t.MountNamespace(), root, wd, &maxTraversals, filename, nil, argv, envv, t.Arch().FeatureSet())
if se != nil {
return 0, nil, se.ToError()
}
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index 5278c96a6..27cd2c336 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -191,7 +191,6 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
// Pwritev2 implements linux syscall pwritev2(2).
-// TODO(b/120162627): Implement RWF_HIPRI functionality.
// TODO(b/120161091): Implement O_SYNC and D_SYNC functionality.
func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
// While the syscall is
@@ -227,6 +226,8 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, syserror.ESPIPE
}
+ // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is
+ // accepted as a valid flag argument for pwritev2.
if flags&^linux.RWF_VALID != 0 {
return uintptr(flags), nil, syserror.EOPNOTSUPP
}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 4de6c41cf..0f247bf77 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -18,6 +18,7 @@ go_library(
"permissions.go",
"resolving_path.go",
"syscalls.go",
+ "testutil.go",
"vfs.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/vfs",
@@ -40,7 +41,16 @@ go_test(
name = "vfs_test",
size = "small",
srcs = [
+ "file_description_impl_util_test.go",
"mount_test.go",
],
embed = [":vfs"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/usermem",
+ "//pkg/syserror",
+ ],
)
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 486893e70..ba230da72 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -15,6 +15,10 @@
package vfs
import (
+ "bytes"
+ "io"
+ "sync"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -24,6 +28,16 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// The following design pattern is strongly recommended for filesystem
+// implementations to adopt:
+// - Have a local fileDescription struct (containing FileDescription) which
+// embeds FileDescriptionDefaultImpl and overrides the default methods
+// which are common to all fd implementations for that filesystem
+// like StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc.
+// - This should be embedded in all file description implementations as the
+// first field by value.
+// - Directory FDs would also embed DirectoryFileDescriptionDefaultImpl.
+
// FileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl to obtain implementations of many FileDescriptionImpl
// methods with default behavior analogous to Linux's.
@@ -115,11 +129,8 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg
// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl that always represent directories to obtain
-// implementations of non-directory I/O methods that return EISDIR, and
-// implementations of other methods consistent with FileDescriptionDefaultImpl.
-type DirectoryFileDescriptionDefaultImpl struct {
- FileDescriptionDefaultImpl
-}
+// implementations of non-directory I/O methods that return EISDIR.
+type DirectoryFileDescriptionDefaultImpl struct{}
// PRead implements FileDescriptionImpl.PRead.
func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
@@ -140,3 +151,104 @@ func (DirectoryFileDescriptionDefaultImpl) PWrite(ctx context.Context, src userm
func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
return 0, syserror.EISDIR
}
+
+// DynamicBytesFileDescriptionImpl may be embedded by implementations of
+// FileDescriptionImpl that represent read-only regular files whose contents
+// are backed by a bytes.Buffer that is regenerated when necessary, consistent
+// with Linux's fs/seq_file.c:single_open().
+//
+// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
+// use.
+type DynamicBytesFileDescriptionImpl struct {
+ data DynamicBytesSource // immutable
+ mu sync.Mutex // protects the following fields
+ buf bytes.Buffer
+ off int64
+ lastRead int64 // offset at which the last Read, PRead, or Seek ended
+}
+
+// DynamicBytesSource represents a data source for a
+// DynamicBytesFileDescriptionImpl.
+type DynamicBytesSource interface {
+ // Generate writes the file's contents to buf.
+ Generate(ctx context.Context, buf *bytes.Buffer) error
+}
+
+// SetDataSource must be called exactly once on fd before first use.
+func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) {
+ fd.data = data
+}
+
+// Preconditions: fd.mu must be locked.
+func (fd *DynamicBytesFileDescriptionImpl) preadLocked(ctx context.Context, dst usermem.IOSequence, offset int64, opts *ReadOptions) (int64, error) {
+ // Regenerate the buffer if it's empty, or before pread() at a new offset.
+ // Compare fs/seq_file.c:seq_read() => traverse().
+ switch {
+ case offset != fd.lastRead:
+ fd.buf.Reset()
+ fallthrough
+ case fd.buf.Len() == 0:
+ if err := fd.data.Generate(ctx, &fd.buf); err != nil {
+ fd.buf.Reset()
+ // fd.off is not updated in this case.
+ fd.lastRead = 0
+ return 0, err
+ }
+ }
+ bs := fd.buf.Bytes()
+ if offset >= int64(len(bs)) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, bs[offset:])
+ fd.lastRead = offset + int64(n)
+ return int64(n), err
+}
+
+// PRead implements FileDescriptionImpl.PRead.
+func (fd *DynamicBytesFileDescriptionImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.preadLocked(ctx, dst, offset, &opts)
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (fd *DynamicBytesFileDescriptionImpl) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.preadLocked(ctx, dst, fd.off, &opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Seek implements FileDescriptionImpl.Seek.
+func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as given.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ default:
+ // fs/seq_file.c:seq_lseek() rejects SEEK_END etc.
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset != fd.lastRead {
+ // Regenerate the file's contents immediately. Compare
+ // fs/seq_file.c:seq_lseek() => traverse().
+ fd.buf.Reset()
+ if err := fd.data.Generate(ctx, &fd.buf); err != nil {
+ fd.buf.Reset()
+ fd.off = 0
+ fd.lastRead = 0
+ return 0, err
+ }
+ fd.lastRead = offset
+ }
+ fd.off = offset
+ return offset, nil
+}
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
new file mode 100644
index 000000000..511b829fc
--- /dev/null
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -0,0 +1,141 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "sync/atomic"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// fileDescription is the common fd struct which a filesystem implementation
+// embeds in all of its file description implementations as required.
+type fileDescription struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+}
+
+// genCountFD is a read-only FileDescriptionImpl representing a regular file
+// that contains the number of times its DynamicBytesSource.Generate()
+// implementation has been called.
+type genCountFD struct {
+ fileDescription
+ DynamicBytesFileDescriptionImpl
+
+ count uint64 // accessed using atomic memory ops
+}
+
+func newGenCountFD(mnt *Mount, vfsd *Dentry) *FileDescription {
+ var fd genCountFD
+ fd.vfsfd.Init(&fd, mnt, vfsd)
+ fd.DynamicBytesFileDescriptionImpl.SetDataSource(&fd)
+ return &fd.vfsfd
+}
+
+// Release implements FileDescriptionImpl.Release.
+func (fd *genCountFD) Release() {
+}
+
+// StatusFlags implements FileDescriptionImpl.StatusFlags.
+func (fd *genCountFD) StatusFlags(ctx context.Context) (uint32, error) {
+ return 0, nil
+}
+
+// SetStatusFlags implements FileDescriptionImpl.SetStatusFlags.
+func (fd *genCountFD) SetStatusFlags(ctx context.Context, flags uint32) error {
+ return syserror.EPERM
+}
+
+// Stat implements FileDescriptionImpl.Stat.
+func (fd *genCountFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ // Note that Statx.Mask == 0 in the return value.
+ return linux.Statx{}, nil
+}
+
+// SetStat implements FileDescriptionImpl.SetStat.
+func (fd *genCountFD) SetStat(ctx context.Context, opts SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Generate implements DynamicBytesSource.Generate.
+func (fd *genCountFD) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d", atomic.AddUint64(&fd.count, 1))
+ return nil
+}
+
+func TestGenCountFD(t *testing.T) {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := New() // vfs.New()
+ vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{})
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &NewFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("failed to create testfs root mount: %v", err)
+ }
+ vd := mntns.Root()
+ defer vd.DecRef()
+
+ fd := newGenCountFD(vd.Mount(), vd.Dentry())
+ defer fd.DecRef()
+
+ // The first read causes Generate to be called to fill the FD's buffer.
+ buf := make([]byte, 2)
+ ioseq := usermem.BytesIOSequence(buf)
+ n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ if n != 1 || (err != nil && err != io.EOF) {
+ t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err)
+ }
+ if want := byte('1'); buf[0] != want {
+ t.Errorf("first Read: got byte %c, wanted %c", buf[0], want)
+ }
+
+ // A second read without seeking is still at EOF.
+ n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ if n != 0 || err != io.EOF {
+ t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err)
+ }
+
+ // Seeking to the beginning of the file causes it to be regenerated.
+ n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET)
+ if n != 0 || err != nil {
+ t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err)
+ }
+ n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ if n != 1 || (err != nil && err != io.EOF) {
+ t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err)
+ }
+ if want := byte('2'); buf[0] != want {
+ t.Errorf("Read after Seek: got byte %c, wanted %c", buf[0], want)
+ }
+
+ // PRead at the beginning of the file also causes it to be regenerated.
+ n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{})
+ if n != 1 || (err != nil && err != io.EOF) {
+ t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err)
+ }
+ if want := byte('3'); buf[0] != want {
+ t.Errorf("PRead: got byte %c, wanted %c", buf[0], want)
+ }
+}
diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go
new file mode 100644
index 000000000..70b192ece
--- /dev/null
+++ b/pkg/sentry/vfs/testutil.go
@@ -0,0 +1,139 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// FDTestFilesystemType is a test-only FilesystemType that produces Filesystems
+// for which all FilesystemImpl methods taking a path return EPERM. It is used
+// to produce Mounts and Dentries for testing of FileDescriptionImpls that do
+// not depend on their originating Filesystem.
+type FDTestFilesystemType struct{}
+
+// FDTestFilesystem is a test-only FilesystemImpl produced by
+// FDTestFilesystemType.
+type FDTestFilesystem struct {
+ vfsfs Filesystem
+}
+
+// NewFilesystem implements FilesystemType.NewFilesystem.
+func (fstype FDTestFilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) {
+ var fs FDTestFilesystem
+ fs.vfsfs.Init(&fs)
+ return &fs.vfsfs, fs.NewDentry(), nil
+}
+
+// Release implements FilesystemImpl.Release.
+func (fs *FDTestFilesystem) Release() {
+}
+
+// Sync implements FilesystemImpl.Sync.
+func (fs *FDTestFilesystem) Sync(ctx context.Context) error {
+ return nil
+}
+
+// GetDentryAt implements FilesystemImpl.GetDentryAt.
+func (fs *FDTestFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// LinkAt implements FilesystemImpl.LinkAt.
+func (fs *FDTestFilesystem) LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error {
+ return syserror.EPERM
+}
+
+// MkdirAt implements FilesystemImpl.MkdirAt.
+func (fs *FDTestFilesystem) MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error {
+ return syserror.EPERM
+}
+
+// MknodAt implements FilesystemImpl.MknodAt.
+func (fs *FDTestFilesystem) MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error {
+ return syserror.EPERM
+}
+
+// OpenAt implements FilesystemImpl.OpenAt.
+func (fs *FDTestFilesystem) OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error) {
+ return nil, syserror.EPERM
+}
+
+// ReadlinkAt implements FilesystemImpl.ReadlinkAt.
+func (fs *FDTestFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error) {
+ return "", syserror.EPERM
+}
+
+// RenameAt implements FilesystemImpl.RenameAt.
+func (fs *FDTestFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry, opts RenameOptions) error {
+ return syserror.EPERM
+}
+
+// RmdirAt implements FilesystemImpl.RmdirAt.
+func (fs *FDTestFilesystem) RmdirAt(ctx context.Context, rp *ResolvingPath) error {
+ return syserror.EPERM
+}
+
+// SetStatAt implements FilesystemImpl.SetStatAt.
+func (fs *FDTestFilesystem) SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// StatAt implements FilesystemImpl.StatAt.
+func (fs *FDTestFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts StatOptions) (linux.Statx, error) {
+ return linux.Statx{}, syserror.EPERM
+}
+
+// StatFSAt implements FilesystemImpl.StatFSAt.
+func (fs *FDTestFilesystem) StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error) {
+ return linux.Statfs{}, syserror.EPERM
+}
+
+// SymlinkAt implements FilesystemImpl.SymlinkAt.
+func (fs *FDTestFilesystem) SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error {
+ return syserror.EPERM
+}
+
+// UnlinkAt implements FilesystemImpl.UnlinkAt.
+func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error {
+ return syserror.EPERM
+}
+
+type fdTestDentry struct {
+ vfsd Dentry
+}
+
+// NewDentry returns a new Dentry.
+func (fs *FDTestFilesystem) NewDentry() *Dentry {
+ var d fdTestDentry
+ d.vfsd.Init(&d)
+ return &d.vfsd
+}
+
+// IncRef implements DentryImpl.IncRef.
+func (d *fdTestDentry) IncRef(vfsfs *Filesystem) {
+}
+
+// TryIncRef implements DentryImpl.TryIncRef.
+func (d *fdTestDentry) TryIncRef(vfsfs *Filesystem) bool {
+ return true
+}
+
+// DecRef implements DentryImpl.DecRef.
+func (d *fdTestDentry) DecRef(vfsfs *Filesystem) {
+}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 047f8329a..df37c7d5a 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -12,6 +12,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/iptables",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index c40924852..0d2637ee4 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -24,6 +24,7 @@ go_test(
embed = [":gonet"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 308f620e5..cd6ce930a 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -404,7 +404,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}
}
- var n uintptr
+ var n int64
var resCh <-chan struct{}
n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
nbytes += int(n)
@@ -556,32 +556,50 @@ type PacketConn struct {
wq *waiter.Queue
}
-// NewPacketConn creates a new PacketConn.
-func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
- // Create UDP endpoint and bind it.
+// DialUDP creates a new PacketConn.
+//
+// If laddr is nil, a local address is automatically chosen.
+//
+// If raddr is nil, the PacketConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
return nil, errors.New(err.String())
}
- if err := ep.Bind(addr); err != nil {
- ep.Close()
- return nil, &net.OpError{
- Op: "bind",
- Net: "udp",
- Addr: fullToUDPAddr(addr),
- Err: errors.New(err.String()),
+ if laddr != nil {
+ if err := ep.Bind(*laddr); err != nil {
+ ep.Close()
+ return nil, &net.OpError{
+ Op: "bind",
+ Net: "udp",
+ Addr: fullToUDPAddr(*laddr),
+ Err: errors.New(err.String()),
+ }
}
}
- c := &PacketConn{
+ c := PacketConn{
stack: s,
ep: ep,
wq: &wq,
}
c.deadlineTimer.init()
- return c, nil
+
+ if raddr != nil {
+ if err := c.ep.Connect(*raddr); err != nil {
+ c.ep.Close()
+ return nil, &net.OpError{
+ Op: "connect",
+ Net: "udp",
+ Addr: fullToUDPAddr(*raddr),
+ Err: errors.New(err.String()),
+ }
+ }
+ }
+
+ return &c, nil
}
func (c *PacketConn) newOpError(op string, err error) *net.OpError {
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 39efe44c7..672f026b2 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -26,6 +26,7 @@ import (
"golang.org/x/net/nettest"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -69,17 +70,13 @@ func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
s.SetRouteTable([]tcpip.Route{
// IPv4
{
- Destination: tcpip.Address(strings.Repeat("\x00", 4)),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", 4)),
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: NICID,
},
// IPv6
{
- Destination: tcpip.Address(strings.Repeat("\x00", 16)),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", 16)),
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: NICID,
},
})
@@ -371,9 +368,9 @@ func TestUDPForwarder(t *testing.T) {
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
- c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 5):", err)
+ t.Fatal("DialUDP(bind port 5):", err)
}
sent := "abc123"
@@ -452,13 +449,13 @@ func TestPacketConnTransfer(t *testing.T) {
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
- c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber)
+ c1, err := DialUDP(s, &addr1, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 4):", err)
+ t.Fatal("DialUDP(bind port 4):", err)
}
- c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+ c2, err := DialUDP(s, &addr2, nil, ipv4.ProtocolNumber)
if err != nil {
- t.Fatal("NewPacketConn(port 5):", err)
+ t.Fatal("DialUDP(bind port 5):", err)
}
c1.SetDeadline(time.Now().Add(time.Second))
@@ -491,6 +488,50 @@ func TestPacketConnTransfer(t *testing.T) {
}
}
+func TestConnectedPacketConnTransfer(t *testing.T) {
+ s, e := newLoopbackStack()
+ if e != nil {
+ t.Fatalf("newLoopbackStack() = %v", e)
+ }
+
+ ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+ addr := tcpip.FullAddress{NICID, ip, 11211}
+ s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+
+ c1, err := DialUDP(s, &addr, nil, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 4):", err)
+ }
+ c2, err := DialUDP(s, nil, &addr, ipv4.ProtocolNumber)
+ if err != nil {
+ t.Fatal("DialUDP(bind port 5):", err)
+ }
+
+ c1.SetDeadline(time.Now().Add(time.Second))
+ c2.SetDeadline(time.Now().Add(time.Second))
+
+ sent := "abc123"
+ if n, err := c2.Write([]byte(sent)); err != nil || n != len(sent) {
+ t.Errorf("got c2.Write(%q) = %d, %v, want = %d, %v", sent, n, err, len(sent), nil)
+ }
+ recv := make([]byte, len(sent))
+ n, err := c1.Read(recv)
+ if err != nil || n != len(recv) {
+ t.Errorf("got c1.Read() = %d, %v, want = %d, %v", n, err, len(recv), nil)
+ }
+
+ if recv := string(recv); recv != sent {
+ t.Errorf("got recv = %q, want = %q", recv, sent)
+ }
+
+ if err := c1.Close(); err != nil {
+ t.Error("c1.Close():", err)
+ }
+ if err := c2.Close(); err != nil {
+ t.Error("c2.Close():", err)
+ }
+}
+
func makePipe() (c1, c2 net.Conn, stop func(), err error) {
s, e := newLoopbackStack()
if e != nil {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 94a3af289..17fc9c68e 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -111,6 +111,15 @@ const (
IPv4FlagDontFragment
)
+// IPv4EmptySubnet is the empty IPv4 subnet, built from IPv4Any as both the address and the mask.
+var IPv4EmptySubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPVersion returns the version of IP used in the given packet. It returns -1
// if the packet is not large enough to contain the version field.
func IPVersion(b []byte) int {
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 95fe8bfc3..bc4e56535 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -27,7 +27,7 @@ const (
nextHdr = 6
hopLimit = 7
v6SrcAddr = 8
- v6DstAddr = 24
+ v6DstAddr = v6SrcAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -82,6 +82,15 @@ const (
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
)
+// IPv6EmptySubnet is the empty IPv6 subnet: the all-zero address with an all-zero mask (::/0).
+var IPv6EmptySubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
@@ -110,13 +119,13 @@ func (b IPv6) Payload() []byte {
// SourceAddress returns the "source address" field of the ipv6 header.
func (b IPv6) SourceAddress() tcpip.Address {
- return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
func (b IPv6) DestinationAddress() tcpip.Address {
- return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
@@ -144,13 +153,13 @@ func (b IPv6) SetPayloadLength(payloadLength uint16) {
// SetSourceAddress sets the "source address" field of the ipv6 header.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+ copy(b[v6SrcAddr:][:IPv6AddressSize], addr)
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+ copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
@@ -169,8 +178,8 @@ func (b IPv6) Encode(i *IPv6Fields) {
b.SetPayloadLength(i.PayloadLength)
b[nextHdr] = i.NextHeader
b[hopLimit] = i.HopLimit
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+ b.SetSourceAddress(i.SrcAddr)
+ b.SetDestinationAddress(i.DstAddr)
}
// IsValid performs basic validation on the packet.
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index fc9abbb55..3fc14bacd 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -11,8 +11,5 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- ],
+ deps = ["//pkg/tcpip/buffer"],
)
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index f1e1d1fad..68c68d4aa 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -32,8 +32,8 @@ const (
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
-func DefaultTables() *IPTables {
- return &IPTables{
+func DefaultTables() IPTables {
+ return IPTables{
Tables: map[string]Table{
tablenameNat: Table{
BuiltinChains: map[Hook]Chain{
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 600bd9a10..42a79ef9f 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -15,7 +15,6 @@
package iptables
import (
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -128,15 +127,29 @@ type Table struct {
// UserChains, and its purpose is to make looking up tables by name
// fast.
Chains map[string]*Chain
+
+ // metadata holds information about the Table that is useful to users
+ // of IPTables, but not to the netstack IPTables code itself.
+ metadata interface{}
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
-func (table *Table) ValidHooks() (uint32, *tcpip.Error) {
+func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
for hook, _ := range table.BuiltinChains {
hooks |= 1 << hook
}
- return hooks, nil
+ return hooks
+}
+
+// Metadata returns the metadata object stored in table.
+func (table *Table) Metadata() interface{} {
+ return table.metadata
+}
+
+// SetMetadata sets the metadata object stored in table.
+func (table *Table) SetMetadata(metadata interface{}) {
+ table.metadata = metadata
}
// A Chain defines a list of rules for packet processing. When a packet
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d786d8fdf..74fbbb896 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -8,8 +8,8 @@ go_library(
"endpoint.go",
"endpoint_unsafe.go",
"mmap.go",
- "mmap_amd64.go",
- "mmap_amd64_unsafe.go",
+ "mmap_stub.go",
+ "mmap_unsafe.go",
"packet_dispatchers.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index fe19c2bc2..8bfeb97e4 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -12,14 +12,183 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build !linux !amd64
+// +build linux,amd64 linux,arm64
package fdbased
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "encoding/binary"
+ "syscall"
-// Stubbed out version for non-linux/non-amd64 platforms.
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+)
-func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, *tcpip.Error) {
- return nil, nil
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2101248 bytes (~2 MiB)
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpFrameSize = 65536 + 128
+ tpBlockSize = tpFrameSize * 32
+ tpBlockNR = 1
+ tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
+)
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is the tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
+
+// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
+// See mmap_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // ringBuffer is only used when PacketMMap dispatcher is used and points
+ // to the start of the mmapped PACKET_RX_RING buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ for hdr.tpStatus()&tpStatusUser == 0 {
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ if errno == syscall.EINTR {
+ continue
+ }
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ if hdr.tpStatus()&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ continue
+ }
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+}
+
+// dispatch reads packets from an mmapped ring buffer and dispatches them to the
+// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go
deleted file mode 100644
index 8bbb4f9ab..000000000
--- a/pkg/tcpip/link/fdbased/mmap_amd64.go
+++ /dev/null
@@ -1,194 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux,amd64
-
-package fdbased
-
-import (
- "encoding/binary"
- "syscall"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
-)
-
-const (
- tPacketAlignment = uintptr(16)
- tpStatusKernel = 0
- tpStatusUser = 1
- tpStatusCopy = 2
- tpStatusLosing = 4
-)
-
-// We overallocate the frame size to accommodate space for the
-// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
-//
-// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
-//
-// NOTE:
-// Frames need to be aligned at 16 byte boundaries.
-// BlockSize needs to be page aligned.
-//
-// For details see PACKET_MMAP setting constraints in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-const (
- tpFrameSize = 65536 + 128
- tpBlockSize = tpFrameSize * 32
- tpBlockNR = 1
- tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
-)
-
-// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
-// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
-func tPacketAlign(v uintptr) uintptr {
- return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
-}
-
-// tPacketReq is the tpacket_req structure as described in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-type tPacketReq struct {
- tpBlockSize uint32
- tpBlockNR uint32
- tpFrameSize uint32
- tpFrameNR uint32
-}
-
-// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
-type tPacketHdr []byte
-
-const (
- tpStatusOffset = 0
- tpLenOffset = 8
- tpSnapLenOffset = 12
- tpMacOffset = 16
- tpNetOffset = 18
- tpSecOffset = 20
- tpUSecOffset = 24
-)
-
-func (t tPacketHdr) tpLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpLenOffset:])
-}
-
-func (t tPacketHdr) tpSnapLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
-}
-
-func (t tPacketHdr) tpMac() uint16 {
- return binary.LittleEndian.Uint16(t[tpMacOffset:])
-}
-
-func (t tPacketHdr) tpNet() uint16 {
- return binary.LittleEndian.Uint16(t[tpNetOffset:])
-}
-
-func (t tPacketHdr) tpSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpSecOffset:])
-}
-
-func (t tPacketHdr) tpUSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpUSecOffset:])
-}
-
-func (t tPacketHdr) Payload() []byte {
- return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
-}
-
-// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
-// See: mmap_amd64_unsafe.go for implementation details.
-type packetMMapDispatcher struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
-
- // e is the endpoint this dispatcher is attached to.
- e *endpoint
-
- // ringBuffer is only used when PacketMMap dispatcher is used and points
- // to the start of the mmapped PACKET_RX_RING buffer.
- ringBuffer []byte
-
- // ringOffset is the current offset into the ring buffer where the next
- // inbound packet will be placed by the kernel.
- ringOffset int
-}
-
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
- hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
- for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, -1); errno != 0 {
- if errno == syscall.EINTR {
- continue
- }
- return nil, rawfile.TranslateErrno(errno)
- }
- if hdr.tpStatus()&tpStatusCopy != 0 {
- // This frame is truncated so skip it after flipping the
- // buffer to the kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
- continue
- }
- }
-
- // Copy out the packet from the mmapped frame to a locally owned buffer.
- pkt := make([]byte, hdr.tpSnapLen())
- copy(pkt, hdr.Payload())
- // Release packet to kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
-}
-
-// dispatch reads packets from an mmaped ring buffer and dispatches them to the
-// network stack.
-func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
- return false, err
- }
- var (
- p tcpip.NetworkProtocolNumber
- remote, local tcpip.LinkAddress
- )
- if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
- p = eth.Type()
- remote = eth.SourceAddress()
- local = eth.DestinationAddress()
- } else {
- // We don't get any indication of what the packet is, so try to guess
- // if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(pkt) {
- case header.IPv4Version:
- p = header.IPv4ProtocolNumber
- case header.IPv6Version:
- p = header.IPv6ProtocolNumber
- default:
- return true, nil
- }
- }
-
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
- return true, nil
-}
diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go
new file mode 100644
index 000000000..67be52d67
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/mmap_stub.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux !amd64,!arm64
+
+package fdbased
+
+// Stubbed out version for non-linux/non-amd64/non-arm64 platforms.
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ return nil, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 47cb1d1cc..3894185ae 100644
--- a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
package fdbased
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 6e3a7a9d7..088eb8a21 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -6,8 +6,10 @@ go_library(
name = "rawfile",
srcs = [
"blockingpoll_amd64.s",
- "blockingpoll_amd64_unsafe.go",
+ "blockingpoll_arm64.s",
+ "blockingpoll_noyield_unsafe.go",
"blockingpoll_unsafe.go",
+ "blockingpoll_yield_unsafe.go",
"errors.go",
"rawfile_unsafe.go",
],
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
index b54131573..298bad55d 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -14,17 +14,18 @@
#include "textflag.h"
-// BlockingPoll makes the poll() syscall while calling the version of
+// BlockingPoll makes the ppoll() syscall while calling the version of
// entersyscall that relinquishes the P so that other Gs can run. This is meant
// to be called in cases when the syscall is expected to block.
//
-// func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (n int, err syscall.Errno)
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
CALL ·callEntersyscallblock(SB)
MOVQ fds+0(FP), DI
MOVQ nfds+8(FP), SI
MOVQ timeout+16(FP), DX
- MOVQ $0x7, AX // SYS_POLL
+ MOVQ $0x0, R10 // sigmask parameter which isn't used here
+ MOVQ $0x10f, AX // SYS_PPOLL
SYSCALL
CMPQ AX, $0xfffffffffffff001
JLS ok
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
new file mode 100644
index 000000000..b62888b93
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the ppoll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ BL ·callEntersyscallblock(SB)
+ MOVD fds+0(FP), R0
+ MOVD nfds+8(FP), R1
+ MOVD timeout+16(FP), R2
+ MOVD $0x0, R3 // sigmask parameter which isn't used here
+ MOVD $0x49, R8 // SYS_PPOLL
+ SVC
+ CMP $0xfffffffffffff001, R0
+ BLS ok
+ MOVD $-1, R1
+ MOVD R1, n+24(FP)
+ NEG R0, R0
+ MOVD R0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVD R0, n+24(FP)
+ MOVD $0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
new file mode 100644
index 000000000..621ab8d29
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,!amd64,!arm64
+
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// BlockingPoll is just a stub function that forwards to the ppoll() system call
+// on non-amd64 and non-arm64 platforms.
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
+ uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
+
+ return int(n), e
+}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
index 4eab77c74..84dc0e918 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
@@ -21,9 +21,11 @@ import (
"unsafe"
)
-// BlockingPoll is just a stub function that forwards to the poll() system call
+// BlockingPoll is just a stub function that forwards to the ppoll() system call
// on non-amd64 platforms.
-func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno) {
- n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout))
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
+ uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
+
return int(n), e
}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index c87268610..dda3b10a6 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
// +build go1.12
// +build !go1.14
@@ -25,8 +25,14 @@ import (
_ "unsafe" // for go:linkname
)
+// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the
+// version of entersyscall that relinquishes the P so that other Gs can
+// run. This is meant to be called in cases when the syscall is expected to
+// block. On non amd64/arm64 platforms it just forwards to the ppoll() system
+// call.
+//
//go:noescape
-func BlockingPoll(fds *PollEvent, nfds int, timeout int64) (int, syscall.Errno)
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
// Use go:linkname to call into the runtime. As of Go 1.12 this has to
// be done from Go code so that we make an ABIInternal call to an
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index e3fbb15c2..7e286a3a6 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -123,7 +123,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
Events: 1, // POLLIN
}
- _, e = BlockingPoll(&event, 1, -1)
+ _, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
@@ -145,7 +145,7 @@ func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
Events: 1, // POLLIN
}
- _, e = BlockingPoll(&event, 1, -1)
+ _, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
@@ -175,7 +175,7 @@ func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
Events: 1, // POLLIN
}
- if _, e := BlockingPoll(&event, 1, -1); e != 0 && e != syscall.EINTR {
+ if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index fc584c6a4..36c8c46fc 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -360,10 +360,9 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
+ details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
+ size -= header.UDPMinimumSize
}
- size -= header.UDPMinimumSize
-
- details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
case header.TCPProtocolNumber:
transName = "tcp"
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 60070874d..fd6395fc1 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -109,13 +109,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
pkt.SetOp(header.ARPReply)
copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
- fallthrough // also fill the cache from requests
case header.ARPReply:
- addr := tcpip.Address(h.ProtocolAddressSender())
- linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 66c55821b..4c4b54469 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,6 +15,7 @@
package arp_test
import (
+ "strconv"
"testing"
"time"
@@ -65,9 +66,7 @@ func newTestContext(t *testing.T) *testContext {
}
s.SetRouteTable([]tcpip.Route{{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
}})
@@ -101,40 +100,30 @@ func TestDirectRequest(t *testing.T) {
c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView())
}
- inject(stackAddr1)
- {
- pkt := <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto)
- }
- rep := header.ARP(pkt.Header)
- if !rep.IsValid() {
- t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
- }
- if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 {
- t.Errorf("stackAddr1: expected sender to be set")
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got)
- }
- }
-
- inject(stackAddr2)
- {
- pkt := <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto)
- }
- rep := header.ARP(pkt.Header)
- if !rep.IsValid() {
- t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
- }
- if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 {
- t.Errorf("stackAddr2: expected sender to be set")
- }
- if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
- t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got)
- }
+ for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ inject(address)
+ pkt := <-c.linkEP.C
+ if pkt.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto)
+ }
+ rep := header.ARP(pkt.Header)
+ if !rep.IsValid() {
+ t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want {
+ t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+ }
+ if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want)
+ }
+ })
}
inject(stackAddrBad)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 55e9eec99..6bbfcd97f 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -173,8 +173,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
- Destination: ipv4SubnetAddr,
- Mask: ipv4SubnetMask,
+ Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
NIC: 1,
}})
@@ -187,8 +186,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
- Destination: ipv6SubnetAddr,
- Mask: ipv6SubnetMask,
+ Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
NIC: 1,
}})
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index fbef6947d..497164cbb 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -94,6 +94,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
copy(pkt, h)
pkt.SetType(header.ICMPv4EchoReply)
+ pkt.SetChecksum(0)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
sent := stats.ICMP.V4PacketsSent
if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil {
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 3207a3d46..1b5a55bea 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -52,9 +52,7 @@ func TestExcludeBroadcast(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
}})
@@ -247,14 +245,22 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
_, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
linkEPId := stack.RegisterLinkEndpoint(linkEP)
s.CreateNIC(1, linkEPId)
- s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01")
- s.SetRouteTable([]tcpip.Route{{
- Destination: "\x10\x00\x00\x02",
- Mask: "\xff\xff\xff\xff",
- Gateway: "",
- NIC: 1,
- }})
- r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */)
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ s.AddAddress(1, ipv4.ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("s.FindRoute got %v, want %v", err, nil)
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 5e6a59e91..1689af16f 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -100,13 +100,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
@@ -146,7 +144,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 726362c87..d0dc72506 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -91,13 +91,18 @@ func TestICMPCounts(t *testing.T) {
t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
}
}
- s.SetRouteTable(
- []tcpip.Route{{
- Destination: lladdr1,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
- NIC: 1,
- }},
- )
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
netProto := s.NetworkProtocolInstance(ProtocolNumber)
if netProto == nil {
@@ -237,17 +242,23 @@ func newTestContext(t *testing.T) *testContext {
t.Fatalf("AddAddress sn lladdr1: %v", err)
}
+ subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
c.s0.SetRouteTable(
[]tcpip.Route{{
- Destination: lladdr1,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
+ Destination: subnet0,
NIC: 1,
}},
)
+ subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0))))
+ if err != nil {
+ t.Fatal(err)
+ }
c.s1.SetRouteTable(
[]tcpip.Route{{
- Destination: lladdr0,
- Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
+ Destination: subnet1,
NIC: 1,
}},
)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index 996939581..a57752a7c 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -8,6 +8,7 @@ go_binary(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/rawfile",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 3ac381631..e2021cd15 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -52,6 +52,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -152,9 +153,7 @@ func main() {
// Add default route.
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index da425394a..1716be285 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -149,12 +149,15 @@ func main() {
log.Fatal(err)
}
+ subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr))))
+ if err != nil {
+ log.Fatal(err)
+ }
+
// Add default route.
s.SetRouteTable([]tcpip.Route{
{
- Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
- Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
- Gateway: "",
+ Destination: subnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 28d11c797..b692c60ce 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,11 +1,25 @@
package(licenses = ["notice"])
+load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+go_template_instance(
+ name = "linkaddrentry_list",
+ out = "linkaddrentry_list.go",
+ package = "stack",
+ prefix = "linkAddrEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*linkAddrEntry",
+ "Linker": "*linkAddrEntry",
+ },
+)
+
go_library(
name = "stack",
srcs = [
"linkaddrcache.go",
+ "linkaddrentry_list.go",
"nic.go",
"registration.go",
"route.go",
@@ -24,6 +38,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/waiter",
@@ -42,6 +57,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/waiter",
@@ -58,3 +74,11 @@ go_test(
"//pkg/tcpip",
],
)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "linkaddrentry_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 77bb0ccb9..267df60d1 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -42,10 +42,11 @@ type linkAddrCache struct {
// resolved before failing.
resolutionAttempts int
- mu sync.Mutex
- cache map[tcpip.FullAddress]*linkAddrEntry
- next int // array index of next available entry
- entries [linkAddrCacheSize]linkAddrEntry
+ cache struct {
+ sync.Mutex
+ table map[tcpip.FullAddress]*linkAddrEntry
+ lru linkAddrEntryList
+ }
}
// entryState controls the state of a single entry in the cache.
@@ -60,9 +61,6 @@ const (
// failed means that address resolution timed out and the address
// could not be resolved.
failed
- // expired means that the cache entry has expired and the address must be
- // resolved again.
- expired
)
// String implements Stringer.
@@ -74,8 +72,6 @@ func (s entryState) String() string {
return "ready"
case failed:
return "failed"
- case expired:
- return "expired"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -84,64 +80,46 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ linkAddrEntryEntry
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
// wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of 'incomplete' these waiters are notified.
+ // state transitions out of incomplete these waiters are notified.
wakers map[*sleep.Waker]struct{}
+ // done is used to allow callers to wait on address resolution. It is non-nil
+ // only while s is incomplete and resolution is in progress.
done chan struct{}
}
-func (e *linkAddrEntry) state() entryState {
- if e.s != expired && time.Now().After(e.expiration) {
- // Force the transition to ensure waiters are notified.
- e.changeState(expired)
- }
- return e.s
-}
-
-func (e *linkAddrEntry) changeState(ns entryState) {
- if e.s == ns {
- return
- }
-
- // Validate state transition.
- switch e.s {
- case incomplete:
- // All transitions are valid.
- case ready, failed:
- if ns != expired {
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- }
- case expired:
- // Terminal state.
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- default:
- panic(fmt.Sprintf("invalid state: %s", e.s))
- }
-
+// changeState sets the entry's state to ns, notifying any waiters.
+//
+// The entry's expiration is bumped up to the greater of itself and the passed
+// expiration; the zero value indicates immediate expiration, and is set
+// unconditionally - this is an implementation detail that allows for entries
+// to be reused.
+func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
// Notify whoever is waiting on address resolution when transitioning
- // out of 'incomplete'.
- if e.s == incomplete {
+ // out of incomplete.
+ if e.s == incomplete && ns != incomplete {
for w := range e.wakers {
w.Assert()
}
e.wakers = nil
- if e.done != nil {
- close(e.done)
+ if ch := e.done; ch != nil {
+ close(ch)
}
+ e.done = nil
}
- e.s = ns
-}
-func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) {
- if w != nil {
- e.wakers[w] = struct{}{}
+ if expiration.IsZero() || expiration.After(e.expiration) {
+ e.expiration = expiration
}
+ e.s = ns
}
func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
@@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
- if ok {
- s := entry.state()
- if s != expired && entry.linkAddr == v {
- // Disregard repeated calls.
- return
- }
- // Check if entry is waiting for address resolution.
- if s == incomplete {
- entry.linkAddr = v
- } else {
- // Otherwise create a new entry to replace it.
- entry = c.makeAndAddEntry(k, v)
- }
- } else {
- entry = c.makeAndAddEntry(k, v)
- }
+ // Calculate expiration time before acquiring the lock, since expiration is
+ // relative to the time when information was learned, rather than when it
+ // happened to be inserted into the cache.
+ expiration := time.Now().Add(c.ageLimit)
- entry.changeState(ready)
+ c.cache.Lock()
+ entry := c.getOrCreateEntryLocked(k)
+ entry.linkAddr = v
+
+ entry.changeState(ready, expiration)
+ c.cache.Unlock()
}
-// makeAndAddEntry is a helper function to create and add a new
-// entry to the cache map and evict older entry as needed.
-func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
- // Take over the next entry.
- entry := &c.entries[c.next]
- if c.cache[entry.addr] == entry {
- delete(c.cache, entry.addr)
+// getOrCreateEntryLocked retrieves a cache entry associated with k. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
+ if entry, ok := c.cache.table[k]; ok {
+ c.cache.lru.Remove(entry)
+ c.cache.lru.PushFront(entry)
+ return entry
}
+ var entry *linkAddrEntry
+ if len(c.cache.table) == linkAddrCacheSize {
+ entry = c.cache.lru.Back()
- // Mark the soon-to-be-replaced entry as expired, just in case there is
- // someone waiting for address resolution on it.
- entry.changeState(expired)
+ delete(c.cache.table, entry.addr)
+ c.cache.lru.Remove(entry)
- *entry = linkAddrEntry{
- addr: k,
- linkAddr: v,
- expiration: time.Now().Add(c.ageLimit),
- wakers: make(map[*sleep.Waker]struct{}),
- done: make(chan struct{}),
+ // Wake waiters and mark the soon-to-be-reused entry as expired. Note
+ // that the state passed doesn't matter when the zero time is passed.
+ entry.changeState(failed, time.Time{})
+ } else {
+ entry = new(linkAddrEntry)
}
- c.cache[k] = entry
- c.next = (c.next + 1) % len(c.entries)
+ *entry = linkAddrEntry{
+ addr: k,
+ s: incomplete,
+ }
+ c.cache.table[k] = entry
+ c.cache.lru.PushFront(entry)
return entry
}
@@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
}
- c.mu.Lock()
- defer c.mu.Unlock()
- if entry, ok := c.cache[k]; ok {
- switch s := entry.state(); s {
- case expired:
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return "", nil, tcpip.ErrNoLinkAddress
- case incomplete:
- // Address resolution is still in progress.
- entry.maybeAddWaker(waker)
- return "", entry.done, tcpip.ErrWouldBlock
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry := c.getOrCreateEntryLocked(k)
+ switch s := entry.s; s {
+ case ready, failed:
+ if !time.Now().After(entry.expiration) {
+ // Not expired.
+ switch s {
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
- }
- if linkRes == nil {
- return "", nil, tcpip.ErrNoLinkAddress
- }
+ entry.changeState(incomplete, time.Time{})
+ fallthrough
+ case incomplete:
+ if waker != nil {
+ if entry.wakers == nil {
+ entry.wakers = make(map[*sleep.Waker]struct{})
+ }
+ entry.wakers[waker] = struct{}{}
+ }
- // Add 'incomplete' entry in the cache to mark that resolution is in progress.
- e := c.makeAndAddEntry(k, "")
- e.maybeAddWaker(waker)
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ }
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ entry.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ }
- return "", e.done, tcpip.ErrWouldBlock
+ return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
// removeWaker removes a waker previously added through get().
func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.mu.Lock()
- defer c.mu.Unlock()
+ c.cache.Lock()
+ defer c.cache.Unlock()
- if entry, ok := c.cache[k]; ok {
+ if entry, ok := c.cache.table[k]; ok {
entry.removeWaker(waker)
}
}
@@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
select {
- case <-time.After(c.resolutionTimeout):
- if stop := c.checkLinkRequest(k, i); stop {
+ case now := <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(now, k, i); stop {
return
}
case <-done:
@@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
// checkLinkRequest checks whether previous attempt to resolve address has succeeded
// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
// can stop, false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry, ok := c.cache.table[k]
if !ok {
// Entry was evicted from the cache.
return true
}
-
- switch s := entry.state(); s {
- case ready, failed, expired:
+ switch s := entry.s; s {
+ case ready, failed:
// Entry was made ready by resolver or failed. Either way we're done.
- return true
case incomplete:
- if attempt+1 >= c.resolutionAttempts {
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed)
- return true
+ if attempt+1 < c.resolutionAttempts {
+ // No response yet, need to send another ARP request.
+ return false
}
- // No response yet, need to send another ARP request.
- return false
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed, now.Add(c.ageLimit))
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
+ return true
}
func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
- return &linkAddrCache{
+ c := &linkAddrCache{
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
- cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
}
+ c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ return c
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 924f4d240..9946b8fe8 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -29,25 +30,34 @@ type testaddr struct {
linkAddr tcpip.LinkAddress
}
-var testaddrs []testaddr
+var testAddrs = func() []testaddr {
+ var addrs []testaddr
+ for i := 0; i < 4*linkAddrCacheSize; i++ {
+ addr := fmt.Sprintf("Addr%06d", i)
+ addrs = append(addrs, testaddr{
+ addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ linkAddr: tcpip.LinkAddress("Link" + addr),
+ })
+ }
+ return addrs
+}()
type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
+ cache *linkAddrCache
+ delay time.Duration
+ onLinkAddressRequest func()
}
func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
- go func() {
- if r.delay > 0 {
- time.Sleep(r.delay)
- }
- r.fakeRequest(addr)
- }()
+ time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
return nil
}
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testaddrs {
+ for _, ta := range testAddrs {
if ta.addr.Addr == addr {
r.cache.add(ta.addr, ta.linkAddr)
break
@@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
}
}
-func init() {
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- testaddrs = append(testaddrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
-}
-
func TestCacheOverflow(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- for i := len(testaddrs) - 1; i >= 0; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= 0; i-- {
+ e := testAddrs[i]
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
@@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) {
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testaddrs[i]
+ e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
@@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
- for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
+ e := testAddrs[i]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
@@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) {
for r := 0; r < 16; r++ {
wg.Add(1)
go func() {
- for _, e := range testaddrs {
+ for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
@@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
- e = testaddrs[0]
+ e = testAddrs[0]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
@@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) {
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
@@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) {
func TestCacheReplace(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
@@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) {
func TestCacheResolution(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
linkRes := &testLinkAddressResolver{cache: c}
- for i, ta := range testaddrs {
+ for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
@@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) {
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
+ var requestCount uint32
+ linkRes.onLinkAddressRequest = func() {
+ atomic.AddUint32(&requestCount, 1)
+ }
+
// First, sanity check that resolution is working...
- e := testaddrs[0]
+ e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
+ before := atomic.LoadUint32(&requestCount)
+
e.addr.Addr += "2"
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
+
+ if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
}
func TestCacheResolutionTimeout(t *testing.T) {
@@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) {
c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
- e := testaddrs[0]
+ e := testAddrs[0]
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 3e6ff4afb..89b4c5960 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -139,7 +139,7 @@ func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Add
if list, ok := n.primary[protocol]; ok {
for e := list.Front(); e != nil; e = e.Next() {
ref := e.(*referencedNetworkEndpoint)
- if ref.holdsInsertRef && ref.tryIncRef() {
+ if ref.kind == permanent && ref.tryIncRef() {
r = ref
break
}
@@ -178,7 +178,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
case header.IPv4Broadcast, header.IPv4Any:
continue
}
- if r.tryIncRef() {
+ if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
}
@@ -186,82 +186,155 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return nil
}
+func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing)
+}
+
+// getRefOrCreateTemp returns the referenced network endpoint for the given
+// protocol and address. If none exists a temporary one may be created if
+// we are in promiscuous mode or spoofing.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- ref := n.endpoints[id]
- if ref != nil && !ref.tryIncRef() {
- ref = nil
+
+ if ref, ok := n.endpoints[id]; ok {
+ // An endpoint with this id exists, check if it can be used and return it.
+ switch ref.kind {
+ case permanentExpired:
+ if !spoofingOrPromiscuous {
+ n.mu.RUnlock()
+ return nil
+ }
+ fallthrough
+ case temporary, permanent:
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
+ }
+
+ // A usable reference was not found, create a temporary one if requested by
+ // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous
+ if !createTempEP {
+ for _, sn := range n.subnets {
+ if sn.Contains(address) {
+ createTempEP = true
+ break
+ }
+ }
}
- spoofing := n.spoofing
+
n.mu.RUnlock()
- if ref != nil || !spoofing {
- return ref
+ if !createTempEP {
+ return nil
}
// Try again with the lock in exclusive mode. If we still can't get the
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- ref = n.endpoints[id]
- if ref == nil || !ref.tryIncRef() {
- if netProto, ok := n.stack.networkProtocols[protocol]; ok {
- addrWithPrefix := tcpip.AddressWithPrefix{address, netProto.DefaultPrefixLen()}
- ref, _ = n.addAddressLocked(protocol, addrWithPrefix, peb, true)
- if ref != nil {
- ref.holdsInsertRef = false
- }
+ if ref, ok := n.endpoints[id]; ok {
+ // No need to check the type as we are ok with expired endpoints at this
+ // point.
+ if ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
}
+ // tryIncRef failing means the endpoint is scheduled to be removed once the
+ // lock is released. Remove it here so we can create a new (temporary) one.
+ // The removal logic waiting for the lock handles this case.
+ n.removeEndpointLocked(ref)
}
- n.mu.Unlock()
- return ref
-}
-func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
+ n.mu.Unlock()
+ return nil
}
+ ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, peb, temporary)
- // Create the new network endpoint.
- ep, err := netProto.NewEndpoint(n.id, addrWithPrefix, n.stack, n, n.linkEP)
- if err != nil {
- return nil, err
- }
+ n.mu.Unlock()
+ return ref
+}
- id := *ep.ID()
+func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
if ref, ok := n.endpoints[id]; ok {
- if !replace {
+ switch ref.kind {
+ case permanent:
+ // The NIC already has a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
+ case permanentExpired, temporary:
+ // Promote the endpoint to become permanent.
+ if ref.tryIncRef() {
+ ref.kind = permanent
+ return ref, nil
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once
+ // the lock is released. Remove it here so we can create a new
+ // (permanent) one. The removal logic waiting for the lock handles this
+ // case.
+ n.removeEndpointLocked(ref)
}
+ }
+ return n.addAddressLocked(protocolAddress, peb, permanent)
+}
- n.removeEndpointLocked(ref)
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // Sanity check.
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if _, ok := n.endpoints[id]; ok {
+ // Endpoint already exists.
+ return nil, tcpip.ErrDuplicateAddress
}
+ netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
+ if !ok {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ // Create the new network endpoint.
+ ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP)
+ if err != nil {
+ return nil, err
+ }
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocol,
- holdsInsertRef: true,
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
}
// Set up cache if link address resolution exists for this protocol.
if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if _, ok := n.stack.linkAddrResolvers[protocol]; ok {
+ if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
ref.linkCache = n.stack
}
}
n.endpoints[id] = ref
- l, ok := n.primary[protocol]
+ l, ok := n.primary[protocolAddress.Protocol]
if !ok {
l = &ilist.List{}
- n.primary[protocol] = l
+ n.primary[protocolAddress.Protocol] = l
}
switch peb {
@@ -276,10 +349,10 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addrWithPre
// AddAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addAddressLocked(protocol, addrWithPrefix, peb, false)
+ _, err := n.addPermanentAddressLocked(protocolAddress, peb)
n.mu.Unlock()
return err
@@ -291,6 +364,12 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
defer n.mu.RUnlock()
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
+ // Don't include expired or temporary endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.kind {
+ case permanentExpired, temporary:
+ continue
+ }
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -356,13 +435,16 @@ func (n *NIC) Subnets() []tcpip.Subnet {
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
- // Nothing to do if the reference has already been replaced with a
- // different one.
+ // Nothing to do if the reference has already been replaced with a different
+ // one. This happens in the case where 1) this endpoint's ref count hit zero
+ // and was waiting (on the lock) to be removed and 2) the same address was
+ // re-added in the meantime by removing this endpoint from the list and
+ // adding a new one.
if n.endpoints[id] != r {
return
}
- if r.holdsInsertRef {
+ if r.kind == permanent {
panic("Reference count dropped to zero before being removed")
}
@@ -381,14 +463,13 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
+func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
r := n.endpoints[NetworkEndpointID{addr}]
- if r == nil || !r.holdsInsertRef {
+ if r == nil || r.kind != permanent {
return tcpip.ErrBadLocalAddress
}
- r.holdsInsertRef = false
-
+ r.kind = permanentExpired
r.decRefLocked()
return nil
@@ -398,7 +479,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.removeAddressLocked(addr)
+ return n.removePermanentAddressLocked(addr)
}
// joinGroup adds a new endpoint for the given multicast address, if none
@@ -414,8 +495,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
if !ok {
return tcpip.ErrUnknownProtocol
}
- addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()}
- if _, err := n.addAddressLocked(protocol, addrWithPrefix, NeverPrimaryEndpoint, false); err != nil {
+ if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, NeverPrimaryEndpoint); err != nil {
return err
}
}
@@ -437,7 +523,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
case 1:
// This is the last one, clean up.
- if err := n.removeAddressLocked(addr); err != nil {
+ if err := n.removePermanentAddressLocked(addr); err != nil {
return err
}
}
@@ -445,6 +531,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return nil
}
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+ r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remotelinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+}
+
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the physical interface.
@@ -472,6 +565,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
src, dst := netProto.ParseAddresses(vv.First())
+ n.stack.AddLinkAddress(n.id, src, remote)
+
// If the packet is destined to the IPv4 Broadcast address, then make a
// route to each IPv4 network endpoint and let each endpoint handle the
// packet.
@@ -479,11 +574,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
// n.endpoints is mutex protected so acquire lock.
n.mu.RLock()
for _, ref := range n.endpoints {
- if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
}
}
n.mu.RUnlock()
@@ -491,10 +583,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
}
if ref := n.getRef(protocol, dst); ref != nil {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
}
@@ -517,8 +606,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n := r.ref.nic
n.mu.RLock()
ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
n.mu.RUnlock()
- if ok && ref.tryIncRef() {
+ if ok {
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, vv)
@@ -543,52 +633,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
-func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- id := NetworkEndpointID{dst}
-
- n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
-
- promiscuous := n.promiscuous
- // Check if the packet is for a subnet this NIC cares about.
- if !promiscuous {
- for _, sn := range n.subnets {
- if sn.Contains(dst) {
- promiscuous = true
- break
- }
- }
- }
- n.mu.RUnlock()
- if promiscuous {
- // Try again with the lock in exclusive mode. If we still can't
- // get the endpoint, create a new "temporary" one. It will only
- // exist while there's a route through it.
- n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.Unlock()
- return ref
- }
- netProto, ok := n.stack.networkProtocols[protocol]
- if !ok {
- n.mu.Unlock()
- return nil
- }
- addrWithPrefix := tcpip.AddressWithPrefix{dst, netProto.DefaultPrefixLen()}
- ref, err := n.addAddressLocked(protocol, addrWithPrefix, CanBePrimaryEndpoint, true)
- n.mu.Unlock()
- if err == nil {
- ref.holdsInsertRef = false
- return ref
- }
- }
-
- return nil
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
@@ -676,9 +720,33 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+type networkEndpointKind int
+
+const (
+ // A permanent endpoint is created by adding a permanent address (vs. a
+ // temporary one) to the NIC. Its reference count is biased by 1 to avoid
+ // removal when no route holds a reference to it. It is removed by explicitly
+ // removing the permanent address from the NIC.
+ permanent networkEndpointKind = iota
+
+ // An expired permanent endpoint is a permanent endpoint that had its address
+ // removed from the NIC, and it is waiting to be removed once no more routes
+ // hold a reference to it. This is achieved by decreasing its reference count
+ // by 1. If its address is re-added before the endpoint is removed, its type
+ // changes back to permanent and its reference count increases by 1 again.
+ permanentExpired
+
+ // A temporary endpoint is created for spoofing outgoing packets, or when in
+ // promiscuous mode and accepting incoming packets that don't match any
+ // permanent endpoint. Its reference count is not biased by 1 and the
+ // endpoint is removed immediately when no route holds a reference to
+ // it. A temporary endpoint can be promoted to permanent if its address
+ // is added permanently.
+ temporary
+)
+
type referencedNetworkEndpoint struct {
ilist.Entry
- refs int32
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -687,11 +755,25 @@ type referencedNetworkEndpoint struct {
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
- // holdsInsertRef is protected by the NIC's mutex. It indicates whether
- // the reference count is biased by 1 due to the insertion of the
- // endpoint. It is reset to false when RemoveAddress is called on the
- // NIC.
- holdsInsertRef bool
+ // refs is counting references held for this endpoint. When refs hits zero it
+ // triggers the automatic removal of the endpoint from the NIC.
+ refs int32
+
+ kind networkEndpointKind
+}
+
+// isValidForOutgoing returns true if the endpoint can be used to send out a
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in spoofing mode.
+func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
+ return r.kind != permanentExpired || r.nic.spoofing
+}
+
+// isValidForIncoming returns true if the endpoint can accept an incoming
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in promiscuous mode.
+func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
+ return r.kind != permanentExpired || r.nic.promiscuous
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 391ab4344..e52cdd674 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
func (r *Route) IsResolutionRequired() bool {
- return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
@@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 57b8a9994..d69162ba1 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
@@ -333,6 +334,15 @@ type TCPEndpointState struct {
Sender TCPSenderState
}
+// ResumableEndpoint is an endpoint that needs to be resumed after restore.
+type ResumableEndpoint interface {
+ // Resume resumes an endpoint after restore. This can be used to restart
+ // background workers such as protocol goroutines. This must be called after
+ // all indirect dependencies of the endpoint have been restored, which
+ // generally implies at the end of the restore process.
+ Resume(*Stack)
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -372,6 +382,13 @@ type Stack struct {
// handleLocal allows non-loopback interfaces to loop packets.
handleLocal bool
+
+ // tables are the iptables packet filtering and manipulation rules.
+ tables iptables.IPTables
+
+ // resumableEndpoints is a list of endpoints that need to be resumed if the
+ // stack is being restored.
+ resumableEndpoints []ResumableEndpoint
}
// Options contains optional Stack configuration.
@@ -751,10 +768,10 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
-// AddAddressWithPrefix adds a new network-layer address/prefixLen to the
+// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix) *tcpip.Error {
- return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, CanBePrimaryEndpoint)
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
+ return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
}
// AddAddressWithOptions is the same as AddAddress, but allows you to specify
@@ -764,13 +781,18 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt
if !ok {
return tcpip.ErrUnknownProtocol
}
- addrWithPrefix := tcpip.AddressWithPrefix{addr, netProto.DefaultPrefixLen()}
- return s.AddAddressWithPrefixAndOptions(id, protocol, addrWithPrefix, peb)
-}
-
-// AddAddressWithPrefixAndOptions is the same as AddAddressWithPrefixLen,
-// but allows you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addrWithPrefix tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) *tcpip.Error {
+ return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: netProto.DefaultPrefixLen(),
+ },
+ }, peb)
+}
+
+// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
+// you to specify whether the new endpoint can be primary or not.
+func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -779,7 +801,7 @@ func (s *Stack) AddAddressWithPrefixAndOptions(id tcpip.NICID, protocol tcpip.Ne
return tcpip.ErrUnknownNICID
}
- return nic.AddAddress(protocol, addrWithPrefix, peb)
+ return nic.AddAddress(protocolAddress, peb)
}
// AddSubnet adds a subnet range to the specified NIC.
@@ -873,7 +895,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
} else {
for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Match(remoteAddr)) {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) {
continue
}
if nic, ok := s.nics[route.NIC]; ok {
@@ -1082,6 +1104,28 @@ func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip
}
}
+// RegisterRestoredEndpoint records e as an endpoint that has been restored on
+// this stack.
+func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
+ s.mu.Lock()
+ s.resumableEndpoints = append(s.resumableEndpoints, e)
+ s.mu.Unlock()
+}
+
+// Resume restarts the stack after a restore. This must be called after the
+// entire system has been restored.
+func (s *Stack) Resume() {
+ // ResumableEndpoint.Resume() may call other methods on s, so we can't hold
+ // s.mu while resuming the endpoints.
+ s.mu.Lock()
+ eps := s.resumableEndpoints
+ s.resumableEndpoints = nil
+ s.mu.Unlock()
+ for _, e := range eps {
+ e.Resume(s)
+ }
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1161,3 +1205,13 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
}
return tcpip.ErrUnknownNICID
}
+
+// IPTables returns the stack's iptables.
+func (s *Stack) IPTables() iptables.IPTables {
+ return s.tables
+}
+
+// SetIPTables sets the stack's iptables.
+func (s *Stack) SetIPTables(ipt iptables.IPTables) {
+ s.tables = ipt
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9d082bba4..4debd1eec 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -181,6 +181,10 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
return fakeDefaultPrefixLen
}
+func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
+ return f.packetCount[int(intfAddr)%len(f.packetCount)]
+}
+
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
@@ -188,7 +192,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
- id: stack.NetworkEndpointID{addrWithPrefix.Address},
+ id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
@@ -289,16 +293,75 @@ func TestNetworkReceive(t *testing.T) {
}
}
-func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) {
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ return err
}
defer r.Release()
+ return send(r, payload)
+}
+func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil {
- t.Error("WritePacket failed:", err)
+ return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123)
+}
+
+func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ linkEP.Drain()
+ if err := sendTo(s, addr, payload); err != nil {
+ t.Error("sendTo failed:", err)
+ }
+ if got, want := linkEP.Drain(), 1; got != want {
+ t.Errorf("sendTo packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ linkEP.Drain()
+ if err := send(r, payload); err != nil {
+ t.Error("send failed:", err)
+ }
+ if got, want := linkEP.Drain(), 1; got != want {
+ t.Errorf("send packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := send(r, payload); gotErr != wantErr {
+ t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
+ t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte) + 1
+ testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want)
+}
+
+func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we do NOT expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte)
+ testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want)
+}
+
+func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) {
+ t.Helper()
+ linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if got := fakeNet.PacketCount(localAddrByte); got != want {
+ t.Errorf("receive packet count: got = %d, want %d", got, want)
}
}
@@ -312,17 +375,20 @@ func TestNetworkSend(t *testing.T) {
t.Fatal("NewNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
}
// Make sure that the link-layer endpoint received the outbound packet.
- sendTo(t, s, "\x03", nil)
- if c := linkEP.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x03", linkEP, nil)
}
func TestNetworkSendMultiRoute(t *testing.T) {
@@ -360,24 +426,26 @@ func TestNetworkSendMultiRoute(t *testing.T) {
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\x01", "\x00", 1},
- {"\x00", "\x01", "\x00", 2},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ })
+ }
// Send a packet to an odd destination.
- sendTo(t, s, "\x05", nil)
-
- if c := linkEP1.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x05", linkEP1, nil)
// Send a packet to an even destination.
- sendTo(t, s, "\x06", nil)
-
- if c := linkEP2.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x06", linkEP2, nil)
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
@@ -439,10 +507,20 @@ func TestRoutes(t *testing.T) {
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\x01", "\x00", 1},
- {"\x00", "\x01", "\x00", 2},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ })
+ }
// Test routes to odd address.
testRoute(t, s, 0, "", "\x05", "\x01")
@@ -472,6 +550,10 @@ func TestRoutes(t *testing.T) {
}
func TestAddressRemoval(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
@@ -479,99 +561,285 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
- // Remove the address, then check that packet doesn't get delivered
- // anymore.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
}
-func TestDelayedRemovalDueToRoute(t *testing.T) {
+func TestAddressRemovalWithRouteHeld(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatal("CreateNIC failed:", err)
}
-
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress failed:", err)
- }
-
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- // Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 2 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSend(t, r, linkEP, nil)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
- // Remove the address, then check that packet is still deliverable
- // because the route is keeping the address alive.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
+}
- // Release the route, then check that packet is not deliverable anymore.
- r.Release()
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
+func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+ t.Helper()
+ info, ok := s.NICInfo()[nicid]
+ if !ok {
+ t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+ }
+ if len(addr) == 0 {
+ // No address given, verify that there is no address assigned to the NIC.
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
+ t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
+ }
+ }
+ return
+ }
+ // Address given, verify the address is assigned to the NIC and no other
+ // address is.
+ found := false
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber {
+ if a.AddressWithPrefix.Address == addr {
+ found = true
+ } else {
+ t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr)
+ }
+ }
+ }
+ if !found {
+ t.Errorf("verify address: couldn't find %s on the NIC", addr)
+ }
+}
+
+func TestEndpointExpiration(t *testing.T) {
+ const (
+ localAddrByte byte = 0x01
+ remoteAddr tcpip.Address = "\x03"
+ noAddr tcpip.Address = ""
+ nicid tcpip.NICID = 1
+ )
+ localAddr := tcpip.Address([]byte{localAddrByte})
+
+ for _, promiscuous := range []bool{true, false} {
+ for _, spoofing := range []bool{true, false} {
+ t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, linkEP := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+ buf[0] = localAddrByte
+
+ if promiscuous {
+ if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ }
+
+ if spoofing {
+ if err := s.SetSpoofing(nicid, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ }
+
+ // 1. No Address yet, send should only work for spoofing, receive for
+ // promiscuous mode.
+ //-----------------------
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 2. Add Address, everything should work.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 3. Remove the address, send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 4. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 5. Take a reference to the endpoint by getting a route. Verify that
+ // we can still send/receive, including sending using the route.
+ //-----------------------
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ testSend(t, r, linkEP, nil)
+
+ // 6. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ testSend(t, r, linkEP, nil)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+
+ // 7. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+ testSend(t, r, linkEP, nil)
+
+ // 8. Release the route, sendTo/recv should still work.
+ //-----------------------
+ r.Release()
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testSendTo(t, s, remoteAddr, linkEP, nil)
+
+ // 9. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, linkEP, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ }
+ })
+ }
}
}
@@ -583,9 +851,13 @@ func TestPromiscuousMode(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
@@ -593,22 +865,15 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
@@ -621,54 +886,120 @@ func TestPromiscuousMode(t *testing.T) {
if err := s.SetPromiscuousMode(1, false); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+func TestSpoofingWithAddress(t *testing.T) {
+ localAddr := tcpip.Address("\x01")
+ nonExistentLocalAddr := tcpip.Address("\x02")
+ dstAddr := tcpip.Address("\x03")
+
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, linkEP := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // With address spoofing disabled, FindRoute does not permit an address
+ // that was not added to the NIC to be used as the source.
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err == nil {
+ t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+ }
+
+ // With address spoofing enabled, FindRoute permits any address to be used
+ // as the source.
+ if err := s.SetSpoofing(1, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
}
+ // Sending a packet works.
+ testSendTo(t, s, dstAddr, linkEP, nil)
+ testSend(t, r, linkEP, nil)
+
+ // FindRoute should also work with a local address that exists on the NIC.
+ r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != localAddr {
+ t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, localAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet using the route works.
+ testSend(t, r, linkEP, nil)
}
-func TestAddressSpoofing(t *testing.T) {
- srcAddr := tcpip.Address("\x01")
+func TestSpoofingNoAddress(t *testing.T) {
+ nonExistentLocalAddr := tcpip.Address("\x01")
dstAddr := tcpip.Address("\x02")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
+ id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
- t.Fatal("AddAddress failed:", err)
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
-
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
+ // Sending a packet fails.
+ testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute)
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
if err := s.SetSpoofing(1, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- if r.LocalAddress != srcAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
}
+ // Sending a packet should work.
+ // FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, dstAddr, linkEP, nil)
}
func TestBroadcastNeedsNoRoute(t *testing.T) {
@@ -806,16 +1137,20 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -824,9 +1159,52 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
t.Fatal("AddSubnet failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+}
+
+// Set the subnet, then check that CheckLocalAddress returns the correct NIC.
+func TestCheckLocalAddressForSubnet(t *testing.T) {
+ const nicID tcpip.NICID = 1
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, id); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}})
+ }
+
+ subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
+
+ if err != nil {
+ t.Fatal("NewSubnet failed:", err)
+ }
+ if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddSubnet failed:", err)
+ }
+
+ // Loop over all subnet addresses and check them.
+ numOfAddresses := 1 << uint(8-subnet.Prefix())
+ if numOfAddresses < 1 || numOfAddresses > 255 {
+ t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
+ }
+ addr := []byte(subnet.ID())
+ for i := 0; i < numOfAddresses; i++ {
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID)
+ }
+ addr[0]++
+ }
+
+ // Trying the next address should fail since it is outside the subnet range.
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0)
}
}
@@ -839,16 +1217,20 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
t.Fatal("CreateNIC failed:", err)
}
- s.SetRouteTable([]tcpip.Route{
- {"\x00", "\x00", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -856,10 +1238,7 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
t.Fatal("AddSubnet failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
}
func TestNetworkOptions(t *testing.T) {
@@ -969,12 +1348,18 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// prefixLen.
address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen))
if behavior == stack.CanBePrimaryEndpoint {
- addressWithPrefix := tcpip.AddressWithPrefix{address, addrLen * 8}
- if err := s.AddAddressWithPrefixAndOptions(1, fakeNetNumber, addressWithPrefix, behavior); err != nil {
- t.Fatal("AddAddressWithPrefixAndOptions failed:", err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: addrLen * 8,
+ },
+ }
+ if err := s.AddProtocolAddressWithOptions(1, protocolAddress, behavior); err != nil {
+ t.Fatal("AddProtocolAddressWithOptions failed:", err)
}
// Remember the address/prefix.
- primaryAddrAdded[addressWithPrefix] = struct{}{}
+ primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{}
} else {
if err := s.AddAddressWithOptions(1, fakeNetNumber, address, behavior); err != nil {
t.Fatal("AddAddressWithOptions failed:", err)
@@ -1024,20 +1409,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
{"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", 116},
} {
t.Run(tc.name, func(t *testing.T) {
- addressWithPrefix := tcpip.AddressWithPrefix{tc.address, tc.prefixLen}
-
- if err := s.AddAddressWithPrefix(1, fakeNetNumber, addressWithPrefix); err != nil {
- t.Fatal("AddAddressWithPrefix failed:", err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tc.address,
+ PrefixLen: tc.prefixLen,
+ },
+ }
+ if err := s.AddProtocolAddress(1, protocolAddress); err != nil {
+ t.Fatal("AddProtocolAddress failed:", err)
}
// Check that we get the right initial address and prefix length.
if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
t.Fatal("GetMainNICAddress failed:", err)
- } else if gotAddressWithPrefix != addressWithPrefix {
- t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, addressWithPrefix)
+ } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix)
}
- if err := s.RemoveAddress(1, addressWithPrefix.Address); err != nil {
+ if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
@@ -1102,7 +1492,7 @@ func TestAddAddress(t *testing.T) {
verifyAddresses(t, expectedAddresses, gotAddresses)
}
-func TestAddAddressWithPrefix(t *testing.T) {
+func TestAddProtocolAddress(t *testing.T) {
const nicid = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
@@ -1116,14 +1506,17 @@ func TestAddAddressWithPrefix(t *testing.T) {
expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange))
for _, addrLen := range addrLenRange {
for _, prefixLen := range prefixLenRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithPrefix(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}); err != nil {
- t.Errorf("AddAddressWithPrefix(address=%s, prefixLen=%d) failed: %s", address, prefixLen, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addrGen.next(addrLen),
+ PrefixLen: prefixLen,
+ },
}
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen},
- })
+ if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil {
+ t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
+ }
+ expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
@@ -1160,7 +1553,7 @@ func TestAddAddressWithOptions(t *testing.T) {
verifyAddresses(t, expectedAddresses, gotAddresses)
}
-func TestAddAddressWithPrefixAndOptions(t *testing.T) {
+func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicid = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
@@ -1176,14 +1569,17 @@ func TestAddAddressWithPrefixAndOptions(t *testing.T) {
for _, addrLen := range addrLenRange {
for _, prefixLen := range prefixLenRange {
for _, behavior := range behaviorRange {
- address := addrGen.next(addrLen)
- if err := s.AddAddressWithPrefixAndOptions(nicid, fakeNetNumber, tcpip.AddressWithPrefix{address, prefixLen}, behavior); err != nil {
- t.Fatalf("AddAddressWithPrefixAndOptions(address=%s, prefixLen=%d, behavior=%d) failed: %s", address, prefixLen, behavior, err)
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addrGen.next(addrLen),
+ PrefixLen: prefixLen,
+ },
}
- expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
- Protocol: fakeNetNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{address, prefixLen},
- })
+ if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
+ }
+ expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
}
@@ -1196,15 +1592,19 @@ func TestNICStats(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, linkEP1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ t.Fatal("CreateNIC failed: ", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
}
// Route all packets for address \x01 to NIC 1.
- s.SetRouteTable([]tcpip.Route{
- {"\x01", "\xff", "\x00", 1},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x01", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
// Send a packet to address 1.
buf := buffer.NewView(30)
@@ -1219,7 +1619,9 @@ func TestNICStats(t *testing.T) {
payload := buffer.NewView(10)
// Write a packet out via the address for NIC 1
- sendTo(t, s, "\x01", payload)
+ if err := sendTo(s, "\x01", payload); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
want := uint64(linkEP1.Drain())
if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want)
@@ -1253,9 +1655,13 @@ func TestNICForwarding(t *testing.T) {
}
// Route all packets to address 3 to NIC 2.
- s.SetRouteTable([]tcpip.Route{
- {"\x03", "\xff", "\x00", 2},
- })
+ {
+ subnet, err := tcpip.NewSubnet("\x03", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
+ }
// Send a packet to address 3.
buf := buffer.NewView(30)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index b418db046..5335897f5 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -64,7 +65,7 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr
return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, nil, tcpip.ErrNoRoute
}
@@ -78,10 +79,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions)
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
-func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -104,6 +105,11 @@ func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
@@ -200,6 +206,13 @@ func (f *fakeTransportEndpoint) State() uint32 {
func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
}
+func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
+ return iptables.IPTables{}, nil
+}
+
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {
+}
+
type fakeTransportGoodOption bool
type fakeTransportBadOption bool
@@ -271,7 +284,13 @@ func TestTransportReceive(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
@@ -327,7 +346,13 @@ func TestTransportControlReceive(t *testing.T) {
t.Fatalf("CreateNIC failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
@@ -393,7 +418,13 @@ func TestTransportSend(t *testing.T) {
t.Fatalf("AddAddress failed: %v", err)
}
- s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
// Create endpoint and bind it.
wq := waiter.Queue{}
@@ -484,10 +515,20 @@ func TestTransportForwarding(t *testing.T) {
// Route all packets to address 3 to NIC 2 and all packets to address
// 1 to NIC 1.
- s.SetRouteTable([]tcpip.Route{
- {"\x03", "\xff", "\x00", 2},
- {"\x01", "\xff", "\x00", 1},
- })
+ {
+ subnet0, err := tcpip.NewSubnet("\x03", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet0, Gateway: "\x00", NIC: 2},
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ })
+ }
wq := waiter.Queue{}
ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4208c0303..8f9b86cce 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,7 @@ package tcpip
import (
"errors"
"fmt"
+ "math/bits"
"reflect"
"strconv"
"strings"
@@ -39,6 +40,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -144,8 +146,17 @@ type Address string
type AddressMask string
// String implements Stringer.
-func (a AddressMask) String() string {
- return Address(a).String()
+func (m AddressMask) String() string {
+ return Address(m).String()
+}
+
+// Prefix returns the number of bits before the first host bit.
+func (m AddressMask) Prefix() int {
+ p := 0
+ for _, b := range []byte(m) {
+ p += bits.LeadingZeros8(^b)
+ }
+ return p
}
// Subnet is a subnet defined by its address and mask.
@@ -167,6 +178,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) {
return Subnet{a, m}, nil
}
+// String implements Stringer.
+func (s Subnet) String() string {
+ return fmt.Sprintf("%s/%d", s.ID(), s.Prefix())
+}
+
// Contains returns true iff the address is of the same length and matches the
// subnet address and mask.
func (s *Subnet) Contains(a Address) bool {
@@ -189,28 +205,13 @@ func (s *Subnet) ID() Address {
// Bits returns the number of ones (network bits) and zeros (host bits) in the
// subnet mask.
func (s *Subnet) Bits() (ones int, zeros int) {
- for _, b := range []byte(s.mask) {
- for i := uint(0); i < 8; i++ {
- if b&(1<<i) == 0 {
- zeros++
- } else {
- ones++
- }
- }
- }
- return
+ ones = s.mask.Prefix()
+ return ones, len(s.mask)*8 - ones
}
// Prefix returns the number of bits before the first host bit.
func (s *Subnet) Prefix() int {
- for i, b := range []byte(s.mask) {
- for j := 7; j >= 0; j-- {
- if b&(1<<uint(j)) == 0 {
- return i*8 + 7 - j
- }
- }
- }
- return len(s.mask) * 8
+ return s.mask.Prefix()
}
// Mask returns the subnet mask.
@@ -328,12 +329,12 @@ type Endpoint interface {
// ErrNoLinkAddress and a notification channel is returned for the caller to
// block. Channel is closed once address resolution is complete (success or
// not). The channel is only non-nil in this case.
- Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error)
+ Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
- Peek([][]byte) (uintptr, ControlMessages, *Error)
+ Peek([][]byte) (int64, ControlMessages, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -352,6 +353,9 @@ type Endpoint interface {
// ErrAddressFamilyNotSupported must be returned.
Connect(address FullAddress) *Error
+ // Disconnect disconnects the endpoint from its peer.
+ Disconnect() *Error
+
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
Shutdown(flags ShutdownFlags) *Error
@@ -403,6 +407,9 @@ type Endpoint interface {
//
// NOTE: This method is a no-op for sockets other than TCP.
ModerateRecvBuf(copied int)
+
+ // IPTables returns the iptables for this endpoint's stack.
+ IPTables() (iptables.IPTables, error)
}
// WriteOptions contains options for Endpoint.Write.
@@ -563,13 +570,8 @@ type BroadcastOption int
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
type Route struct {
- // Destination is the address that must be matched against the masked
- // target address to check if this row is viable.
- Destination Address
-
- // Mask specifies which bits of the Destination and the target address
- // must match for this row to be viable.
- Mask AddressMask
+ // Destination must contain the target address for this row to be viable.
+ Destination Subnet
// Gateway is the gateway to be used if this row is viable.
Gateway Address
@@ -578,25 +580,15 @@ type Route struct {
NIC NICID
}
-// Match determines if r is viable for the given destination address.
-func (r *Route) Match(addr Address) bool {
- if len(addr) != len(r.Destination) {
- return false
- }
-
- // Using header.Ipv4Broadcast would introduce an import cycle, so
- // we'll use a literal instead.
- if addr == "\xff\xff\xff\xff" {
- return true
- }
-
- for i := 0; i < len(r.Destination); i++ {
- if (addr[i] & r.Mask[i]) != r.Destination[i] {
- return false
- }
+// String implements the fmt.Stringer interface.
+func (r Route) String() string {
+ var out strings.Builder
+ fmt.Fprintf(&out, "%s", r.Destination)
+ if len(r.Gateway) > 0 {
+ fmt.Fprintf(&out, " via %s", r.Gateway)
}
-
- return true
+ fmt.Fprintf(&out, " nic %d", r.NIC)
+ return out.String()
}
// LinkEndpointID represents a data link layer endpoint.
@@ -1068,6 +1060,11 @@ type AddressWithPrefix struct {
PrefixLen int
}
+// String implements the fmt.Stringer interface.
+func (a AddressWithPrefix) String() string {
+ return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen)
+}
+
// ProtocolAddress is an address and the network protocol it is associated
// with.
type ProtocolAddress struct {
@@ -1078,11 +1075,13 @@ type ProtocolAddress struct {
AddressWithPrefix AddressWithPrefix
}
-// danglingEndpointsMu protects access to danglingEndpoints.
-var danglingEndpointsMu sync.Mutex
+var (
+ // danglingEndpointsMu protects access to danglingEndpoints.
+ danglingEndpointsMu sync.Mutex
-// danglingEndpoints tracks all dangling endpoints no longer owned by the app.
-var danglingEndpoints = make(map[Endpoint]struct{})
+ // danglingEndpoints tracks all dangling endpoints no longer owned by the app.
+ danglingEndpoints = make(map[Endpoint]struct{})
+)
// GetDanglingEndpoints returns all dangling endpoints.
func GetDanglingEndpoints() []Endpoint {
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index ebb1c1b56..fb3a0a5ee 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -60,12 +60,12 @@ func TestSubnetBits(t *testing.T) {
}{
{"\x00", 0, 8},
{"\x00\x00", 0, 16},
- {"\x36", 4, 4},
- {"\x5c", 4, 4},
- {"\x5c\x5c", 8, 8},
- {"\x5c\x36", 8, 8},
- {"\x36\x5c", 8, 8},
- {"\x36\x36", 8, 8},
+ {"\x36", 0, 8},
+ {"\x5c", 0, 8},
+ {"\x5c\x5c", 0, 16},
+ {"\x5c\x36", 0, 16},
+ {"\x36\x5c", 0, 16},
+ {"\x36\x36", 0, 16},
{"\xff", 8, 0},
{"\xff\xff", 16, 0},
}
@@ -122,26 +122,6 @@ func TestSubnetCreation(t *testing.T) {
}
}
-func TestRouteMatch(t *testing.T) {
- tests := []struct {
- d Address
- m AddressMask
- a Address
- want bool
- }{
- {"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
- {"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
- {"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
- {"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
- }
- for _, tt := range tests {
- r := Route{Destination: tt.d, Mask: tt.m}
- if got := r.Match(tt.a); got != tt.want {
- t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want)
- }
- }
-}
-
func TestAddressString(t *testing.T) {
for _, want := range []string{
// Taken from stdlib.
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 62182a3e6..d78a162b8 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index ba6671c26..451d3880e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -130,6 +131,11 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
+}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -199,7 +205,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -301,11 +307,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -422,16 +428,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
nicid := addr.NIC
localPort := uint16(0)
switch e.state {
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 99b8c4093..c587b96b6 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -63,7 +63,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
if e.state != stateBound && e.state != stateConnected {
return
@@ -73,7 +78,7 @@ func (e *endpoint) afterLoad() {
if e.state == stateConnected {
e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
if err != nil {
- panic(*err)
+ panic(err)
}
e.id.LocalAddress = e.route.LocalAddress
@@ -85,6 +90,6 @@ func (e *endpoint) afterLoad() {
e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
if err != nil {
- panic(*err)
+ panic(err)
}
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index bc4b255b4..7241f6c19 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b633cd9d8..13e17e2a6 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -168,6 +169,11 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+ return ep.stack.IPTables(), nil
+}
+
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
if !ep.associated {
@@ -201,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
// Write implements tcpip.Endpoint.Write.
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -305,7 +311,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
// We may need to resolve the route (match a link layer address to the
// network address). If that requires blocking (e.g. to use ARP),
// return a channel on which the caller can wait.
@@ -335,24 +341,24 @@ func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintpt
return 0, nil, tcpip.ErrUnknownProtocol
}
- return uintptr(len(payloadBytes)), nil, nil
+ return int64(len(payloadBytes)), nil, nil
}
// Peek implements tcpip.Endpoint.Peek.
-func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect implements tcpip.Endpoint.Connect.
func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if addr.Addr == "" {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
if ep.closed {
return tcpip.ErrInvalidEndpointState
}
@@ -484,7 +490,7 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return nil
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index cb5534d90..168953dec 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -63,19 +63,23 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- // StackFromEnv is a stack used specifically for save/restore.
- ep.stack = stack.StackFromEnv
+ stack.StackFromEnv.RegisterRestoredEndpoint(ep)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (ep *endpoint) Resume(s *stack.Stack) {
+ ep.stack = s
- // If the endpoint is connected, re-connect via the save/restore stack.
+ // If the endpoint is connected, re-connect.
if ep.connected {
var err *tcpip.Error
ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
if err != nil {
- panic(*err)
+ panic(err)
}
}
- // If the endpoint is bound, re-bind via the save/restore stack.
+ // If the endpoint is bound, re-bind.
if ep.bound {
if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
panic(tcpip.ErrBadLocalAddress)
@@ -83,6 +87,6 @@ func (ep *endpoint) afterLoad() {
}
if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
- panic(*err)
+ panic(err)
}
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 4cd25e8e2..1ee1a53f8 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -48,6 +48,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 52fd1bfa3..e9c5099ea 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -96,6 +96,17 @@ type listenContext struct {
hasher hash.Hash
v6only bool
netProto tcpip.NetworkProtocolNumber
+ // pendingMu protects pendingEndpoints. This should only be accessed
+ // by the listening endpoint's worker goroutine.
+ //
+ // Lock Ordering: listenEP.workerMu -> pendingMu
+ pendingMu sync.Mutex
+ // pending is used to wait for all pendingEndpoints to finish when
+ // a socket is closed.
+ pending sync.WaitGroup
+ // pendingEndpoints is a map of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -133,14 +144,15 @@ func decSynRcvdCount() {
}
// newListenContext creates a new listen context.
-func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stack,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6only: v6only,
- netProto: netProto,
- listenEP: listenEP,
+ stack: stk,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6only: v6only,
+ netProto: netProto,
+ listenEP: listenEP,
+ pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
rand.Read(l.nonce[0][:])
@@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
+ // listenEP is nil when listenContext is used by tcp.Forwarder.
+ if l.listenEP != nil {
+ l.listenEP.mu.Lock()
+ if l.listenEP.state != StateListen {
+ l.listenEP.mu.Unlock()
+ return nil, tcpip.ErrConnectionAborted
+ }
+ l.addPendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
@@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if err := h.execute(); err != nil {
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
return nil, err
}
ep.mu.Lock()
@@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return ep, nil
}
+func (l *listenContext) addPendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ l.pendingEndpoints[n.id] = n
+ l.pending.Add(1)
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) removePendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ delete(l.pendingEndpoints, n.id)
+ l.pending.Done()
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) closeAllPendingEndpoints() {
+ l.pendingMu.Lock()
+ for _, n := range l.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
+ }
+ l.pendingMu.Unlock()
+ l.pending.Wait()
+}
+
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state, the new endpoint is closed
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
- e.mu.RLock()
+ e.mu.Lock()
state := e.state
- e.mu.RUnlock()
+ e.pendingAccepted.Add(1)
+ defer e.pendingAccepted.Done()
+ acceptedChan := e.acceptedChan
+ e.mu.Unlock()
if state == StateListen {
- e.acceptedChan <- n
+ acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
} else {
n.Close()
@@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
return
}
-
+ ctx.removePendingEndpoint(n)
e.deliverAccepted(n)
}
@@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+ e.mu.Lock()
+ v6only := e.v6only
+ e.mu.Unlock()
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
@@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
e.state = StateClose
+ // Close any endpoints in SYN-RCVD state.
+ ctx.closeAllPendingEndpoints()
+
// Do cleanup if needed.
e.completeWorkerLocked()
@@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
}()
- e.mu.Lock()
- v6only := e.v6only
- e.mu.Unlock()
-
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
-
s := sleep.Sleeper{}
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
@@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.handleListenSegment(ctx, s)
s.decRef()
}
- synRcvdCount.pending.Wait()
close(e.drainDone)
<-e.undrain
}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index d9f79e8c5..c54610a87 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -570,3 +570,89 @@ func TestV4AcceptOnV4(t *testing.T) {
// Test acceptance.
testV4Accept(t, c)
}
+
+func testV4ListenClose(t *testing.T, c *context.Context) {
+ // Set the SynRcvd threshold to zero to force a syn cookie based accept
+ // to happen.
+ saved := tcp.SynRcvdCountThreshold
+ defer func() {
+ tcp.SynRcvdCountThreshold = saved
+ }()
+ tcp.SynRcvdCountThreshold = 0
+ const n = uint16(32)
+
+ // Start listening.
+ if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ irs := seqnum.Value(789)
+ for i := uint16(0); i < n; i++ {
+ // Send a SYN request.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + i,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Each of these ACK's will cause a syn-cookie based connection to be
+ // accepted and delivered to the listening endpoint.
+ for i := uint16(0); i < n; i++ {
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcp.DestinationPort(),
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+ nep, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ nep, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+ nep.Close()
+ c.EP.Close()
+}
+
+func TestV4ListenCloseOnV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test closing the listening endpoint after connections are queued.
+ testV4ListenClose(t, c)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index cc49c8272..ac927569a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tmutex"
@@ -361,6 +362,12 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
+ // pendingAccepted is a synchronization primitive used to track number
+ // of connections that are queued up to be delivered to the accepted
+ // channel. We use this to ensure that all goroutines blocked on writing
+ // to the acceptedChan below terminate before we close acceptedChan.
+ pendingAccepted sync.WaitGroup `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -374,7 +381,11 @@ type endpoint struct {
// The goroutine drain completion notification channel.
drainDone chan struct{} `state:"nosave"`
- // The goroutine undrain notification channel.
+ // The goroutine undrain notification channel. This is currently used as
+ // a way to block the worker goroutines. Today nothing closes/writes
+ // this channel and this causes any goroutines waiting on this to just
+ // block. This is used during save/restore to prevent worker goroutines
+ // from mutating state as it's being saved.
undrain chan struct{} `state:"nosave"`
// probe if not nil is invoked on every received segment. It is passed
@@ -574,6 +585,34 @@ func (e *endpoint) Close() {
e.mu.Unlock()
}
+// closePendingAcceptableConnectionsLocked closes all connections that have
+// completed handshake but not yet been delivered to the application.
+func (e *endpoint) closePendingAcceptableConnectionsLocked() {
+ done := make(chan struct{})
+ // Spin a goroutine up as ranging on e.acceptedChan will just block when
+ // there are no more connections in the channel. Using a non-blocking
+ // select does not work as it can potentially select the default case
+ // even when there are pending writes but that are not yet written to
+ // the channel.
+ go func() {
+ defer close(done)
+ for n := range e.acceptedChan {
+ n.mu.Lock()
+ n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
+ n.Close()
+ }
+ }()
+ // pendingAccepted (see endpoint.deliverAccepted) tracks the number of
+ // endpoints which have completed handshake but are not yet written to
+ // the e.acceptedChan. We wait here till the goroutine above can drain
+ // all such connections from e.acceptedChan.
+ e.pendingAccepted.Wait()
+ close(e.acceptedChan)
+ <-done
+ e.acceptedChan = nil
+}
+
// cleanupLocked frees all resources associated with the endpoint. It is called
// after Close() is called and the worker goroutine (if any) is done with its
// work.
@@ -581,14 +620,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
- close(e.acceptedChan)
- for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
- n.Close()
- }
- e.acceptedChan = nil
+ e.closePendingAcceptableConnectionsLocked()
}
e.workerCleanup = false
@@ -683,6 +715,11 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
@@ -740,60 +777,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
return v, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
-
- e.mu.RLock()
- defer e.mu.RUnlock()
-
+// isEndpointWritableLocked checks if a given endpoint is writable
+// and also returns the number of bytes that can be written at this
+// moment. If the endpoint is not writable then it returns an error
+// indicating the reason why it's not writable.
+// Caller must hold e.mu and e.sndBufMu.
+func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
if !e.state.connected() {
switch e.state {
case StateError:
- return 0, nil, e.hardError
+ return 0, e.hardError
default:
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
}
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
- }
-
- e.sndBufMu.Lock()
-
// Check if the connection has already been closed for sends.
if e.sndClosed {
- e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrClosedForSend
+ return 0, tcpip.ErrClosedForSend
}
- // Check against the limit.
avail := e.sndBufSize - e.sndBufUsed
if avail <= 0 {
+ return 0, tcpip.ErrWouldBlock
+ }
+ return avail, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
e.sndBufMu.Unlock()
- return 0, nil, tcpip.ErrWouldBlock
+ e.mu.RUnlock()
+ return 0, nil, err
}
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+
+ // Nothing to do if the buffer is empty.
+ if p.Size() == 0 {
+ return 0, nil, nil
+ }
+
+ // Copy in memory without holding sndBufMu so that worker goroutine can
+ // make progress independent of this operation.
v, perr := p.Get(avail)
if perr != nil {
- e.sndBufMu.Unlock()
return 0, nil, perr
}
- l := len(v)
- s := newSegmentFromView(&e.route, e.id, v)
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a
+ // write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due to a
+ // simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
// Add data to the send queue.
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
e.sndBufUsed += l
e.sndBufInQueue += seqnum.Size(l)
e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
if e.workMu.TryLock() {
// Do the work inline.
@@ -803,13 +875,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return uintptr(l), nil, nil
+ return int64(l), nil, nil
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -835,8 +907,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
// Make a copy of vec so we can modify the slice headers.
vec = append([][]byte(nil), vec...)
- var num uintptr
-
+ var num int64
for s := e.rcvList.Front(); s != nil; s = s.Next() {
views := s.data.Views()
@@ -855,7 +926,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er
n := copy(vec[0], v)
v = v[n:]
vec[0] = vec[0][n:]
- num += uintptr(n)
+ num += int64(n)
}
}
}
@@ -1277,7 +1348,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == "\x00\x00\x00\x00" {
+ if addr.Addr == header.IPv4Any {
addr.Addr = ""
}
}
@@ -1291,13 +1362,13 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
return netProto, nil
}
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- if addr.Addr == "" && addr.Port == 0 {
- // AF_UNSPEC isn't supported.
- return tcpip.ErrAddressFamilyNotSupported
- }
-
return e.connect(addr, true, true)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b3f0f6c5d..831389ec7 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -165,7 +165,12 @@ func (e *endpoint) loadState(state EndpointState) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
@@ -197,14 +202,13 @@ func (e *endpoint) afterLoad() {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
bind()
if len(e.connectingAddress) == 0 {
+ e.connectingAddress = e.id.RemoteAddress
// This endpoint is accepted by netstack but not yet by
// the app. If the endpoint is IPv6 but the remote
// address is IPv4, we need to connect as IPv6 so that
// dual-stack mode can be properly activated.
if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
- } else {
- e.connectingAddress = e.id.RemoteAddress
}
}
// Reset the scoreboard to reinitialize the sack information as
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 0fee7ab72..735edfe55 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -39,6 +39,28 @@ const (
nDupAckThreshold = 3
)
+// ccState indicates the current congestion control state for this sender.
+type ccState int
+
+const (
+ // Open indicates that the sender is receiving acks in order and
+ // no loss or dupACK's etc have been detected.
+ Open ccState = iota
+ // RTORecovery indicates that an RTO has occurred and the sender
+ // has entered an RTO based recovery phase.
+ RTORecovery
+ // FastRecovery indicates that the sender has entered FastRecovery
+ // based on receiving nDupAck's. This state is entered only when
+ // SACK is not in use.
+ FastRecovery
+ // SACKRecovery indicates that the sender has entered SACK based
+ // recovery.
+ SACKRecovery
+ // Disorder indicates the sender either received some SACK blocks
+ // or dupACK's.
+ Disorder
+)
+
// congestionControl is an interface that must be implemented by any supported
// congestion control algorithm.
type congestionControl interface {
@@ -138,6 +160,9 @@ type sender struct {
// maxSentAck is the maxium acknowledgement actually sent.
maxSentAck seqnum.Value
+ // state is the current state of congestion control for this endpoint.
+ state ccState
+
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
}
@@ -435,6 +460,7 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveFastRecovery()
}
+ s.state = RTORecovery
s.cc.HandleRTOExpired()
// Mark the next segment to be sent as the first unacknowledged one and
@@ -638,7 +664,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
segEnd = seg.sequenceNumber.Add(1)
// Transition to FIN-WAIT1 state since we're initiating an active close.
s.ep.mu.Lock()
- s.ep.state = StateFinWait1
+ switch s.ep.state {
+ case StateCloseWait:
+ // We've already received a FIN and are now sending our own. The
+ // sender is now awaiting a final ACK for this FIN.
+ s.ep.state = StateLastAck
+ default:
+ s.ep.state = StateFinWait1
+ }
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
@@ -820,9 +853,11 @@ func (s *sender) enterFastRecovery() {
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
if s.ep.sackPermitted {
+ s.state = SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
return
}
+ s.state = FastRecovery
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
@@ -981,6 +1016,7 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
s.fr.highRxt = s.sndUna - 1
// Do run SetPipe() to calculate the outstanding segments.
s.SetPipe()
+ s.state = Disorder
return false
}
@@ -1112,6 +1148,9 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// window based on the number of acknowledged packets.
if !s.fr.active {
s.cc.Update(originalOutstanding - s.outstanding)
+ if s.fr.last.LessThan(s.sndUna) {
+ s.state = Open
+ }
}
// It is possible for s.outstanding to drop below zero if we get
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 915a98047..f79b8ec5f 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2874,15 +2874,11 @@ func makeStack() (*stack.Stack, *tcpip.Error) {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index bcc0f3e28..272481aa0 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -168,15 +168,11 @@ func New(t *testing.T, mtu uint32) *Context {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 6dac66b50..ac2666f69 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 91f89a781..ac5905772 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -172,6 +173,11 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
+// IPTables implements tcpip.Endpoint.IPTables.
+func (e *endpoint) IPTables() (iptables.IPTables, error) {
+ return e.stack.IPTables(), nil
+}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -241,13 +247,13 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto, err := e.checkV4Mapped(&addr, false)
- if err != nil {
- return stack.Route{}, 0, 0, err
+func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+ localAddr := e.id.LocalAddress
+ if isBroadcastOrMulticast(localAddr) {
+ // A packet can only originate from a unicast address (i.e., an interface).
+ localAddr = ""
}
- localAddr := e.id.LocalAddress
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
if nicid == 0 {
nicid = e.multicastNICID
@@ -260,14 +266,14 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac
// Find a route to the desired destination.
r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
- return stack.Route{}, 0, 0, err
+ return stack.Route{}, 0, err
}
- return r, nicid, netProto, nil
+ return r, nicid, nil
}
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -336,7 +342,12 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- r, _, _, err := e.connectRoute(nicid, *to)
+ netProto, err := e.checkV4Mapped(to, false)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ r, _, err := e.connectRoute(nicid, *to, netProto)
if err != nil {
return 0, nil, err
}
@@ -368,11 +379,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
return 0, nil, err
}
- return uintptr(len(v)), nil, nil
+ return int64(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -442,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
nicID := v.NIC
- if v.InterfaceAddr == header.IPv4Any {
+
+ // The interface address is considered not-set if it is empty or contains
+ // all-zeros. The former represents the zero-value in golang, the latter the
+ // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
+ allZeros := header.IPv4Any
+ if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
if nicID == 0 {
r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
if err == nil {
@@ -686,7 +702,7 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == "\x00\x00\x00\x00" {
+ if addr.Addr == header.IPv4Any {
addr.Addr = ""
}
@@ -705,7 +721,8 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
-func (e *endpoint) disconnect() *tcpip.Error {
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -740,8 +757,9 @@ func (e *endpoint) disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- if addr.Addr == "" {
- return e.disconnect()
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
}
if addr.Port == 0 {
// We don't support connecting to port zero.
@@ -770,7 +788,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- r, nicid, netProto, err := e.connectRoute(nicid, addr)
+ r, nicid, err := e.connectRoute(nicid, addr, netProto)
if err != nil {
return err
}
@@ -906,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
nicid := addr.NIC
- if len(addr.Addr) != 0 {
- // A local address was specified, verify that it's valid.
+ if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
+ // A local unicast address was specified, verify that it's valid.
nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicid == 0 {
return tcpip.ErrBadLocalAddress
@@ -1056,3 +1074,7 @@ func (e *endpoint) State() uint32 {
// TODO(b/112063468): Translate internal state to values returned by Linux.
return 0
}
+
+func isBroadcastOrMulticast(a tcpip.Address) bool {
+ return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 18e786397..5cbb56120 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -64,7 +64,12 @@ func (e *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- e.stack = stack.StackFromEnv
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
for _, m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
@@ -90,9 +95,10 @@ func (e *endpoint) afterLoad() {
if e.state == stateConnected {
e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
- panic(*err)
+ panic(err)
}
- } else if len(e.id.LocalAddress) != 0 { // stateBound
+ } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound
+ // A local unicast address is specified, verify that it's valid.
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
@@ -105,6 +111,6 @@ func (e *endpoint) afterLoad() {
e.id.LocalPort = 0
e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
if err != nil {
- panic(*err)
+ panic(err)
}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 56c285f88..9da6edce2 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "fmt"
"math"
"math/rand"
"testing"
@@ -34,13 +35,19 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// Addresses and ports used for testing. It is recommended that tests stick to
+// using these addresses as it allows using the testFlow helper.
+// Naming rules: 'stack*' denotes local addresses and ports, while 'test*'
+// represents the remote endpoint.
const (
+ v4MappedAddrPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
- testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
- multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ stackV4MappedAddr = v4MappedAddrPrefix + stackAddr
+ testV4MappedAddr = v4MappedAddrPrefix + testAddr
+ multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
+ broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
+ v4MappedWildcardAddr = v4MappedAddrPrefix + "\x00\x00\x00\x00"
stackAddr = "\x0a\x00\x00\x01"
stackPort = 1234
@@ -48,7 +55,7 @@ const (
testPort = 4096
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- multicastPort = 1234
+ broadcastAddr = header.IPv4Broadcast
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -56,6 +63,205 @@ const (
defaultMTU = 65536
)
+// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
+// a packet header. These values are used to populate a header or verify one.
+// Note that because they are used in packet headers, the addresses are never in
+// a V4-mapped format.
+type header4Tuple struct {
+ srcAddr tcpip.FullAddress
+ dstAddr tcpip.FullAddress
+}
+
+// testFlow implements a helper type used for sending and receiving test
+// packets. A given test flow value defines 1) the socket endpoint used for the
+// test and 2) the type of packet sent or received on the endpoint. E.g., a
+// multicastV6Only flow is a V6 multicast packet passing through a V6-only
+// endpoint. The type provides helper methods to characterize the flow (e.g.,
+// isV4) as well as return a proper header4Tuple for it.
+type testFlow int
+
+const (
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+)
+
+func (flow testFlow) String() string {
+ switch flow {
+ case unicastV4:
+ return "unicastV4"
+ case unicastV6:
+ return "unicastV6"
+ case unicastV6Only:
+ return "unicastV6Only"
+ case unicastV4in6:
+ return "unicastV4in6"
+ case multicastV4:
+ return "multicastV4"
+ case multicastV6:
+ return "multicastV6"
+ case multicastV6Only:
+ return "multicastV6Only"
+ case multicastV4in6:
+ return "multicastV4in6"
+ case broadcast:
+ return "broadcast"
+ case broadcastIn6:
+ return "broadcastIn6"
+ default:
+ return "unknown"
+ }
+}
+
+// packetDirection indicates whether a flow is incoming (read) or outgoing (write).
+type packetDirection int
+
+const (
+ incoming packetDirection = iota
+ outgoing
+)
+
+// header4Tuple returns the header4Tuple for the given flow and direction. Note
+// that the tuple contains no mapped addresses as those only exist at the socket
+// level but not at the packet header level.
+func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
+ var h header4Tuple
+ if flow.isV4() {
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastAddr
+ } else if flow.isBroadcast() {
+ h.dstAddr.Addr = broadcastAddr
+ }
+ } else { // IPv6
+ if d == outgoing {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ }
+ } else {
+ h = header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
+ }
+ }
+ if flow.isMulticast() {
+ h.dstAddr.Addr = multicastV6Addr
+ }
+ }
+ return h
+}
+
+func (flow testFlow) getMcastAddr() tcpip.Address {
+ if flow.isV4() {
+ return multicastAddr
+ }
+ return multicastV6Addr
+}
+
+// mapAddrIfApplicable converts the given V4 address into its V4-mapped version
+// if it is applicable to the flow.
+func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address {
+ if flow.isMapped() {
+ return v4MappedAddrPrefix + v4Addr
+ }
+ return v4Addr
+}
+
+// netProto returns the protocol number used for the network packet.
+func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
+ if flow.isV4() {
+ return ipv4.ProtocolNumber
+ }
+ return ipv6.ProtocolNumber
+}
+
+// sockProto returns the protocol number used when creating the socket
+// endpoint for this flow.
+func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
+ switch flow {
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ return ipv6.ProtocolNumber
+ case unicastV4, multicastV4, broadcast:
+ return ipv4.ProtocolNumber
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) {
+ if flow.isV4() {
+ return checker.IPv4
+ }
+ return checker.IPv6
+}
+
+func (flow testFlow) isV6() bool { return !flow.isV4() }
+func (flow testFlow) isV4() bool {
+ return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped()
+}
+
+func (flow testFlow) isV6Only() bool {
+ switch flow {
+ case unicastV6Only, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMulticast() bool {
+ switch flow {
+ case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isBroadcast() bool {
+ switch flow {
+ case broadcast, broadcastIn6:
+ return true
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
+func (flow testFlow) isMapped() bool {
+ switch flow {
+ case unicastV4in6, multicastV4in6, broadcastIn6:
+ return true
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ return false
+ default:
+ panic(fmt.Sprintf("invalid testFlow given: %d", flow))
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -65,12 +271,9 @@ type testContext struct {
wq waiter.Queue
}
-type headers struct {
- srcPort uint16
- dstPort uint16
-}
-
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
+ t.Helper()
+
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
id, linkEP := channel.New(256, mtu, "")
@@ -91,15 +294,11 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
s.SetRouteTable([]tcpip.Route{
{
- Destination: "\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv4EmptySubnet,
NIC: 1,
},
{
- Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
- Gateway: "",
+ Destination: header.IPv6EmptySubnet,
NIC: 1,
},
})
@@ -117,51 +316,54 @@ func (c *testContext) cleanup() {
}
}
-func (c *testContext) createV6Endpoint(v6only bool) {
+func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
+ c.t.Helper()
+
var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
+ c.t.Fatal("NewEndpoint failed: ", err)
}
+}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.ep.SetSockOpt(v); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
+func (c *testContext) createEndpointForFlow(flow testFlow) {
+ c.t.Helper()
+
+ c.createEndpoint(flow.sockProto())
+ if flow.isV6Only() {
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ } else if flow.isBroadcast() {
+ if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
}
}
-func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
+// getPacketAndVerify reads a packet from the link endpoint and verifies the
+// header against expected values from the given test flow. In addition, it
+// calls any extra checker functions provided.
+func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
+ c.t.Helper()
+
select {
case p := <-c.linkEP.C:
- if p.Proto != protocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
- var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
- var srcAddr, dstAddr tcpip.Address
- switch protocolNumber {
- case ipv4.ProtocolNumber:
- checkerFn = checker.IPv4
- srcAddr, dstAddr = stackAddr, testAddr
- if multicast {
- dstAddr = multicastAddr
- }
- case ipv6.ProtocolNumber:
- checkerFn = checker.IPv6
- srcAddr, dstAddr = stackV6Addr, testV6Addr
- if multicast {
- dstAddr = multicastV6Addr
- }
- default:
- c.t.Fatalf("unknown protocol %d", protocolNumber)
- }
- checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
+ h := flow.header4Tuple(outgoing)
+ checkers := append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
return b
case <-time.After(2 * time.Second):
@@ -171,7 +373,22 @@ func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, mult
return nil
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers) {
+// injectPacket creates a packet of the given flow and with the given payload,
+// and injects it into the link endpoint.
+func (c *testContext) injectPacket(flow testFlow, payload []byte) {
+ c.t.Helper()
+
+ h := flow.header4Tuple(incoming)
+ if flow.isV4() {
+ c.injectV4Packet(payload, &h)
+ } else {
+ c.injectV6Packet(payload, &h)
+ }
+}
+
+// injectV6Packet creates a V6 test packet with the given payload and header
+// values, and injects it into the link endpoint.
+func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -182,20 +399,20 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -205,7 +422,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
-func (c *testContext) sendPacket(payload []byte, h *headers) {
+// injectV4Packet creates a V4 test packet with the given payload and header
+// values, and injects it into the link endpoint.
+func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -217,21 +436,21 @@ func (c *testContext) sendPacket(payload []byte, h *headers) {
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
- SrcAddr: testAddr,
- DstAddr: stackAddr,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
})
ip.SetChecksum(^ip.CalculateChecksum())
// Initialize the UDP header.
u := header.UDP(buf[header.IPv4MinimumSize:])
u.Encode(&header.UDPFields{
- SrcPort: h.srcPort,
- DstPort: h.dstPort,
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testAddr, stackAddr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -253,7 +472,7 @@ func TestBindPortReuse(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
var eps [5]tcpip.Endpoint
reusePortOpt := tcpip.ReusePortOption(1)
@@ -296,9 +515,9 @@ func TestBindPortReuse(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
- c.sendV6Packet(payload, &headers{
- srcPort: testPort + port,
- dstPort: stackPort,
+ c.injectV6Packet(payload, &header4Tuple{
+ srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
+ dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
})
var addr tcpip.FullAddress
@@ -333,13 +552,14 @@ func TestBindPortReuse(t *testing.T) {
}
}
-func testV4Read(c *testContext) {
- // Send a packet.
+// testRead sends a packet of the given test flow into the stack by injecting it
+// into the link endpoint. It then reads it from the UDP endpoint and verifies
+// its correctness.
+func testRead(c *testContext, flow testFlow) {
+ c.t.Helper()
+
payload := newPayload()
- c.sendPacket(payload, &headers{
- srcPort: testPort,
- dstPort: stackPort,
- })
+ c.injectPacket(flow, payload)
// Try to receive the data.
we, ch := waiter.NewChannelEntry(nil)
@@ -363,8 +583,9 @@ func testV4Read(c *testContext) {
}
// Check the peer address.
- if addr.Addr != testAddr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
+ h := flow.header4Tuple(incoming)
+ if addr.Addr != h.srcAddr.Addr {
+ c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, h.srcAddr)
}
// Check the payload.
@@ -377,7 +598,7 @@ func TestBindEphemeralPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
t.Fatalf("ep.Bind(...) failed: %v", err)
@@ -388,7 +609,7 @@ func TestBindReservedPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
@@ -447,7 +668,7 @@ func TestV4ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -455,29 +676,29 @@ func TestV4ReadOnV6(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to v4 mapped wildcard.
- if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}); err != nil {
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV4in6)
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
@@ -485,69 +706,29 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4in6)
}
func TestV6ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpointForFlow(unicastV6)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- // Send a packet.
- payload := newPayload()
- c.sendV6Packet(payload, &headers{
- srcPort: testPort,
- dstPort: stackPort,
- })
-
- // Try to receive the data.
- we, ch := waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&we, waiter.EventIn)
- defer c.wq.EventUnregister(&we)
-
- var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
- if err == tcpip.ErrWouldBlock {
- // Wait for data to become available.
- select {
- case <-ch:
- v, _, err = c.ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- c.t.Fatalf("Timed out waiting for data")
- }
- }
-
- // Check the peer address.
- if addr.Addr != testV6Addr {
- c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
- }
-
- // Check the payload.
- if !bytes.Equal(payload, v) {
- c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
- }
+ // Test acceptance.
+ testRead(c, unicastV6)
}
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- // Create v4 UDP endpoint.
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
+ c.createEndpointForFlow(unicastV4)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -555,62 +736,123 @@ func TestV4ReadOnV4(t *testing.T) {
}
// Test acceptance.
- testV4Read(c)
+ testRead(c, unicastV4)
}
-func testV4Write(c *testContext) uint16 {
- // Write to V4 mapped address.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
+// address and receive data sent to that address.
+func TestReadOnBoundToMulticast(t *testing.T) {
+ // FIXME(b/128189410): multicastV4in6 currently doesn't work as
+ // AddMembershipOption doesn't handle V4in6 addresses.
+ for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to multicast address.
+ mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ // Join multicast group.
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatal("SetSockOpt failed:", err)
+ }
+
+ testRead(c, flow)
+ })
}
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+}
+
+// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
+// address and receive broadcast data on it.
+func TestV4ReadOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to broadcast address.
+ bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Test acceptance.
+ testRead(c, flow)
+ })
}
+}
- // Check that we received the packet.
- b := c.getPacket(ipv4.ProtocolNumber, false)
- udp := header.UDP(header.IPv4(b).Payload())
- checker.IPv4(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
+// testFailingWrite sends a packet of the given test flow into the UDP endpoint
+// and verifies it fails with the provided error code.
+func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
+ c.t.Helper()
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+
+ payload := buffer.View(newPayload())
+ _, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ })
+ if gotErr != wantErr {
+ c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
}
+}
- return udp.SourcePort()
+// testWrite sends a packet of the given test flow from the UDP endpoint to the
+// flow's destination address:port. It then receives it from the link endpoint
+// and verifies its correctness including any additional checker functions
+// provided.
+func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, true, checkers...)
}
-func testV6Write(c *testContext) uint16 {
- // Write to v6 address.
+// testWriteWithoutDestination sends a packet of the given test flow from the
+// UDP endpoint without giving a destination address:port. It then receives it
+// from the link endpoint and verifies its correctness including any additional
+// checker functions provided.
+func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+ return testWriteInternal(c, flow, false, checkers...)
+}
+
+func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+ c.t.Helper()
+
+ writeOpts := tcpip.WriteOptions{}
+ if setDest {
+ h := flow.header4Tuple(outgoing)
+ writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+ writeOpts = tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+ }
+ }
payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
+ n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
- if n != uintptr(len(payload)) {
+ if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
- // Check that we received the packet.
- b := c.getPacket(ipv6.ProtocolNumber, false)
- udp := header.UDP(header.IPv6(b).Payload())
- checker.IPv6(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
-
- // Check the payload.
+ // Receive the packet and check the payload.
+ b := c.getPacketAndVerify(flow, checkers...)
+ var udp header.UDP
+ if flow.isV4() {
+ udp = header.UDP(header.IPv4(b).Payload())
+ } else {
+ udp = header.UDP(header.IPv6(b).Payload())
+ }
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
@@ -619,8 +861,10 @@ func testV6Write(c *testContext) uint16 {
}
func testDualWrite(c *testContext) uint16 {
- v4Port := testV4Write(c)
- v6Port := testV6Write(c)
+ c.t.Helper()
+
+ v4Port := testWrite(c, unicastV4in6)
+ v6Port := testWrite(c, unicastV6)
if v4Port != v6Port {
c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
}
@@ -632,7 +876,7 @@ func TestDualWriteUnbound(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
testDualWrite(c)
}
@@ -641,7 +885,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
@@ -658,69 +902,51 @@ func TestDualWriteConnectedToV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV6Write(c)
+ testWrite(c, unicastV6)
// Write to V4 mapped address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != tcpip.ErrNetworkUnreachable {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable)
- }
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV4Write(c)
+ testWrite(c, unicastV4in6)
// Write to v6 address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
- if err != tcpip.ErrInvalidEndpointState {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
- }
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
}
func TestV4WriteOnV6Only(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(true)
+ c.createEndpointForFlow(unicastV6Only)
// Write to V4 mapped address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
- })
- if err != tcpip.ErrNoRoute {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
- }
+ testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
}
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
@@ -728,84 +954,154 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
}
// Write to v6 address.
- payload := buffer.View(newPayload())
- _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
- To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
- })
- if err != tcpip.ErrInvalidEndpointState {
- c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
- }
+ testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
}
func TestV6WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
- // Write without destination.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
- }
-
- // Check that we received the packet.
- b := c.getPacket(ipv6.ProtocolNumber, false)
- udp := header.UDP(header.IPv6(b).Payload())
- checker.IPv6(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
-
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
- }
+ testWriteWithoutDestination(c, unicastV6)
}
func TestV4WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
- // Write without destination.
- payload := buffer.View(newPayload())
- n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ testWriteWithoutDestination(c, unicastV4)
+}
+
+// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
+// that is bound to a V4 multicast address.
+func TestWriteOnBoundToV4Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
}
- if n != uintptr(len(payload)) {
- c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+}
+
+// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
+// socket that is bound to a V4-mapped multicast address.
+func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
}
+}
- // Check that we received the packet.
- b := c.getPacket(ipv4.ProtocolNumber, false)
- udp := header.UDP(header.IPv4(b).Payload())
- checker.IPv4(c.t, b,
- checker.UDP(
- checker.DstPort(testPort),
- ),
- )
+// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
+// socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6Multicast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6, multicastV6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- // Check the payload.
- if !bytes.Equal(payload, udp.Payload()) {
- c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV6OnlyMulticast checks that we can send packets out of a
+// V6-only socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V6 mcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToBroadcast checks that we can send packets out of a
+// socket that is bound to the broadcast address.
+func TestWriteOnBoundToBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4 broadcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+
+ testWrite(c, flow)
+ })
+ }
+}
+
+// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
+// socket that is bound to the V4-mapped broadcast address.
+func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+ t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Bind to V4Mapped broadcast address.
+ if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ testWrite(c, flow)
+ })
}
}
@@ -814,18 +1110,14 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
defer c.cleanup()
// Create IPv4 UDP endpoint
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
+ c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
- testV4Read(c)
+ testRead(c, unicastV4)
var want uint64 = 1
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
@@ -837,7 +1129,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
- c.createV6Endpoint(false)
+ c.createEndpoint(ipv6.ProtocolNumber)
testDualWrite(c)
@@ -847,244 +1139,102 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
-func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) {
- for _, name := range []string{"v4", "v6", "dual"} {
- t.Run(name, func(t *testing.T) {
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch name {
- case "v4":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6", "dual":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
- }
+func TestTTL(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- var variants []string
- switch name {
- case "v4":
- variants = []string{"v4"}
- case "v6":
- variants = []string{"v6"}
- case "dual":
- variants = []string{"v6", "mapped"}
- }
+ c.createEndpointForFlow(flow)
- for _, variant := range variants {
- t.Run(variant, func(t *testing.T) {
- optFunc(t, name, networkProtocolNumber, variant)
- })
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
}
- })
- }
-}
-func TestTTL(t *testing.T) {
- payload := tcpip.SlicePayload(buffer.View(newPayload()))
-
- setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
- for _, typ := range []string{"unicast", "multicast"} {
- t.Run(typ, func(t *testing.T) {
- var addr tcpip.Address
- var port uint16
- switch typ {
- case "unicast":
- port = testPort
- switch variant {
- case "v4":
- addr = testAddr
- case "mapped":
- addr = testV4MappedAddr
- case "v6":
- addr = testV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- port = multicastPort
- switch variant {
- case "v4":
- addr = multicastAddr
- case "mapped":
- addr = multicastV4MappedAddr
- case "v6":
- addr = multicastV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- default:
- t.Fatal("unknown test variant")
+ var wantTTL uint8
+ if flow.isMulticast() {
+ wantTTL = multicastTTL
+ } else {
+ var p stack.NetworkProtocol
+ if flow.isV4() {
+ p = ipv4.NewProtocol()
+ } else {
+ p = ipv6.NewProtocol()
}
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- switch name {
- case "v4":
- case "v6":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- case "dual":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- default:
- t.Fatal("unknown test variant")
- }
-
- const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ t.Fatal(err)
}
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ }
- n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
- }
+ testWrite(c, flow, checker.TTL(wantTTL))
+ })
+ }
+}
- checkerFn := checker.IPv4
- switch variant {
- case "v4", "mapped":
- case "v6":
- checkerFn = checker.IPv6
- default:
- t.Fatal("unknown test variant")
- }
- var wantTTL uint8
- var multicast bool
- switch typ {
- case "unicast":
- multicast = false
- switch variant {
- case "v4", "mapped":
- ep, err := ipv4.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- case "v6":
- ep, err := ipv6.NewProtocol().NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- wantTTL = multicastTTL
- multicast = true
- default:
- t.Fatal("unknown test variant")
- }
+func TestMulticastInterfaceOption(t *testing.T) {
+ for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ for _, bindTyp := range []string{"bound", "unbound"} {
+ t.Run(bindTyp, func(t *testing.T) {
+ for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
+ t.Run(optTyp, func(t *testing.T) {
+ h := flow.header4Tuple(outgoing)
+ mcastAddr := h.dstAddr.Addr
+ localIfAddr := h.srcAddr.Addr
+
+ var ifoptSet tcpip.MulticastInterfaceOption
+ switch optTyp {
+ case "use local-addr":
+ ifoptSet.InterfaceAddr = localIfAddr
+ case "use NICID":
+ ifoptSet.NIC = 1
+ case "use local-addr and NIC":
+ ifoptSet.InterfaceAddr = localIfAddr
+ ifoptSet.NIC = 1
+ default:
+ t.Fatal("unknown test variant")
+ }
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch variant {
- case "v4", "mapped":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
- }
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(flow.sockProto())
+
+ if bindTyp == "bound" {
+ // Bind the socket by connecting to the multicast address.
+ // This may have an influence on how the multicast interface
+ // is set.
+ addr := tcpip.FullAddress{
+ Addr: flow.mapAddrIfApplicable(mcastAddr),
+ Port: stackPort,
+ }
+ if err := c.ep.Connect(addr); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
+ }
+ }
- b := c.getPacket(networkProtocolNumber, multicast)
- checkerFn(c.t, b,
- checker.TTL(wantTTL),
- checker.UDP(
- checker.DstPort(port),
- ),
- )
- })
- }
- })
-}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
-func TestMulticastInterfaceOption(t *testing.T) {
- setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
- for _, bindTyp := range []string{"bound", "unbound"} {
- t.Run(bindTyp, func(t *testing.T) {
- for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
- t.Run(optTyp, func(t *testing.T) {
- var mcastAddr, localIfAddr tcpip.Address
- switch variant {
- case "v4":
- mcastAddr = multicastAddr
- localIfAddr = stackAddr
- case "mapped":
- mcastAddr = multicastV4MappedAddr
- localIfAddr = stackAddr
- case "v6":
- mcastAddr = multicastV6Addr
- localIfAddr = stackV6Addr
- default:
- t.Fatal("unknown test variant")
- }
-
- var ifoptSet tcpip.MulticastInterfaceOption
- switch optTyp {
- case "use local-addr":
- ifoptSet.InterfaceAddr = localIfAddr
- case "use NICID":
- ifoptSet.NIC = 1
- case "use local-addr and NIC":
- ifoptSet.InterfaceAddr = localIfAddr
- ifoptSet.NIC = 1
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if bindTyp == "bound" {
- // Bind the socket by connecting to the multicast address.
- // This may have an influence on how the multicast interface
- // is set.
- addr := tcpip.FullAddress{
- Addr: mcastAddr,
- Port: multicastPort,
+ // Verify multicast interface addr and NIC were set correctly.
+ // Note that NIC must be 1 since this is our outgoing interface.
+ ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
+ var ifoptGot tcpip.MulticastInterfaceOption
+ if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
+ c.t.Fatalf("GetSockOpt failed: %v", err)
}
- if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ if ifoptGot != ifoptWant {
+ c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
}
- }
-
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-
- // Verify multicast interface addr and NIC were set correctly.
- // Note that NIC must be 1 since this is our outgoing interface.
- ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
- var ifoptGot tcpip.MulticastInterfaceOption
- if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
- }
- if ifoptGot != ifoptWant {
- c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
- }
- })
- }
- })
- }
- })
+ })
+ }
+ })
+ }
+ })
+ }
}
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index 769509e80..cbd92fc05 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -11,8 +11,8 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/unet",
visibility = ["//visibility:public"],
deps = [
- "//pkg/abi/linux",
"//pkg/gate",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/unet/unet_unsafe.go b/pkg/unet/unet_unsafe.go
index f8a42c914..85ef46edf 100644
--- a/pkg/unet/unet_unsafe.go
+++ b/pkg/unet/unet_unsafe.go
@@ -16,12 +16,11 @@ package unet
import (
"io"
- "math"
"sync/atomic"
"syscall"
"unsafe"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "golang.org/x/sys/unix"
)
// wait blocks until the socket FD is ready for reading or writing, depending
@@ -37,23 +36,23 @@ func (s *Socket) wait(write bool) error {
return errClosing
}
- events := []linux.PollFD{
+ events := []unix.PollFd{
{
// The actual socket FD.
- FD: fd,
- Events: linux.POLLIN,
+ Fd: fd,
+ Events: unix.POLLIN,
},
{
// The eventfd, signaled when we are closing.
- FD: int32(s.efd),
- Events: linux.POLLIN,
+ Fd: int32(s.efd),
+ Events: unix.POLLIN,
},
}
if write {
- events[0].Events = linux.POLLOUT
+ events[0].Events = unix.POLLOUT
}
- _, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(&events[0])), 2, uintptr(math.MaxUint64))
+ _, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(&events[0])), 2, 0, 0, 0, 0)
if e == syscall.EINTR {
continue
}
@@ -61,7 +60,7 @@ func (s *Socket) wait(write bool) error {
return e
}
- if events[1].REvents&linux.POLLIN == linux.POLLIN {
+ if events[1].Revents&unix.POLLIN == unix.POLLIN {
// eventfd signaled, we're closing.
return errClosing
}