Diffstat (limited to 'pkg')
-rw-r--r-- pkg/abi/linux/BUILD | 3
-rw-r--r-- pkg/abi/linux/futex.go | 18
-rw-r--r-- pkg/abi/linux/ioctl.go | 27
-rw-r--r-- pkg/abi/linux/netfilter.go | 146
-rw-r--r-- pkg/abi/linux/netlink_route.go | 2
-rw-r--r-- pkg/abi/linux/socket.go | 17
-rw-r--r-- pkg/iovec/BUILD | 18
-rw-r--r-- pkg/iovec/iovec.go | 75
-rw-r--r-- pkg/iovec/iovec_test.go | 121
-rw-r--r-- pkg/p9/messages.go | 2
-rw-r--r-- pkg/sentry/arch/arch_aarch64.go | 23
-rw-r--r-- pkg/sentry/arch/arch_amd64.go | 4
-rw-r--r-- pkg/sentry/arch/arch_arm64.go | 4
-rw-r--r-- pkg/sentry/arch/arch_x86.go | 16
-rw-r--r-- pkg/sentry/fs/host/BUILD | 1
-rw-r--r-- pkg/sentry/fs/host/socket_iovec.go | 7
-rw-r--r-- pkg/sentry/fsimpl/devpts/master.go | 6
-rw-r--r-- pkg/sentry/fsimpl/devpts/slave.go | 8
-rw-r--r-- pkg/sentry/fsimpl/ext/BUILD | 4
-rw-r--r-- pkg/sentry/fsimpl/fuse/BUILD | 5
-rw-r--r-- pkg/sentry/fsimpl/fuse/dev.go | 33
-rw-r--r-- pkg/sentry/fsimpl/fuse/fusefs.go | 200
-rw-r--r-- pkg/sentry/fsimpl/gofer/directory.go | 9
-rw-r--r-- pkg/sentry/fsimpl/gofer/filesystem.go | 28
-rw-r--r-- pkg/sentry/fsimpl/gofer/gofer.go | 87
-rw-r--r-- pkg/sentry/fsimpl/gofer/handle.go | 5
-rw-r--r-- pkg/sentry/fsimpl/gofer/regular_file.go | 70
-rw-r--r-- pkg/sentry/fsimpl/gofer/special_file.go | 92
-rw-r--r-- pkg/sentry/fsimpl/host/BUILD | 1
-rw-r--r-- pkg/sentry/fsimpl/host/host.go | 13
-rw-r--r-- pkg/sentry/fsimpl/host/socket_iovec.go | 7
-rw-r--r-- pkg/sentry/fsimpl/kernfs/BUILD | 2
-rw-r--r-- pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go | 2
-rw-r--r-- pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 14
-rw-r--r-- pkg/sentry/fsimpl/kernfs/filesystem.go | 2
-rw-r--r-- pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 4
-rw-r--r-- pkg/sentry/fsimpl/kernfs/kernfs.go | 2
-rw-r--r-- pkg/sentry/fsimpl/overlay/filesystem.go | 2
-rw-r--r-- pkg/sentry/fsimpl/overlay/non_directory.go | 4
-rw-r--r-- pkg/sentry/fsimpl/pipefs/pipefs.go | 2
-rw-r--r-- pkg/sentry/fsimpl/proc/subtasks.go | 6
-rw-r--r-- pkg/sentry/fsimpl/proc/task.go | 4
-rw-r--r-- pkg/sentry/fsimpl/proc/task_files.go | 2
-rw-r--r-- pkg/sentry/fsimpl/proc/tasks.go | 4
-rw-r--r-- pkg/sentry/fsimpl/sys/BUILD | 2
-rw-r--r-- pkg/sentry/fsimpl/testutil/BUILD | 2
-rw-r--r-- pkg/sentry/fsimpl/testutil/kernel.go | 8
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/filesystem.go | 4
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/regular_file.go | 35
-rw-r--r-- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 7
-rw-r--r-- pkg/sentry/kernel/BUILD | 1
-rw-r--r-- pkg/sentry/kernel/futex/futex.go | 8
-rw-r--r-- pkg/sentry/kernel/kernel.go | 9
-rw-r--r-- pkg/sentry/kernel/syslog.go | 9
-rw-r--r-- pkg/sentry/kernel/task.go | 19
-rw-r--r-- pkg/sentry/kernel/task_exec.go | 3
-rw-r--r-- pkg/sentry/kernel/task_exit.go | 3
-rw-r--r-- pkg/sentry/kernel/task_futex.go | 125
-rw-r--r-- pkg/sentry/kernel/task_run.go | 17
-rw-r--r-- pkg/sentry/kernel/task_work.go | 38
-rw-r--r-- pkg/sentry/kernel/time/BUILD | 1
-rw-r--r-- pkg/sentry/kernel/time/tcpip.go | 131
-rw-r--r-- pkg/sentry/kernel/timekeeper.go | 3
-rw-r--r-- pkg/sentry/loader/BUILD | 4
-rw-r--r-- pkg/sentry/loader/elf.go | 15
-rw-r--r-- pkg/sentry/loader/loader.go | 21
-rw-r--r-- pkg/sentry/loader/vdso.go | 61
-rw-r--r-- pkg/sentry/platform/kvm/BUILD | 1
-rw-r--r-- pkg/sentry/platform/kvm/kvm_amd64_test.go | 51
-rw-r--r-- pkg/sentry/platform/kvm/kvm_test.go | 24
-rw-r--r-- pkg/sentry/platform/kvm/machine_arm64_unsafe.go | 13
-rw-r--r-- pkg/sentry/platform/kvm/testutil/testutil_arm64.s | 6
-rw-r--r-- pkg/sentry/platform/ring0/entry_arm64.s | 13
-rw-r--r-- pkg/sentry/platform/ring0/kernel_arm64.go | 6
-rw-r--r-- pkg/sentry/socket/BUILD | 1
-rw-r--r-- pkg/sentry/socket/hostinet/BUILD | 2
-rw-r--r-- pkg/sentry/socket/hostinet/socket.go | 7
-rw-r--r-- pkg/sentry/socket/hostinet/socket_vfs2.go | 1
-rw-r--r-- pkg/sentry/socket/netfilter/netfilter.go | 49
-rw-r--r-- pkg/sentry/socket/netlink/BUILD | 2
-rw-r--r-- pkg/sentry/socket/netlink/socket.go | 14
-rw-r--r-- pkg/sentry/socket/netstack/BUILD | 2
-rw-r--r-- pkg/sentry/socket/netstack/netstack.go | 339
-rw-r--r-- pkg/sentry/socket/netstack/netstack_vfs2.go | 16
-rw-r--r-- pkg/sentry/socket/netstack/stack.go | 22
-rw-r--r-- pkg/sentry/socket/socket.go | 3
-rw-r--r-- pkg/sentry/socket/unix/BUILD | 1
-rw-r--r-- pkg/sentry/socket/unix/unix.go | 3
-rw-r--r-- pkg/sentry/socket/unix/unix_vfs2.go | 3
-rw-r--r-- pkg/sentry/syscalls/linux/BUILD | 2
-rw-r--r-- pkg/sentry/syscalls/linux/linux64.go | 4
-rw-r--r-- pkg/sentry/syscalls/linux/sys_futex.go | 48
-rw-r--r-- pkg/sentry/syscalls/linux/sys_socket.go | 17
-rw-r--r-- pkg/sentry/syscalls/linux/vfs2/BUILD | 2
-rw-r--r-- pkg/sentry/syscalls/linux/vfs2/filesystem.go | 13
-rw-r--r-- pkg/sentry/syscalls/linux/vfs2/mount.go | 9
-rw-r--r-- pkg/sentry/syscalls/linux/vfs2/setstat.go | 5
-rw-r--r-- pkg/sentry/syscalls/linux/vfs2/socket.go | 17
-rw-r--r-- pkg/sentry/syscalls/linux/vfs2/splice.go | 321
-rw-r--r-- pkg/sentry/syscalls/linux/vfs2/vfs2.go | 2
-rw-r--r-- pkg/sentry/time/parameters.go | 12
-rw-r--r-- pkg/sentry/vfs/inotify.go | 6
-rw-r--r-- pkg/sentry/vfs/options.go | 17
-rw-r--r-- pkg/sentry/vfs/permissions.go | 8
-rw-r--r-- pkg/sentry/vfs/vfs.go | 3
-rw-r--r-- pkg/shim/runsc/BUILD | 16
-rw-r--r-- pkg/shim/runsc/runsc.go | 514
-rw-r--r-- pkg/shim/runsc/utils.go | 44
-rw-r--r-- pkg/shim/v1/proc/BUILD | 36
-rw-r--r-- pkg/shim/v1/proc/deleted_state.go | 49
-rw-r--r-- pkg/shim/v1/proc/exec.go | 281
-rw-r--r-- pkg/shim/v1/proc/exec_state.go | 154
-rw-r--r-- pkg/shim/v1/proc/init.go | 460
-rw-r--r-- pkg/shim/v1/proc/init_state.go | 182
-rw-r--r-- pkg/shim/v1/proc/io.go | 162
-rw-r--r-- pkg/shim/v1/proc/process.go | 37
-rw-r--r-- pkg/shim/v1/proc/types.go | 69
-rw-r--r-- pkg/shim/v1/proc/utils.go | 90
-rw-r--r-- pkg/shim/v1/shim/BUILD | 40
-rw-r--r-- pkg/shim/v1/shim/api.go | 28
-rw-r--r-- pkg/shim/v1/shim/platform.go | 106
-rw-r--r-- pkg/shim/v1/shim/service.go | 573
-rw-r--r-- pkg/shim/v1/utils/BUILD | 27
-rw-r--r-- pkg/shim/v1/utils/annotations.go | 25
-rw-r--r-- pkg/shim/v1/utils/utils.go | 56
-rw-r--r-- pkg/shim/v1/utils/volumes.go | 155
-rw-r--r-- pkg/shim/v1/utils/volumes_test.go | 308
-rw-r--r-- pkg/shim/v2/BUILD | 43
-rw-r--r-- pkg/shim/v2/api.go | 22
-rw-r--r-- pkg/shim/v2/epoll.go | 129
-rw-r--r-- pkg/shim/v2/options/BUILD | 11
-rw-r--r-- pkg/shim/v2/options/options.go | 33
-rw-r--r-- pkg/shim/v2/runtimeoptions/BUILD | 20
-rw-r--r-- pkg/shim/v2/runtimeoptions/runtimeoptions.go | 27
-rw-r--r-- pkg/shim/v2/runtimeoptions/runtimeoptions.proto | 25
-rw-r--r-- pkg/shim/v2/service.go | 824
-rw-r--r-- pkg/shim/v2/service_linux.go | 108
-rw-r--r-- pkg/sleep/BUILD | 1
-rw-r--r-- pkg/sleep/sleep_test.go | 20
-rw-r--r-- pkg/sleep/sleep_unsafe.go | 7
-rw-r--r-- pkg/sync/BUILD | 1
-rw-r--r-- pkg/sync/nocopy.go | 28
-rw-r--r-- pkg/syserr/netstack.go | 2
-rw-r--r-- pkg/tcpip/header/BUILD | 4
-rw-r--r-- pkg/tcpip/header/arp.go | 77
-rw-r--r-- pkg/tcpip/header/eth.go | 4
-rw-r--r-- pkg/tcpip/header/icmpv4.go | 1
-rw-r--r-- pkg/tcpip/header/icmpv6.go | 11
-rw-r--r-- pkg/tcpip/link/channel/BUILD | 1
-rw-r--r-- pkg/tcpip/link/channel/channel.go | 10
-rw-r--r-- pkg/tcpip/link/fdbased/BUILD | 1
-rw-r--r-- pkg/tcpip/link/fdbased/endpoint.go | 107
-rw-r--r-- pkg/tcpip/link/fdbased/endpoint_test.go | 81
-rw-r--r-- pkg/tcpip/link/fdbased/packet_dispatchers.go | 2
-rw-r--r-- pkg/tcpip/link/loopback/loopback.go | 8
-rw-r--r-- pkg/tcpip/link/muxed/BUILD | 1
-rw-r--r-- pkg/tcpip/link/muxed/injectable.go | 10
-rw-r--r-- pkg/tcpip/link/nested/BUILD | 1
-rw-r--r-- pkg/tcpip/link/nested/nested.go | 21
-rw-r--r-- pkg/tcpip/link/nested/nested_test.go | 4
-rw-r--r-- pkg/tcpip/link/packetsocket/BUILD | 14
-rw-r--r-- pkg/tcpip/link/packetsocket/endpoint.go | 50
-rw-r--r-- pkg/tcpip/link/qdisc/fifo/BUILD | 1
-rw-r--r-- pkg/tcpip/link/qdisc/fifo/endpoint.go | 18
-rw-r--r-- pkg/tcpip/link/rawfile/rawfile_unsafe.go | 33
-rw-r--r-- pkg/tcpip/link/sharedmem/sharedmem.go | 26
-rw-r--r-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 4
-rw-r--r-- pkg/tcpip/link/sniffer/sniffer.go | 5
-rw-r--r-- pkg/tcpip/link/tun/device.go | 53
-rw-r--r-- pkg/tcpip/link/waitable/BUILD | 2
-rw-r--r-- pkg/tcpip/link/waitable/waitable.go | 20
-rw-r--r-- pkg/tcpip/link/waitable/waitable_test.go | 15
-rw-r--r-- pkg/tcpip/network/arp/arp.go | 6
-rw-r--r-- pkg/tcpip/network/ip_test.go | 14
-rw-r--r-- pkg/tcpip/network/ipv4/BUILD | 2
-rw-r--r-- pkg/tcpip/network/ipv4/icmp.go | 3
-rw-r--r-- pkg/tcpip/network/ipv4/ipv4.go | 21
-rw-r--r-- pkg/tcpip/network/ipv6/BUILD | 2
-rw-r--r-- pkg/tcpip/network/ipv6/icmp.go | 4
-rw-r--r-- pkg/tcpip/stack/BUILD | 17
-rw-r--r-- pkg/tcpip/stack/conntrack.go | 391
-rw-r--r-- pkg/tcpip/stack/forwarder_test.go | 11
-rw-r--r-- pkg/tcpip/stack/iptables.go | 245
-rw-r--r-- pkg/tcpip/stack/iptables_state.go | 40
-rw-r--r-- pkg/tcpip/stack/iptables_targets.go | 2
-rw-r--r-- pkg/tcpip/stack/iptables_types.go | 33
-rw-r--r-- pkg/tcpip/stack/ndp.go | 140
-rw-r--r-- pkg/tcpip/stack/ndp_test.go | 28
-rw-r--r-- pkg/tcpip/stack/nic.go | 37
-rw-r--r-- pkg/tcpip/stack/nic_test.go | 10
-rw-r--r-- pkg/tcpip/stack/packet_buffer.go | 18
-rw-r--r-- pkg/tcpip/stack/registration.go | 28
-rw-r--r-- pkg/tcpip/stack/stack.go | 19
-rw-r--r-- pkg/tcpip/tcpip.go | 111
-rw-r--r-- pkg/tcpip/time_unsafe.go | 30
-rw-r--r-- pkg/tcpip/timer.go | 168
-rw-r--r-- pkg/tcpip/timer_test.go | 91
-rw-r--r-- pkg/tcpip/transport/icmp/endpoint.go | 18
-rw-r--r-- pkg/tcpip/transport/packet/endpoint.go | 197
-rw-r--r-- pkg/tcpip/transport/packet/endpoint_state.go | 19
-rw-r--r-- pkg/tcpip/transport/raw/endpoint.go | 46
-rw-r--r-- pkg/tcpip/transport/tcp/BUILD | 2
-rw-r--r-- pkg/tcpip/transport/tcp/connect.go | 6
-rw-r--r-- pkg/tcpip/transport/tcp/dispatcher.go | 3
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint.go | 41
-rw-r--r-- pkg/tcpip/transport/tcp/tcp_test.go | 57
-rw-r--r-- pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go | 5
-rw-r--r-- pkg/tcpip/transport/udp/endpoint.go | 27
-rw-r--r-- pkg/tcpip/transport/udp/udp_test.go | 125
-rw-r--r-- pkg/test/criutil/criutil.go | 77
-rw-r--r-- pkg/test/dockerutil/BUILD | 15
-rw-r--r-- pkg/test/dockerutil/container.go | 501
-rw-r--r-- pkg/test/dockerutil/dockerutil.go | 601
-rw-r--r-- pkg/test/dockerutil/exec.go | 193
-rw-r--r-- pkg/test/dockerutil/network.go | 113
-rw-r--r-- pkg/test/testutil/BUILD | 2
-rw-r--r-- pkg/test/testutil/testutil.go | 15
217 files changed, 9689 insertions, 1836 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 2b789c4ec..a4bb62013 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -72,6 +72,9 @@ go_library(
"//pkg/abi",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go
index 08bfde3b5..8138088a6 100644
--- a/pkg/abi/linux/futex.go
+++ b/pkg/abi/linux/futex.go
@@ -60,3 +60,21 @@ const (
FUTEX_WAITERS = 0x80000000
FUTEX_OWNER_DIED = 0x40000000
)
+
+// FUTEX_BITSET_MATCH_ANY has all bits set.
+const FUTEX_BITSET_MATCH_ANY = 0xffffffff
+
+// ROBUST_LIST_LIMIT protects against a deliberately circular list.
+const ROBUST_LIST_LIMIT = 2048
+
+// RobustListHead corresponds to Linux's struct robust_list_head.
+//
+// +marshal
+type RobustListHead struct {
+ List uint64
+ FutexOffset uint64
+ ListOpPending uint64
+}
+
+// SizeOfRobustListHead is the size of a RobustListHead struct.
+var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes()
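Aside: the new ROBUST_LIST_LIMIT above caps traversal of the user-supplied robust futex list. A minimal, self-contained sketch (illustrative Go, not the sentry's actual traversal code) of why such a cap is needed:

// Illustrative only: a robust futex list lives in untrusted user memory and
// may be made circular on purpose, so any walk over it must be bounded.
package main

import "fmt"

const robustListLimit = 2048 // mirrors ROBUST_LIST_LIMIT

type node struct{ next *node }

func walk(head *node) int {
	visited := 0
	for cur := head; cur != nil && visited < robustListLimit; cur = cur.next {
		visited++
	}
	return visited
}

func main() {
	a, b := &node{}, &node{}
	a.next, b.next = b, a // deliberately circular
	fmt.Println(walk(a))  // prints 2048: the cap ends the walk
}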
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 2062e6a4b..2c5e56ae5 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -67,10 +67,29 @@ const (
// ioctl(2) requests provided by uapi/linux/sockios.h
const (
- SIOCGIFMEM = 0x891f
- SIOCGIFPFLAGS = 0x8935
- SIOCGMIIPHY = 0x8947
- SIOCGMIIREG = 0x8948
+ SIOCGIFNAME = 0x8910
+ SIOCGIFCONF = 0x8912
+ SIOCGIFFLAGS = 0x8913
+ SIOCGIFADDR = 0x8915
+ SIOCGIFDSTADDR = 0x8917
+ SIOCGIFBRDADDR = 0x8919
+ SIOCGIFNETMASK = 0x891b
+ SIOCGIFMETRIC = 0x891d
+ SIOCGIFMTU = 0x8921
+ SIOCGIFMEM = 0x891f
+ SIOCGIFHWADDR = 0x8927
+ SIOCGIFINDEX = 0x8933
+ SIOCGIFPFLAGS = 0x8935
+ SIOCGIFTXQLEN = 0x8942
+ SIOCETHTOOL = 0x8946
+ SIOCGMIIPHY = 0x8947
+ SIOCGMIIREG = 0x8948
+ SIOCGIFMAP = 0x8970
+)
+
+// ioctl(2) requests provided by uapi/asm-generic/sockios.h
+const (
+ SIOCGSTAMP = 0x8906
)
// ioctl(2) directions. Used to calculate request numbers.
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 46d8b0b42..a91f9f018 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -14,6 +14,14 @@
package linux
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
// This file contains structures required to support netfilter, specifically
// the iptables tool.
@@ -76,6 +84,8 @@ const (
// IPTEntry is an iptable rule. It corresponds to struct ipt_entry in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTEntry struct {
// IP is used to filter packets based on the IP header.
IP IPTIP
@@ -112,21 +122,41 @@ type IPTEntry struct {
// SizeOfIPTEntry is the size of an IPTEntry.
const SizeOfIPTEntry = 112
-// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This
-// struct marshaled via the binary package to write an IPTEntry to userspace.
+// KernelIPTEntry is identical to IPTEntry, but includes the Elems field.
+// KernelIPTEntry itself is not Marshallable, but it implements some methods of
+// marshal.Marshallable that help other implementations of Marshallable.
type KernelIPTEntry struct {
- IPTEntry
+ Entry IPTEntry
// Elems holds the data for all this rule's matches followed by the
// target. It is variable length -- users have to iterate over any
// matches and use TargetOffset and NextOffset to make sense of the
// data.
- Elems []byte
+ Elems primitive.ByteSlice
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTEntry) SizeBytes() int {
+ return ke.Entry.SizeBytes() + ke.Elems.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTEntry) MarshalBytes(dst []byte) {
+ ke.Entry.MarshalBytes(dst)
+ ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
+ ke.Entry.UnmarshalBytes(src)
+ ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
// IPTIP contains information for matching a packet's IP header.
// It corresponds to struct ipt_ip in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTIP struct {
// Src is the source IP address.
Src InetAddr
@@ -189,6 +219,8 @@ const SizeOfIPTIP = 84
// XTCounters holds packet and byte counts for a rule. It corresponds to struct
// xt_counters in include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTCounters struct {
// Pcnt is the packet count.
Pcnt uint64
@@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetinfo struct {
Name TableName
ValidHooks uint32
@@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84
// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
// corresponds to struct ipt_get_entries in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetEntries struct {
Name TableName
Size uint32
@@ -350,13 +386,103 @@ type IPTGetEntries struct {
const SizeOfIPTGetEntries = 40
// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
-// Entrytable field. This struct marshaled via the binary package to write an
-// KernelIPTGetEntries to userspace.
+// Entrytable field. This has been manually made marshal.Marshallable since it
+// is dynamically sized.
type KernelIPTGetEntries struct {
IPTGetEntries
Entrytable []KernelIPTEntry
}
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTGetEntries) SizeBytes() int {
+ res := ke.IPTGetEntries.SizeBytes()
+ for _, entry := range ke.Entrytable {
+ res += entry.SizeBytes()
+ }
+ return res
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
+ ke.IPTGetEntries.MarshalBytes(dst)
+ marshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
+ marshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
+ ke.IPTGetEntries.UnmarshalBytes(src)
+ unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
+ unmarshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (ke *KernelIPTGetEntries) Packed() bool {
+ // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an
+ // indirection to the actual data we want to marshal (the slice data
+ // pointer), and the memory for KernelIPTGetEntries contains the slice
+ // header which we don't want to marshal.
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) {
+ // Fall back to safe Marshal because the type is not packed.
+ ke.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
+ // Fall back to safe Unmarshal because the type is not packed.
+ ke.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results in a
+ // partially unmarshalled struct.
+ ke.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task))
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+}
+
+func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte {
+ buf := task.CopyScratchBuffer(ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ return buf
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) {
+ buf := make([]byte, ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ length, err := w.Write(buf)
+ return int64(length), err
+}
+
+var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
+
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
@@ -374,12 +500,6 @@ type IPTReplace struct {
// Entries [0]IPTEntry
}
-// KernelIPTReplace is identical to IPTReplace, but includes the Entries field.
-type KernelIPTReplace struct {
- IPTReplace
- Entries [0]IPTEntry
-}
-
// SizeOfIPTReplace is the size of an IPTReplace.
const SizeOfIPTReplace = 96
@@ -392,6 +512,8 @@ func (en ExtensionName) String() string {
}
// TableName holds the name of a netfilter table.
+//
+// +marshal
type TableName [XT_TABLE_MAXNAMELEN]byte
// String implements fmt.Stringer.
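Aside: the hand-written SizeBytes/MarshalBytes above follow the usual pattern for dynamically sized wire structures: marshal the fixed part first, then append each variable-length entry at a running offset. A self-contained sketch of that pattern with stand-in types (not the real netfilter structures):

// Stand-in types only; KernelIPTGetEntries.MarshalBytes does the same walk
// over Entrytable with a running offset.
package main

import (
	"encoding/binary"
	"fmt"
)

type entry struct{ data []byte }

func (e *entry) sizeBytes() int { return 4 + len(e.data) }

func (e *entry) marshalBytes(dst []byte) {
	binary.LittleEndian.PutUint32(dst, uint32(len(e.data)))
	copy(dst[4:], e.data)
}

type getEntries struct {
	size    uint32 // fixed-size "header" part
	entries []entry
}

func (g *getEntries) sizeBytes() int {
	res := 4
	for i := range g.entries {
		res += g.entries[i].sizeBytes()
	}
	return res
}

func (g *getEntries) marshalBytes(dst []byte) {
	binary.LittleEndian.PutUint32(dst, g.size)
	off := 4
	for i := range g.entries {
		g.entries[i].marshalBytes(dst[off:])
		off += g.entries[i].sizeBytes()
	}
}

func main() {
	g := getEntries{size: 2, entries: []entry{{[]byte("ab")}, {[]byte("c")}}}
	buf := make([]byte, g.sizeBytes())
	g.marshalBytes(buf)
	fmt.Println(len(buf), buf) // 15 bytes: 4 header + (4+2) + (4+1)
}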
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index 40bec566c..ceda0a8d3 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -187,6 +187,8 @@ const (
// Device types, from uapi/linux/if_arp.h.
const (
+ ARPHRD_NONE = 65534
+ ARPHRD_ETHER = 1
ARPHRD_LOOPBACK = 772
)
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 4a14ef691..c24a8216e 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -134,6 +134,15 @@ const (
SHUT_RDWR = 2
)
+// Packet types from <linux/if_packet.h>
+const (
+ PACKET_HOST = 0 // To us
+ PACKET_BROADCAST = 1 // To all
+ PACKET_MULTICAST = 2 // To group
+ PACKET_OTHERHOST = 3 // To someone else
+ PACKET_OUTGOING = 4 // Outgoing of any type
+)
+
// Socket options from socket.h.
const (
SO_DEBUG = 1
@@ -225,6 +234,8 @@ const (
const SockAddrMax = 128
// InetAddr is struct in_addr, from uapi/linux/in.h.
+//
+// +marshal
type InetAddr [4]byte
// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
@@ -294,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
// Linger is struct linger, from include/linux/socket.h.
+//
+// +marshal
type Linger struct {
OnOff int32
Linger int32
@@ -308,6 +321,8 @@ const SizeOfLinger = 8
// the end of this struct or within existing unused space, so its size grows
// over time. The current iteration is based on linux v4.17. New versions are
// always backwards compatible.
+//
+// +marshal
type TCPInfo struct {
State uint8
CaState uint8
@@ -405,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
//
// ControlMessageCredentials represents struct ucred from linux/socket.h.
+//
+// +marshal
type ControlMessageCredentials struct {
PID int32
UID uint32
diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD
new file mode 100644
index 000000000..eda82cfc1
--- /dev/null
+++ b/pkg/iovec/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "iovec",
+ srcs = ["iovec.go"],
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/abi/linux"],
+)
+
+go_test(
+ name = "iovec_test",
+ size = "small",
+ srcs = ["iovec_test.go"],
+ library = ":iovec",
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go
new file mode 100644
index 000000000..dd70fe80f
--- /dev/null
+++ b/pkg/iovec/iovec.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package iovec provides helpers to interact with vectorized I/O on host
+// system.
+package iovec
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// MaxIovs is the maximum number of iovecs the host platform can accept.
+var MaxIovs = linux.UIO_MAXIOV
+
+// Builder is a builder for a slice of syscall.Iovec.
+type Builder struct {
+ iovec []syscall.Iovec
+ storage [8]syscall.Iovec
+
+ // overflow tracks the last buffer when iovec length is at MaxIovs.
+ overflow []byte
+}
+
+// Add adds buf to b, preparing it to be written. A zero-length buf won't be added.
+func (b *Builder) Add(buf []byte) {
+ if len(buf) == 0 {
+ return
+ }
+ if b.iovec == nil {
+ b.iovec = b.storage[:0]
+ }
+ if len(b.iovec) >= MaxIovs {
+ b.addByAppend(buf)
+ return
+ }
+ b.iovec = append(b.iovec, syscall.Iovec{
+ Base: &buf[0],
+ Len: uint64(len(buf)),
+ })
+ // Keep the last buf if iovec is at max capacity. We will need to append to it
+ // for later bufs.
+ if len(b.iovec) == MaxIovs {
+ n := len(buf)
+ b.overflow = buf[:n:n]
+ }
+}
+
+func (b *Builder) addByAppend(buf []byte) {
+ b.overflow = append(b.overflow, buf...)
+ b.iovec[len(b.iovec)-1] = syscall.Iovec{
+ Base: &b.overflow[0],
+ Len: uint64(len(b.overflow)),
+ }
+}
+
+// Build returns the final Iovec slice. The length of the returned iovec will
+// not exceed MaxIovs.
+func (b *Builder) Build() []syscall.Iovec {
+ return b.iovec
+}
diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go
new file mode 100644
index 000000000..a3900c299
--- /dev/null
+++ b/pkg/iovec/iovec_test.go
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package iovec
+
+import (
+ "bytes"
+ "fmt"
+ "syscall"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func TestBuilderEmpty(t *testing.T) {
+ var builder Builder
+ iovecs := builder.Build()
+ if got, want := len(iovecs), 0; got != want {
+ t.Errorf("len(iovecs) = %d, want %d", got, want)
+ }
+}
+
+func TestBuilderBuild(t *testing.T) {
+ a := []byte{1, 2}
+ b := []byte{3, 4, 5}
+
+ var builder Builder
+ builder.Add(a)
+ builder.Add(b)
+ builder.Add(nil) // Nil slice won't be added.
+ builder.Add([]byte{}) // Empty slice won't be added.
+ iovecs := builder.Build()
+
+ if got, want := len(iovecs), 2; got != want {
+ t.Fatalf("len(iovecs) = %d, want %d", got, want)
+ }
+ for i, data := range [][]byte{a, b} {
+ if got, want := *iovecs[i].Base, data[0]; got != want {
+ t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want)
+ }
+ if got, want := iovecs[i].Len, uint64(len(data)); got != want {
+ t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want)
+ }
+ }
+}
+
+func TestBuilderBuildMaxIov(t *testing.T) {
+ for _, test := range []struct {
+ numIov int
+ }{
+ {
+ numIov: MaxIovs - 1,
+ },
+ {
+ numIov: MaxIovs,
+ },
+ {
+ numIov: MaxIovs + 1,
+ },
+ {
+ numIov: MaxIovs + 10,
+ },
+ } {
+ name := fmt.Sprintf("numIov=%v", test.numIov)
+ t.Run(name, func(t *testing.T) {
+ var data []byte
+ var builder Builder
+ for i := 0; i < test.numIov; i++ {
+ buf := []byte{byte(i)}
+ builder.Add(buf)
+ data = append(data, buf...)
+ }
+ iovec := builder.Build()
+
+ // Check the expected length of iovec.
+ wantNum := test.numIov
+ if wantNum > MaxIovs {
+ wantNum = MaxIovs
+ }
+ if got, want := len(iovec), wantNum; got != want {
+ t.Errorf("len(iovec) = %d, want %d", got, want)
+ }
+
+ // Test a real read-write.
+ var fds [2]int
+ if err := unix.Pipe(fds[:]); err != nil {
+ t.Fatalf("Pipe: %v", err)
+ }
+ defer syscall.Close(fds[0])
+ defer syscall.Close(fds[1])
+
+ wrote, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec)))
+ if int(wrote) != len(data) || e != 0 {
+ t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data))
+ }
+
+ got := make([]byte, len(data))
+ if n, err := syscall.Read(fds[0], got); n != len(got) || err != nil {
+ t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got))
+ }
+
+ if !bytes.Equal(got, data) {
+ t.Errorf("read: got data %v, want %v", got, data)
+ }
+ })
+ }
+}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index 57b89ad7d..2cb59f934 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -2506,7 +2506,7 @@ type msgFactory struct {
var msgRegistry registry
type registry struct {
- factories [math.MaxUint8]msgFactory
+ factories [math.MaxUint8 + 1]msgFactory
// largestFixedSize is computed so that given some message size M, you can
// compute the maximum payload size (e.g. for Twrite, Rread) with
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index daba8b172..fd95eb2d2 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -28,7 +28,14 @@ import (
)
// Registers represents the CPU registers for this architecture.
-type Registers = linux.PtraceRegs
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+
+ // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register.
+ TPIDR_EL0 uint64
+}
const (
// SyscallWidth is the width of instructions.
@@ -101,9 +108,6 @@ type State struct {
// Our floating point state.
aarch64FPState `state:"wait"`
- // TLS pointer
- TPValue uint64
-
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
@@ -157,7 +161,6 @@ func (s *State) Fork() State {
return State{
Regs: s.Regs,
aarch64FPState: s.aarch64FPState.fork(),
- TPValue: s.TPValue,
FeatureSet: s.FeatureSet,
OrigR0: s.OrigR0,
}
@@ -241,18 +244,18 @@ func (s *State) ptraceGetRegs() Registers {
return s.Regs
}
-var registersSize = (*Registers)(nil).SizeBytes()
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
var regs Registers
- buf := make([]byte, registersSize)
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
regs.UnmarshalUnsafe(buf)
s.Regs = regs
- return registersSize, nil
+ return ptraceRegistersSize, nil
}
// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
@@ -278,7 +281,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -291,7 +294,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 3b3a0a272..1c3e3c14c 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -300,7 +300,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
// PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and
// u_debugreg, returning 0 or silently no-oping for other fields
// respectively.
- if addr < uintptr(registersSize) {
+ if addr < uintptr(ptraceRegistersSize) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
@@ -315,7 +315,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
if addr&7 != 0 || addr >= userStructSize {
return syscall.EIO
}
- if addr < uintptr(registersSize) {
+ if addr < uintptr(ptraceRegistersSize) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index ada7ac7b8..cabbf60e0 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -142,7 +142,7 @@ func (c *context64) SetStack(value uintptr) {
// TLS returns the current TLS pointer.
func (c *context64) TLS() uintptr {
- return uintptr(c.TPValue)
+ return uintptr(c.Regs.TPIDR_EL0)
}
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
@@ -151,7 +151,7 @@ func (c *context64) SetTLS(value uintptr) bool {
return false
}
- c.TPValue = uint64(value)
+ c.Regs.TPIDR_EL0 = uint64(value)
return true
}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index dc458b37f..b9405b320 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -31,7 +31,11 @@ import (
)
// Registers represents the CPU registers for this architecture.
-type Registers = linux.PtraceRegs
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+}
// System-related constants for x86.
const (
@@ -311,12 +315,12 @@ func (s *State) ptraceGetRegs() Registers {
return regs
}
-var registersSize = (*Registers)(nil).SizeBytes()
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
var regs Registers
- buf := make([]byte, registersSize)
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
@@ -374,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
}
regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable)
s.Regs = regs
- return registersSize, nil
+ return ptraceRegistersSize, nil
}
// isUserSegmentSelector returns true if the given segment selector specifies a
@@ -543,7 +547,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -563,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index aabce6cc9..d41d23a43 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
+ "//pkg/iovec",
"//pkg/log",
"//pkg/refs",
"//pkg/safemem",
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
index 5c18dbd5e..905afb50d 100644
--- a/pkg/sentry/fs/host/socket_iovec.go
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -17,15 +17,12 @@ package host
import (
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/syserror"
)
// LINT.IfChange
-// maxIovs is the maximum number of iovecs to pass to the host.
-var maxIovs = linux.UIO_MAXIOV
-
// copyToMulti copies as many bytes from src to dst as possible.
func copyToMulti(dst [][]byte, src []byte) {
for _, d := range dst {
@@ -76,7 +73,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
}
}
- if iovsRequired > maxIovs {
+ if iovsRequired > iovec.MaxIovs {
// The kernel will reject our call if we pass this many iovs.
// Use a single intermediate buffer instead.
b := make([]byte, stopLen)
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index 69879498a..1081fff52 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -67,8 +67,8 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf
}
// Stat implements kernfs.Inode.Stat.
-func (mi *masterInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- statx, err := mi.InodeAttrs.Stat(vfsfs, opts)
+func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := mi.InodeAttrs.Stat(ctx, vfsfs, opts)
if err != nil {
return linux.Statx{}, err
}
@@ -186,7 +186,7 @@ func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatO
// Stat implements vfs.FileDescriptionImpl.Stat.
func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem()
- return mfd.inode.Stat(fs, opts)
+ return mfd.inode.Stat(ctx, fs, opts)
}
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
index cf1a0f0ac..a91cae3ef 100644
--- a/pkg/sentry/fsimpl/devpts/slave.go
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -73,8 +73,8 @@ func (si *slaveInode) Valid(context.Context) bool {
}
// Stat implements kernfs.Inode.Stat.
-func (si *slaveInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- statx, err := si.InodeAttrs.Stat(vfsfs, opts)
+func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts)
if err != nil {
return linux.Statx{}, err
}
@@ -132,7 +132,7 @@ func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequen
return sfd.inode.t.ld.outputQueueWrite(ctx, src)
}
-// Ioctl implements vfs.FileDescripionImpl.Ioctl.
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := args[1].Uint(); cmd {
case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
@@ -183,7 +183,7 @@ func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOp
// Stat implements vfs.FileDescriptionImpl.Stat.
func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
- return sfd.inode.Stat(fs, opts)
+ return sfd.inode.Stat(ctx, fs, opts)
}
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index ef24f8159..abc610ef3 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -96,7 +96,7 @@ go_test(
"//pkg/syserror",
"//pkg/test/testutil",
"//pkg/usermem",
- "@com_github_google_go-cmp//cmp:go_default_library",
- "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 41567967d..737007748 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -6,12 +6,17 @@ go_library(
name = "fuse",
srcs = [
"dev.go",
+ "fusefs.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/log",
"//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index f6a67d005..c9e12a94f 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -30,6 +31,10 @@ type fuseDevice struct{}
// Open implements vfs.Device.Open.
func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if !kernel.FUSEEnabled {
+ return nil, syserror.ENOENT
+ }
+
var fd DeviceFD
if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
UseDentryMetadata: true,
@@ -46,6 +51,9 @@ type DeviceFD struct {
vfs.DentryMetadataFileDescriptionImpl
vfs.NoLockFD
+ // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD.
+ mounted bool
+
// TODO(gvisor.dev/issue/2987): Add all the data structures needed to enqueue
// and dequeue requests, control synchronization and establish communication
// between the FUSE kernel module and the /dev/fuse character device.
@@ -56,26 +64,51 @@ func (fd *DeviceFD) Release() {}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// Read implements vfs.FileDescriptionImpl.Read.
func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
new file mode 100644
index 000000000..f7775fb9b
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -0,0 +1,200 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fuse implements fusefs.
+package fuse
+
+import (
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "fuse"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+type filesystemOptions struct {
+ // userID specifies the numeric uid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see the
+ // fuse(8) man page.
+ userID uint32
+
+ // groupID specifies the numeric gid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see the
+ // fuse(8) man page.
+ groupID uint32
+
+ // rootMode specifies the file mode of the filesystem's root.
+ rootMode linux.FileMode
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+ devMinor uint32
+
+ // fuseFD is the FD returned when opening /dev/fuse. It is used for communication
+ // between the FUSE server daemon and the sentry fusefs.
+ fuseFD *DeviceFD
+
+ // opts is the options the fusefs is initialized with.
+ opts filesystemOptions
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ var fsopts filesystemOptions
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ deviceDescriptorStr, ok := mopts["fd"]
+ if !ok {
+ log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "fd")
+
+ deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor))
+
+ // Parse and set all the other supported FUSE mount options.
+ // TODO: Expand the supported mount options.
+ if userIDStr, ok := mopts["user_id"]; ok {
+ delete(mopts, "user_id")
+ userID, err := strconv.ParseUint(userIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.userID = uint32(userID)
+ }
+
+ if groupIDStr, ok := mopts["group_id"]; ok {
+ delete(mopts, "group_id")
+ groupID, err := strconv.ParseUint(groupIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.groupID = uint32(groupID)
+ }
+
+ rootMode := linux.FileMode(0777)
+ modeStr, ok := mopts["rootmode"]
+ if ok {
+ delete(mopts, "rootmode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode)
+ }
+ fsopts.rootMode = rootMode
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
+ // mount a FUSE filesystem.
+ fuseFD := fuseFd.Impl().(*DeviceFD)
+ fuseFD.mounted = true
+
+ fs := &filesystem{
+ devMinor: devMinor,
+ fuseFD: fuseFD,
+ opts: fsopts,
+ }
+
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ // TODO: dispatch a FUSE_INIT request to the FUSE daemon server before
+ // returning. Mount will not block on this dispatched request.
+
+ // root is the fusefs root directory.
+ root := fs.newInode(creds, fsopts.rootMode)
+
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// Inode implements kernfs.Inode.
+type Inode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+ i := &Inode{}
+ i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ i.dentry.Init(i)
+
+ return &i.dentry
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *Inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
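Aside: the option string parsed above is the key=value list that libfuse passes to mount(2), e.g. "fd=7,rootmode=40000,user_id=0,group_id=0" (values here are made up). A standalone sketch of that parsing using only the standard library; the real code uses vfs.GenericParseMountOptions:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

func main() {
	data := "fd=7,rootmode=40000,user_id=0,group_id=0" // hypothetical example
	opts := make(map[string]string)
	for _, kv := range strings.Split(data, ",") {
		parts := strings.SplitN(kv, "=", 2)
		opts[parts[0]] = parts[1]
	}
	// rootmode is octal, matching ParseUint(modeStr, 8, 32) above.
	mode, err := strconv.ParseUint(opts["rootmode"], 8, 32)
	if err != nil {
		panic(err)
	}
	fmt.Printf("fd=%s rootmode=%#o\n", opts["fd"], mode) // fd=7 rootmode=040000 (a directory)
}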
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 5d83fe363..8c7c8e1b3 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -85,6 +85,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
d2 := &dentry{
refs: 1, // held by d
fs: d.fs,
+ ino: d.fs.nextSyntheticIno(),
mode: uint32(opts.mode),
uid: uint32(opts.kuid),
gid: uint32(opts.kgid),
@@ -184,13 +185,13 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
{
Name: ".",
Type: linux.DT_DIR,
- Ino: d.ino,
+ Ino: uint64(d.ino),
NextOff: 1,
},
{
Name: "..",
Type: uint8(atomic.LoadUint32(&parent.mode) >> 12),
- Ino: parent.ino,
+ Ino: uint64(parent.ino),
NextOff: 2,
},
}
@@ -226,7 +227,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
}
dirent := vfs.Dirent{
Name: p9d.Name,
- Ino: p9d.QID.Path,
+ Ino: uint64(inoFromPath(p9d.QID.Path)),
NextOff: int64(len(dirents) + 1),
}
// p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
@@ -259,7 +260,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
dirents = append(dirents, vfs.Dirent{
Name: child.name,
Type: uint8(atomic.LoadUint32(&child.mode) >> 12),
- Ino: child.ino,
+ Ino: uint64(child.ino),
NextOff: int64(len(dirents) + 1),
})
}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 7bcc99b29..00e3c99cd 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -150,11 +150,9 @@ afterSymlink:
return nil, err
}
if d != d.parent && !d.cachedMetadataAuthoritative() {
- _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask())
- if err != nil {
+ if err := d.parent.updateFromGetattr(ctx); err != nil {
return nil, err
}
- d.parent.updateFromP9Attrs(attrMask, &attr)
}
rp.Advance()
return d.parent, nil
@@ -209,18 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil
// Preconditions: As for getChildLocked. !parent.isSynthetic().
func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
+ if child != nil {
+ // Need to lock child.metadataMu because we might be updating child
+ // metadata. We need to hold the lock *before* getting metadata from the
+ // server and release it after updating local metadata.
+ child.metadataMu.Lock()
+ }
qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
if err != nil && err != syserror.ENOENT {
+ if child != nil {
+ child.metadataMu.Unlock()
+ }
return nil, err
}
if child != nil {
- if !file.isNil() && qid.Path == child.ino {
- // The file at this path hasn't changed. Just update cached
- // metadata.
+ if !file.isNil() && inoFromPath(qid.Path) == child.ino {
+ // The file at this path hasn't changed. Just update cached metadata.
file.close(ctx)
- child.updateFromP9Attrs(attrMask, &attr)
+ child.updateFromP9AttrsLocked(attrMask, &attr)
+ child.metadataMu.Unlock()
return child, nil
}
+ child.metadataMu.Unlock()
if file.isNil() && child.isSynthetic() {
// We have a synthetic file, and no remote file has arisen to
// replace it.
@@ -1326,7 +1334,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
- if err := d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()); err != nil {
+ if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil {
fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
@@ -1499,3 +1507,7 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
+
+func (fs *filesystem) nextSyntheticIno() inodeNumber {
+ return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask)
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 8e74e60a5..e20de84b5 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -110,6 +110,26 @@ type filesystem struct {
syncMu sync.Mutex
syncableDentries map[*dentry]struct{}
specialFileFDs map[*specialFileFD]struct{}
+
+ // syntheticSeq stores a counter used to generate a unique inodeNumber for
+ // synthetic dentries.
+ syntheticSeq uint64
+}
+
+// inodeNumber represents the inode number reported in Dirent.Ino. For regular
+// dentries, it comes from QID.Path from the 9P server. Synthetic dentries
+// have their inodeNumber generated sequentially, with the MSB reserved to
+// prevent conflicts with regular dentries.
+type inodeNumber uint64
+
+// Reserve MSB for synthetic mounts.
+const syntheticInoMask = uint64(1) << 63
+
+func inoFromPath(path uint64) inodeNumber {
+ if path&syntheticInoMask != 0 {
+ log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask)
+ }
+ return inodeNumber(path &^ syntheticInoMask)
}
type filesystemOptions struct {
@@ -582,21 +602,27 @@ type dentry struct {
// returned by the server. dirents is protected by dirMu.
dirents []vfs.Dirent
- // Cached metadata; protected by metadataMu and accessed using atomic
- // memory operations unless otherwise specified.
+ // Cached metadata; protected by metadataMu.
+ // To access:
+ // - In situations where consistency is not required (like stat), these
+ // can be accessed using atomic operations only (without locking).
+ // - Lock metadataMu and access without atomic operations.
+ // To mutate:
+ // - Lock metadataMu and use atomic operations to update because we might
+ // have atomic readers that don't hold the lock.
metadataMu sync.Mutex
- ino uint64 // immutable
- mode uint32 // type is immutable, perms are mutable
- uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
- gid uint32 // auth.KGID, but ...
- blockSize uint32 // 0 if unknown
+ ino inodeNumber // immutable
+ mode uint32 // type is immutable, perms are mutable
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ blockSize uint32 // 0 if unknown
// Timestamps, all nsecs from the Unix epoch.
atime int64
mtime int64
ctime int64
btime int64
// File size, protected by both metadataMu and dataMu (i.e. both must be
- // locked to mutate it).
+ // locked to mutate it; locking either is sufficient to access it).
size uint64
// nlink counts the number of hard links to this dentry. It's updated and
@@ -704,7 +730,7 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d := &dentry{
fs: fs,
file: file,
- ino: qid.Path,
+ ino: inoFromPath(qid.Path),
mode: uint32(attr.Mode),
uid: uint32(fs.opts.dfltuid),
gid: uint32(fs.opts.dfltgid),
@@ -759,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool {
// updateFromP9Attrs is called to update d's metadata after an update from the
// remote filesystem.
-func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
- d.metadataMu.Lock()
+// Precondition: d.metadataMu must be locked.
+func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
if mask.Mode {
if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
d.metadataMu.Unlock()
@@ -796,7 +822,6 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
if mask.Size {
d.updateFileSizeLocked(attr.Size)
}
- d.metadataMu.Unlock()
}
// Preconditions: !d.isSynthetic()
@@ -808,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
file p9file
handleMuRLocked bool
)
+ // d.metadataMu must be locked *before* we getAttr so that we do not end up
+ // updating stale attributes in d.updateFromP9AttrsLocked().
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
d.handleMu.RLock()
if !d.handle.file.isNil() {
file = d.handle.file
@@ -823,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
if err != nil {
return err
}
- d.updateFromP9Attrs(attrMask, &attr)
+ d.updateFromP9AttrsLocked(attrMask, &attr)
return nil
}
@@ -846,7 +875,7 @@ func (d *dentry) statTo(stat *linux.Statx) {
stat.UID = atomic.LoadUint32(&d.uid)
stat.GID = atomic.LoadUint32(&d.gid)
stat.Mode = uint16(atomic.LoadUint32(&d.mode))
- stat.Ino = d.ino
+ stat.Ino = uint64(d.ino)
stat.Size = atomic.LoadUint64(&d.size)
// This is consistent with regularFileFD.Seek(), which treats regular files
// as having no holes.
@@ -859,7 +888,8 @@ func (d *dentry) statTo(stat *linux.Statx) {
stat.DevMinor = d.fs.devMinor
}
-func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error {
+func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error {
+ stat := &opts.Stat
if stat.Mask == 0 {
return nil
}
@@ -867,7 +897,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
return syserror.EPERM
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
if err := mnt.CheckBeginWrite(); err != nil {
@@ -884,14 +914,14 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
// Prepare for truncate.
if stat.Mask&linux.STATX_SIZE != 0 {
- switch d.mode & linux.S_IFMT {
- case linux.S_IFREG:
+ switch mode.FileType() {
+ case linux.ModeRegular:
if !setLocalMtime {
// Truncate updates mtime.
setLocalMtime = true
stat.Mtime.Nsec = linux.UTIME_NOW
}
- case linux.S_IFDIR:
+ case linux.ModeDirectory:
return syserror.EISDIR
default:
return syserror.EINVAL
@@ -908,6 +938,17 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
}
if !d.isSynthetic() {
if stat.Mask != 0 {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // Check whether to allow a truncate request to be made.
+ switch d.mode & linux.S_IFMT {
+ case linux.S_IFREG:
+ // Allow.
+ case linux.S_IFDIR:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
+ }
if err := d.file.setAttr(ctx, p9.SetAttrMask{
Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
@@ -974,7 +1015,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
func (d *dentry) updateFileSizeLocked(newSize uint64) {
d.dataMu.Lock()
oldSize := d.size
- d.size = newSize
+ atomic.StoreUint64(&d.size, newSize)
// d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings
// below. This allows concurrent calls to Read/Translate/etc. These
// functions synchronize with truncation by refusing to use cache
@@ -1320,8 +1361,8 @@ func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name
// Extended attributes in the user.* namespace are only supported for regular
// files and directories.
func (d *dentry) userXattrSupported() bool {
- filetype := linux.S_IFMT & atomic.LoadUint32(&d.mode)
- return filetype == linux.S_IFREG || filetype == linux.S_IFDIR
+ filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType()
+ return filetype == linux.ModeRegular || filetype == linux.ModeDirectory
}
// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir().
@@ -1469,7 +1510,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()); err != nil {
+ if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil {
return err
}
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
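The gofer.go hunks above rework the attribute-update path: updateFromP9Attrs becomes updateFromP9AttrsLocked with d.metadataMu as a documented precondition, and updateFromGetattr now takes the lock before issuing the getattr so a slow RPC cannot clobber newer attributes. A standalone sketch of that discipline, with invented names and no gVisor dependencies:

// Sketch only: hold the metadata lock across both the remote fetch and the
// local update, mirroring updateFromGetattr -> updateFromP9AttrsLocked.
package main

import (
	"fmt"
	"sync"
)

type remoteAttr struct{ size uint64 }

// fetch stands in for the p9 getattr round trip.
func fetch() remoteAttr { return remoteAttr{size: 42} }

type node struct {
	mu   sync.Mutex // serializes attribute updates (metadataMu analogue)
	size uint64
}

// refresh locks mu before the fetch so a concurrent, newer update cannot
// be overwritten by this (possibly stale) result.
func (n *node) refresh() {
	n.mu.Lock()
	defer n.mu.Unlock()
	n.updateLocked(fetch())
}

// updateLocked applies attributes. Precondition: n.mu must be locked.
func (n *node) updateLocked(attr remoteAttr) {
	n.size = attr.size
}

func main() {
	var n node
	n.refresh()
	fmt.Println(n.size) // 42
}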
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index 724a3f1f7..8792ca4f2 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -126,11 +126,16 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
}
func (h *handle) sync(ctx context.Context) error {
+ // Handle most common case first.
if h.fd >= 0 {
ctx.UninterruptibleSleepStart(false)
err := syscall.Fsync(int(h.fd))
ctx.UninterruptibleSleepFinish(false)
return err
}
+ if h.file.isNil() {
+ // File hasn't been touched, there is nothing to sync.
+ return nil
+ }
return h.file.fsync(ctx)
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 3d2d3530a..02317a133 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -89,7 +89,9 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint
if err != nil {
return err
}
- d.size = size
+ d.dataMu.Lock()
+ atomic.StoreUint64(&d.size, size)
+ d.dataMu.Unlock()
if !d.cachedMetadataAuthoritative() {
d.touchCMtimeLocked()
}
@@ -153,26 +155,53 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset, error. The final
+// offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported.
//
// TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
if opts.Flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
}
+ d := fd.dentry()
+ // If the fd was opened with O_APPEND, make sure the file size is updated.
+ // There is a possible race here if size is modified externally after
+ // metadata cache is updated.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
+ }
+
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ // Set offset to file size if the fd was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(limit)
+ n, err := fd.pwriteLocked(ctx, src, offset, opts)
+ return n, offset + n, err
+}
+// Preconditions: fd.dentry().metadataMu must be locked.
+func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
d := fd.dentry()
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
if d.fs.opts.interop != InteropModeShared {
// Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
// file_update_time(). This is d.touchCMtime(), but without locking
@@ -235,8 +264,8 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
fd.mu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.mu.Unlock()
return n, err
}
@@ -582,20 +611,19 @@ func (fd *regularFileFD) Sync(ctx context.Context) error {
func (d *dentry) syncSharedHandle(ctx context.Context) error {
d.handleMu.RLock()
- if !d.handleWritable {
- d.handleMu.RUnlock()
- return nil
- }
- d.dataMu.Lock()
- // Write dirty cached data to the remote file.
- err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt)
- d.dataMu.Unlock()
- if err == nil {
- // Sync the remote file.
- err = d.handle.sync(ctx)
+ defer d.handleMu.RUnlock()
+
+ if d.handleWritable {
+ d.dataMu.Lock()
+ // Write dirty cached data to the remote file.
+ err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt)
+ d.dataMu.Unlock()
+ if err != nil {
+ return err
+ }
}
- d.handleMu.RUnlock()
- return err
+ // Sync the remote file.
+ return d.handle.sync(ctx)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
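The PWrite/Write refactor above has pwrite return the final offset next to the byte count so Write can store that offset directly instead of doing fd.off += n; with O_APPEND the effective offset jumps to the file size (read under d.metadataMu), so simple addition would be wrong. A minimal standalone sketch of the same shape (invented types, no gVisor APIs):

// Sketch only: pwrite reports (written, finalOff) so the caller's offset
// tracking stays correct when O_APPEND moves the write to end-of-file.
package main

import "fmt"

type file struct {
	off      int64
	size     int64
	appendFl bool // O_APPEND analogue
}

func (f *file) pwrite(p []byte, offset int64) (written, finalOff int64, err error) {
	if f.appendFl {
		offset = f.size // O_APPEND: start at current EOF
	}
	// ... actual data copy elided ...
	n := int64(len(p))
	if offset+n > f.size {
		f.size = offset + n
	}
	return n, offset + n, nil
}

func (f *file) Write(p []byte) (int64, error) {
	n, off, err := f.pwrite(p, f.off)
	f.off = off // not f.off += n: the offset may have jumped to EOF first
	return n, err
}

func main() {
	f := &file{size: 10, appendFl: true}
	n, _ := f.Write([]byte("abc"))
	fmt.Println(n, f.off, f.size) // 3 13 13
}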
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 3c4e7e2e4..811528982 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -16,6 +16,7 @@ package gofer
import (
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -28,9 +29,9 @@ import (
)
// specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device
-// special files, and (when filesystemOptions.specialRegularFiles is in effect)
-// regular files. specialFileFD differs from regularFileFD by using per-FD
-// handles instead of shared per-dentry handles, and never buffering I/O.
+// special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is
+// in effect) regular files. specialFileFD differs from regularFileFD by using
+// per-FD handles instead of shared per-dentry handles, and never buffering I/O.
type specialFileFD struct {
fileDescription
@@ -41,10 +42,10 @@ type specialFileFD struct {
// file offset is significant, i.e. a regular file. seekable is immutable.
seekable bool
- // mayBlock is true if this file description represents a file for which
- // queue may send I/O readiness events. mayBlock is immutable.
- mayBlock bool
- queue waiter.Queue
+ // haveQueue is true if this file description represents a file for which
+ // queue may send I/O readiness events. haveQueue is immutable.
+ haveQueue bool
+ queue waiter.Queue
// If seekable is true, off is the file offset. off is protected by mu.
mu sync.Mutex
@@ -54,14 +55,14 @@ type specialFileFD struct {
func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
ftype := d.fileType()
seekable := ftype == linux.S_IFREG
- mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK
+ haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0
fd := &specialFileFD{
- handle: h,
- seekable: seekable,
- mayBlock: mayBlock,
+ handle: h,
+ seekable: seekable,
+ haveQueue: haveQueue,
}
fd.LockFD.Init(locks)
- if mayBlock && h.fd >= 0 {
+ if haveQueue {
if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil {
return nil, err
}
@@ -70,7 +71,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
DenyPRead: !seekable,
DenyPWrite: !seekable,
}); err != nil {
- if mayBlock && h.fd >= 0 {
+ if haveQueue {
fdnotifier.RemoveFD(h.fd)
}
return nil, err
@@ -80,7 +81,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
// Release implements vfs.FileDescriptionImpl.Release.
func (fd *specialFileFD) Release() {
- if fd.mayBlock && fd.handle.fd >= 0 {
+ if fd.haveQueue {
fdnotifier.RemoveFD(fd.handle.fd)
}
fd.handle.close(context.Background())
@@ -100,7 +101,7 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error {
// Readiness implements waiter.Waitable.Readiness.
func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask {
- if fd.mayBlock {
+ if fd.haveQueue {
return fdnotifier.NonBlockingPoll(fd.handle.fd, mask)
}
return fd.fileDescription.Readiness(mask)
@@ -108,8 +109,9 @@ func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask {
// EventRegister implements waiter.Waitable.EventRegister.
func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- if fd.mayBlock {
+ if fd.haveQueue {
fd.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(fd.handle.fd)
return
}
fd.fileDescription.EventRegister(e, mask)
@@ -117,8 +119,9 @@ func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
// EventUnregister implements waiter.Waitable.EventUnregister.
func (fd *specialFileFD) EventUnregister(e *waiter.Entry) {
- if fd.mayBlock {
+ if fd.haveQueue {
fd.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(fd.handle.fd)
return
}
fd.fileDescription.EventUnregister(e)
@@ -142,7 +145,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
// hold here since specialFileFD doesn't client-cache data. Just buffer the
// read instead.
- if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ if d := fd.dentry(); d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
buf := make([]byte, dst.NumBytes())
@@ -174,39 +177,76 @@ func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset, error. The final
+// offset should be ignored by PWrite.
+func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if fd.seekable && offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported.
//
// TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
if opts.Flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the regular file fd was opened with O_APPEND, make sure the file size
+ // is updated. There is a possible race here if size is modified externally
+ // after metadata cache is updated.
+ if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
if fd.seekable {
+ // We need to hold the metadataMu *while* writing to a regular file.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ // Set offset to file size if the regular file was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(limit)
}
// Do a buffered write. See rationale in PRead.
- if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ if d.cachedMetadataAuthoritative() {
d.touchCMtime()
}
buf := make([]byte, src.NumBytes())
// Don't do partial writes if we get a partial read from src.
if _, err := src.CopyIn(ctx, buf); err != nil {
- return 0, err
+ return 0, offset, err
}
n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
if err == syserror.EAGAIN {
err = syserror.ErrWouldBlock
}
- return int64(n), err
+ finalOff = offset
+ // Update file size for regular files.
+ if fd.seekable {
+ finalOff += int64(n)
+ // d.metadataMu is already locked at this point.
+ if uint64(finalOff) > d.size {
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ atomic.StoreUint64(&d.size, uint64(finalOff))
+ }
+ }
+ return int64(n), finalOff, err
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -216,8 +256,8 @@ func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts
}
fd.mu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.mu.Unlock()
return n, err
}
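The specialFileFD changes above collapse the old "mayBlock && h.fd >= 0" checks into a single haveQueue flag computed once in newSpecialFD, so fdnotifier registration, the failure path, Release, and Readiness all agree. A standalone sketch of that gating with a hypothetical notifier type (not the real fdnotifier API):

// Sketch only: compute the "registered with the notifier" condition once
// and reuse it for setup, error cleanup, and release.
package main

import "fmt"

type notifier struct{ fds map[int]bool }

func (n *notifier) add(fd int) error { n.fds[fd] = true; return nil }
func (n *notifier) remove(fd int)    { delete(n.fds, fd) }

type specialFD struct {
	hostFD    int
	haveQueue bool // pipe/socket with a real host fd
}

func newSpecialFD(n *notifier, hostFD int, pipeOrSocket bool) (*specialFD, error) {
	fd := &specialFD{
		hostFD:    hostFD,
		haveQueue: pipeOrSocket && hostFD >= 0,
	}
	if fd.haveQueue {
		if err := n.add(hostFD); err != nil {
			return nil, err
		}
	}
	return fd, nil
}

func (fd *specialFD) release(n *notifier) {
	if fd.haveQueue {
		n.remove(fd.hostFD)
	}
}

func main() {
	n := &notifier{fds: map[int]bool{}}
	fd, _ := newSpecialFD(n, 7, true)
	fmt.Println(fd.haveQueue, len(n.fds)) // true 1
	fd.release(n)
	fmt.Println(len(n.fds)) // 0
}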
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 44a09d87a..e86fbe2d5 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -22,6 +22,7 @@ go_library(
"//pkg/context",
"//pkg/fdnotifier",
"//pkg/fspath",
+ "//pkg/iovec",
"//pkg/log",
"//pkg/refs",
"//pkg/safemem",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 1cd2982cb..c894f2ca0 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -259,7 +259,7 @@ func (i *inode) Mode() linux.FileMode {
}
// Stat implements kernfs.Inode.
-func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
if opts.Mask&linux.STATX__RESERVED != 0 {
return linux.Statx{}, syserror.EINVAL
}
@@ -373,7 +373,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) {
// SetStat implements kernfs.Inode.
func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
- s := opts.Stat
+ s := &opts.Stat
m := s.Mask
if m == 0 {
@@ -386,7 +386,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if err := syscall.Fstat(i.hostFD, &hostStat); err != nil {
return err
}
- if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
return err
}
@@ -396,6 +396,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
}
if m&linux.STATX_SIZE != 0 {
+ if hostStat.Mode&linux.S_IFMT != linux.S_IFREG {
+ return syserror.EINVAL
+ }
if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
return err
}
@@ -534,8 +537,8 @@ func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
}
// Stat implements vfs.FileDescriptionImpl.
-func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) {
- return f.inode.Stat(f.vfsfd.Mount().Filesystem(), opts)
+func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts)
}
// Release implements vfs.FileDescriptionImpl.
diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go
index 584c247d2..fc0d5fd38 100644
--- a/pkg/sentry/fsimpl/host/socket_iovec.go
+++ b/pkg/sentry/fsimpl/host/socket_iovec.go
@@ -17,13 +17,10 @@ package host
import (
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/syserror"
)
-// maxIovs is the maximum number of iovecs to pass to the host.
-var maxIovs = linux.UIO_MAXIOV
-
// copyToMulti copies as many bytes from src to dst as possible.
func copyToMulti(dst [][]byte, src []byte) {
for _, d := range dst {
@@ -74,7 +71,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
}
}
- if iovsRequired > maxIovs {
+ if iovsRequired > iovec.MaxIovs {
// The kernel will reject our call if we pass this many iovs.
// Use a single intermediate buffer instead.
b := make([]byte, stopLen)
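The socket_iovec.go change above swaps the package-local maxIovs for the shared iovec.MaxIovs constant; the surrounding logic is unchanged: if the scatter list would exceed the host's iovec limit, the buffers are coalesced into one intermediate buffer. A simplified standalone sketch of that fallback (the 1024 limit matches Linux's UIO_MAXIOV but is hard-coded here just for illustration):

// Sketch only: coalesce into a single buffer when the segment count would
// exceed what the host kernel accepts in one vectored call.
package main

import "fmt"

const maxIovs = 1024 // stand-in for iovec.MaxIovs

func buildBuffers(bufs [][]byte) [][]byte {
	if len(bufs) <= maxIovs {
		return bufs // small enough to pass through as-is
	}
	total := 0
	for _, b := range bufs {
		total += len(b)
	}
	single := make([]byte, 0, total)
	for _, b := range bufs {
		single = append(single, b...)
	}
	return [][]byte{single}
}

func main() {
	bufs := make([][]byte, 2000)
	for i := range bufs {
		bufs[i] = []byte{byte(i)}
	}
	fmt.Println(len(buildBuffers(bufs))) // 1
}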
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 179df6c1e..3835557fe 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -70,6 +70,6 @@ go_test(
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/usermem",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 6886b0876..c6c4472e7 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -127,7 +127,7 @@ func (fd *DynamicBytesFD) Release() {}
// Stat implements vfs.FileDescriptionImpl.Stat.
func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
- return fd.inode.Stat(fs, opts)
+ return fd.inode.Stat(ctx, fs, opts)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index ca8b8c63b..1d37ccb98 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -112,7 +112,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence
return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts)
}
-// Release implements vfs.FileDecriptionImpl.Release.
+// Release implements vfs.FileDescriptionImpl.Release.
func (fd *GenericDirectoryFD) Release() {}
func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
@@ -123,7 +123,7 @@ func (fd *GenericDirectoryFD) inode() Inode {
return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
}
-// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
// o.mu when calling cb.
func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
fd.mu.Lock()
@@ -132,7 +132,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
opts := vfs.StatOptions{Mask: linux.STATX_INO}
// Handle ".".
if fd.off == 0 {
- stat, err := fd.inode().Stat(fd.filesystem(), opts)
+ stat, err := fd.inode().Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
@@ -152,7 +152,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
if fd.off == 1 {
vfsd := fd.vfsfd.VirtualDentry().Dentry()
parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode
- stat, err := parentInode.Stat(fd.filesystem(), opts)
+ stat, err := parentInode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
@@ -176,7 +176,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
childIdx := fd.off - 2
for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
inode := it.Dentry.Impl().(*Dentry).inode
- stat, err := inode.Stat(fd.filesystem(), opts)
+ stat, err := inode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
}
@@ -198,7 +198,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
return err
}
-// Seek implements vfs.FileDecriptionImpl.Seek.
+// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
fd.mu.Lock()
defer fd.mu.Unlock()
@@ -226,7 +226,7 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int
func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
fs := fd.filesystem()
inode := fd.inode()
- return inode.Stat(fs, opts)
+ return inode.Stat(ctx, fs, opts)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 8939871c1..61a36cff9 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -684,7 +684,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
- return inode.Stat(fs.VFSFilesystem(), opts)
+ return inode.Stat(ctx, fs.VFSFilesystem(), opts)
}
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 4cb885d87..579e627f0 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -243,7 +243,7 @@ func (a *InodeAttrs) Mode() linux.FileMode {
// Stat partially implements Inode.Stat. Note that this function doesn't provide
// all the stat fields, and the embedder should consider extending the result
// with filesystem-specific fields.
-func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
+func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
stat.DevMajor = a.devMajor
@@ -267,7 +267,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 596de1edf..46f207664 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -346,7 +346,7 @@ type inodeMetadata interface {
// Stat returns the metadata for this inode. This corresponds to
// vfs.FilesystemImpl.StatAt.
- Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error)
+ Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error)
// SetStat updates the metadata for this inode. This corresponds to
// vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index ff82e1f20..6b705e955 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -1104,7 +1104,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
mnt := rp.Mount()
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
index a3c1f7a8d..c0749e711 100644
--- a/pkg/sentry/fsimpl/overlay/non_directory.go
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -151,7 +151,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux
func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
d := fd.dentry()
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
mnt := fd.vfsfd.Mount()
@@ -176,7 +176,7 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return nil
}
-// StatFS implements vfs.FileDesciptionImpl.StatFS.
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) {
return fd.filesystem().statFS(ctx)
}
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index dd7eaf4a8..811f80a5f 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -115,7 +115,7 @@ func (i *inode) Mode() linux.FileMode {
}
// Stat implements kernfs.Inode.Stat.
-func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds())
return linux.Statx{
Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index 36a89540c..79c2725f3 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -128,7 +128,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac
return fd.GenericDirectoryFD.IterDirents(ctx, cb)
}
-// Seek implements vfs.FileDecriptionImpl.Seek.
+// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
if fd.task.ExitState() >= kernel.TaskExitZombie {
return 0, syserror.ENOENT
@@ -165,8 +165,8 @@ func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *v
}
// Stat implements kernfs.Inode.
-func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- stat, err := i.InodeAttrs.Stat(vsfs, opts)
+func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts)
if err != nil {
return linux.Statx{}, err
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 8bb2b0ce1..a5c7aa470 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -156,8 +156,8 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.
}
// Stat implements kernfs.Inode.
-func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- stat, err := i.Inode.Stat(fs, opts)
+func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.Inode.Stat(ctx, fs, opts)
if err != nil {
return linux.Statx{}, err
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 9af43b859..859b7d727 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -876,7 +876,7 @@ var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil)
// Stat implements FileDescriptionImpl.
func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
- return fd.inode.Stat(vfs, opts)
+ return fd.inode.Stat(ctx, vfs, opts)
}
// SetStat implements FileDescriptionImpl.
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 2f214d0c2..6d2b90a8b 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -206,8 +206,8 @@ func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.
return fd.VFSFileDescription(), nil
}
-func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
- stat, err := i.InodeAttrs.Stat(vsfs, opts)
+func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts)
if err != nil {
return linux.Statx{}, err
}
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index a741e2bb6..1b548ccd4 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -29,6 +29,6 @@ go_test(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index 0e4053a46..400a97996 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -32,6 +32,6 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/sync",
"//pkg/usermem",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index c16a36cdb..e743e8114 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -62,6 +62,7 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("creating platform: %v", err)
}
+ kernel.VFS2Enabled = true
k := &kernel.Kernel{
Platform: plat,
}
@@ -73,7 +74,7 @@ func Boot() (*kernel.Kernel, error) {
k.SetMemoryFile(mf)
// Pass k as the platform since it is savable, unlike the actual platform.
- vdso, err := loader.PrepareVDSO(nil, k)
+ vdso, err := loader.PrepareVDSO(k)
if err != nil {
return nil, fmt.Errorf("creating vdso: %v", err)
}
@@ -103,11 +104,6 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("initializing kernel: %v", err)
}
- kernel.VFS2Enabled = true
-
- if err := k.VFS().Init(); err != nil {
- return nil, fmt.Errorf("VFS init: %v", err)
- }
k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
AllowUserList: true,
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index ed40f6b52..ef210a69b 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -277,7 +277,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
creds := rp.Credentials()
var childInode *inode
switch opts.Mode.FileType() {
- case 0, linux.S_IFREG:
+ case linux.S_IFREG:
childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
case linux.S_IFIFO:
childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
@@ -649,7 +649,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.mu.RUnlock()
return err
}
- if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil {
+ if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil {
fs.mu.RUnlock()
return err
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 1cdb46e6f..abbaa5d60 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -325,8 +325,15 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset and error. The
+// final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
@@ -334,40 +341,44 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
//
// TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
}
srclen := src.NumBytes()
if srclen == 0 {
- return 0, nil
+ return 0, offset, nil
}
f := fd.inode().impl.(*regularFile)
+ f.inode.mu.Lock()
+ defer f.inode.mu.Unlock()
+ // If the file is opened with O_APPEND, update offset to file size.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Locking f.inode.mu is sufficient for reading f.size.
+ offset = int64(f.size)
+ }
if end := offset + srclen; end < offset {
// Overflow.
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
- var err error
srclen, err = vfs.CheckLimit(ctx, offset, srclen)
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(srclen)
- f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
- fd.inode().touchCMtimeLocked()
- f.inode.mu.Unlock()
+ f.inode.touchCMtimeLocked()
putRegularFileReadWriter(rw)
- return n, err
+ return n, n + offset, err
}
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
fd.offMu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.offMu.Unlock()
return n, err
}
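The tmpfs pwrite above follows the same pattern as the gofer one: the O_APPEND offset is read under the inode lock, vfs.CheckLimit shrinks the request to the file-size limit, and TakeFirst64 truncates the source rather than failing the write outright. The sketch below is a simplified, invented stand-in for that limit check, not vfs.CheckLimit's exact semantics:

// Sketch only: shorten the write to what the size limit still allows,
// and only error out when nothing at all can be written.
package main

import (
	"errors"
	"fmt"
)

var errFileTooBig = errors.New("file size limit exceeded")

func checkLimit(offset, want, limit int64) (int64, error) {
	if offset >= limit {
		return 0, errFileTooBig
	}
	if remaining := limit - offset; want > remaining {
		return remaining, nil
	}
	return want, nil
}

func main() {
	n, err := checkLimit(90, 20, 100)
	fmt.Println(n, err) // 10 <nil>: the write is silently shortened
}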
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index d7f4f0779..2545d88e9 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -452,7 +452,8 @@ func (i *inode) statTo(stat *linux.Statx) {
}
}
-func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error {
+func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error {
+ stat := &opts.Stat
if stat.Mask == 0 {
return nil
}
@@ -460,7 +461,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
return syserror.EPERM
}
mode := linux.FileMode(atomic.LoadUint32(&i.mode))
- if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
return err
}
i.mu.Lock()
@@ -695,7 +696,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
creds := auth.CredentialsFromContext(ctx)
d := fd.dentry()
- if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil {
+ if err := d.inode.setStat(ctx, creds, &opts); err != nil {
return err
}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 25fe1921b..f6886a758 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -132,6 +132,7 @@ go_library(
"task_stop.go",
"task_syscall.go",
"task_usermem.go",
+ "task_work.go",
"thread_group.go",
"threads.go",
"timekeeper.go",
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 732e66da4..bcc1b29a8 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -717,10 +717,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3
}
}
-// UnlockPI unlock the futex following the Priority-inheritance futex
-// rules. The address provided must contain the caller's TID. If there are
-// waiters, TID of the next waiter (FIFO) is set to the given address, and the
-// waiter woken up. If there are no waiters, 0 is set to the address.
+// UnlockPI unlocks the futex following the Priority-inheritance futex rules.
+// The address provided must contain the caller's TID. If there are waiters,
+// TID of the next waiter (FIFO) is set to the given address, and the waiter
+// woken up. If there are no waiters, 0 is set to the address.
func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error {
k, err := getKey(t, addr, private)
if err != nil {
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 2177b785a..15dae0f5b 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -81,6 +81,10 @@ import (
// easy access everywhere. To be removed once VFS2 becomes the default.
var VFS2Enabled = false
+// FUSEEnabled is set to true when FUSE is enabled. Added as a global to allow
+// easy access everywhere. To be removed once FUSE is completed.
+var FUSEEnabled = false
+
// Kernel represents an emulated Linux kernel. It must be initialized by calling
// Init() or LoadFrom().
//
@@ -1465,6 +1469,11 @@ func (k *Kernel) NowMonotonic() int64 {
return now
}
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ return ktime.TcpipAfterFunc(k.realtimeClock, d, f)
+}
+
// SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or
// LoadFrom.
func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) {
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
index 4607cde2f..a83ce219c 100644
--- a/pkg/sentry/kernel/syslog.go
+++ b/pkg/sentry/kernel/syslog.go
@@ -98,6 +98,15 @@ func (s *syslog) Log() []byte {
s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...)
}
+ if VFS2Enabled {
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up VFS2..."))...)
+ if FUSEEnabled {
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up FUSE..."))...)
+ }
+ }
+
time += rand.Float64() / 2
s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...)
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index f48247c94..c4db05bd8 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -68,6 +68,21 @@ type Task struct {
// runState is exclusive to the task goroutine.
runState taskRunState
+ // taskWorkCount represents the current size of the task work queue. It is
+ // used to avoid acquiring taskWorkMu when the queue is empty.
+ //
+ // Must be accessed with atomic memory operations.
+ taskWorkCount int32
+
+ // taskWorkMu protects taskWork.
+ taskWorkMu sync.Mutex `state:"nosave"`
+
+ // taskWork is a queue of work to be executed before resuming user execution.
+ // It is similar to the task_work mechanism in Linux.
+ //
+ // taskWork is exclusive to the task goroutine.
+ taskWork []TaskWorker
+
// haveSyscallReturn is true if tc.Arch().Return() represents a value
// returned by a syscall (or set by ptrace after a syscall).
//
@@ -550,6 +565,10 @@ type Task struct {
// futexWaiter is exclusive to the task goroutine.
futexWaiter *futex.Waiter `state:"nosave"`
+ // robustList is a pointer to the head of the task's robust futex
+ // list.
+ robustList usermem.Addr
+
// startTime is the real time at which the task started. It is set when
// a Task is created or invokes execve(2).
//
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 9b69f3cbe..7803b98d0 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -207,6 +207,9 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
return flags.CloseOnExec
})
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// NOTE(b/30815691): We currently do not implement privileged
// executables (set-user/group-ID bits and file capabilities). This
// allows us to unconditionally enable user dumpability on the new mm.
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index c4ade6e8e..231ac548a 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -253,6 +253,9 @@ func (*runExitMain) execute(t *Task) taskRunState {
}
}
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// Deactivate the address space and update max RSS before releasing the
// task's MM.
t.Deactivate()
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
index a53e77c9f..4b535c949 100644
--- a/pkg/sentry/kernel/task_futex.go
+++ b/pkg/sentry/kernel/task_futex.go
@@ -15,6 +15,7 @@
package kernel
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -52,3 +53,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) {
func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) {
return t.MemoryManager().GetSharedFutexKey(t, addr)
}
+
+// GetRobustList returns the address of the task's robust futex list.
+func (t *Task) GetRobustList() usermem.Addr {
+ t.mu.Lock()
+ addr := t.robustList
+ t.mu.Unlock()
+ return addr
+}
+
+// SetRobustList sets the robust futex list for the task.
+func (t *Task) SetRobustList(addr usermem.Addr) {
+ t.mu.Lock()
+ t.robustList = addr
+ t.mu.Unlock()
+}
+
+// exitRobustList walks the robust futex list, marking locks dead and notifying
+// wakers. It corresponds to Linux's exit_robust_list(). Following Linux,
+// errors are silently ignored.
+func (t *Task) exitRobustList() {
+ t.mu.Lock()
+ addr := t.robustList
+ t.robustList = 0
+ t.mu.Unlock()
+
+ if addr == 0 {
+ return
+ }
+
+ var rl linux.RobustListHead
+ if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil {
+ return
+ }
+
+ next := rl.List
+ done := 0
+ var pendingLockAddr usermem.Addr
+ if rl.ListOpPending != 0 {
+ pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset)
+ }
+
+ // Wake up normal elements.
+ for usermem.Addr(next) != addr {
+ // We traverse to the next element of the list before we
+ // actually wake anything. This prevents the race where waking
+ // this futex causes a modification of the list.
+ thisLockAddr := usermem.Addr(next + rl.FutexOffset)
+
+ // Try to decode the next element in the list before waking the
+ // current futex. But don't check the error until after we've
+ // woken the current futex. Linux does it in this order too.
+ _, nextErr := t.CopyIn(usermem.Addr(next), &next)
+
+ // Wake up the current futex if it's not pending.
+ if thisLockAddr != pendingLockAddr {
+ t.wakeRobustListOne(thisLockAddr)
+ }
+
+ // If there was an error copying the next futex, we must bail.
+ if nextErr != nil {
+ break
+ }
+
+ // This is a user structure, so it could be a massive list, or
+ // even contain a loop if they are trying to mess with us. We
+ // cap traversal to prevent that.
+ done++
+ if done >= linux.ROBUST_LIST_LIMIT {
+ break
+ }
+ }
+
+ // Is there a pending entry to wake?
+ if pendingLockAddr != 0 {
+ t.wakeRobustListOne(pendingLockAddr)
+ }
+}
+
+// wakeRobustListOne wakes a single futex from the robust list.
+func (t *Task) wakeRobustListOne(addr usermem.Addr) {
+ // Bit 0 in address signals PI futex.
+ pi := addr&1 == 1
+ addr = addr &^ 1
+
+ // Load the futex.
+ f, err := t.LoadUint32(addr)
+ if err != nil {
+ // Can't read this single value? Ignore the problem.
+ // We can wake the other futexes in the list.
+ return
+ }
+
+ tid := uint32(t.ThreadID())
+ for {
+ // Is this held by someone else?
+ if f&linux.FUTEX_TID_MASK != tid {
+ return
+ }
+
+ // This thread is dying and it's holding this futex. We need to
+ // set the owner died bit and wake up any waiters.
+ newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED
+ if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil {
+ return
+ } else if curF != f {
+ // Futex changed out from under us. Try again...
+ f = curF
+ continue
+ }
+
+ // Wake waiters if there are any.
+ if f&linux.FUTEX_WAITERS != 0 {
+ private := f&linux.FUTEX_PRIVATE_FLAG != 0
+ if pi {
+ t.Futex().UnlockPI(t, addr, tid, private)
+ return
+ }
+ t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1)
+ }
+
+ // Done.
+ return
+ }
+}
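The heart of wakeRobustListOne above is the compare-and-swap that moves a dying owner's futex word into the owner-died state while preserving the waiters bit. The standalone sketch below isolates just that transition; the constants are the standard Linux futex ABI values, reproduced here purely for illustration:

// Sketch only: the owner-died handoff on a bare futex word.
package main

import (
	"fmt"
	"sync/atomic"
)

const (
	futexTIDMask   = 0x3fffffff
	futexWaiters   = 0x80000000
	futexOwnerDied = 0x40000000
)

// markOwnerDied flips word from "held by tid" to "owner died", keeping the
// waiters bit, and reports whether waiters need to be woken.
func markOwnerDied(word *uint32, tid uint32) (wake bool) {
	for {
		f := atomic.LoadUint32(word)
		if f&futexTIDMask != tid {
			return false // held by someone else; nothing to clean up
		}
		newF := (f & futexWaiters) | futexOwnerDied
		if atomic.CompareAndSwapUint32(word, f, newF) {
			return f&futexWaiters != 0
		}
		// Lost a race with a concurrent update; re-read and retry.
	}
}

func main() {
	word := uint32(1234) | futexWaiters
	fmt.Println(markOwnerDied(&word, 1234)) // true
	fmt.Printf("%#x\n", word)               // 0xc0000000
}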
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index d654dd997..7d4f44caf 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -167,7 +167,22 @@ func (app *runApp) execute(t *Task) taskRunState {
return (*runInterrupt)(nil)
}
- // We're about to switch to the application again. If there's still a
+ // Execute any task work callbacks before returning to user space.
+ if atomic.LoadInt32(&t.taskWorkCount) > 0 {
+ t.taskWorkMu.Lock()
+ queue := t.taskWork
+ t.taskWork = nil
+ atomic.StoreInt32(&t.taskWorkCount, 0)
+ t.taskWorkMu.Unlock()
+
+ // Do not hold taskWorkMu while executing task work, which may register
+ // more work.
+ for _, work := range queue {
+ work.TaskWork(t)
+ }
+ }
+
+ // We're about to switch to the application again. If there's still an
// unhandled SyscallRestartErrno that wasn't translated to an EINTR,
// restart the syscall that was interrupted. If there's a saved signal
// mask, restore it. (Note that restoring the saved signal mask may unblock
diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go
new file mode 100644
index 000000000..dda5a433a
--- /dev/null
+++ b/pkg/sentry/kernel/task_work.go
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "sync/atomic"
+
+// TaskWorker is a deferred task.
+//
+// This must be savable.
+type TaskWorker interface {
+ // TaskWork will be executed prior to returning to user space. Note that
+ // TaskWork may call RegisterWork again, but this will not be executed until
+ // the next return to user space, unlike in Linux. This effectively allows
+ // registration of indefinite user return hooks, but not by default.
+ TaskWork(t *Task)
+}
+
+// RegisterWork can be used to register additional task work that will be
+// performed prior to returning to user space. See TaskWorker.TaskWork for
+// semantics regarding registration.
+func (t *Task) RegisterWork(work TaskWorker) {
+ t.taskWorkMu.Lock()
+ defer t.taskWorkMu.Unlock()
+ atomic.AddInt32(&t.taskWorkCount, 1)
+ t.taskWork = append(t.taskWork, work)
+}
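Taken together with the runApp hook earlier in this diff, task_work.go implements a Linux task_work-style queue: RegisterWork appends under taskWorkMu and bumps an atomic counter, and the run loop drains the queue with the mutex released so callbacks may register more work. A standalone sketch of the same pattern with invented names:

// Sketch only: queue work under a mutex, let the hot path skip the lock via
// an atomic counter, and drain without the lock held.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type worker interface{ run(t *task) }

type task struct {
	mu    sync.Mutex
	count int32 // read atomically on the hot path
	queue []worker
}

func (t *task) register(w worker) {
	t.mu.Lock()
	defer t.mu.Unlock()
	atomic.AddInt32(&t.count, 1)
	t.queue = append(t.queue, w)
}

// drain is what the run loop does just before returning to user space.
func (t *task) drain() {
	if atomic.LoadInt32(&t.count) == 0 {
		return // fast path: nothing queued, no lock taken
	}
	t.mu.Lock()
	q := t.queue
	t.queue = nil
	atomic.StoreInt32(&t.count, 0)
	t.mu.Unlock()
	for _, w := range q { // lock released: w may call register again
		w.run(t)
	}
}

type printWork struct{ msg string }

func (p printWork) run(*task) { fmt.Println(p.msg) }

func main() {
	var t task
	t.register(printWork{msg: "deferred until the next return to user space"})
	t.drain()
}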
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 7ba7dc50c..2817aa3ba 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "time",
srcs = [
"context.go",
+ "tcpip.go",
"time.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go
new file mode 100644
index 000000000..c4474c0cf
--- /dev/null
+++ b/pkg/sentry/kernel/time/tcpip.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "sync"
+ "time"
+)
+
+// TcpipAfterFunc waits for duration to elapse according to clock then runs fn.
+// The timer is started immediately and will fire exactly once.
+func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer {
+ timer := &TcpipTimer{
+ clock: clock,
+ }
+ timer.notifier = functionNotifier{
+ fn: func() {
+ // tcpip.Timer.Stop() explicitly states that the function is called in a
+ // separate goroutine that Stop() does not synchronize with.
+ // Timer.Destroy() synchronizes with calls to TimerListener.Notify().
+ // This is semantically meaningful because, in the former case, it's
+ // legal to call tcpip.Timer.Stop() while holding locks that may also be
+ // taken by the function, but this isn't so in the latter case. Most
+ // immediately, Timer calls TimerListener.Notify() while holding
+ // Timer.mu. A deadlock occurs without spawning a goroutine:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify()
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock!
+ //
+ // Spawning a goroutine avoids the deadlock:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify() <- Launches T2
+ // T2:
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, blocks
+ // T1:
+ // => (returns) <- Timer.mu.Unlock() called
+ // T2:
+ // => (continues) <- No deadlock!
+ go func() {
+ timer.Stop()
+ fn()
+ }()
+ },
+ }
+ timer.Reset(duration)
+ return timer
+}
+
+// TcpipTimer is a resettable timer with variable duration expirations.
+// Implements tcpip.Timer, which does not define a Destroy method; instead, all
+// resources are released after timer expiration and calls to Timer.Stop.
+//
+// Must be created by AfterFunc.
+type TcpipTimer struct {
+ // clock is the time source. clock is immutable.
+ clock Clock
+
+ // notifier is called when the Timer expires. notifier is immutable.
+ notifier functionNotifier
+
+ // mu protects t.
+ mu sync.Mutex
+
+ // t stores the latest running Timer. This is replaced whenever Reset is
+ // called since Timer cannot be restarted once it has been Destroyed by Stop.
+ //
+ // This field is nil iff Stop has been called.
+ t *Timer
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (r *TcpipTimer) Stop() bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ return false
+ }
+ _, lastSetting := r.t.Swap(Setting{})
+ r.t.Destroy()
+ r.t = nil
+ return lastSetting.Enabled
+}
+
+// Reset implements tcpip.Timer.Reset.
+func (r *TcpipTimer) Reset(d time.Duration) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ r.t = NewTimer(r.clock, &r.notifier)
+ }
+
+ r.t.Swap(Setting{
+ Enabled: true,
+ Period: 0,
+ Next: r.clock.Now().Add(d),
+ })
+}
+
+// functionNotifier is a TimerListener that runs a function.
+//
+// functionNotifier cannot be saved or loaded.
+type functionNotifier struct {
+ fn func()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) {
+ f.fn()
+ return Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (f *functionNotifier) Destroy() {}
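The long comment inside TcpipAfterFunc above explains the key constraint: Notify runs while the underlying Timer holds its own lock, so calling Stop from the callback inline would self-deadlock; handing Stop and fn to a fresh goroutine breaks the cycle. A toy standalone example of that pattern (invented timer type, not the ktime API):

// Sketch only: a callback invoked under the timer's lock must defer any
// call back into the timer to another goroutine.
package main

import (
	"fmt"
	"sync"
)

type timer struct {
	mu      sync.Mutex
	stopped bool
	notify  func()
}

// fire simulates expiration: the listener is notified with mu held.
func (t *timer) fire() {
	t.mu.Lock()
	defer t.mu.Unlock()
	if !t.stopped {
		t.notify()
	}
}

func (t *timer) Stop() {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.stopped = true
}

func main() {
	var wg sync.WaitGroup
	wg.Add(1)
	t := &timer{}
	t.notify = func() {
		// Calling t.Stop() directly here would deadlock on t.mu.
		go func() {
			t.Stop()
			fmt.Println("fired once")
			wg.Done()
		}()
	}
	t.fire()
	wg.Wait()
}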
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index 0adf25691..5f3908d8b 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -210,9 +210,6 @@ func (t *Timekeeper) startUpdater() {
p.realtimeBaseRef = int64(realtimeParams.BaseRef)
p.realtimeFrequency = realtimeParams.Frequency
}
-
- log.Debugf("Updating VDSO parameters: %+v", p)
-
return p
}); err != nil {
log.Warningf("Unable to update VDSO parameter page: %v", err)
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index c6aa65f28..34bdb0b69 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -30,9 +30,6 @@ go_library(
"//pkg/rand",
"//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/fs",
- "//pkg/sentry/fs/anon",
- "//pkg/sentry/fs/fsutil",
"//pkg/sentry/fsbridge",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/limits",
@@ -45,6 +42,5 @@ go_library(
"//pkg/syserr",
"//pkg/syserror",
"//pkg/usermem",
- "//pkg/waiter",
],
)
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 616fafa2c..ddeaff3db 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -90,14 +90,23 @@ type elfInfo struct {
sharedObject bool
}
+// fullReader interface extracts the ReadFull method from fsbridge.File so that
+// client code does not need to define an entire fsbridge.File when only read
+// functionality is needed.
+//
+// TODO(gvisor.dev/issue/1035): Once VFS2 ships, rewrite this to wrap
+// vfs.FileDescription's PRead/Read instead.
+type fullReader interface {
+ // ReadFull is the same as fsbridge.File.ReadFull.
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error)
+}
+
// parseHeader parses the ELF header, verifying that this is a supported ELF
// file and returning the ELF program headers.
//
// This is similar to elf.NewFile, except that it is more strict about what it
// accepts from the ELF, and it doesn't parse unnecessary parts of the file.
-//
-// ctx may be nil if f does not need it.
-func parseHeader(ctx context.Context, f fsbridge.File) (elfInfo, error) {
+func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
// Check ident first; it will tell us the endianness of the rest of the
// structs.
var ident [elf.EI_NIDENT]byte
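fullReader above narrows fsbridge.File to the one ReadFull method parseHeader actually needs, which is what lets the vdso.go changes later in this diff feed the parser from an in-memory byteFullReader. A generic standalone sketch of that interface-narrowing technique (all names invented):

// Sketch only: the consumer declares the minimal read interface it needs,
// and any larger type satisfies it implicitly.
package main

import "fmt"

type fullReader interface {
	readFull(dst []byte, offset int64) (int, error)
}

// byteSource is an in-memory implementation, analogous to byteFullReader.
type byteSource struct{ data []byte }

func (b byteSource) readFull(dst []byte, offset int64) (int, error) {
	return copy(dst, b.data[offset:]), nil
}

// parseMagic needs only fullReader, not a whole file abstraction.
func parseMagic(f fullReader) ([4]byte, error) {
	var ident [4]byte
	_, err := f.readFull(ident[:], 0)
	return ident, err
}

func main() {
	ident, _ := parseMagic(byteSource{data: []byte{0x7f, 'E', 'L', 'F', 0}})
	fmt.Println(string(ident[1:])) // ELF
}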
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 88449fe95..986c7fb4d 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -80,22 +79,6 @@ type LoadArgs struct {
Features *cpuid.FeatureSet
}
-// readFull behaves like io.ReadFull for an *fs.File.
-func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- var total int64
- for dst.NumBytes() > 0 {
- n, err := f.Preadv(ctx, dst, offset+total)
- total += n
- if err == io.EOF && total != 0 {
- return total, io.ErrUnexpectedEOF
- } else if err != nil {
- return total, err
- }
- dst = dst.DropFirst64(n)
- }
- return total, nil
-}
-
// openPath opens args.Filename and checks that it is valid for loading.
//
// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not
@@ -238,14 +221,14 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
// Load the executable itself.
loaded, ac, file, newArgv, err := loadExecutable(ctx, args)
if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
}
defer file.DecRef()
// Load the VDSO.
vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
}
// Set up the heap. brk starts at the next page after the end of the
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 165869028..05a294fe6 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -26,10 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/anon"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -37,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
)
const vdsoPrelink = 0xffffffffff700000
@@ -55,52 +50,11 @@ func (f *fileContext) Value(key interface{}) interface{} {
}
}
-// byteReader implements fs.FileOperations for reading from a []byte source.
-type byteReader struct {
- fsutil.FileNoFsync `state:"nosave"`
- fsutil.FileNoIoctl `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileNoopRelease `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FilePipeSeek `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- waiter.AlwaysReady `state:"nosave"`
-
+type byteFullReader struct {
data []byte
}
-var _ fs.FileOperations = (*byteReader)(nil)
-
-// newByteReaderFile creates a fake file to read data from.
-//
-// TODO(gvisor.dev/issue/2921): Convert to VFS2.
-func newByteReaderFile(ctx context.Context, data []byte) *fs.File {
- // Create a fake inode.
- inode := fs.NewInode(
- ctx,
- &fsutil.SimpleFileInode{},
- fs.NewPseudoMountSource(ctx),
- fs.StableAttr{
- Type: fs.Anonymous,
- DeviceID: anon.PseudoDevice.DeviceID(),
- InodeID: anon.PseudoDevice.NextIno(),
- BlockSize: usermem.PageSize,
- })
-
- // Use the fake inode to create a fake dirent.
- dirent := fs.NewTransientDirent(inode)
- defer dirent.DecRef()
-
- // Use the fake dirent to make a fake file.
- flags := fs.FileFlags{Read: true, Pread: true}
- return fs.NewFile(&fileContext{Context: context.Background()}, dirent, flags, &byteReader{
- data: data,
- })
-}
-
-func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
return 0, syserror.EINVAL
}
@@ -111,10 +65,6 @@ func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequ
return int64(n), err
}
-func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
- panic("Write not supported")
-}
-
// validateVDSO checks that the VDSO can be loaded by loadVDSO.
//
// VDSOs are special (see below). Since we are going to map the VDSO directly
@@ -130,7 +80,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq
// * PT_LOAD segments don't extend beyond the end of the file.
//
// ctx may be nil if f does not need it.
-func validateVDSO(ctx context.Context, f fsbridge.File, size uint64) (elfInfo, error) {
+func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, error) {
info, err := parseHeader(ctx, f)
if err != nil {
log.Infof("Unable to parse VDSO header: %v", err)
@@ -248,13 +198,12 @@ func getSymbolValueFromVDSO(symbol string) (uint64, error) {
// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
// param page for updating by the kernel.
-func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
- vdsoFile := fsbridge.NewFSFile(newByteReaderFile(ctx, vdsoBin))
+func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
+ vdsoFile := &byteFullReader{data: vdsoBin}
// First make sure the VDSO is valid. vdsoFile does not use ctx, so a
// nil context can be passed.
info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin)))
- vdsoFile.DecRef()
if err != nil {
return nil, err
}
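
byteFullReader replaces the fake inode/dirent/file scaffolding with a plain struct that implements the read-at-offset contract directly. A self-contained sketch of that shape, not part of this change, with standard-library errors standing in for syserror.EINVAL and usermem.IOSequence:

package loadersketch

import (
	"errors"
	"io"
)

// byteFullReader mirrors the type added above, with plain []byte destinations.
type byteFullReader struct {
	data []byte
}

func (b *byteFullReader) ReadFull(dst []byte, offset int64) (int64, error) {
	if offset < 0 {
		return 0, errors.New("negative offset") // syserror.EINVAL in the sentry
	}
	if offset >= int64(len(b.data)) {
		return 0, io.EOF
	}
	n := copy(dst, b.data[offset:])
	return int64(n), nil
}
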
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 4792454c4..10a10bfe2 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -60,6 +60,7 @@ go_library(
go_test(
name = "kvm_test",
srcs = [
+ "kvm_amd64_test.go",
"kvm_test.go",
"virtual_map_test.go",
],
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
new file mode 100644
index 000000000..c0b4fd374
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+)
+
+func TestSegments(t *testing.T) {
+ applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTestSegments(regs)
+ for {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ continue // Retry.
+ } else if err != nil {
+ t.Errorf("application segment check with full restore got unexpected error: %v", err)
+ }
+ if err := testutil.CheckTestSegments(regs); err != nil {
+ t.Errorf("application segment check with full restore failed: %v", err)
+ }
+ break // Done.
+ }
+ return false
+ })
+}
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 6c8f4fa28..45b3180f1 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -262,30 +262,6 @@ func TestRegistersFault(t *testing.T) {
})
}
-func TestSegments(t *testing.T) {
- applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
- testutil.SetTestSegments(regs)
- for {
- var si arch.SignalInfo
- if _, err := c.SwitchToUser(ring0.SwitchOpts{
- Registers: regs,
- FloatingPointState: dummyFPState,
- PageTables: pt,
- FullRestore: true,
- }, &si); err == platform.ErrContextInterrupt {
- continue // Retry.
- } else if err != nil {
- t.Errorf("application segment check with full restore got unexpected error: %v", err)
- }
- if err := testutil.CheckTestSegments(regs); err != nil {
- t.Errorf("application segment check with full restore failed: %v", err)
- }
- break // Done.
- }
- return false
- })
-}
-
func TestBounce(t *testing.T) {
applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
go func() {
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 8bed34922..3de309c1a 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -78,19 +78,6 @@ func (c *vCPU) initArchState() error {
return err
}
- // sctlr_el1
- regGet.id = _KVM_ARM64_REGS_SCTLR_EL1
- if err := c.getOneRegister(&regGet); err != nil {
- return err
- }
-
- dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I)
- data = dataGet
- reg.id = _KVM_ARM64_REGS_SCTLR_EL1
- if err := c.setOneRegister(&reg); err != nil {
- return err
- }
-
// tcr_el1
data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
reg.id = _KVM_ARM64_REGS_TCR_EL1
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
index 0bebee852..07658144e 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -104,3 +104,9 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
TWIDDLE_REGS()
SVC
RET // never reached
+
+TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
+ TWIDDLE_REGS()
+ // Branch to Register branches unconditionally to an address in <Rn>.
+ JMP (R4) // <=> br x4, must fault
+ RET // never reached
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 2bc5f3ecd..6ed73699b 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -40,6 +40,14 @@
#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT)
+// sctlr_el1: system control register el1.
+#define SCTLR_M 1 << 0
+#define SCTLR_C 1 << 2
+#define SCTLR_I 1 << 12
+#define SCTLR_UCT 1 << 15
+
+#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT)
+
// Saves a register set.
//
// This is a macro because it may need to be executed in contexts where a stack is
@@ -496,6 +504,11 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
// Start is the CPU entrypoint.
TEXT ·Start(SB),NOSPLIT,$0
IRQ_DISABLE
+
+ // Init.
+ MOVD $SCTLR_EL1_DEFAULT, R1
+ MSR R1, SCTLR_EL1
+
MOVD R8, RSV_REG
ORR $0xffff000000000000, RSV_REG, RSV_REG
WORD $0xd518d092 //MSR R18, TPIDR_EL1
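
The SCTLR_EL1 setup now happens once in the CPU entrypoint rather than per-vCPU in Go, and the default also enables UCT. A small Go sketch, not part of this change, showing how the default composes from the bit positions defined above:

package ring0sketch

import "fmt"

// Bit positions copied from the macros above: M enables the MMU, C the data
// cache, I the instruction cache, and UCT lets EL0 read CTR_EL0.
const (
	sctlrM   = 1 << 0
	sctlrC   = 1 << 2
	sctlrI   = 1 << 12
	sctlrUCT = 1 << 15

	sctlrEL1Default = sctlrM | sctlrC | sctlrI | sctlrUCT
)

func printDefault() {
	fmt.Printf("SCTLR_EL1_DEFAULT = %#x\n", sctlrEL1Default) // 0x9005
}
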
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index ccacaea6b..fca3a5478 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -58,7 +58,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate &= ^uint64(UserFlagsClear)
regs.Pstate |= UserFlagsSet
+
+ SetTLS(regs.TPIDR_EL0)
+
kernelExitToEl0()
+
+ regs.TPIDR_EL0 = GetTLS()
+
vector = c.vecCode
// Perform the switch.
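
SwitchToUser now brackets the exit to EL0 with TPIDR_EL0 handling so each task's TLS pointer survives the round trip. A self-contained sketch of the ordering, not part of this change, with pure-Go stand-ins for the SetTLS/GetTLS register accessors:

package ring0sketch

// tpidrEL0 stands in for the TPIDR_EL0 system register.
var tpidrEL0 uint64

func setTLS(v uint64) { tpidrEL0 = v }
func getTLS() uint64  { return tpidrEL0 }

type registers struct{ TPIDR_EL0 uint64 }

func switchToUser(regs *registers, runUser func()) {
	setTLS(regs.TPIDR_EL0)    // install this task's TLS before entering EL0
	runUser()                 // the application may install its own TLS pointer
	regs.TPIDR_EL0 = getTLS() // capture it so the next switch restores it
}
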
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index c40c6d673..c0fd3425b 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,5 +20,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index ff81ea6e6..e76e498de 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -40,6 +40,8 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index fda19e7bb..532a1ea5d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -36,6 +36,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
@@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
if err != nil {
return nil, syserr.FromError(err)
}
- return opt, nil
+ optP := primitive.ByteSlice(opt)
+ return &optP, nil
}
// SetSockOpt implements socket.Socket.SetSockOpt.
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 8f192c62f..8a1d52ebf 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -71,6 +71,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in
DenyPWrite: true,
UseDentryMetadata: true,
}); err != nil {
+ fdnotifier.RemoveFD(int32(s.fd))
return nil, syserr.FromError(err)
}
return vfsfd, nil
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index f7abe77d3..d9394055d 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
// Each rule corresponds to an entry.
entry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
+ Entry: linux.IPTEntry{
IP: linux.IPTIP{
Protocol: uint16(rule.Filter.Protocol),
},
@@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
TargetOffset: linux.SizeOfIPTEntry,
},
}
- copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst)
- copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask)
- copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src)
- copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask)
- copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface)
- copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ copy(entry.Entry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IP.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
if rule.Filter.DstInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP
}
if rule.Filter.SrcInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP
}
if rule.Filter.OutputInterfaceInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
}
for _, matcher := range rule.Matchers {
@@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
- entry.TargetOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
}
// Serialize and append the target.
@@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
nflog("convert to binary: adding entry: %+v", entry)
- entries.Size += uint32(entry.NextOffset)
+ entries.Size += uint32(entry.Entry.NextOffset)
entries.Entrytable = append(entries.Entrytable, entry)
info.NumEntries++
}
@@ -342,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
- case stack.TablenameFilter:
+ case stack.FilterTable:
table = stack.EmptyFilterTable()
- case stack.TablenameNat:
- table = stack.EmptyNatTable()
+ case stack.NATTable:
+ table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
@@ -431,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
+ table.BuiltinChains[hk] = stack.HookUnset
+ table.Underflows[hk] = stack.HookUnset
for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
@@ -456,8 +458,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Add the user chains.
for ruleIdx, rule := range table.Rules {
- target, ok := rule.Target.(stack.UserChainTarget)
- if !ok {
+ if _, ok := rule.Target.(stack.UserChainTarget); !ok {
continue
}
@@ -473,7 +474,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
nflog("user chain's first node must have no matchers")
return syserr.ErrInvalidArgument
}
- table.UserChains[target.Name] = ruleIdx + 1
}
// Set each jump to point to the appropriate rule. Right now they hold byte
@@ -499,7 +499,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Since we only support modifying the INPUT, PREROUTING and OUTPUT chains right now,
// make sure all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook == stack.Forward || hook == stack.Postrouting {
+ if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
+ if ruleIdx == stack.HookUnset {
+ continue
+ }
if !isUnconditionalAccept(table.Rules[ruleIdx]) {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
@@ -512,9 +515,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- stk.IPTables().ReplaceTable(replace.Name.String(), table)
-
- return nil
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table))
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
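
The SetEntries change initializes every valid hook to stack.HookUnset before scanning the offsets, so a hook the replace request does not bind (FORWARD or POSTROUTING on a filter-only ruleset, under the reading above) is skipped by the later validation instead of silently aliasing rule 0. A simplified, self-contained sketch of that resolution step, not part of this change:

package netfiltersketch

// hookUnset marks a builtin chain that the replace request did not bind.
const hookUnset = -1

// resolveChains binds each requested hook offset to a rule index, defaulting
// to hookUnset so unbound hooks can be skipped rather than misread as rule 0.
func resolveChains(hookEntry map[int]uint32, offsets map[uint32]int) map[int]int {
	chains := make(map[int]int, len(hookEntry))
	for hook, wanted := range hookEntry {
		chains[hook] = hookUnset
		for offset, ruleIdx := range offsets {
			if offset == wanted {
				chains[hook] = ruleIdx
				break
			}
		}
	}
	return chains
}
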
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index d5ca3ac56..0546801bf 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -36,6 +36,8 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 81f34c5a2..98ca7add0 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -38,6 +38,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
@@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
}
s.mu.Lock()
defer s.mu.Unlock()
- return int32(s.sendBufferSize), nil
+ sendBufferSizeP := primitive.Int32(s.sendBufferSize)
+ return &sendBufferSizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have a limit on receiving size.
- return int32(math.MaxInt32), nil
+ recvBufferSizeP := primitive.Int32(math.MaxInt32)
+ return &recvBufferSizeP, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var passcred int32
+ var passcred primitive.Int32
if s.Passcred() {
passcred = 1
}
- return passcred, nil
+ return &passcred, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ea6ebd0e2..1fb777a6c 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e7d2c83d7..9856ab8c5 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,6 +26,7 @@ package netstack
import (
"bytes"
+ "fmt"
"io"
"math"
"reflect"
@@ -61,6 +62,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
@@ -192,6 +195,7 @@ var Metrics = tcpip.Stats{
PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
+ InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."),
},
}
@@ -296,8 +300,9 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
- sender tcpip.FullAddress
+ readCM tcpip.ControlMessages
+ sender tcpip.FullAddress
+ linkPacketInfo tcpip.LinkPacketInfo
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
// of returned messages can be returned via control messages. When
@@ -446,8 +451,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = nil
s.sender = tcpip.FullAddress{}
+ s.linkPacketInfo = tcpip.LinkPacketInfo{}
- v, cms, err := s.Endpoint.Read(&s.sender)
+ var v buffer.View
+ var cms tcpip.ControlMessages
+ var err *tcpip.Error
+
+ switch e := s.Endpoint.(type) {
+ // The ordering of these interfaces matters. The most specific
+ // interfaces must be specified before the more generic Endpoint
+ // interface.
+ case tcpip.PacketEndpoint:
+ v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
+ case tcpip.Endpoint:
+ v, cms, err = e.Read(&s.sender)
+ }
if err != nil {
atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
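
fetchReadView now type-switches on the endpoint so packet endpoints can surface link-layer info, and the comment above is the key point: the more specific interface must be listed first, because the first matching case wins. A self-contained illustration with hypothetical stand-in types, not part of this change:

package netstacksketch

// endpoint and packetEndpoint stand in for tcpip.Endpoint and
// tcpip.PacketEndpoint.
type endpoint interface {
	Read() string
}

type packetEndpoint interface {
	endpoint
	ReadPacket() string
}

type packetConn struct{}

func (packetConn) Read() string       { return "generic read" }
func (packetConn) ReadPacket() string { return "packet read with link info" }

// dispatch lists the more specific interface first; swapping the cases would
// send every packetEndpoint down the generic Read path.
func dispatch(ep endpoint) string {
	switch e := ep.(type) {
	case packetEndpoint:
		return e.ReadPacket()
	case endpoint:
		return e.Read()
	}
	return ""
}
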
@@ -894,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -904,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -940,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -955,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
@@ -965,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -998,7 +1016,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1009,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// Get the last error and convert it.
err := ep.GetSockOpt(tcpip.ErrorOption{})
if err == nil {
- return int32(0), nil
+ optP := primitive.Int32(0)
+ return &optP, nil
}
- return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number())
+ return &optP, nil
case linux.SO_PEERCRED:
if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
@@ -1019,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
tcred := t.Credentials()
- return syscall.Ucred{
- Pid: int32(t.ThreadGroup().ID()),
- Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
- Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
- }, nil
+ creds := linux.ControlMessageCredentials{
+ PID: int32(t.ThreadGroup().ID()),
+ UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }
+ return &creds, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -1034,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1050,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
@@ -1066,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_REUSEADDR:
if outLen < sizeOfInt32 {
@@ -1077,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
@@ -1088,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1096,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
if v == 0 {
- return []byte{}, nil
+ var b primitive.ByteSlice
+ return &b, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
@@ -1111,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// interface was removed.
return nil, syserr.ErrUnknownDevice
}
- return append([]byte(nic.Name), 0), nil
+
+ name := primitive.ByteSlice(append([]byte(nic.Name), 0))
+ return &name, nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1122,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
@@ -1133,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- return linux.Linger{}, nil
+
+ linger := linux.Linger{}
+ return &linger, nil
case linux.SO_SNDTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1147,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.SendTimeout()), nil
+ sendTimeout := linux.NsecToTimeval(s.SendTimeout())
+ return &sendTimeout, nil
case linux.SO_RCVTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1155,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.RecvTimeout()), nil
+ recvTimeout := linux.NsecToTimeval(s.RecvTimeout())
+ return &recvTimeout, nil
case linux.SO_OOBINLINE:
if outLen < sizeOfInt32 {
@@ -1167,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1178,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1187,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
@@ -1198,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(!v), nil
+
+ vP := primitive.Int32(boolToInt32(!v))
+ return &vP, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
@@ -1209,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
@@ -1220,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1231,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_KEEPIDLE:
if outLen < sizeOfInt32 {
@@ -1243,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveIdle, nil
case linux.TCP_KEEPINTVL:
if outLen < sizeOfInt32 {
@@ -1255,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveInterval, nil
case linux.TCP_KEEPCNT:
if outLen < sizeOfInt32 {
@@ -1267,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_USER_TIMEOUT:
if outLen < sizeOfInt32 {
@@ -1279,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Millisecond), nil
+ tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond)
+ return &tcpUserTimeout, nil
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
@@ -1293,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
info := linux.TCPInfo{}
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &info)
- if len(ib) > outLen {
- ib = ib[:outLen]
+ buf := t.CopyScratchBuffer(info.SizeBytes())
+ info.MarshalUnsafe(buf)
+ if len(buf) > outLen {
+ buf = buf[:outLen]
}
-
- return ib, nil
+ bufP := primitive.ByteSlice(buf)
+ return &bufP, nil
case linux.TCP_CC_INFO,
linux.TCP_NOTSENT_LOWAT,
@@ -1328,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
b := make([]byte, toCopy)
copy(b, v)
- return b, nil
+
+ bP := primitive.ByteSlice(b)
+ return &bP, nil
case linux.TCP_LINGER2:
if outLen < sizeOfInt32 {
@@ -1340,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
if outLen < sizeOfInt32 {
@@ -1352,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second)
+ return &tcpDeferAccept, nil
case linux.TCP_SYNCNT:
if outLen < sizeOfInt32 {
@@ -1363,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_WINDOW_CLAMP:
if outLen < sizeOfInt32 {
@@ -1375,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1384,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
@@ -1395,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1403,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_TCLASS:
// Length handling for parity with Linux.
if outLen == 0 {
- return make([]byte, 0), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- uintv := uint32(v)
+ uintv := primitive.Uint32(v)
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ ib := t.CopyScratchBuffer(uintv.SizeBytes())
+ uintv.MarshalUnsafe(ib)
// Handle cases where outLen is less than sizeOfInt32.
if len(ib) > outLen {
ib = ib[:outLen]
}
- return ib, nil
+ ibP := primitive.ByteSlice(ib)
+ return &ibP, nil
case linux.IPV6_RECVTCLASS:
if outLen < sizeOfInt32 {
@@ -1428,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIPv6(t, name)
@@ -1437,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1450,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
// Fill in the default value, if needed.
- if v == 0 {
- v = DefaultTTL
+ vP := primitive.Int32(v)
+ if vP == 0 {
+ vP = DefaultTTL
}
- return int32(v), nil
+ return &vP, nil
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
@@ -1466,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_MULTICAST_IF:
if outLen < len(linux.InetAddr{}) {
@@ -1480,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(*linux.SockAddrInet).Addr, nil
+ return &a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
@@ -1491,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
- return []byte(nil), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
- return uint8(v), nil
+ vP := primitive.Uint8(v)
+ return &vP, nil
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_RECVTOS:
if outLen < sizeOfInt32 {
@@ -1516,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1527,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIP(t, name)
@@ -1753,6 +1824,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
return nil
+ case linux.SO_DETACH_FILTER:
+ // optval is ignored.
+ var v tcpip.SocketDetachFilterOption
+ return syserr.TranslateNetstackError(ep.SetSockOpt(v))
+
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
@@ -2112,13 +2188,22 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ case linux.IP_HDRINCL:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
linux.IP_CHECKSUM,
linux.IP_DROP_SOURCE_MEMBERSHIP,
linux.IP_FREEBIND,
- linux.IP_HDRINCL,
linux.IP_IPSEC_POLICY,
linux.IP_MINTTL,
linux.IP_MSFILTER,
@@ -2439,6 +2524,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
}
+func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
+ switch pktType {
+ case tcpip.PacketHost:
+ return linux.PACKET_HOST
+ case tcpip.PacketOtherHost:
+ return linux.PACKET_OTHERHOST
+ case tcpip.PacketOutgoing:
+ return linux.PACKET_OUTGOING
+ case tcpip.PacketBroadcast:
+ return linux.PACKET_BROADCAST
+ case tcpip.PacketMulticast:
+ return linux.PACKET_MULTICAST
+ default:
+ panic(fmt.Sprintf("unknown packet type: %d", pktType))
+ }
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -2494,6 +2596,11 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
var addrLen uint32
if isPacket && senderRequested {
addr, addrLen = ConvertAddress(s.family, s.sender)
+ switch v := addr.(type) {
+ case *linux.SockAddrLink:
+ v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
+ }
}
if peek {
@@ -2732,7 +2839,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
switch args[1].Int() {
- case syscall.SIOCGSTAMP:
+ case linux.SIOCGSTAMP:
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
@@ -2773,18 +2880,19 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
// Ioctl performs a socket ioctl.
func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch arg := int(args[1].Int()); arg {
- case syscall.SIOCGIFFLAGS,
- syscall.SIOCGIFADDR,
- syscall.SIOCGIFBRDADDR,
- syscall.SIOCGIFDSTADDR,
- syscall.SIOCGIFHWADDR,
- syscall.SIOCGIFINDEX,
- syscall.SIOCGIFMAP,
- syscall.SIOCGIFMETRIC,
- syscall.SIOCGIFMTU,
- syscall.SIOCGIFNAME,
- syscall.SIOCGIFNETMASK,
- syscall.SIOCGIFTXQLEN:
+ case linux.SIOCGIFFLAGS,
+ linux.SIOCGIFADDR,
+ linux.SIOCGIFBRDADDR,
+ linux.SIOCGIFDSTADDR,
+ linux.SIOCGIFHWADDR,
+ linux.SIOCGIFINDEX,
+ linux.SIOCGIFMAP,
+ linux.SIOCGIFMETRIC,
+ linux.SIOCGIFMTU,
+ linux.SIOCGIFNAME,
+ linux.SIOCGIFNETMASK,
+ linux.SIOCGIFTXQLEN,
+ linux.SIOCETHTOOL:
var ifr linux.IFReq
if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
@@ -2800,7 +2908,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
})
return 0, err
- case syscall.SIOCGIFCONF:
+ case linux.SIOCGIFCONF:
// Return a list of interface addresses or the buffer size
// necessary to hold the list.
var ifc linux.IFConf
@@ -2874,7 +2982,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to
// identify a device.
- if arg == syscall.SIOCGIFNAME {
+ if arg == linux.SIOCGIFNAME {
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4]))
@@ -2897,21 +3005,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
switch arg {
- case syscall.SIOCGIFINDEX:
+ case linux.SIOCGIFINDEX:
// Copy out the index to the data.
usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
- case syscall.SIOCGIFHWADDR:
+ case linux.SIOCGIFHWADDR:
// Copy the hardware address out.
- ifr.Data[0] = 6 // IEEE802.2 arp type.
- ifr.Data[1] = 0
+ //
+ // Refer: https://linux.die.net/man/7/netdevice
+ // SIOCGIFHWADDR, SIOCSIFHWADDR
+ //
+ // Get or set the hardware address of a device using
+ // ifr_hwaddr. The hardware address is specified in a struct
+ // sockaddr. sa_family contains the ARPHRD_* device type,
+ // sa_data the L2 hardware address starting from byte 0. Setting
+ // the hardware address is a privileged operation.
+ usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType)
n := copy(ifr.Data[2:], iface.Addr)
for i := 2 + n; i < len(ifr.Data); i++ {
ifr.Data[i] = 0 // Clear padding.
}
- usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n))
- case syscall.SIOCGIFFLAGS:
+ case linux.SIOCGIFFLAGS:
f, err := interfaceStatusFlags(stack, iface.Name)
if err != nil {
return err
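
For SIOCGIFHWADDR the patch writes the ARPHRD_* device type into the first two bytes of ifr_hwaddr and the L2 address from byte 2, per netdevice(7). A small sketch of that packing, not part of this change, assuming a little-endian host (as with usermem.ByteOrder on the supported targets) and an illustrative 16-byte buffer:

package netstacksketch

import "encoding/binary"

// arphrdEther is Linux's ARPHRD_ETHER device type.
const arphrdEther = 1

// packHWAddr lays out ifr_hwaddr the way the patch does: device type in the
// first two bytes, L2 address from byte 2, remaining bytes left cleared.
func packHWAddr(deviceType uint16, mac []byte) [16]byte {
	var data [16]byte
	binary.LittleEndian.PutUint16(data[:2], deviceType)
	copy(data[2:], mac)
	return data
}

// exampleEther packs a MAC for an Ethernet-type interface.
func exampleEther(mac []byte) [16]byte { return packHWAddr(arphrdEther, mac) }
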
@@ -2920,7 +3035,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// matches Linux behavior.
usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
- case syscall.SIOCGIFADDR:
+ case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
@@ -2931,32 +3046,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
break
}
- case syscall.SIOCGIFMETRIC:
+ case linux.SIOCGIFMETRIC:
// Gets the metric of the device. As per netdevice(7), this
// always just sets ifr_metric to 0.
usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
- case syscall.SIOCGIFMTU:
+ case linux.SIOCGIFMTU:
// Gets the MTU of the device.
usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
- case syscall.SIOCGIFMAP:
+ case linux.SIOCGIFMAP:
// Gets the hardware parameters of the device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFTXQLEN:
+ case linux.SIOCGIFTXQLEN:
// Gets the transmit queue length of the device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFDSTADDR:
+ case linux.SIOCGIFDSTADDR:
// Gets the destination address of a point-to-point device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFBRDADDR:
+ case linux.SIOCGIFBRDADDR:
// Gets the broadcast address of a device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFNETMASK:
+ case linux.SIOCGIFNETMASK:
// Gets the network mask of a device.
for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
@@ -2973,6 +3088,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
break
}
+ case linux.SIOCETHTOOL:
+ // Stubbed out for now. Ideally we should implement the required
+ // ETHTOOL sub-commands.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c
+ return syserr.ErrEndpointOperation
+
default:
// Not a valid call.
return syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index d65a89316..a9025b0ec 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -31,6 +31,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
@@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketVFS2 rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 548442b96..67737ae87 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -15,6 +15,8 @@
package netstack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
@@ -40,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool {
return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber)
}
+// toLinuxARPHardwareType converts Netstack's ARPHardwareType to the equivalent Linux constant.
+func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 {
+ switch t {
+ case header.ARPHardwareNone:
+ return linux.ARPHRD_NONE
+ case header.ARPHardwareLoopback:
+ return linux.ARPHRD_LOOPBACK
+ case header.ARPHardwareEther:
+ return linux.ARPHRD_ETHER
+ default:
+ panic(fmt.Sprintf("unknown ARPHRD type: %d", t))
+ }
+}
+
// Interfaces implements inet.Stack.Interfaces.
func (s *Stack) Interfaces() map[int32]inet.Interface {
is := make(map[int32]inet.Interface)
for id, ni := range s.Stack.NICInfo() {
- var devType uint16
- if ni.Flags.Loopback {
- devType = linux.ARPHRD_LOOPBACK
- }
is[int32(id)] = inet.Interface{
Name: ni.Name,
Addr: []byte(ni.LinkAddress),
Flags: uint32(nicStateFlagsToLinux(ni.Flags)),
- DeviceType: devType,
+ DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType),
MTU: ni.MTU,
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fcd7f9d7f..d112757fb 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
@@ -86,7 +87,7 @@ type SocketOps interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux syscall.
- GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux syscall.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
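
With GetSockOpt returning marshal.Marshallable, the syscall layer no longer has to reflect over interface{}; it asks the returned value for its size and bytes and copies those out. A self-contained sketch of that contract with stand-in types, not part of this change and not the go_marshal API itself:

package socketsketch

import "encoding/binary"

// marshallable stands in for marshal.Marshallable: a value that knows its
// encoded size and how to write itself into a buffer.
type marshallable interface {
	SizeBytes() int
	MarshalBytes(dst []byte)
}

// int32Opt plays the role of primitive.Int32 for scalar options.
type int32Opt int32

func (i int32Opt) SizeBytes() int          { return 4 }
func (i int32Opt) MarshalBytes(dst []byte) { binary.LittleEndian.PutUint32(dst, uint32(i)) }

// copyOutOpt is what the getsockopt(2) path reduces to: one call through the
// interface (the sentry copies the bytes to the user address instead of
// returning them).
func copyOutOpt(v marshallable) []byte {
	buf := make([]byte, v.SizeBytes())
	v.MarshalBytes(buf)
	return buf
}
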
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cca5e70f1..061a689a9 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -35,5 +35,6 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4bb2b6ff4..0482d33cf 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -40,6 +40,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
@@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index ff2149250..05c16fcfe 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketVFS2 implements socket.SocketVFS2 (and by extension,
@@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 217fcfef2..4a9b04fd0 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -99,5 +99,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index ea4f9b1a7..80c65164a 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -325,8 +325,8 @@ var AMD64 = &kernel.SyscallTable{
270: syscalls.Supported("pselect", Pselect),
271: syscalls.Supported("ppoll", Ppoll),
272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
- 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 273: syscalls.Supported("set_robust_list", SetRobustList),
+ 274: syscalls.Supported("get_robust_list", GetRobustList),
275: syscalls.Supported("splice", Splice),
276: syscalls.Supported("tee", Tee),
277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index b68261f72..f04d78856 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
switch cmd {
case linux.FUTEX_WAIT:
// WAIT uses a relative timeout.
- mask = ^uint32(0)
+ mask = linux.FUTEX_BITSET_MATCH_ANY
var timeoutDur time.Duration
if !forever {
timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond
@@ -286,3 +286,49 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, syserror.ENOSYS
}
}
+
+// SetRobustList implements linux syscall set_robust_list(2).
+func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+	// set_robust_list(2) takes the address of the caller's robust list head
+	// and the size of struct robust_list_head; it only affects this thread.
+ head := args[0].Pointer()
+ length := args[1].SizeT()
+
+ if length != uint(linux.SizeOfRobustListHead) {
+ return 0, nil, syserror.EINVAL
+ }
+ t.SetRobustList(head)
+ return 0, nil, nil
+}
+
+// GetRobustList implements linux syscall get_robust_list(2).
+func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Despite the syscall using the name 'pid' for this variable, it is
+ // very much a tid.
+ tid := args[0].Int()
+ head := args[1].Pointer()
+ size := args[2].Pointer()
+
+ if tid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ot := t
+ if tid != 0 {
+ if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // Copy out head pointer.
+ if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil {
+ return 0, nil, err
+ }
+
+ // Copy out size, which is a constant.
+ if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
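
For reference, the newly supported pair can be exercised directly from a Linux/amd64 Go program via the raw syscall numbers exported by golang.org/x/sys/unix; with this change the sandbox answers instead of returning ENOSYS. This is only a sketch of the caller's side, assuming the standard get_robust_list(2) semantics (tid 0 means the calling thread, and the reported size is sizeof(struct robust_list_head), 24 bytes on amd64):

    package main

    import (
        "fmt"
        "unsafe"

        "golang.org/x/sys/unix"
    )

    func main() {
        var head, size uintptr
        // get_robust_list(0, &head, &size): tid 0 queries the calling thread.
        _, _, errno := unix.Syscall(unix.SYS_GET_ROBUST_LIST, 0,
            uintptr(unsafe.Pointer(&head)), uintptr(unsafe.Pointer(&size)))
        if errno != 0 {
            fmt.Println("get_robust_list failed:", errno)
            return
        }
        fmt.Printf("robust list head=%#x size=%d\n", head, size)
    }
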
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 0760af77b..414fce8e3 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -29,6 +29,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
@@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
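
The option values that getsockopt(2) returns here are now typed go_marshal primitives rather than bare Go integers, but what applications observe is unchanged: SO_TYPE, SO_DOMAIN and SO_PROTOCOL are still copied out as 32-bit integers. A quick check from a Go program, assuming only the standard x/sys/unix wrappers:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        fd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP)
        if err != nil {
            panic(err)
        }
        defer unix.Close(fd)

        // Each option is copied out as a 4-byte integer.
        typ, _ := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_TYPE)
        dom, _ := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
        proto, _ := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_PROTOCOL)
        fmt.Println(typ, dom, proto) // 1 2 6 for a TCP/IPv4 socket
    }
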
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 0c740335b..64696b438 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -72,5 +72,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
index b12b5967b..6b14c2bef 100644
--- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -107,7 +107,7 @@ func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
addr := args[0].Pointer()
mode := args[1].ModeT()
dev := args[2].Uint()
- return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev)
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, linux.FileMode(mode), dev)
}
// Mknodat implements Linux syscall mknodat(2).
@@ -116,10 +116,10 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
addr := args[1].Pointer()
mode := args[2].ModeT()
dev := args[3].Uint()
- return 0, nil, mknodat(t, dirfd, addr, mode, dev)
+ return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev)
}
-func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error {
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error {
path, err := copyInPath(t, addr)
if err != nil {
return err
@@ -129,9 +129,14 @@ func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint
return err
}
defer tpop.Release()
+
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ if mode.FileType() == 0 {
+ mode |= linux.ModeRegular
+ }
major, minor := linux.DecodeDeviceID(dev)
return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
- Mode: linux.FileMode(mode &^ t.FSContext().Umask()),
+ Mode: mode &^ linux.FileMode(t.FSContext().Umask()),
DevMajor: uint32(major),
DevMinor: minor,
})
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
index adeaa39cc..ea337de7c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mount.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -77,8 +77,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Silently allow MS_NOSUID, since we don't implement set-id bits
// anyway.
- const unsupportedFlags = linux.MS_NODEV |
- linux.MS_NODIRATIME | linux.MS_STRICTATIME
+ const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME
// Linux just allows passing any flags to mount(2) - it won't fail when
// unknown or unsupported flags are passed. Since we don't implement
@@ -94,6 +93,12 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
opts.Flags.NoExec = true
}
+ if flags&linux.MS_NODEV == linux.MS_NODEV {
+ opts.Flags.NoDev = true
+ }
+ if flags&linux.MS_NOSUID == linux.MS_NOSUID {
+ opts.Flags.NoSUID = true
+ }
if flags&linux.MS_RDONLY == linux.MS_RDONLY {
opts.ReadOnly = true
}
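
With MS_NODEV and MS_NOSUID now carried in MountFlags instead of being treated as unsupported, a mount(2) call that passes them is accepted and the flags are recorded on the mount. A minimal sketch of such a call from Go (requires CAP_SYS_ADMIN; the target path is hypothetical):

    package main

    import "golang.org/x/sys/unix"

    func main() {
        // Mount a tmpfs with device files and set-id bits disabled.
        flags := uintptr(unix.MS_NODEV | unix.MS_NOSUID | unix.MS_NOEXEC)
        if err := unix.Mount("none", "/mnt/scratch", "tmpfs", flags, ""); err != nil {
            panic(err)
        }
    }
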
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 09ecfed26..6daedd173 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -178,6 +178,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
Mask: linux.STATX_SIZE,
Size: uint64(length),
},
+ NeedWritePerm: true,
})
return 0, nil, handleSetSizeError(t, err)
}
@@ -197,6 +198,10 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
defer file.DecRef()
+ if !file.IsWritable() {
+ return 0, nil, syserror.EINVAL
+ }
+
err := file.SetStat(t, vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_SIZE,
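
The two new checks mirror Linux: truncate(2) needs write permission on the file itself (hence NeedWritePerm), while ftruncate(2) only needs a file description that is open for writing and otherwise fails with EINVAL. A small sketch of the difference, using a hypothetical scratch path:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        const path = "/tmp/trunc-demo" // hypothetical scratch file
        fd, err := unix.Open(path, unix.O_CREAT|unix.O_RDONLY, 0600)
        if err != nil {
            panic(err)
        }
        defer unix.Close(fd)

        // ftruncate(2) checks the file description: a read-only fd gets EINVAL.
        fmt.Println("ftruncate:", unix.Ftruncate(fd, 0))
        // truncate(2) checks write permission on the file itself (the
        // NeedWritePerm path); a 0600 file owned by the caller succeeds, while
        // a 0400 file would fail with EACCES absent CAP_DAC_OVERRIDE.
        fmt.Println("truncate:", unix.Truncate(path, 0))
    }
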
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 10b668477..8096a8f9c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -30,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// minListenBacklog is the minimum reasonable backlog for listening sockets.
@@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 945a364a7..63ab11f8c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -15,12 +15,15 @@
package vfs2
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -110,16 +113,20 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Move data.
var (
- n int64
- err error
- inCh chan struct{}
- outCh chan struct{}
+ n int64
+ err error
)
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
for {
// If both input and output are pipes, delegate to the pipe
- // implementation. Otherwise, exactly one end is a pipe, which we
- // ensure is consistently ordered after the non-pipe FD's locks by
- // passing the pipe FD as usermem.IO to the non-pipe end.
+ // implementation. Otherwise, exactly one end is a pipe, which
+ // we ensure is consistently ordered after the non-pipe FD's
+ // locks by passing the pipe FD as usermem.IO to the non-pipe
+ // end.
switch {
case inIsPipe && outIsPipe:
n, err = pipe.Splice(t, outPipeFD, inPipeFD, count)
@@ -137,38 +144,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
} else {
n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
}
+ default:
+ panic("not possible")
}
+
if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
break
}
-
- // Note that the blocking behavior here is a bit different than the
- // normal pattern. Because we need to have both data to read and data
- // to write simultaneously, we actually explicitly block on both of
- // these cases in turn before returning to the splice operation.
- if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
- if inCh == nil {
- inCh = make(chan struct{}, 1)
- inW, _ := waiter.NewChannelEntry(inCh)
- inFile.EventRegister(&inW, eventMaskRead)
- defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(inCh); err != nil {
- break
- }
- }
- if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, eventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(outCh); err != nil {
- break
- }
+ if err = dw.waitForBoth(t); err != nil {
+ break
}
}
@@ -247,45 +231,256 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
// Copy data.
var (
- inCh chan struct{}
- outCh chan struct{}
+ n int64
+ err error
)
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
for {
- n, err := pipe.Tee(t, outPipeFD, inPipeFD, count)
- if n != 0 {
- return uintptr(n), nil, nil
+ n, err = pipe.Tee(t, outPipeFD, inPipeFD, count)
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+ if err = dw.waitForBoth(t); err != nil {
+ break
+ }
+ }
+ if n == 0 {
+ return 0, nil, err
+ }
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// Sendfile implements linux syscall sendfile(2).
+func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ outFD := args[0].Int()
+ inFD := args[1].Int()
+ offsetAddr := args[2].Pointer()
+ count := int64(args[3].SizeT())
+
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ if !inFile.IsReadable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+ if !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Verify that the outFile Append flag is not set.
+ if outFile.StatusFlags()&linux.O_APPEND != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Verify that inFile is a regular file or block device. This is a
+ // requirement; the same check appears in Linux
+ // (fs/splice.c:splice_direct_to_actor).
+ if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil {
+ return 0, nil, err
+ } else if stat.Mask&linux.STATX_TYPE == 0 ||
+ (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy offset if it exists.
+ offset := int64(-1)
+ if offsetAddr != 0 {
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.ESPIPE
}
- if err != syserror.ErrWouldBlock || nonBlock {
+ if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
return 0, nil, err
}
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if offset+count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Validate count. This must come after offset checks.
+ if count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
- // Note that the blocking behavior here is a bit different than the
- // normal pattern. Because we need to have both data to read and data
- // to write simultaneously, we actually explicitly block on both of
- // these cases in turn before returning to the tee operation.
- if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
- if inCh == nil {
- inCh = make(chan struct{}, 1)
- inW, _ := waiter.NewChannelEntry(inCh)
- inFile.EventRegister(&inW, eventMaskRead)
- defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
+ // Copy data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+	// Reading from the input file should never block, since it is a regular
+	// file or block device. We only need to check whether writing to the
+	// output file can block.
+ nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0
+ if outIsPipe {
+ for n < count {
+ var spliceN int64
+ if offset != -1 {
+ spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{})
+ offset += spliceN
+ } else {
+ spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
}
- if err := t.Block(inCh); err != nil {
- return 0, nil, err
+ n += spliceN
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
+ }
+ if err != nil {
+ break
}
}
- if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, eventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
+ } else {
+ // Read inFile to buffer, then write the contents to outFile.
+ buf := make([]byte, count)
+ for n < count {
+ var readN int64
+ if offset != -1 {
+ readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{})
+ offset += readN
+ } else {
+ readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ }
+ if readN == 0 && err == io.EOF {
+ // We reached the end of the file. Eat the
+ // error and exit the loop.
+ err = nil
+ break
}
- if err := t.Block(outCh); err != nil {
- return 0, nil, err
+ n += readN
+ if err != nil {
+ break
+ }
+
+ // Write all of the bytes that we read. This may need
+ // multiple write calls to complete.
+			wbuf := buf[:readN]
+ for len(wbuf) > 0 {
+ var writeN int64
+ writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{})
+ wbuf = wbuf[writeN:]
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForOut(t)
+ }
+ if err != nil {
+ // We didn't complete the write. Only
+ // report the bytes that were actually
+ // written, and rewind the offset.
+ notWritten := int64(len(wbuf))
+ n -= notWritten
+ if offset != -1 {
+ offset -= notWritten
+ }
+ break
+ }
+ }
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
}
+ if err != nil {
+ break
+ }
+ }
+ }
+
+ if offsetAddr != 0 {
+ // Copy out the new offset.
+ if _, err := t.CopyOut(offsetAddr, offset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+
+ inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not
+// thread-safe, and does not take a reference on the vfs.FileDescriptions.
+//
+// Users must call destroy() when finished.
+type dualWaiter struct {
+ inFile *vfs.FileDescription
+ outFile *vfs.FileDescription
+
+ inW waiter.Entry
+ inCh chan struct{}
+ outW waiter.Entry
+ outCh chan struct{}
+}
+
+// waitForBoth waits for both dw.inFile and dw.outFile to be ready.
+func (dw *dualWaiter) waitForBoth(t *kernel.Task) error {
+ if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if dw.inCh == nil {
+ dw.inW, dw.inCh = waiter.NewChannelEntry(nil)
+ dw.inFile.EventRegister(&dw.inW, eventMaskRead)
+ // We might be ready now. Try again before blocking.
+ return nil
+ }
+ if err := t.Block(dw.inCh); err != nil {
+ return err
+ }
+ }
+ return dw.waitForOut(t)
+}
+
+// waitForOut waits for dw.outFile to be ready for writing.
+func (dw *dualWaiter) waitForOut(t *kernel.Task) error {
+ if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if dw.outCh == nil {
+ dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
+ dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
+ // We might be ready now. Try again before blocking.
+ return nil
}
+ if err := t.Block(dw.outCh); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// destroy cleans up resources held by dw. No more calls to wait* can occur
+// after destroy is called.
+func (dw *dualWaiter) destroy() {
+ if dw.inCh != nil {
+ dw.inFile.EventUnregister(&dw.inW)
+ dw.inCh = nil
+ }
+ if dw.outCh != nil {
+ dw.outFile.EventUnregister(&dw.outW)
+ dw.outCh = nil
}
+ dw.inFile = nil
+ dw.outFile = nil
}
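
With sendfile(2) wired into the VFS2 syscall table, applications can again copy a regular file into a pipe or socket without a userspace buffer. A sketch of the caller's side, with hypothetical paths:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        in, err := unix.Open("/etc/hostname", unix.O_RDONLY, 0) // any regular file
        if err != nil {
            panic(err)
        }
        defer unix.Close(in)

        // A pipe stands in for the usual socket destination.
        var p [2]int
        if err := unix.Pipe(p[:]); err != nil {
            panic(err)
        }
        defer unix.Close(p[0])
        defer unix.Close(p[1])

        var off int64
        n, err := unix.Sendfile(p[1], in, &off, 4096)
        fmt.Println("copied", n, "bytes, new offset", off, "err:", err)
    }
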
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index 8f497ecc7..1b2cfad7d 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -44,7 +44,7 @@ func Override() {
s.Table[23] = syscalls.Supported("select", Select)
s.Table[32] = syscalls.Supported("dup", Dup)
s.Table[33] = syscalls.Supported("dup2", Dup2)
- delete(s.Table, 40) // sendfile
+ s.Table[40] = syscalls.Supported("sendfile", Sendfile)
s.Table[41] = syscalls.Supported("socket", Socket)
s.Table[42] = syscalls.Supported("connect", Connect)
s.Table[43] = syscalls.Supported("accept", Accept)
diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go
index 65868cb26..cd1b95117 100644
--- a/pkg/sentry/time/parameters.go
+++ b/pkg/sentry/time/parameters.go
@@ -228,11 +228,15 @@ func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Par
//
// The log level is determined by the error severity.
func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) {
- fn := log.Debugf
- if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() {
+ magNS := int64(errorNS.Magnitude())
+ if magNS <= 10*time.Microsecond.Nanoseconds() {
+ // Don't log small errors.
+ return
+ }
+ fn := log.Infof
+ if magNS > time.Millisecond.Nanoseconds() {
+ // Upgrade large errors to warning.
fn = log.Warningf
- } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() {
- fn = log.Infof
}
fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency)
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index c2e21ac5f..167b731ac 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -179,12 +179,12 @@ func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask {
return mask & ready
}
-// PRead implements FileDescriptionImpl.
+// PRead implements FileDescriptionImpl.PRead.
func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
return 0, syserror.ESPIPE
}
-// PWrite implements FileDescriptionImpl.
+// PWrite implements FileDescriptionImpl.PWrite.
func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
return 0, syserror.ESPIPE
}
@@ -243,7 +243,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt
return writeLen, nil
}
-// Ioctl implements fs.FileOperations.Ioctl.
+// Ioctl implements FileDescriptionImpl.Ioctl.
func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch args[1].Int() {
case linux.FIONREAD:
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index f223aeda8..dfc8573fd 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -79,6 +79,17 @@ type MountFlags struct {
// NoATime is equivalent to MS_NOATIME and indicates that the
// filesystem should not update access time in-place.
NoATime bool
+
+ // NoDev is equivalent to MS_NODEV and indicates that the
+ // filesystem should not allow access to devices (special files).
+	// TODO(gVisor.dev/issue/3186): respect this flag in non-FUSE
+	// filesystems.
+ NoDev bool
+
+ // NoSUID is equivalent to MS_NOSUID and indicates that the
+ // filesystem should not honor set-user-ID and set-group-ID bits or
+ // file capabilities when executing programs.
+ NoSUID bool
}
// MountOptions contains options to VirtualFilesystem.MountAt().
@@ -153,6 +164,12 @@ type SetStatOptions struct {
// == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask
// instead).
Stat linux.Statx
+
+ // NeedWritePerm indicates that write permission on the file is needed for
+ // this operation. This is needed for truncate(2) (note that ftruncate(2)
+ // does not require the same check--instead, it checks that the fd is
+ // writable).
+ NeedWritePerm bool
}
// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt()
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index 9cb050597..33389c1df 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -183,7 +183,8 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
// CheckSetStat checks that creds has permission to change the metadata of a
// file with the given permissions, UID, and GID as specified by stat, subject
// to the rules of Linux's fs/attr.c:setattr_prepare().
-func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ stat := &opts.Stat
if stat.Mask&linux.STATX_SIZE != 0 {
limit, err := CheckLimit(ctx, 0, int64(stat.Size))
if err != nil {
@@ -215,6 +216,11 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat
return syserror.EPERM
}
}
+ if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 {
if !CanActAsOwner(creds, kuid) {
if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) ||
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 58c7ad778..522e27475 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -123,6 +123,9 @@ type VirtualFilesystem struct {
// Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes.
func (vfs *VirtualFilesystem) Init() error {
+ if vfs.mountpoints != nil {
+ panic("VFS already initialized")
+ }
vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{})
vfs.devices = make(map[devTuple]*registeredDevice)
vfs.anonBlockDevMinorNext = 1
diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD
new file mode 100644
index 000000000..f08599ebd
--- /dev/null
+++ b/pkg/shim/runsc/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "runsc",
+ srcs = [
+ "runsc.go",
+ "utils.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "@com_github_containerd_go_runc//:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go
new file mode 100644
index 000000000..c5cf68efa
--- /dev/null
+++ b/pkg/shim/runsc/runsc.go
@@ -0,0 +1,514 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runsc
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "syscall"
+ "time"
+
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
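+// Monitor is the process monitor used to start runsc commands and to wait
+// for their exit status. It defaults to the shared monitor from go-runc.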
+var Monitor runc.ProcessMonitor = runc.Monitor
+
+// DefaultCommand is the default command for Runsc.
+const DefaultCommand = "runsc"
+
+// Runsc is the client to the runsc cli.
+type Runsc struct {
+ Command string
+ PdeathSignal syscall.Signal
+ Setpgid bool
+ Root string
+ Log string
+ LogFormat runc.Format
+ Config map[string]string
+}
+
+// List returns all containers created inside the provided runsc root directory.
+func (r *Runsc) List(context context.Context) ([]*runc.Container, error) {
+ data, err := cmdOutput(r.command(context, "list", "--format=json"), false)
+ if err != nil {
+ return nil, err
+ }
+ var out []*runc.Container
+ if err := json.Unmarshal(data, &out); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// State returns the state for the container provided by id.
+func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) {
+ data, err := cmdOutput(r.command(context, "state", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+ var c runc.Container
+ if err := json.Unmarshal(data, &c); err != nil {
+ return nil, err
+ }
+ return &c, nil
+}
+
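+// CreateOpts is a set of options to Runsc.Create.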
+type CreateOpts struct {
+ runc.IO
+ ConsoleSocket runc.ConsoleSocket
+
+ // PidFile is a path to where a pid file should be created.
+ PidFile string
+
+ // UserLog is a path to where runsc user log should be generated.
+ UserLog string
+}
+
+func (o *CreateOpts) args() (out []string, err error) {
+ if o.PidFile != "" {
+ abs, err := filepath.Abs(o.PidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--pid-file", abs)
+ }
+ if o.ConsoleSocket != nil {
+ out = append(out, "--console-socket", o.ConsoleSocket.Path())
+ }
+ if o.UserLog != "" {
+ out = append(out, "--user-log", o.UserLog)
+ }
+ return out, nil
+}
+
+// Create creates a new container with the provided id and bundle.
+func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateOpts) error {
+ args := []string{"create", "--bundle", bundle}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if opts != nil && opts.IO != nil {
+ if c, ok := opts.IO.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+		err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0])
+ }
+
+ return err
+}
+
+// Start will start an already created container.
+func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error {
+ cmd := r.command(context, "start", id)
+ if cio != nil {
+ cio.Set(cmd)
+ }
+
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if cio != nil {
+ if c, ok := cio.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+		err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0])
+ }
+
+ return err
+}
+
+type waitResult struct {
+ ID string `json:"id"`
+ ExitStatus int `json:"exitStatus"`
+}
+
+// Wait will wait for a running container, and return its exit status.
+//
+// TODO(random-liu): Add exec process support.
+func (r *Runsc) Wait(context context.Context, id string) (int, error) {
+ data, err := cmdOutput(r.command(context, "wait", id), true)
+ if err != nil {
+ return 0, fmt.Errorf("%s: %s", err, data)
+ }
+ var res waitResult
+ if err := json.Unmarshal(data, &res); err != nil {
+ return 0, err
+ }
+ return res.ExitStatus, nil
+}
+
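+// ExecOpts is a set of options to Runsc.Exec.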
+type ExecOpts struct {
+ runc.IO
+ PidFile string
+ InternalPidFile string
+ ConsoleSocket runc.ConsoleSocket
+ Detach bool
+}
+
+func (o *ExecOpts) args() (out []string, err error) {
+ if o.ConsoleSocket != nil {
+ out = append(out, "--console-socket", o.ConsoleSocket.Path())
+ }
+ if o.Detach {
+ out = append(out, "--detach")
+ }
+ if o.PidFile != "" {
+ abs, err := filepath.Abs(o.PidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--pid-file", abs)
+ }
+ if o.InternalPidFile != "" {
+ abs, err := filepath.Abs(o.InternalPidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--internal-pid-file", abs)
+ }
+ return out, nil
+}
+
+// Exec executes an additional process inside the container based on a full OCI
+// Process specification.
+func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opts *ExecOpts) error {
+ f, err := ioutil.TempFile(os.Getenv("XDG_RUNTIME_DIR"), "runsc-process")
+ if err != nil {
+ return err
+ }
+ defer os.Remove(f.Name())
+ err = json.NewEncoder(f).Encode(spec)
+ f.Close()
+ if err != nil {
+ return err
+ }
+ args := []string{"exec", "--process", f.Name()}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if opts != nil && opts.IO != nil {
+ if c, ok := opts.IO.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+		err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0])
+ }
+ return err
+}
+
+// Run runs the create, start, delete lifecycle of the container and returns
+// its exit status after it has exited.
+func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts) (int, error) {
+ args := []string{"run", "--bundle", bundle}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return -1, err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return -1, err
+ }
+ return Monitor.Wait(cmd, ec)
+}
+
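+// DeleteOpts is a set of options to Runsc.Delete.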
+type DeleteOpts struct {
+ Force bool
+}
+
+func (o *DeleteOpts) args() (out []string) {
+ if o.Force {
+ out = append(out, "--force")
+ }
+ return out
+}
+
+// Delete deletes the container.
+func (r *Runsc) Delete(context context.Context, id string, opts *DeleteOpts) error {
+ args := []string{"delete"}
+ if opts != nil {
+ args = append(args, opts.args()...)
+ }
+ return r.runOrError(r.command(context, append(args, id)...))
+}
+
+// KillOpts specifies options for killing a container and its processes.
+type KillOpts struct {
+ All bool
+ Pid int
+}
+
+func (o *KillOpts) args() (out []string) {
+ if o.All {
+ out = append(out, "--all")
+ }
+ if o.Pid != 0 {
+ out = append(out, "--pid", strconv.Itoa(o.Pid))
+ }
+ return out
+}
+
+// Kill sends the specified signal to the container.
+func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts) error {
+ args := []string{
+ "kill",
+ }
+ if opts != nil {
+ args = append(args, opts.args()...)
+ }
+ return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...))
+}
+
+// Stats returns the stats for a container, such as cpu, memory, and I/O.
+func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) {
+ cmd := r.command(context, "events", "--stats", id)
+ rd, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ rd.Close()
+ Monitor.Wait(cmd, ec)
+ }()
+ var e runc.Event
+ if err := json.NewDecoder(rd).Decode(&e); err != nil {
+ return nil, err
+ }
+ return e.Stats, nil
+}
+
+// Events returns an event stream from runsc for a container with stats and OOM notifications.
+func (r *Runsc) Events(context context.Context, id string, interval time.Duration) (chan *runc.Event, error) {
+ cmd := r.command(context, "events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id)
+ rd, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ rd.Close()
+ return nil, err
+ }
+ var (
+ dec = json.NewDecoder(rd)
+ c = make(chan *runc.Event, 128)
+ )
+ go func() {
+ defer func() {
+ close(c)
+ rd.Close()
+ Monitor.Wait(cmd, ec)
+ }()
+ for {
+ var e runc.Event
+ if err := dec.Decode(&e); err != nil {
+ if err == io.EOF {
+ return
+ }
+ e = runc.Event{
+ Type: "error",
+ Err: err,
+ }
+ }
+ c <- &e
+ }
+ }()
+ return c, nil
+}
+
+// Ps lists all the processes inside the container returning their pids.
+func (r *Runsc) Ps(context context.Context, id string) ([]int, error) {
+ data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+ var pids []int
+ if err := json.Unmarshal(data, &pids); err != nil {
+ return nil, err
+ }
+ return pids, nil
+}
+
+// Top lists all the processes inside the container returning the full ps data.
+func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) {
+ data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+
+ topResults, err := runc.ParsePSOutput(data)
+ if err != nil {
+		return nil, fmt.Errorf("%s: %s", err, data)
+ }
+ return topResults, nil
+}
+
+func (r *Runsc) args() []string {
+ var args []string
+ if r.Root != "" {
+ args = append(args, fmt.Sprintf("--root=%s", r.Root))
+ }
+ if r.Log != "" {
+ args = append(args, fmt.Sprintf("--log=%s", r.Log))
+ }
+ if r.LogFormat != "" {
+ args = append(args, fmt.Sprintf("--log-format=%s", r.LogFormat))
+ }
+ for k, v := range r.Config {
+ args = append(args, fmt.Sprintf("--%s=%s", k, v))
+ }
+ return args
+}
+
+// runOrError will run the provided command.
+//
+// If an error is encountered and neither Stdout nor Stderr was set, the error
+// will be returned in the format of <error>: <stderr>.
+func (r *Runsc) runOrError(cmd *exec.Cmd) error {
+ if cmd.Stdout != nil || cmd.Stderr != nil {
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+			err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0])
+ }
+ return err
+ }
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+}
+
+func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd {
+ command := r.Command
+ if command == "" {
+ command = DefaultCommand
+ }
+ cmd := exec.CommandContext(context, command, append(r.args(), args...)...)
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setpgid: r.Setpgid,
+ }
+ if r.PdeathSignal != 0 {
+ cmd.SysProcAttr.Pdeathsig = r.PdeathSignal
+ }
+
+ return cmd
+}
+
+func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) {
+ b := getBuf()
+ defer putBuf(b)
+
+ cmd.Stdout = b
+ if combined {
+ cmd.Stderr = b
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return nil, err
+ }
+
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+		err = fmt.Errorf("%s did not terminate successfully", cmd.Args[0])
+ }
+
+ return b.Bytes(), err
+}
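
A sketch of how a shim might drive this client; the root directory and log path are hypothetical, and error handling is abbreviated:

    package main

    import (
        "context"
        "fmt"

        runc "github.com/containerd/go-runc"

        "gvisor.dev/gvisor/pkg/shim/runsc"
    )

    func main() {
        r := &runsc.Runsc{
            Root:      "/run/containerd/runsc/default", // hypothetical root
            Log:       "/tmp/runsc-shim.log",
            LogFormat: runc.JSON,
        }
        ctx := context.Background()
        containers, err := r.List(ctx)
        if err != nil {
            panic(err)
        }
        for _, c := range containers {
            st, err := r.State(ctx, c.ID)
            if err != nil {
                continue
            }
            fmt.Println(c.ID, st.Status, st.Pid)
        }
    }
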
diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go
new file mode 100644
index 000000000..c514b3bc7
--- /dev/null
+++ b/pkg/shim/runsc/utils.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runsc
+
+import (
+ "bytes"
+ "strings"
+ "sync"
+)
+
+var bytesBufferPool = sync.Pool{
+ New: func() interface{} {
+ return bytes.NewBuffer(nil)
+ },
+}
+
+func getBuf() *bytes.Buffer {
+ return bytesBufferPool.Get().(*bytes.Buffer)
+}
+
+func putBuf(b *bytes.Buffer) {
+ b.Reset()
+ bytesBufferPool.Put(b)
+}
+
+// FormatLogPath parses the runsc config and fills in %ID% in the log path.
+func FormatLogPath(id string, config map[string]string) {
+ if path, ok := config["debug-log"]; ok {
+ config["debug-log"] = strings.Replace(path, "%ID%", id, -1)
+ }
+}
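
FormatLogPath only rewrites the debug-log entry, so it is safe to call with any config map; a short usage sketch with a hypothetical template:

    package main

    import (
        "fmt"

        "gvisor.dev/gvisor/pkg/shim/runsc"
    )

    func main() {
        config := map[string]string{
            "debug-log": "/var/log/runsc/%ID%/debug.log", // hypothetical template
        }
        runsc.FormatLogPath("abc123", config)
        fmt.Println(config["debug-log"]) // /var/log/runsc/abc123/debug.log
    }
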
diff --git a/pkg/shim/v1/proc/BUILD b/pkg/shim/v1/proc/BUILD
new file mode 100644
index 000000000..4377306af
--- /dev/null
+++ b/pkg/shim/v1/proc/BUILD
@@ -0,0 +1,36 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "proc",
+ srcs = [
+ "deleted_state.go",
+ "exec.go",
+ "exec_state.go",
+ "init.go",
+ "init_state.go",
+ "io.go",
+ "process.go",
+ "types.go",
+ "utils.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "//pkg/shim/runsc",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_go_runc//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v1/proc/deleted_state.go b/pkg/shim/v1/proc/deleted_state.go
new file mode 100644
index 000000000..d9b970c4d
--- /dev/null
+++ b/pkg/shim/v1/proc/deleted_state.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/process"
+)
+
+type deletedState struct{}
+
+func (*deletedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a deleted process.ss")
+}
+
+func (*deletedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a deleted process.ss")
+}
+
+func (*deletedState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound)
+}
+
+func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound)
+}
+
+func (*deletedState) SetExited(status int) {}
+
+func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return nil, fmt.Errorf("cannot exec in a deleted state")
+}
diff --git a/pkg/shim/v1/proc/exec.go b/pkg/shim/v1/proc/exec.go
new file mode 100644
index 000000000..1d1d90488
--- /dev/null
+++ b/pkg/shim/v1/proc/exec.go
@@ -0,0 +1,281 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+type execProcess struct {
+ wg sync.WaitGroup
+
+ execState execState
+
+ mu sync.Mutex
+ id string
+ console console.Console
+ io runc.IO
+ status int
+ exited time.Time
+ pid int
+ internalPid int
+ closers []io.Closer
+ stdin io.Closer
+ stdio stdio.Stdio
+ path string
+ spec specs.Process
+
+ parent *Init
+ waitBlock chan struct{}
+}
+
+func (e *execProcess) Wait() {
+ <-e.waitBlock
+}
+
+func (e *execProcess) ID() string {
+ return e.id
+}
+
+func (e *execProcess) Pid() int {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.pid
+}
+
+func (e *execProcess) ExitStatus() int {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.status
+}
+
+func (e *execProcess) ExitedAt() time.Time {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.exited
+}
+
+func (e *execProcess) SetExited(status int) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.execState.SetExited(status)
+}
+
+func (e *execProcess) setExited(status int) {
+ e.status = status
+ e.exited = time.Now()
+ e.parent.Platform.ShutdownConsole(context.Background(), e.console)
+ close(e.waitBlock)
+}
+
+func (e *execProcess) Delete(ctx context.Context) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Delete(ctx)
+}
+
+func (e *execProcess) delete(ctx context.Context) error {
+ e.wg.Wait()
+ if e.io != nil {
+ for _, c := range e.closers {
+ c.Close()
+ }
+ e.io.Close()
+ }
+ pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id))
+ // silently ignore error
+ os.Remove(pidfile)
+ internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id))
+ // silently ignore error
+ os.Remove(internalPidfile)
+ return nil
+}
+
+func (e *execProcess) Resize(ws console.WinSize) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Resize(ws)
+}
+
+func (e *execProcess) resize(ws console.WinSize) error {
+ if e.console == nil {
+ return nil
+ }
+ return e.console.Resize(ws)
+}
+
+func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Kill(ctx, sig, false)
+}
+
+func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error {
+ internalPid := e.internalPid
+ if internalPid != 0 {
+ if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{
+ Pid: internalPid,
+ }); err != nil {
+			// If this returns an error, assume the process has
+			// already stopped.
+ //
+ // TODO: Fix after signal handling is fixed.
+ return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound)
+ }
+ }
+ return nil
+}
+
+func (e *execProcess) Stdin() io.Closer {
+ return e.stdin
+}
+
+func (e *execProcess) Stdio() stdio.Stdio {
+ return e.stdio
+}
+
+func (e *execProcess) Start(ctx context.Context) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Start(ctx)
+}
+
+func (e *execProcess) start(ctx context.Context) (err error) {
+ var (
+ socket *runc.Socket
+ pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id))
+ internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id))
+ )
+ if e.stdio.Terminal {
+ if socket, err = runc.NewTempConsoleSocket(); err != nil {
+ return fmt.Errorf("failed to create runc console socket: %w", err)
+ }
+ defer socket.Close()
+ } else if e.stdio.IsNull() {
+ if e.io, err = runc.NewNullIO(); err != nil {
+ return fmt.Errorf("creating new NULL IO: %w", err)
+ }
+ } else {
+ if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil {
+ return fmt.Errorf("failed to create runc io pipes: %w", err)
+ }
+ }
+ opts := &runsc.ExecOpts{
+ PidFile: pidfile,
+ InternalPidFile: internalPidfile,
+ IO: e.io,
+ Detach: true,
+ }
+ if socket != nil {
+ opts.ConsoleSocket = socket
+ }
+ eventCh := e.parent.Monitor.Subscribe()
+ defer func() {
+ // Unsubscribe if an error is returned.
+ if err != nil {
+ e.parent.Monitor.Unsubscribe(eventCh)
+ }
+ }()
+ if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil {
+ close(e.waitBlock)
+ return e.parent.runtimeError(err, "OCI runtime exec failed")
+ }
+ if e.stdio.Stdin != "" {
+ sc, err := fifo.OpenFifo(context.Background(), e.stdio.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("failed to open stdin fifo %s: %w", e.stdio.Stdin, err)
+ }
+ e.closers = append(e.closers, sc)
+ e.stdin = sc
+ }
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ if socket != nil {
+ console, err := socket.ReceiveMaster()
+ if err != nil {
+ return fmt.Errorf("failed to retrieve console master: %w", err)
+ }
+ if e.console, err = e.parent.Platform.CopyConsole(ctx, console, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil {
+ return fmt.Errorf("failed to start console copy: %w", err)
+ }
+ } else if !e.stdio.IsNull() {
+ if err := copyPipes(ctx, e.io, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil {
+ return fmt.Errorf("failed to start io pipe copy: %w", err)
+ }
+ }
+ pid, err := runc.ReadPidFile(opts.PidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err)
+ }
+ e.pid = pid
+ internalPid, err := runc.ReadPidFile(opts.InternalPidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err)
+ }
+ e.internalPid = internalPid
+ go func() {
+ defer e.parent.Monitor.Unsubscribe(eventCh)
+ for event := range eventCh {
+ if event.Pid == e.pid {
+ ExitCh <- Exit{
+ Timestamp: event.Timestamp,
+ ID: e.id,
+ Status: event.Status,
+ }
+ break
+ }
+ }
+ }()
+ return nil
+}
+
+func (e *execProcess) Status(ctx context.Context) (string, error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ // if we don't have a pid then the exec process has just been created
+ if e.pid == 0 {
+ return "created", nil
+ }
+ // if we have a pid and it can be signaled, the process is running
+ // TODO(random-liu): Use `runsc kill --pid`.
+ if err := unix.Kill(e.pid, 0); err == nil {
+ return "running", nil
+ }
+	// else if we have a pid but it can no longer be signaled, it has stopped
+ return "stopped", nil
+}
diff --git a/pkg/shim/v1/proc/exec_state.go b/pkg/shim/v1/proc/exec_state.go
new file mode 100644
index 000000000..4dcda8b44
--- /dev/null
+++ b/pkg/shim/v1/proc/exec_state.go
@@ -0,0 +1,154 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+)
+
+type execState interface {
+ Resize(console.WinSize) error
+ Start(context.Context) error
+ Delete(context.Context) error
+ Kill(context.Context, uint32, bool) error
+ SetExited(int)
+}
+
+type execCreatedState struct {
+ p *execProcess
+}
+
+func (s *execCreatedState) transition(name string) error {
+ switch name {
+ case "running":
+ s.p.execState = &execRunningState{p: s.p}
+ case "stopped":
+ s.p.execState = &execStoppedState{p: s.p}
+ case "deleted":
+ s.p.execState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execCreatedState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *execCreatedState) Start(ctx context.Context) error {
+ if err := s.p.start(ctx); err != nil {
+ return err
+ }
+ return s.transition("running")
+}
+
+func (s *execCreatedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execCreatedState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+type execRunningState struct {
+ p *execProcess
+}
+
+func (s *execRunningState) transition(name string) error {
+ switch name {
+ case "stopped":
+ s.p.execState = &execStoppedState{p: s.p}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execRunningState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *execRunningState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a running process")
+}
+
+func (s *execRunningState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a running process")
+}
+
+func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execRunningState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+type execStoppedState struct {
+ p *execProcess
+}
+
+func (s *execStoppedState) transition(name string) error {
+ switch name {
+ case "deleted":
+ s.p.execState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execStoppedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a stopped container")
+}
+
+func (s *execStoppedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a stopped process")
+}
+
+func (s *execStoppedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execStoppedState) SetExited(status int) {
+ // no op
+}
diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go
new file mode 100644
index 000000000..dab3123d6
--- /dev/null
+++ b/pkg/shim/v1/proc/init.go
@@ -0,0 +1,460 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+// InitPidFile is the name of the file that contains the init pid.
+const InitPidFile = "init.pid"
+
+// Init represents an initial process for a container.
+type Init struct {
+ wg sync.WaitGroup
+ initState initState
+
+ // mu is used to ensure that `Start()` and `Exited()` calls return in
+ // the right order when invoked in separate goroutines. This is the
+ // case within the shim implementation, since it makes use of the
+ // reaper interface.
+ mu sync.Mutex
+
+ waitBlock chan struct{}
+
+ WorkDir string
+
+ id string
+ Bundle string
+ console console.Console
+ Platform stdio.Platform
+ io runc.IO
+ runtime *runsc.Runsc
+ status int
+ exited time.Time
+ pid int
+ closers []io.Closer
+ stdin io.Closer
+ stdio stdio.Stdio
+ Rootfs string
+ IoUID int
+ IoGID int
+ Sandbox bool
+ UserLog string
+ Monitor ProcessMonitor
+}
+
+// NewRunsc returns a new runsc instance for a process.
+func NewRunsc(root, path, namespace, runtime string, config map[string]string) *runsc.Runsc {
+ if root == "" {
+ root = RunscRoot
+ }
+ return &runsc.Runsc{
+ Command: runtime,
+ PdeathSignal: syscall.SIGKILL,
+ Log: filepath.Join(path, "log.json"),
+ LogFormat: runc.JSON,
+ Root: filepath.Join(root, namespace),
+ Config: config,
+ }
+}
+
+// New returns a new init process.
+func New(id string, runtime *runsc.Runsc, stdio stdio.Stdio) *Init {
+ p := &Init{
+ id: id,
+ runtime: runtime,
+ stdio: stdio,
+ status: 0,
+ waitBlock: make(chan struct{}),
+ }
+ p.initState = &createdState{p: p}
+ return p
+}
+
+// Create the process with the provided config.
+func (p *Init) Create(ctx context.Context, r *CreateConfig) (err error) {
+ var socket *runc.Socket
+ if r.Terminal {
+ if socket, err = runc.NewTempConsoleSocket(); err != nil {
+ return fmt.Errorf("failed to create OCI runtime console socket: %w", err)
+ }
+ defer socket.Close()
+ } else if hasNoIO(r) {
+ if p.io, err = runc.NewNullIO(); err != nil {
+ return fmt.Errorf("creating new NULL IO: %w", err)
+ }
+ } else {
+ if p.io, err = runc.NewPipeIO(p.IoUID, p.IoGID, withConditionalIO(p.stdio)); err != nil {
+ return fmt.Errorf("failed to create OCI runtime io pipes: %w", err)
+ }
+ }
+ pidFile := filepath.Join(p.Bundle, InitPidFile)
+ opts := &runsc.CreateOpts{
+ PidFile: pidFile,
+ }
+ if socket != nil {
+ opts.ConsoleSocket = socket
+ }
+ if p.Sandbox {
+ opts.IO = p.io
+ // UserLog is only useful for sandbox.
+ opts.UserLog = p.UserLog
+ }
+ if err := p.runtime.Create(ctx, r.ID, r.Bundle, opts); err != nil {
+ return p.runtimeError(err, "OCI runtime create failed")
+ }
+ if r.Stdin != "" {
+ sc, err := fifo.OpenFifo(context.Background(), r.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("failed to open stdin fifo %s: %w", r.Stdin, err)
+ }
+ p.stdin = sc
+ p.closers = append(p.closers, sc)
+ }
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ if socket != nil {
+ console, err := socket.ReceiveMaster()
+ if err != nil {
+ return fmt.Errorf("failed to retrieve console master: %w", err)
+ }
+ console, err = p.Platform.CopyConsole(ctx, console, r.Stdin, r.Stdout, r.Stderr, &p.wg)
+ if err != nil {
+ return fmt.Errorf("failed to start console copy: %w", err)
+ }
+ p.console = console
+ } else if !hasNoIO(r) {
+ if err := copyPipes(ctx, p.io, r.Stdin, r.Stdout, r.Stderr, &p.wg); err != nil {
+ return fmt.Errorf("failed to start io pipe copy: %w", err)
+ }
+ }
+ pid, err := runc.ReadPidFile(pidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime container pid: %w", err)
+ }
+ p.pid = pid
+ return nil
+}
+
+// Wait waits for the process to exit.
+func (p *Init) Wait() {
+ <-p.waitBlock
+}
+
+// ID returns the ID of the process.
+func (p *Init) ID() string {
+ return p.id
+}
+
+// Pid returns the PID of the process.
+func (p *Init) Pid() int {
+ return p.pid
+}
+
+// ExitStatus returns the exit status of the process.
+func (p *Init) ExitStatus() int {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.status
+}
+
+// ExitedAt returns the time when the process exited.
+func (p *Init) ExitedAt() time.Time {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.exited
+}
+
+// Status returns the status of the process.
+func (p *Init) Status(ctx context.Context) (string, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ c, err := p.runtime.State(ctx, p.id)
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ return "stopped", nil
+ }
+ return "", p.runtimeError(err, "OCI runtime state failed")
+ }
+ return p.convertStatus(c.Status), nil
+}
+
+// Start starts the init process.
+func (p *Init) Start(ctx context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Start(ctx)
+}
+
+func (p *Init) start(ctx context.Context) error {
+ var cio runc.IO
+ if !p.Sandbox {
+ cio = p.io
+ }
+ if err := p.runtime.Start(ctx, p.id, cio); err != nil {
+ return p.runtimeError(err, "OCI runtime start failed")
+ }
+ go func() {
+ status, err := p.runtime.Wait(context.Background(), p.id)
+ if err != nil {
+ log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id)
+ // TODO(random-liu): Handle runsc kill error.
+ if err := p.killAll(ctx); err != nil {
+ log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id)
+ }
+ status = internalErrorCode
+ }
+ ExitCh <- Exit{
+ Timestamp: time.Now(),
+ ID: p.id,
+ Status: status,
+ }
+ }()
+ return nil
+}
+
+// SetExited sets the exit status of the init process.
+func (p *Init) SetExited(status int) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.initState.SetExited(status)
+}
+
+func (p *Init) setExited(status int) {
+ p.exited = time.Now()
+ p.status = status
+ p.Platform.ShutdownConsole(context.Background(), p.console)
+ close(p.waitBlock)
+}
+
+// Delete deletes the init process.
+func (p *Init) Delete(ctx context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Delete(ctx)
+}
+
+func (p *Init) delete(ctx context.Context) error {
+ p.killAll(ctx)
+ p.wg.Wait()
+ err := p.runtime.Delete(ctx, p.id, nil)
+ // Ignore errors if the runtime has already deleted the process but we
+ // still hold metadata and pipes.
+ //
+ // This is common during a checkpoint: runc deletes the container state
+ // after a checkpoint, and the container no longer exists within runc.
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ err = nil
+ } else {
+ err = p.runtimeError(err, "failed to delete task")
+ }
+ }
+ if p.io != nil {
+ for _, c := range p.closers {
+ c.Close()
+ }
+ p.io.Close()
+ }
+ if err2 := mount.UnmountAll(p.Rootfs, 0); err2 != nil {
+ log.G(ctx).WithError(err2).Warn("failed to cleanup rootfs mount")
+ if err == nil {
+ err = fmt.Errorf("failed rootfs umount: %w", err2)
+ }
+ }
+ return err
+}
+
+// Resize resizes the init process's console.
+func (p *Init) Resize(ws console.WinSize) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.console == nil {
+ return nil
+ }
+ return p.console.Resize(ws)
+}
+
+func (p *Init) resize(ws console.WinSize) error {
+ if p.console == nil {
+ return nil
+ }
+ return p.console.Resize(ws)
+}
+
+// Kill kills the init process.
+func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Kill(ctx, signal, all)
+}
+
+func (p *Init) kill(context context.Context, signal uint32, all bool) error {
+ var (
+ killErr error
+ backoff = 100 * time.Millisecond
+ )
+ timeout := 1 * time.Second
+ for start := time.Now(); time.Since(start) < timeout; {
+ c, err := p.runtime.State(context, p.id)
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ return fmt.Errorf("no such process: %w", errdefs.ErrNotFound)
+ }
+ return p.runtimeError(err, "OCI runtime state failed")
+ }
+ // For runsc, signal delivery only works while the container is in the
+ // running state. If the container is not running, return
+ // "no such process" directly.
+ if p.convertStatus(c.Status) == "stopped" {
+ return fmt.Errorf("no such process: %w", errdefs.ErrNotFound)
+ }
+ killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{
+ All: all,
+ })
+ if killErr == nil {
+ return nil
+ }
+ time.Sleep(backoff)
+ backoff *= 2
+ }
+ return p.runtimeError(killErr, "kill timeout")
+}
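
The kill path above polls the container state and retries the signal with an exponentially growing backoff until a one-second deadline. A self-contained sketch of the same retry pattern, with a stand-in operation in place of the runsc call:

package main

import (
    "errors"
    "fmt"
    "time"
)

// retryWithBackoff retries op until it succeeds or the deadline expires,
// doubling the sleep between attempts, mirroring the kill loop above.
func retryWithBackoff(op func() error, timeout time.Duration) error {
    backoff := 100 * time.Millisecond
    var lastErr error
    for start := time.Now(); time.Since(start) < timeout; {
        if lastErr = op(); lastErr == nil {
            return nil
        }
        time.Sleep(backoff)
        backoff *= 2
    }
    return fmt.Errorf("timed out: %w", lastErr)
}

func main() {
    attempts := 0
    err := retryWithBackoff(func() error {
        attempts++
        if attempts < 3 {
            return errors.New("not ready")
        }
        return nil
    }, time.Second)
    fmt.Println(attempts, err) // 3 <nil>
}
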
+
+// KillAll kills all processes belonging to the init process.
+func (p *Init) KillAll(context context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.killAll(context)
+}
+
+func (p *Init) killAll(context context.Context) error {
+ p.runtime.Kill(context, p.id, int(syscall.SIGKILL), &runsc.KillOpts{
+ All: true,
+ })
+ // Ignore error handling for `runsc kill --all` for now.
+ // * If it doesn't return an error, all is well;
+ // * If it returns an error, assume the container has already stopped.
+ // TODO: Fix `runsc kill --all` error handling.
+ return nil
+}
+
+// Stdin returns the stdin of the process.
+func (p *Init) Stdin() io.Closer {
+ return p.stdin
+}
+
+// Runtime returns the OCI runtime configured for the init process.
+func (p *Init) Runtime() *runsc.Runsc {
+ return p.runtime
+}
+
+// Exec returns a new child process.
+func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Exec(ctx, path, r)
+}
+
+// exec returns a new exec'd process.
+func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ // process exec request
+ var spec specs.Process
+ if err := json.Unmarshal(r.Spec.Value, &spec); err != nil {
+ return nil, err
+ }
+ spec.Terminal = r.Terminal
+
+ e := &execProcess{
+ id: r.ID,
+ path: path,
+ parent: p,
+ spec: spec,
+ stdio: stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ },
+ waitBlock: make(chan struct{}),
+ }
+ e.execState = &execCreatedState{p: e}
+ return e, nil
+}
+
+// Stdio returns the stdio of the process.
+func (p *Init) Stdio() stdio.Stdio {
+ return p.stdio
+}
+
+func (p *Init) runtimeError(rErr error, msg string) error {
+ if rErr == nil {
+ return nil
+ }
+
+ rMsg, err := getLastRuntimeError(p.runtime)
+ switch {
+ case err != nil:
+ return fmt.Errorf("%s: %w (unable to retrieve OCI runtime error: %v)", msg, rErr, err)
+ case rMsg == "":
+ return fmt.Errorf("%s: %w", msg, rErr)
+ default:
+ return fmt.Errorf("%s: %s", msg, rMsg)
+ }
+}
+
+func (p *Init) convertStatus(status string) string {
+ if status == "created" && !p.Sandbox && p.status == internalErrorCode {
+ // Treat the start-failure state for a non-root container as stopped.
+ return "stopped"
+ }
+ return status
+}
+
+func withConditionalIO(c stdio.Stdio) runc.IOOpt {
+ return func(o *runc.IOOption) {
+ o.OpenStdin = c.Stdin != ""
+ o.OpenStdout = c.Stdout != ""
+ o.OpenStderr = c.Stderr != ""
+ }
+}
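
withConditionalIO is a functional option for go-runc's pipe IO: the returned closure only enables the streams that were actually requested, so pipes are only created for stdio the caller wired up. A sketch of the same pattern with local stand-in types (not the real go-runc option types):

package main

import "fmt"

// Stand-ins for the go-runc option types, for illustration only.
type ioOption struct{ OpenStdin, OpenStdout, OpenStderr bool }
type ioOpt func(*ioOption)

// conditionalIO enables only the streams that were actually requested.
func conditionalIO(stdin, stdout, stderr string) ioOpt {
    return func(o *ioOption) {
        o.OpenStdin = stdin != ""
        o.OpenStdout = stdout != ""
        o.OpenStderr = stderr != ""
    }
}

func main() {
    var o ioOption
    conditionalIO("", "/run/out", "/run/err")(&o)
    fmt.Printf("%+v\n", o) // {OpenStdin:false OpenStdout:true OpenStderr:true}
}
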
diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go
new file mode 100644
index 000000000..9233ecc85
--- /dev/null
+++ b/pkg/shim/v1/proc/init_state.go
@@ -0,0 +1,182 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/process"
+)
+
+type initState interface {
+ Resize(console.WinSize) error
+ Start(context.Context) error
+ Delete(context.Context) error
+ Exec(context.Context, string, *ExecConfig) (process.Process, error)
+ Kill(context.Context, uint32, bool) error
+ SetExited(int)
+}
+
+type createdState struct {
+ p *Init
+}
+
+func (s *createdState) transition(name string) error {
+ switch name {
+ case "running":
+ s.p.initState = &runningState{p: s.p}
+ case "stopped":
+ s.p.initState = &stoppedState{p: s.p}
+ case "deleted":
+ s.p.initState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *createdState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *createdState) Start(ctx context.Context) error {
+ if err := s.p.start(ctx); err != nil {
+ // Containerd doesn't allow deleting a container in the created state.
+ // However, for gVisor, a non-root container in the created state can
+ // only go to the running state. If the container can't be started,
+ // it can only stay in the created state and never be deleted.
+ // To work around that, we treat a non-root container whose start
+ // failed as stopped.
+ if !s.p.Sandbox {
+ s.p.io.Close()
+ s.p.setExited(internalErrorCode)
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+ }
+ return err
+ }
+ return s.transition("running")
+}
+
+func (s *createdState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *createdState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return s.p.exec(ctx, path, r)
+}
+
+type runningState struct {
+ p *Init
+}
+
+func (s *runningState) transition(name string) error {
+ switch name {
+ case "stopped":
+ s.p.initState = &stoppedState{p: s.p}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *runningState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *runningState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a running process.ss")
+}
+
+func (s *runningState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a running process.ss")
+}
+
+func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *runningState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return s.p.exec(ctx, path, r)
+}
+
+type stoppedState struct {
+ p *Init
+}
+
+func (s *stoppedState) transition(name string) error {
+ switch name {
+ case "deleted":
+ s.p.initState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *stoppedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a stopped container")
+}
+
+func (s *stoppedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a stopped process.ss")
+}
+
+func (s *stoppedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return errdefs.ToGRPCf(errdefs.ErrNotFound, "process.ss %s not found", s.p.id)
+}
+
+func (s *stoppedState) SetExited(status int) {
+ // no op
+}
+
+func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return nil, fmt.Errorf("cannot exec in a stopped state")
+}
diff --git a/pkg/shim/v1/proc/io.go b/pkg/shim/v1/proc/io.go
new file mode 100644
index 000000000..34d825fb7
--- /dev/null
+++ b/pkg/shim/v1/proc/io.go
@@ -0,0 +1,162 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+)
+
+// TODO(random-liu): This file can be a util.
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+}
+
+func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg *sync.WaitGroup) error {
+ var sameFile *countingWriteCloser
+ for _, i := range []struct {
+ name string
+ dest func(wc io.WriteCloser, rc io.Closer)
+ }{
+ {
+ name: stdout,
+ dest: func(wc io.WriteCloser, rc io.Closer) {
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ if _, err := io.CopyBuffer(wc, rio.Stdout(), *p); err != nil {
+ log.G(ctx).Warn("error copying stdout")
+ }
+ wg.Done()
+ wc.Close()
+ if rc != nil {
+ rc.Close()
+ }
+ }()
+ },
+ }, {
+ name: stderr,
+ dest: func(wc io.WriteCloser, rc io.Closer) {
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ if _, err := io.CopyBuffer(wc, rio.Stderr(), *p); err != nil {
+ log.G(ctx).Warn("error copying stderr")
+ }
+ wg.Done()
+ wc.Close()
+ if rc != nil {
+ rc.Close()
+ }
+ }()
+ },
+ },
+ } {
+ ok, err := isFifo(i.name)
+ if err != nil {
+ return err
+ }
+ var (
+ fw io.WriteCloser
+ fr io.Closer
+ )
+ if ok {
+ if fw, err = fifo.OpenFifo(ctx, i.name, syscall.O_WRONLY, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ if fr, err = fifo.OpenFifo(ctx, i.name, syscall.O_RDONLY, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ } else {
+ if sameFile != nil {
+ sameFile.count++
+ i.dest(sameFile, nil)
+ continue
+ }
+ if fw, err = os.OpenFile(i.name, syscall.O_WRONLY|syscall.O_APPEND, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ if stdout == stderr {
+ sameFile = &countingWriteCloser{
+ WriteCloser: fw,
+ count: 1,
+ }
+ }
+ }
+ i.dest(fw, fr)
+ }
+ if stdin == "" {
+ return nil
+ }
+ f, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", stdin, err)
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+
+ io.CopyBuffer(rio.Stdin(), f, *p)
+ rio.Stdin().Close()
+ f.Close()
+ }()
+ return nil
+}
+
+// countingWriteCloser masks the underlying Close method until Close has been invoked the expected number of times.
+type countingWriteCloser struct {
+ io.WriteCloser
+ count int64
+}
+
+func (c *countingWriteCloser) Close() error {
+ if atomic.AddInt64(&c.count, -1) > 0 {
+ return nil
+ }
+ return c.WriteCloser.Close()
+}
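
countingWriteCloser exists so stdout and stderr can share a single log file: when both copiers point at the same writer, the file is only closed once the last holder calls Close. A small runnable sketch of that behavior, using an in-memory writer instead of a FIFO:

package main

import (
    "fmt"
    "io"
    "sync/atomic"
)

// closeOnce is a toy io.WriteCloser that records whether it was closed.
type closeOnce struct {
    io.Writer
    closed bool
}

func (c *closeOnce) Close() error { c.closed = true; return nil }

// countingWriteCloser mirrors the type above: Close is a no-op until the
// count reaches zero.
type countingWriteCloser struct {
    io.WriteCloser
    count int64
}

func (c *countingWriteCloser) Close() error {
    if atomic.AddInt64(&c.count, -1) > 0 {
        return nil
    }
    return c.WriteCloser.Close()
}

func main() {
    under := &closeOnce{Writer: io.Discard}
    shared := &countingWriteCloser{WriteCloser: under, count: 2}
    shared.Close()
    fmt.Println(under.closed) // false: one holder remains
    shared.Close()
    fmt.Println(under.closed) // true: the last holder closed it
}
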
+
+// isFifo checks if a file is a fifo.
+//
+// If the file does not exist then it returns false.
+func isFifo(path string) (bool, error) {
+ stat, err := os.Stat(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, err
+ }
+ if stat.Mode()&os.ModeNamedPipe == os.ModeNamedPipe {
+ return true, nil
+ }
+ return false, nil
+}
diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go
new file mode 100644
index 000000000..d462c3eef
--- /dev/null
+++ b/pkg/shim/v1/proc/process.go
@@ -0,0 +1,37 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+)
+
+// RunscRoot is the path to the root runsc state directory.
+const RunscRoot = "/run/containerd/runsc"
+
+func stateName(v interface{}) string {
+ switch v.(type) {
+ case *runningState, *execRunningState:
+ return "running"
+ case *createdState, *execCreatedState:
+ return "created"
+ case *deletedState:
+ return "deleted"
+ case *stoppedState:
+ return "stopped"
+ }
+ panic(fmt.Errorf("invalid state %v", v))
+}
diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go
new file mode 100644
index 000000000..2b0df4663
--- /dev/null
+++ b/pkg/shim/v1/proc/types.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "time"
+
+ runc "github.com/containerd/go-runc"
+ "github.com/gogo/protobuf/types"
+)
+
+// Mount holds filesystem mount configuration.
+type Mount struct {
+ Type string
+ Source string
+ Target string
+ Options []string
+}
+
+// CreateConfig holds task creation configuration.
+type CreateConfig struct {
+ ID string
+ Bundle string
+ Runtime string
+ Rootfs []Mount
+ Terminal bool
+ Stdin string
+ Stdout string
+ Stderr string
+ Options *types.Any
+}
+
+// ExecConfig holds exec creation configuration.
+type ExecConfig struct {
+ ID string
+ Terminal bool
+ Stdin string
+ Stdout string
+ Stderr string
+ Spec *types.Any
+}
+
+// Exit is the type of exit events.
+type Exit struct {
+ Timestamp time.Time
+ ID string
+ Status int
+}
+
+// ProcessMonitor monitors process exit changes.
+type ProcessMonitor interface {
+ // Subscribe to process exit changes
+ Subscribe() chan runc.Exit
+ // Unsubscribe from process exit changes
+ Unsubscribe(c chan runc.Exit)
+}
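
ProcessMonitor is the small subscription interface the shim expects from containerd's reaper: each subscriber gets its own channel of exit events. A toy sketch of such a broadcaster, with a local exit type standing in for runc.Exit:

package main

import (
    "fmt"
    "sync"
)

// exit stands in for runc.Exit in this sketch.
type exit struct {
    Pid    int
    Status int
}

// fanoutMonitor is a toy ProcessMonitor-style broadcaster: every subscriber
// gets its own buffered channel, and published exits are fanned out to all.
type fanoutMonitor struct {
    mu   sync.Mutex
    subs map[chan exit]struct{}
}

func newFanoutMonitor() *fanoutMonitor {
    return &fanoutMonitor{subs: make(map[chan exit]struct{})}
}

func (m *fanoutMonitor) Subscribe() chan exit {
    m.mu.Lock()
    defer m.mu.Unlock()
    c := make(chan exit, 8)
    m.subs[c] = struct{}{}
    return c
}

func (m *fanoutMonitor) Unsubscribe(c chan exit) {
    m.mu.Lock()
    defer m.mu.Unlock()
    delete(m.subs, c)
    close(c)
}

func (m *fanoutMonitor) publish(e exit) {
    m.mu.Lock()
    defer m.mu.Unlock()
    for c := range m.subs {
        c <- e
    }
}

func main() {
    m := newFanoutMonitor()
    c := m.Subscribe()
    m.publish(exit{Pid: 42, Status: 0})
    fmt.Println(<-c) // {42 0}
    m.Unsubscribe(c)
}
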
diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go
new file mode 100644
index 000000000..716de2f59
--- /dev/null
+++ b/pkg/shim/v1/proc/utils.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "encoding/json"
+ "io"
+ "os"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+const (
+ internalErrorCode = 128
+ bufferSize = 32
+)
+
+// ExitCh is the exit events channel for containers and exec processes
+// inside the sandbox.
+var ExitCh = make(chan Exit, bufferSize)
+
+// TODO(mlaventure): move to runc package?
+func getLastRuntimeError(r *runsc.Runsc) (string, error) {
+ if r.Log == "" {
+ return "", nil
+ }
+
+ f, err := os.OpenFile(r.Log, os.O_RDONLY, 0400)
+ if err != nil {
+ return "", err
+ }
+
+ var (
+ errMsg string
+ log struct {
+ Level string
+ Msg string
+ Time time.Time
+ }
+ )
+
+ dec := json.NewDecoder(f)
+ for err = nil; err == nil; {
+ if err = dec.Decode(&log); err != nil && err != io.EOF {
+ return "", err
+ }
+ if log.Level == "error" {
+ errMsg = strings.TrimSpace(log.Msg)
+ }
+ }
+
+ return errMsg, nil
+}
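
getLastRuntimeError scans runsc's JSON log (one object per line, written because LogFormat is runc.JSON) and keeps the message of the last entry whose level is "error". A sketch of the same scan over an in-memory log, assuming the same {Level, Msg, Time} shape:

package main

import (
    "encoding/json"
    "fmt"
    "io"
    "strings"
    "time"
)

// lastError decodes newline-delimited JSON log entries and returns the
// message of the last "error"-level entry, mirroring the helper above.
func lastError(r io.Reader) (string, error) {
    var (
        errMsg string
        entry  struct {
            Level string
            Msg   string
            Time  time.Time
        }
    )
    dec := json.NewDecoder(r)
    for {
        if err := dec.Decode(&entry); err != nil {
            if err == io.EOF {
                break
            }
            return "", err
        }
        if entry.Level == "error" {
            errMsg = strings.TrimSpace(entry.Msg)
        }
    }
    return errMsg, nil
}

func main() {
    log := `{"level":"info","msg":"starting"}
{"level":"error","msg":"container init failed "}
`
    msg, err := lastError(strings.NewReader(log))
    fmt.Println(msg, err) // container init failed <nil>
}
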
+
+func copyFile(to, from string) error {
+ ff, err := os.Open(from)
+ if err != nil {
+ return err
+ }
+ defer ff.Close()
+ tt, err := os.Create(to)
+ if err != nil {
+ return err
+ }
+ defer tt.Close()
+
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ _, err = io.CopyBuffer(tt, ff, *p)
+ return err
+}
+
+func hasNoIO(r *CreateConfig) bool {
+ return r.Stdin == "" && r.Stdout == "" && r.Stderr == ""
+}
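
copyFile and the pipe copiers all draw their scratch space from the shared bufPool defined in io.go: a sync.Pool of 32 KiB buffers handed to io.CopyBuffer. A minimal sketch of that pooling pattern:

package main

import (
    "bytes"
    "fmt"
    "io"
    "strings"
    "sync"
)

// bufPool mirrors the pool above: reusable 32 KiB copy buffers.
var bufPool = sync.Pool{
    New: func() interface{} {
        buffer := make([]byte, 32<<10)
        return &buffer
    },
}

// copyWithPool copies src to dst using a pooled buffer, as the shim's copy
// helpers do.
func copyWithPool(dst io.Writer, src io.Reader) (int64, error) {
    p := bufPool.Get().(*[]byte)
    defer bufPool.Put(p)
    return io.CopyBuffer(dst, src, *p)
}

func main() {
    var out bytes.Buffer
    n, err := copyWithPool(&out, strings.NewReader("hello"))
    fmt.Println(n, err, out.String()) // 5 <nil> hello
}
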
diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD
new file mode 100644
index 000000000..05c595bc9
--- /dev/null
+++ b/pkg/shim/v1/shim/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "shim",
+ srcs = [
+ "api.go",
+ "platform.go",
+ "service.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "//pkg/shim/runsc",
+ "//pkg/shim/v1/proc",
+ "//pkg/shim/v1/utils",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//api/events:go_default_library",
+ "@com_github_containerd_containerd//api/types/task:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//events:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_containerd//runtime:go_default_library",
+ "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library",
+ "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library",
+ "@com_github_containerd_containerd//sys/reaper:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_typeurl//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@org_golang_google_grpc//codes:go_default_library",
+ "@org_golang_google_grpc//status:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v1/shim/api.go b/pkg/shim/v1/shim/api.go
new file mode 100644
index 000000000..5dd8ff172
--- /dev/null
+++ b/pkg/shim/v1/shim/api.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "github.com/containerd/containerd/api/events"
+)
+
+type TaskCreate = events.TaskCreate
+type TaskStart = events.TaskStart
+type TaskOOM = events.TaskOOM
+type TaskExit = events.TaskExit
+type TaskDelete = events.TaskDelete
+type TaskExecAdded = events.TaskExecAdded
+type TaskExecStarted = events.TaskExecStarted
diff --git a/pkg/shim/v1/shim/platform.go b/pkg/shim/v1/shim/platform.go
new file mode 100644
index 000000000..f590f80ef
--- /dev/null
+++ b/pkg/shim/v1/shim/platform.go
@@ -0,0 +1,106 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/console"
+ "github.com/containerd/fifo"
+)
+
+type linuxPlatform struct {
+ epoller *console.Epoller
+}
+
+func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) {
+ if p.epoller == nil {
+ return nil, fmt.Errorf("uninitialized epoller")
+ }
+
+ epollConsole, err := p.epoller.Add(console)
+ if err != nil {
+ return nil, err
+ }
+
+ if stdin != "" {
+ in, err := fifo.OpenFifo(ctx, stdin, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(epollConsole, in, *p)
+ }()
+ }
+
+ outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(outw, epollConsole, *p)
+ epollConsole.Close()
+ outr.Close()
+ outw.Close()
+ wg.Done()
+ }()
+ return epollConsole, nil
+}
+
+func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error {
+ if p.epoller == nil {
+ return fmt.Errorf("uninitialized epoller")
+ }
+ epollConsole, ok := cons.(*console.EpollConsole)
+ if !ok {
+ return fmt.Errorf("expected EpollConsole, got %#v", cons)
+ }
+ return epollConsole.Shutdown(p.epoller.CloseConsole)
+}
+
+func (p *linuxPlatform) Close() error {
+ return p.epoller.Close()
+}
+
+// initialize a single epoll fd to manage our consoles. `initPlatform` should
+// only be called once.
+func (s *Service) initPlatform() error {
+ if s.platform != nil {
+ return nil
+ }
+ epoller, err := console.NewEpoller()
+ if err != nil {
+ return fmt.Errorf("failed to initialize epoller: %w", err)
+ }
+ s.platform = &linuxPlatform{
+ epoller: epoller,
+ }
+ go epoller.Wait()
+ return nil
+}
diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go
new file mode 100644
index 000000000..84a810cb2
--- /dev/null
+++ b/pkg/shim/v1/shim/service.go
@@ -0,0 +1,573 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/api/types/task"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/events"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/namespaces"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/containerd/runtime"
+ "github.com/containerd/containerd/runtime/linux/runctypes"
+ shim "github.com/containerd/containerd/runtime/v1/shim/v1"
+ "github.com/containerd/containerd/sys/reaper"
+ "github.com/containerd/typeurl"
+ "github.com/gogo/protobuf/types"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+ "gvisor.dev/gvisor/pkg/shim/v1/proc"
+ "gvisor.dev/gvisor/pkg/shim/v1/utils"
+)
+
+var (
+ empty = &types.Empty{}
+ bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+ }
+)
+
+// Config contains shim specific configuration.
+type Config struct {
+ Path string
+ Namespace string
+ WorkDir string
+ RuntimeRoot string
+ RunscConfig map[string]string
+}
+
+// NewService returns a new shim service that can be used via GRPC.
+func NewService(config Config, publisher events.Publisher) (*Service, error) {
+ if config.Namespace == "" {
+ return nil, fmt.Errorf("shim namespace cannot be empty")
+ }
+ ctx := namespaces.WithNamespace(context.Background(), config.Namespace)
+ s := &Service{
+ config: config,
+ context: ctx,
+ processes: make(map[string]process.Process),
+ events: make(chan interface{}, 128),
+ ec: proc.ExitCh,
+ }
+ go s.processExits()
+ if err := s.initPlatform(); err != nil {
+ return nil, fmt.Errorf("failed to initialized platform behavior: %w", err)
+ }
+ go s.forward(publisher)
+ return s, nil
+}
+
+// Service is the shim implementation of a remote shim over GRPC.
+type Service struct {
+ mu sync.Mutex
+
+ config Config
+ context context.Context
+ processes map[string]process.Process
+ events chan interface{}
+ platform stdio.Platform
+ ec chan proc.Exit
+
+ // Filled by Create()
+ id string
+ bundle string
+}
+
+// Create creates a new initial process and container with the underlying OCI runtime.
+func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shim.CreateTaskResponse, err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var mounts []proc.Mount
+ for _, m := range r.Rootfs {
+ mounts = append(mounts, proc.Mount{
+ Type: m.Type,
+ Source: m.Source,
+ Target: m.Target,
+ Options: m.Options,
+ })
+ }
+
+ rootfs := filepath.Join(r.Bundle, "rootfs")
+ if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) {
+ return nil, err
+ }
+
+ config := &proc.CreateConfig{
+ ID: r.ID,
+ Bundle: r.Bundle,
+ Runtime: r.Runtime,
+ Rootfs: mounts,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Options: r.Options,
+ }
+ defer func() {
+ if err != nil {
+ if err2 := mount.UnmountAll(rootfs, 0); err2 != nil {
+ log.G(ctx).WithError(err2).Warn("Failed to cleanup rootfs mount")
+ }
+ }
+ }()
+ for _, rm := range mounts {
+ m := &mount.Mount{
+ Type: rm.Type,
+ Source: rm.Source,
+ Options: rm.Options,
+ }
+ if err := m.Mount(rootfs); err != nil {
+ return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err)
+ }
+ }
+ process, err := newInit(
+ ctx,
+ s.config.Path,
+ s.config.WorkDir,
+ s.config.RuntimeRoot,
+ s.config.Namespace,
+ s.config.RunscConfig,
+ s.platform,
+ config,
+ )
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ if err := process.Create(ctx, config); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ // Save the main task id and bundle to the shim for additional
+ // requests.
+ s.id = r.ID
+ s.bundle = r.Bundle
+ pid := process.Pid()
+ s.processes[r.ID] = process
+ return &shim.CreateTaskResponse{
+ Pid: uint32(pid),
+ }, nil
+}
+
+// Start starts a process.
+func (s *Service) Start(ctx context.Context, r *shim.StartRequest) (*shim.StartResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Start(ctx); err != nil {
+ return nil, err
+ }
+ return &shim.StartResponse{
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Delete deletes the initial process and container.
+func (s *Service) Delete(ctx context.Context, r *types.Empty) (*shim.DeleteResponse, error) {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ s.mu.Lock()
+ delete(s.processes, s.id)
+ s.mu.Unlock()
+ s.platform.Close()
+ return &shim.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// DeleteProcess deletes an exec'd process.
+func (s *Service) DeleteProcess(ctx context.Context, r *shim.DeleteProcessRequest) (*shim.DeleteResponse, error) {
+ if r.ID == s.id {
+ return nil, status.Errorf(codes.InvalidArgument, "cannot delete init process with DeleteProcess")
+ }
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ s.mu.Lock()
+ delete(s.processes, r.ID)
+ s.mu.Unlock()
+ return &shim.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Exec spawns an additional process inside the container.
+func (s *Service) Exec(ctx context.Context, r *shim.ExecProcessRequest) (*types.Empty, error) {
+ s.mu.Lock()
+
+ if p := s.processes[r.ID]; p != nil {
+ s.mu.Unlock()
+ return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ID)
+ }
+
+ p := s.processes[s.id]
+ s.mu.Unlock()
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+
+ process, err := p.(*proc.Init).Exec(ctx, s.config.Path, &proc.ExecConfig{
+ ID: r.ID,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Spec: r.Spec,
+ })
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ s.mu.Lock()
+ s.processes[r.ID] = process
+ s.mu.Unlock()
+ return empty, nil
+}
+
+// ResizePty resizes the terminal of a process.
+func (s *Service) ResizePty(ctx context.Context, r *shim.ResizePtyRequest) (*types.Empty, error) {
+ if r.ID == "" {
+ return nil, errdefs.ToGRPCf(errdefs.ErrInvalidArgument, "id not provided")
+ }
+ ws := console.WinSize{
+ Width: uint16(r.Width),
+ Height: uint16(r.Height),
+ }
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Resize(ws); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// State returns runtime state information for a process.
+func (s *Service) State(ctx context.Context, r *shim.StateRequest) (*shim.StateResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ st, err := p.Status(ctx)
+ if err != nil {
+ return nil, err
+ }
+ status := task.StatusUnknown
+ switch st {
+ case "created":
+ status = task.StatusCreated
+ case "running":
+ status = task.StatusRunning
+ case "stopped":
+ status = task.StatusStopped
+ }
+ sio := p.Stdio()
+ return &shim.StateResponse{
+ ID: p.ID(),
+ Bundle: s.bundle,
+ Pid: uint32(p.Pid()),
+ Status: status,
+ Stdin: sio.Stdin,
+ Stdout: sio.Stdout,
+ Stderr: sio.Stderr,
+ Terminal: sio.Terminal,
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+// Pause pauses the container.
+func (s *Service) Pause(ctx context.Context, r *types.Empty) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Resume resumes the container.
+func (s *Service) Resume(ctx context.Context, r *types.Empty) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Kill kills a process with the provided signal.
+func (s *Service) Kill(ctx context.Context, r *shim.KillRequest) (*types.Empty, error) {
+ if r.ID == "" {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+ }
+
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// ListPids returns all pids inside the container.
+func (s *Service) ListPids(ctx context.Context, r *shim.ListPidsRequest) (*shim.ListPidsResponse, error) {
+ pids, err := s.getContainerPids(ctx, r.ID)
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ var processes []*task.ProcessInfo
+ for _, pid := range pids {
+ pInfo := task.ProcessInfo{
+ Pid: pid,
+ }
+ for _, p := range s.processes {
+ if p.Pid() == int(pid) {
+ d := &runctypes.ProcessDetails{
+ ExecID: p.ID(),
+ }
+ a, err := typeurl.MarshalAny(d)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err)
+ }
+ pInfo.Info = a
+ break
+ }
+ }
+ processes = append(processes, &pInfo)
+ }
+ return &shim.ListPidsResponse{
+ Processes: processes,
+ }, nil
+}
+
+// CloseIO closes the I/O context of a process.
+func (s *Service) CloseIO(ctx context.Context, r *shim.CloseIORequest) (*types.Empty, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if stdin := p.Stdin(); stdin != nil {
+ if err := stdin.Close(); err != nil {
+ return nil, fmt.Errorf("close stdin: %w", err)
+ }
+ }
+ return empty, nil
+}
+
+// Checkpoint checkpoints the container.
+func (s *Service) Checkpoint(ctx context.Context, r *shim.CheckpointTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// ShimInfo returns shim information such as the shim's pid.
+func (s *Service) ShimInfo(ctx context.Context, r *types.Empty) (*shim.ShimInfoResponse, error) {
+ return &shim.ShimInfoResponse{
+ ShimPid: uint32(os.Getpid()),
+ }, nil
+}
+
+// Update updates a running container.
+func (s *Service) Update(ctx context.Context, r *shim.UpdateTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Wait waits for a process to exit.
+func (s *Service) Wait(ctx context.Context, r *shim.WaitRequest) (*shim.WaitResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ p.Wait()
+
+ return &shim.WaitResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+func (s *Service) processExits() {
+ for e := range s.ec {
+ s.checkProcesses(e)
+ }
+}
+
+func (s *Service) allProcesses() []process.Process {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ res := make([]process.Process, 0, len(s.processes))
+ for _, p := range s.processes {
+ res = append(res, p)
+ }
+ return res
+}
+
+func (s *Service) checkProcesses(e proc.Exit) {
+ for _, p := range s.allProcesses() {
+ if p.ID() == e.ID {
+ if ip, ok := p.(*proc.Init); ok {
+ // Ensure all children are killed.
+ if err := ip.KillAll(s.context); err != nil {
+ log.G(s.context).WithError(err).WithField("id", ip.ID()).
+ Error("failed to kill init's children")
+ }
+ }
+ p.SetExited(e.Status)
+ s.events <- &TaskExit{
+ ContainerID: s.id,
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ ExitStatus: uint32(e.Status),
+ ExitedAt: p.ExitedAt(),
+ }
+ return
+ }
+ }
+}
+
+func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+
+ ps, err := p.(*proc.Init).Runtime().Ps(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ pids := make([]uint32, 0, len(ps))
+ for _, pid := range ps {
+ pids = append(pids, uint32(pid))
+ }
+ return pids, nil
+}
+
+func (s *Service) forward(publisher events.Publisher) {
+ for e := range s.events {
+ if err := publisher.Publish(s.context, getTopic(s.context, e), e); err != nil {
+ log.G(s.context).WithError(err).Error("post event")
+ }
+ }
+}
+
+// getInitProcess returns the init process.
+func (s *Service) getInitProcess() (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ p := s.processes[s.id]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ return p, nil
+}
+
+// getExecProcess returns the given exec process.
+func (s *Service) getExecProcess(id string) (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ p := s.processes[id]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s does not exist", id)
+ }
+ return p, nil
+}
+
+func getTopic(ctx context.Context, e interface{}) string {
+ switch e.(type) {
+ case *TaskCreate:
+ return runtime.TaskCreateEventTopic
+ case *TaskStart:
+ return runtime.TaskStartEventTopic
+ case *TaskOOM:
+ return runtime.TaskOOMEventTopic
+ case *TaskExit:
+ return runtime.TaskExitEventTopic
+ case *TaskDelete:
+ return runtime.TaskDeleteEventTopic
+ case *TaskExecAdded:
+ return runtime.TaskExecAddedEventTopic
+ case *TaskExecStarted:
+ return runtime.TaskExecStartedEventTopic
+ default:
+ log.L.Printf("no topic for type %#v", e)
+ }
+ return runtime.TaskUnknownTopic
+}
+
+func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) {
+ var options runctypes.CreateOptions
+ if r.Options != nil {
+ v, err := typeurl.UnmarshalAny(r.Options)
+ if err != nil {
+ return nil, err
+ }
+ options = *v.(*runctypes.CreateOptions)
+ }
+
+ spec, err := utils.ReadSpec(r.Bundle)
+ if err != nil {
+ return nil, fmt.Errorf("read oci spec: %w", err)
+ }
+ if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil {
+ return nil, fmt.Errorf("update volume annotations: %w", err)
+ }
+
+ runsc.FormatLogPath(r.ID, config)
+ rootfs := filepath.Join(path, "rootfs")
+ runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config)
+ p := proc.New(r.ID, runtime, stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ })
+ p.Bundle = r.Bundle
+ p.Platform = platform
+ p.Rootfs = rootfs
+ p.WorkDir = workDir
+ p.IoUID = int(options.IoUid)
+ p.IoGID = int(options.IoGid)
+ p.Sandbox = utils.IsSandbox(spec)
+ p.UserLog = utils.UserLogPath(spec)
+ p.Monitor = reaper.Default
+ return p, nil
+}
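
Exit handling in the service above is a two-stage pipeline: proc.ExitCh feeds processExits/checkProcesses, which converts each exit into a TaskExit event on s.events, and forward drains that channel into the containerd publisher. A minimal sketch of that pipeline shape with local stand-in types:

package main

import (
    "fmt"
    "sync"
)

type exit struct {
    id     string
    status int
}

type taskExit struct {
    ID         string
    ExitStatus int
}

func main() {
    exits := make(chan exit)         // stands in for proc.ExitCh
    events := make(chan interface{}) // stands in for s.events

    var wg sync.WaitGroup
    wg.Add(2)

    // processExits: translate raw exits into events.
    go func() {
        defer wg.Done()
        for e := range exits {
            events <- &taskExit{ID: e.id, ExitStatus: e.status}
        }
        close(events)
    }()

    // forward: publish every event (here, just print it).
    go func() {
        defer wg.Done()
        for ev := range events {
            fmt.Printf("publish %T %+v\n", ev, ev)
        }
    }()

    exits <- exit{id: "task1", status: 0}
    close(exits)
    wg.Wait()
}
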
diff --git a/pkg/shim/v1/utils/BUILD b/pkg/shim/v1/utils/BUILD
new file mode 100644
index 000000000..54a0aabb7
--- /dev/null
+++ b/pkg/shim/v1/utils/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "utils",
+ srcs = [
+ "annotations.go",
+ "utils.go",
+ "volumes.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
+
+go_test(
+ name = "utils_test",
+ size = "small",
+ srcs = ["volumes_test.go"],
+ library = ":utils",
+ deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"],
+)
diff --git a/pkg/shim/v1/utils/annotations.go b/pkg/shim/v1/utils/annotations.go
new file mode 100644
index 000000000..1e9d3f365
--- /dev/null
+++ b/pkg/shim/v1/utils/annotations.go
@@ -0,0 +1,25 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+// Annotations from the CRI annotations package.
+//
+// These are vendored here due to import conflicts.
+const (
+ sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory"
+ containerTypeAnnotation = "io.kubernetes.cri.container-type"
+ containerTypeSandbox = "sandbox"
+ containerTypeContainer = "container"
+)
diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go
new file mode 100644
index 000000000..07e346654
--- /dev/null
+++ b/pkg/shim/v1/utils/utils.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+// ReadSpec reads OCI spec from the bundle directory.
+func ReadSpec(bundle string) (*specs.Spec, error) {
+ f, err := os.Open(filepath.Join(bundle, "config.json"))
+ if err != nil {
+ return nil, err
+ }
+ b, err := ioutil.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+ var spec specs.Spec
+ if err := json.Unmarshal(b, &spec); err != nil {
+ return nil, err
+ }
+ return &spec, nil
+}
+
+// IsSandbox checks whether a container is a sandbox container.
+func IsSandbox(spec *specs.Spec) bool {
+ t, ok := spec.Annotations[containerTypeAnnotation]
+ return !ok || t == containerTypeSandbox
+}
+
+// UserLogPath gets user log path from OCI annotation.
+func UserLogPath(spec *specs.Spec) string {
+ sandboxLogDir := spec.Annotations[sandboxLogDirAnnotation]
+ if sandboxLogDir == "" {
+ return ""
+ }
+ return filepath.Join(sandboxLogDir, "gvisor.log")
+}
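
IsSandbox and UserLogPath both key off the CRI annotations vendored in annotations.go: a spec is treated as a sandbox unless the container-type annotation says otherwise, and the user log lands under the sandbox log directory. A quick sketch of the log-path computation using a bare annotation map rather than the specs-go types:

package main

import (
    "fmt"
    "path/filepath"
)

const (
    sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory"
    containerTypeAnnotation = "io.kubernetes.cri.container-type"
    containerTypeSandbox    = "sandbox"
)

// userLogPath mirrors UserLogPath above but takes a plain annotation map.
func userLogPath(annotations map[string]string) string {
    dir := annotations[sandboxLogDirAnnotation]
    if dir == "" {
        return ""
    }
    return filepath.Join(dir, "gvisor.log")
}

func main() {
    ann := map[string]string{
        containerTypeAnnotation: containerTypeSandbox,
        sandboxLogDirAnnotation: "/var/log/pods/testns_testname_testuid",
    }
    fmt.Println(userLogPath(ann)) // /var/log/pods/testns_testname_testuid/gvisor.log
}
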
diff --git a/pkg/shim/v1/utils/volumes.go b/pkg/shim/v1/utils/volumes.go
new file mode 100644
index 000000000..52a428179
--- /dev/null
+++ b/pkg/shim/v1/utils/volumes.go
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "path/filepath"
+ "strings"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+const volumeKeyPrefix = "dev.gvisor.spec.mount."
+
+var kubeletPodsDir = "/var/lib/kubelet/pods"
+
+// volumeName gets the volume name from a volume annotation key, for example:
+// dev.gvisor.spec.mount.NAME.share
+func volumeName(k string) string {
+ return strings.SplitN(strings.TrimPrefix(k, volumeKeyPrefix), ".", 2)[0]
+}
+
+// volumeFieldName gets the volume field name from a volume annotation key.
+// For example, `type` is the field of dev.gvisor.spec.mount.NAME.type.
+func volumeFieldName(k string) string {
+ parts := strings.Split(strings.TrimPrefix(k, volumeKeyPrefix), ".")
+ return parts[len(parts)-1]
+}
+
+// podUID gets the pod UID from the sandbox log path.
+func podUID(s *specs.Spec) (string, error) {
+ sandboxLogDir := s.Annotations[sandboxLogDirAnnotation]
+ if sandboxLogDir == "" {
+ return "", fmt.Errorf("no sandbox log path annotation")
+ }
+ fields := strings.Split(filepath.Base(sandboxLogDir), "_")
+ switch len(fields) {
+ case 1: // This is the old CRI logging path.
+ return fields[0], nil
+ case 3: // This is the new CRI logging path.
+ return fields[2], nil
+ }
+ return "", fmt.Errorf("unexpected sandbox log path %q", sandboxLogDir)
+}
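
podUID relies purely on the layout of the sandbox log directory name: the old CRI layout is /var/log/pods/UID, the new one is /var/log/pods/NAMESPACE_NAME_UID. A standalone sketch of the same parsing, assuming those two layouts:

package main

import (
    "fmt"
    "path/filepath"
    "strings"
)

// podUIDFromLogDir extracts the pod UID from a CRI sandbox log directory,
// handling both the old (/var/log/pods/UID) and new
// (/var/log/pods/NAMESPACE_NAME_UID) layouts, as podUID above does.
func podUIDFromLogDir(dir string) (string, error) {
    fields := strings.Split(filepath.Base(dir), "_")
    switch len(fields) {
    case 1:
        return fields[0], nil
    case 3:
        return fields[2], nil
    }
    return "", fmt.Errorf("unexpected sandbox log path %q", dir)
}

func main() {
    uid, _ := podUIDFromLogDir("/var/log/pods/testns_testname_testuid")
    fmt.Println(uid) // testuid
}
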
+
+// isVolumeKey checks whether an annotation key refers to a volume.
+func isVolumeKey(k string) bool {
+ return strings.HasPrefix(k, volumeKeyPrefix)
+}
+
+// volumeSourceKey constructs the annotation key for volume source.
+func volumeSourceKey(volume string) string {
+ return volumeKeyPrefix + volume + ".source"
+}
+
+// volumePath searches the volume path in the kubelet pod directory.
+func volumePath(volume, uid string) (string, error) {
+ // TODO: Support subpath when gvisor supports pod volume bind mount.
+ volumeSearchPath := fmt.Sprintf("%s/%s/volumes/*/%s", kubeletPodsDir, uid, volume)
+ dirs, err := filepath.Glob(volumeSearchPath)
+ if err != nil {
+ return "", err
+ }
+ if len(dirs) != 1 {
+ return "", fmt.Errorf("unexpected matched volume list %v", dirs)
+ }
+ return dirs[0], nil
+}
+
+// isVolumePath checks whether a string is the volume path.
+func isVolumePath(volume, path string) (bool, error) {
+ // TODO: Support subpath when gvisor supports pod volume bind mount.
+ volumeSearchPath := fmt.Sprintf("%s/*/volumes/*/%s", kubeletPodsDir, volume)
+ return filepath.Match(volumeSearchPath, path)
+}
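
The volume annotations follow the key scheme dev.gvisor.spec.mount.NAME.FIELD, and the source lookup then globs the kubelet pods directory for the named volume. A small sketch of the key-parsing side, a simplified stand-in for volumeName and volumeFieldName above:

package main

import (
    "fmt"
    "strings"
)

const volumeKeyPrefix = "dev.gvisor.spec.mount."

// parseVolumeKey splits an annotation key of the form
// dev.gvisor.spec.mount.NAME.FIELD into its name and field.
func parseVolumeKey(k string) (name, field string, ok bool) {
    if !strings.HasPrefix(k, volumeKeyPrefix) {
        return "", "", false
    }
    parts := strings.Split(strings.TrimPrefix(k, volumeKeyPrefix), ".")
    if len(parts) < 2 {
        return "", "", false
    }
    return parts[0], parts[len(parts)-1], true
}

func main() {
    name, field, ok := parseVolumeKey("dev.gvisor.spec.mount.testvolume.type")
    fmt.Println(name, field, ok) // testvolume type true
}
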
+
+// UpdateVolumeAnnotations adds the necessary OCI annotations for gVisor
+// volume optimization.
+func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
+ var (
+ uid string
+ err error
+ )
+ if IsSandbox(s) {
+ uid, err = podUID(s)
+ if err != nil {
+ // Skip if we can't get pod UID, because this doesn't work
+ // for containerd 1.1.
+ return nil
+ }
+ }
+ var updated bool
+ for k, v := range s.Annotations {
+ if !isVolumeKey(k) {
+ continue
+ }
+ if volumeFieldName(k) != "type" {
+ continue
+ }
+ volume := volumeName(k)
+ if uid != "" {
+ // This is a sandbox.
+ path, err := volumePath(volume, uid)
+ if err != nil {
+ return fmt.Errorf("get volume path for %q: %w", volume, err)
+ }
+ s.Annotations[volumeSourceKey(volume)] = path
+ updated = true
+ } else {
+ // This is a container.
+ for i := range s.Mounts {
+ // An error is returned for sandbox if source
+ // annotation is not successfully applied, so
+ // it is guaranteed that the source annotation
+ // for sandbox has already been successfully
+ // applied at this point.
+ //
+ // The volume name is unique inside a pod, so
+ // matching without podUID is fine here.
+ //
+ // TODO: Pass podUID down to shim for containers to do
+ // more accurate matching.
+ if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes {
+ // gVisor requires the container mount type to match
+ // sandbox mount type.
+ s.Mounts[i].Type = v
+ updated = true
+ }
+ }
+ }
+ }
+ if !updated {
+ return nil
+ }
+ // Update bundle.
+ b, err := json.Marshal(s)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666)
+}
diff --git a/pkg/shim/v1/utils/volumes_test.go b/pkg/shim/v1/utils/volumes_test.go
new file mode 100644
index 000000000..3e02c6151
--- /dev/null
+++ b/pkg/shim/v1/utils/volumes_test.go
@@ -0,0 +1,308 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "reflect"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+func TestUpdateVolumeAnnotations(t *testing.T) {
+ dir, err := ioutil.TempDir("", "test-update-volume-annotations")
+ if err != nil {
+ t.Fatalf("create tempdir: %v", err)
+ }
+ defer os.RemoveAll(dir)
+ kubeletPodsDir = dir
+
+ const (
+ testPodUID = "testuid"
+ testVolumeName = "testvolume"
+ testLogDirPath = "/var/log/pods/testns_testname_" + testPodUID
+ testLegacyLogDirPath = "/var/log/pods/" + testPodUID
+ )
+ testVolumePath := fmt.Sprintf("%s/%s/volumes/kubernetes.io~empty-dir/%s", dir, testPodUID, testVolumeName)
+
+ if err := os.MkdirAll(testVolumePath, 0755); err != nil {
+ t.Fatalf("Create test volume: %v", err)
+ }
+
+ for _, test := range []struct {
+ desc string
+ spec *specs.Spec
+ expected *specs.Spec
+ expectErr bool
+ expectUpdate bool
+ }{
+ {
+ desc: "volume annotations for sandbox",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath,
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "volume annotations for sandbox with legacy log path",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLegacyLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLegacyLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath,
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "tmpfs: volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "tmpfs",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "bind: volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "container",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "container",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "should not return error without pod log directory",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ },
+ {
+ desc: "should return error if volume path does not exist",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount.notexist.share": "pod",
+ "dev.gvisor.spec.mount.notexist.type": "tmpfs",
+ "dev.gvisor.spec.mount.notexist.options": "ro",
+ },
+ },
+ expectErr: true,
+ },
+ {
+ desc: "no volume annotations for sandbox",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ },
+ },
+ },
+ {
+ desc: "no volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: "/test",
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: "/test",
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ },
+ },
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ bundle, err := ioutil.TempDir(dir, "test-bundle")
+ if err != nil {
+ t.Fatalf("Create test bundle: %v", err)
+ }
+ err = UpdateVolumeAnnotations(bundle, test.spec)
+ if test.expectErr {
+ if err == nil {
+ t.Fatal("Expected error, but got nil")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ if !reflect.DeepEqual(test.expected, test.spec) {
+ t.Fatalf("Expected %+v, got %+v", test.expected, test.spec)
+ }
+ if test.expectUpdate {
+ b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json"))
+ if err != nil {
+ t.Fatalf("Read spec from bundle: %v", err)
+ }
+ var spec specs.Spec
+ if err := json.Unmarshal(b, &spec); err != nil {
+ t.Fatalf("Unmarshal spec: %v", err)
+ }
+ if !reflect.DeepEqual(test.expected, &spec) {
+ t.Fatalf("Expected %+v, got %+v", test.expected, &spec)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD
new file mode 100644
index 000000000..7e0a114a0
--- /dev/null
+++ b/pkg/shim/v2/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "v2",
+ srcs = [
+ "api.go",
+ "epoll.go",
+ "service.go",
+ "service_linux.go",
+ ],
+ visibility = ["//shim:__subpackages__"],
+ deps = [
+ "//pkg/shim/runsc",
+ "//pkg/shim/v1/proc",
+ "//pkg/shim/v1/utils",
+ "//pkg/shim/v2/options",
+ "//pkg/shim/v2/runtimeoptions",
+ "//runsc/specutils",
+ "@com_github_burntsushi_toml//:go_default_library",
+ "@com_github_containerd_cgroups//:go_default_library",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//api/events:go_default_library",
+ "@com_github_containerd_containerd//api/types/task:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//events:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_containerd//runtime:go_default_library",
+ "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library",
+ "@com_github_containerd_containerd//runtime/v2/shim:go_default_library",
+ "@com_github_containerd_containerd//runtime/v2/task:go_default_library",
+ "@com_github_containerd_containerd//sys/reaper:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_typeurl//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v2/api.go b/pkg/shim/v2/api.go
new file mode 100644
index 000000000..dbe5c59f6
--- /dev/null
+++ b/pkg/shim/v2/api.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v2
+
+import (
+ "github.com/containerd/containerd/api/events"
+)
+
+type TaskOOM = events.TaskOOM
diff --git a/pkg/shim/v2/epoll.go b/pkg/shim/v2/epoll.go
new file mode 100644
index 000000000..41232cca8
--- /dev/null
+++ b/pkg/shim/v2/epoll.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "sync"
+
+ "github.com/containerd/cgroups"
+ "github.com/containerd/containerd/events"
+ "github.com/containerd/containerd/runtime"
+ "golang.org/x/sys/unix"
+)
+
+func newOOMEpoller(publisher events.Publisher) (*epoller, error) {
+ fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
+ if err != nil {
+ return nil, err
+ }
+ return &epoller{
+ fd: fd,
+ publisher: publisher,
+ set: make(map[uintptr]*item),
+ }, nil
+}
+
+type epoller struct {
+ mu sync.Mutex
+
+ fd int
+ publisher events.Publisher
+ set map[uintptr]*item
+}
+
+type item struct {
+ id string
+ cg cgroups.Cgroup
+}
+
+func (e *epoller) Close() error {
+ return unix.Close(e.fd)
+}
+
+func (e *epoller) run(ctx context.Context) {
+ var events [128]unix.EpollEvent
+ for {
+ select {
+ case <-ctx.Done():
+ e.Close()
+ return
+ default:
+ n, err := unix.EpollWait(e.fd, events[:], -1)
+ if err != nil {
+ if err == unix.EINTR || err == unix.EAGAIN {
+ continue
+ }
+ // Should not happen.
+ panic(fmt.Errorf("cgroups: epoll wait: %w", err))
+ }
+ for i := 0; i < n; i++ {
+ e.process(ctx, uintptr(events[i].Fd))
+ }
+ }
+ }
+}
+
+func (e *epoller) add(id string, cg cgroups.Cgroup) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ fd, err := cg.OOMEventFD()
+ if err != nil {
+ return err
+ }
+ e.set[fd] = &item{
+ id: id,
+ cg: cg,
+ }
+ event := unix.EpollEvent{
+ Fd: int32(fd),
+ Events: unix.EPOLLHUP | unix.EPOLLIN | unix.EPOLLERR,
+ }
+ return unix.EpollCtl(e.fd, unix.EPOLL_CTL_ADD, int(fd), &event)
+}
+
+func (e *epoller) process(ctx context.Context, fd uintptr) {
+ flush(fd)
+ e.mu.Lock()
+ i, ok := e.set[fd]
+ if !ok {
+ e.mu.Unlock()
+ return
+ }
+ e.mu.Unlock()
+ if i.cg.State() == cgroups.Deleted {
+ e.mu.Lock()
+ delete(e.set, fd)
+ e.mu.Unlock()
+ unix.Close(int(fd))
+ return
+ }
+ if err := e.publisher.Publish(ctx, runtime.TaskOOMEventTopic, &TaskOOM{
+ ContainerID: i.id,
+ }); err != nil {
+ // Should not happen.
+ panic(fmt.Errorf("publish OOM event: %w", err))
+ }
+}
+
+func flush(fd uintptr) error {
+ var buf [8]byte
+ _, err := unix.Read(int(fd), buf[:])
+ return err
+}
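
A rough sketch of how the epoller above is meant to be driven, modeled on what service.go does in Create: construct it with an event publisher, start the run loop, and register each container's cgroup so that OOM eventfd notifications are republished as TaskOOM events. The printPublisher stand-in and the monitorOOM wrapper are illustrative only and assume this snippet sits alongside epoll.go in package v2.

package v2

import (
	"context"
	"fmt"

	"github.com/containerd/cgroups"
	"github.com/containerd/containerd/events"
)

// printPublisher is an illustrative stand-in for the shim's event publisher;
// it only prints each published event and its topic.
type printPublisher struct{}

func (printPublisher) Publish(ctx context.Context, topic string, event events.Event) error {
	fmt.Printf("publish %q: %+v\n", topic, event)
	return nil
}

// monitorOOM wires a container's cgroup into a fresh OOM epoller, mirroring
// the s.oomPoller handling in service.Create.
func monitorOOM(ctx context.Context, id string, cg cgroups.Cgroup) error {
	ep, err := newOOMEpoller(printPublisher{})
	if err != nil {
		return err
	}
	go ep.run(ctx) // closes the epoll fd once ctx is cancelled
	return ep.add(id, cg)
}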
diff --git a/pkg/shim/v2/options/BUILD b/pkg/shim/v2/options/BUILD
new file mode 100644
index 000000000..ca212e874
--- /dev/null
+++ b/pkg/shim/v2/options/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "options",
+ srcs = [
+ "options.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go
new file mode 100644
index 000000000..de09f2f79
--- /dev/null
+++ b/pkg/shim/v2/options/options.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package options
+
+const OptionType = "io.containerd.runsc.v1.options"
+
+// Options is the set of runtime options for io.containerd.runsc.v1.
+type Options struct {
+ // ShimCgroup is the cgroup the shim should be in.
+ ShimCgroup string `toml:"shim_cgroup"`
+	// IoUid is the UID of the container's I/O pipes.
+	IoUid uint32 `toml:"io_uid"`
+	// IoGid is the GID of the container's I/O pipes.
+	IoGid uint32 `toml:"io_gid"`
+ // BinaryName is the binary name of the runsc binary.
+ BinaryName string `toml:"binary_name"`
+ // Root is the runsc root directory.
+ Root string `toml:"root"`
+ // RunscConfig is a key/value map of all runsc flags.
+ RunscConfig map[string]string `toml:"runsc_config"`
+}
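
For reference, the toml tags above define the shape of the per-runtime config file that service.go decodes with toml.DecodeFile when the runtime options carry a config path. Below is a minimal, self-contained sketch of that decoding; the field values are purely illustrative.

package main

import (
	"fmt"

	"github.com/BurntSushi/toml"

	"gvisor.dev/gvisor/pkg/shim/v2/options"
)

func main() {
	// An illustrative config matching the toml tags on options.Options.
	const config = `
binary_name = "/usr/local/bin/runsc"
root = "/run/containerd/runsc"
shim_cgroup = "/runsc-shim"

[runsc_config]
  debug = "true"
  debug-log = "/var/log/runsc/"
`
	var opts options.Options
	if _, err := toml.Decode(config, &opts); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", opts)
}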
diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD
new file mode 100644
index 000000000..01716034c
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "proto_library")
+
+package(licenses = ["notice"])
+
+proto_library(
+ name = "api",
+ srcs = [
+ "runtimeoptions.proto",
+ ],
+)
+
+go_library(
+ name = "runtimeoptions",
+ srcs = ["runtimeoptions.go"],
+ visibility = ["//pkg/shim/v2:__pkg__"],
+ deps = [
+ "//pkg/shim/v2/runtimeoptions:api_go_proto",
+ "@com_github_gogo_protobuf//proto:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
new file mode 100644
index 000000000..1c1a0c5d1
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runtimeoptions
+
+import (
+ proto "github.com/gogo/protobuf/proto"
+ pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto"
+)
+
+type Options = pb.Options
+
+func init() {
+ proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options")
+}
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto
new file mode 100644
index 000000000..edb19020a
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package runtimeoptions;
+
+// This is a vendored copy of the runtimeoptions CRI API.
+//
+// Importing the full CRI package is a nightmare.
+message Options {
+ string type_url = 1;
+ string config_path = 2;
+}
diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go
new file mode 100644
index 000000000..1534152fc
--- /dev/null
+++ b/pkg/shim/v2/service.go
@@ -0,0 +1,824 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/BurntSushi/toml"
+ "github.com/containerd/cgroups"
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/api/events"
+ "github.com/containerd/containerd/api/types/task"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/namespaces"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/containerd/runtime"
+ "github.com/containerd/containerd/runtime/linux/runctypes"
+ "github.com/containerd/containerd/runtime/v2/shim"
+ taskAPI "github.com/containerd/containerd/runtime/v2/task"
+ "github.com/containerd/containerd/sys/reaper"
+ "github.com/containerd/typeurl"
+ "github.com/gogo/protobuf/types"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+ "gvisor.dev/gvisor/pkg/shim/v1/proc"
+ "gvisor.dev/gvisor/pkg/shim/v1/utils"
+ "gvisor.dev/gvisor/pkg/shim/v2/options"
+ "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+ empty = &types.Empty{}
+ bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+ }
+)
+
+var _ = (taskAPI.TaskService)(&service{})
+
+// configFile is the default config file name. For containerd 1.2,
+// we assume that a config.toml should exist in the runtime root.
+const configFile = "config.toml"
+
+// New returns a new shim service that can be used via GRPC.
+func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) {
+ ep, err := newOOMEpoller(publisher)
+ if err != nil {
+ return nil, err
+ }
+ go ep.run(ctx)
+ s := &service{
+ id: id,
+ context: ctx,
+ processes: make(map[string]process.Process),
+ events: make(chan interface{}, 128),
+ ec: proc.ExitCh,
+ oomPoller: ep,
+ cancel: cancel,
+ }
+ go s.processExits()
+ runsc.Monitor = reaper.Default
+ if err := s.initPlatform(); err != nil {
+ cancel()
+		return nil, fmt.Errorf("failed to initialize platform behavior: %w", err)
+ }
+ go s.forward(publisher)
+ return s, nil
+}
+
+// service implements the containerd Task service for the remote shim over GRPC.
+type service struct {
+ mu sync.Mutex
+
+ context context.Context
+ task process.Process
+ processes map[string]process.Process
+ events chan interface{}
+ platform stdio.Platform
+ opts options.Options
+ ec chan proc.Exit
+ oomPoller *epoller
+
+ id string
+ bundle string
+ cancel func()
+}
+
+func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) {
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ self, err := os.Executable()
+ if err != nil {
+ return nil, err
+ }
+ cwd, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ args := []string{
+ "-namespace", ns,
+ "-address", containerdAddress,
+ "-publish-binary", containerdBinary,
+ }
+ cmd := exec.Command(self, args...)
+ cmd.Dir = cwd
+ cmd.Env = append(os.Environ(), "GOMAXPROCS=2")
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setpgid: true,
+ }
+ return cmd, nil
+}
+
+func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) {
+ cmd, err := newCommand(ctx, containerdBinary, containerdAddress)
+ if err != nil {
+ return "", err
+ }
+ address, err := shim.SocketAddress(ctx, id)
+ if err != nil {
+ return "", err
+ }
+ socket, err := shim.NewSocket(address)
+ if err != nil {
+ return "", err
+ }
+ defer socket.Close()
+ f, err := socket.File()
+ if err != nil {
+ return "", err
+ }
+ defer f.Close()
+
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+
+ if err := cmd.Start(); err != nil {
+ return "", err
+ }
+ defer func() {
+ if err != nil {
+ cmd.Process.Kill()
+ }
+ }()
+ // make sure to wait after start
+ go cmd.Wait()
+ if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil {
+ return "", err
+ }
+ if err := shim.WriteAddress("address", address); err != nil {
+ return "", err
+ }
+ if err := shim.SetScore(cmd.Process.Pid); err != nil {
+ return "", fmt.Errorf("failed to set OOM Score on shim: %w", err)
+ }
+ return address, nil
+}
+
+func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) {
+ path, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ runtime, err := s.readRuntime(path)
+ if err != nil {
+ return nil, err
+ }
+ r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil)
+ if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{
+ Force: true,
+ }); err != nil {
+		log.L.Printf("failed to remove runsc container: %v", err)
+ }
+ if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil {
+ log.L.Printf("failed to cleanup rootfs mount: %v", err)
+ }
+ return &taskAPI.DeleteResponse{
+ ExitedAt: time.Now(),
+ ExitStatus: 128 + uint32(unix.SIGKILL),
+ }, nil
+}
+
+func (s *service) readRuntime(path string) (string, error) {
+ data, err := ioutil.ReadFile(filepath.Join(path, "runtime"))
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+func (s *service) writeRuntime(path, runtime string) error {
+ return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600)
+}
+
+// Create creates a new initial process and container with the underlying OCI
+// runtime.
+func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("create namespace: %w", err)
+ }
+
+ // Read from root for now.
+ var opts options.Options
+ if r.Options != nil {
+ v, err := typeurl.UnmarshalAny(r.Options)
+ if err != nil {
+ return nil, err
+ }
+ var path string
+ switch o := v.(type) {
+ case *runctypes.CreateOptions: // containerd 1.2.x
+ opts.IoUid = o.IoUid
+ opts.IoGid = o.IoGid
+ opts.ShimCgroup = o.ShimCgroup
+ case *runctypes.RuncOptions: // containerd 1.2.x
+ root := proc.RunscRoot
+ if o.RuntimeRoot != "" {
+ root = o.RuntimeRoot
+ }
+
+ opts.BinaryName = o.Runtime
+
+ path = filepath.Join(root, configFile)
+ if _, err := os.Stat(path); err != nil {
+ if !os.IsNotExist(err) {
+ return nil, fmt.Errorf("stat config file %q: %w", path, err)
+ }
+ // A config file in runtime root is not required.
+ path = ""
+ }
+ case *runtimeoptions.Options: // containerd 1.3.x+
+ if o.ConfigPath == "" {
+ break
+ }
+ if o.TypeUrl != options.OptionType {
+ return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl)
+ }
+ path = o.ConfigPath
+ default:
+ return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl)
+ }
+ if path != "" {
+ if _, err = toml.DecodeFile(path, &opts); err != nil {
+ return nil, fmt.Errorf("decode config file %q: %w", path, err)
+ }
+ }
+ }
+
+ var mounts []proc.Mount
+ for _, m := range r.Rootfs {
+ mounts = append(mounts, proc.Mount{
+ Type: m.Type,
+ Source: m.Source,
+ Target: m.Target,
+ Options: m.Options,
+ })
+ }
+
+ rootfs := filepath.Join(r.Bundle, "rootfs")
+ if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) {
+ return nil, err
+ }
+
+ config := &proc.CreateConfig{
+ ID: r.ID,
+ Bundle: r.Bundle,
+ Runtime: opts.BinaryName,
+ Rootfs: mounts,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Options: r.Options,
+ }
+ if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ if err := mount.UnmountAll(rootfs, 0); err != nil {
+ log.L.Printf("failed to cleanup rootfs mount: %v", err)
+ }
+ }
+ }()
+ for _, rm := range mounts {
+ m := &mount.Mount{
+ Type: rm.Type,
+ Source: rm.Source,
+ Options: rm.Options,
+ }
+ if err := m.Mount(rootfs); err != nil {
+ return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err)
+ }
+ }
+ process, err := newInit(
+ ctx,
+ r.Bundle,
+ filepath.Join(r.Bundle, "work"),
+ ns,
+ s.platform,
+ config,
+ &opts,
+ rootfs,
+ )
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ if err := process.Create(ctx, config); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ // Save the main task id and bundle to the shim for additional
+ // requests.
+ s.id = r.ID
+ s.bundle = r.Bundle
+
+ // Set up OOM notification on the sandbox's cgroup. This is done on
+ // sandbox create since the sandbox process will be created here.
+ pid := process.Pid()
+ if pid > 0 {
+ cg, err := cgroups.Load(cgroups.V1, cgroups.PidPath(pid))
+ if err != nil {
+ return nil, fmt.Errorf("loading cgroup for %d: %w", pid, err)
+ }
+ if err := s.oomPoller.add(s.id, cg); err != nil {
+ return nil, fmt.Errorf("add cg to OOM monitor: %w", err)
+ }
+ }
+ s.task = process
+ s.opts = opts
+ return &taskAPI.CreateTaskResponse{
+ Pid: uint32(process.Pid()),
+ }, nil
+
+}
+
+// Start starts a process.
+func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Start(ctx); err != nil {
+ return nil, err
+ }
+ // TODO: Set the cgroup and oom notifications on restore.
+ // https://github.com/google/gvisor-containerd-shim/issues/58
+ return &taskAPI.StartResponse{
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Delete deletes the initial process and container.
+func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ isTask := r.ExecID == ""
+ if !isTask {
+ s.mu.Lock()
+ delete(s.processes, r.ExecID)
+ s.mu.Unlock()
+ }
+ if isTask && s.platform != nil {
+ s.platform.Close()
+ }
+ return &taskAPI.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Exec spawns an additional process inside the container.
+func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) {
+ s.mu.Lock()
+ p := s.processes[r.ExecID]
+ s.mu.Unlock()
+ if p != nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID)
+ }
+ p = s.task
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{
+ ID: r.ExecID,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Spec: r.Spec,
+ })
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ s.mu.Lock()
+ s.processes[r.ExecID] = process
+ s.mu.Unlock()
+ return empty, nil
+}
+
+// ResizePty resizes the terminal of a process.
+func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ ws := console.WinSize{
+ Width: uint16(r.Width),
+ Height: uint16(r.Height),
+ }
+ if err := p.Resize(ws); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// State returns runtime state information for a process.
+func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ st, err := p.Status(ctx)
+ if err != nil {
+ return nil, err
+ }
+ status := task.StatusUnknown
+ switch st {
+ case "created":
+ status = task.StatusCreated
+ case "running":
+ status = task.StatusRunning
+ case "stopped":
+ status = task.StatusStopped
+ }
+ sio := p.Stdio()
+ return &taskAPI.StateResponse{
+ ID: p.ID(),
+ Bundle: s.bundle,
+ Pid: uint32(p.Pid()),
+ Status: status,
+ Stdin: sio.Stdin,
+ Stdout: sio.Stdout,
+ Stderr: sio.Stderr,
+ Terminal: sio.Terminal,
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+// Pause the container.
+func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Resume the container.
+func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Kill a process with the provided signal.
+func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// Pids returns all pids inside the container.
+func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) {
+ pids, err := s.getContainerPids(ctx, r.ID)
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ var processes []*task.ProcessInfo
+ for _, pid := range pids {
+ pInfo := task.ProcessInfo{
+ Pid: pid,
+ }
+ for _, p := range s.processes {
+ if p.Pid() == int(pid) {
+ d := &runctypes.ProcessDetails{
+ ExecID: p.ID(),
+ }
+ a, err := typeurl.MarshalAny(d)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err)
+ }
+ pInfo.Info = a
+ break
+ }
+ }
+ processes = append(processes, &pInfo)
+ }
+ return &taskAPI.PidsResponse{
+ Processes: processes,
+ }, nil
+}
+
+// CloseIO closes the I/O context of a process.
+func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if stdin := p.Stdin(); stdin != nil {
+ if err := stdin.Close(); err != nil {
+ return nil, fmt.Errorf("close stdin: %w", err)
+ }
+ }
+ return empty, nil
+}
+
+// Checkpoint checkpoints the container.
+func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Connect returns shim information such as the shim's pid.
+func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) {
+ var pid int
+ if s.task != nil {
+ pid = s.task.Pid()
+ }
+ return &taskAPI.ConnectResponse{
+ ShimPid: uint32(os.Getpid()),
+ TaskPid: uint32(pid),
+ }, nil
+}
+
+func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) {
+ s.cancel()
+ os.Exit(0)
+ return empty, nil
+}
+
+func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) {
+ path, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ runtime, err := s.readRuntime(path)
+ if err != nil {
+ return nil, err
+ }
+ rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil)
+ stats, err := rs.Stats(ctx, s.id)
+ if err != nil {
+ return nil, err
+ }
+
+	// gVisor currently (as of 2020-03-03) only returns the total memory
+	// usage and the current PID count[0]. However, we copy the common fields here
+ // so that future updates will propagate correct information. We're
+ // using the cgroups.Metrics structure so we're returning the same type
+ // as runc.
+ //
+ // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81
+ data, err := typeurl.MarshalAny(&cgroups.Metrics{
+ CPU: &cgroups.CPUStat{
+ Usage: &cgroups.CPUUsage{
+ Total: stats.Cpu.Usage.Total,
+ Kernel: stats.Cpu.Usage.Kernel,
+ User: stats.Cpu.Usage.User,
+ PerCPU: stats.Cpu.Usage.Percpu,
+ },
+ Throttling: &cgroups.Throttle{
+ Periods: stats.Cpu.Throttling.Periods,
+ ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods,
+ ThrottledTime: stats.Cpu.Throttling.ThrottledTime,
+ },
+ },
+ Memory: &cgroups.MemoryStat{
+ Cache: stats.Memory.Cache,
+ Usage: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Usage.Limit,
+ Usage: stats.Memory.Usage.Usage,
+ Max: stats.Memory.Usage.Max,
+ Failcnt: stats.Memory.Usage.Failcnt,
+ },
+ Swap: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Swap.Limit,
+ Usage: stats.Memory.Swap.Usage,
+ Max: stats.Memory.Swap.Max,
+ Failcnt: stats.Memory.Swap.Failcnt,
+ },
+ Kernel: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Kernel.Limit,
+ Usage: stats.Memory.Kernel.Usage,
+ Max: stats.Memory.Kernel.Max,
+ Failcnt: stats.Memory.Kernel.Failcnt,
+ },
+ KernelTCP: &cgroups.MemoryEntry{
+ Limit: stats.Memory.KernelTCP.Limit,
+ Usage: stats.Memory.KernelTCP.Usage,
+ Max: stats.Memory.KernelTCP.Max,
+ Failcnt: stats.Memory.KernelTCP.Failcnt,
+ },
+ },
+ Pids: &cgroups.PidsStat{
+ Current: stats.Pids.Current,
+ Limit: stats.Pids.Limit,
+ },
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &taskAPI.StatsResponse{
+ Stats: data,
+ }, nil
+}
+
+// Update updates a running container.
+func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Wait waits for a process to exit.
+func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ p.Wait()
+
+ return &taskAPI.WaitResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+func (s *service) processExits() {
+ for e := range s.ec {
+ s.checkProcesses(e)
+ }
+}
+
+func (s *service) checkProcesses(e proc.Exit) {
+ // TODO(random-liu): Add `shouldKillAll` logic if container pid
+ // namespace is supported.
+ for _, p := range s.allProcesses() {
+ if p.ID() == e.ID {
+ if ip, ok := p.(*proc.Init); ok {
+ // Ensure all children are killed.
+ if err := ip.KillAll(s.context); err != nil {
+ log.G(s.context).WithError(err).WithField("id", ip.ID()).
+ Error("failed to kill init's children")
+ }
+ }
+ p.SetExited(e.Status)
+ s.events <- &events.TaskExit{
+ ContainerID: s.id,
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ ExitStatus: uint32(e.Status),
+ ExitedAt: p.ExitedAt(),
+ }
+ return
+ }
+ }
+}
+
+func (s *service) allProcesses() (o []process.Process) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for _, p := range s.processes {
+ o = append(o, p)
+ }
+ if s.task != nil {
+ o = append(o, s.task)
+ }
+ return o
+}
+
+func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, error) {
+ s.mu.Lock()
+ p := s.task
+ s.mu.Unlock()
+ if p == nil {
+ return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition)
+ }
+ ps, err := p.(*proc.Init).Runtime().Ps(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ pids := make([]uint32, 0, len(ps))
+ for _, pid := range ps {
+ pids = append(pids, uint32(pid))
+ }
+ return pids, nil
+}
+
+func (s *service) forward(publisher shim.Publisher) {
+ for e := range s.events {
+ ctx, cancel := context.WithTimeout(s.context, 5*time.Second)
+ err := publisher.Publish(ctx, getTopic(e), e)
+ cancel()
+ if err != nil {
+ // Should not happen.
+ panic(fmt.Errorf("post event: %w", err))
+ }
+ }
+}
+
+func (s *service) getProcess(execID string) (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if execID == "" {
+ return s.task, nil
+ }
+ p := s.processes[execID]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID)
+ }
+ return p, nil
+}
+
+func getTopic(e interface{}) string {
+ switch e.(type) {
+ case *events.TaskCreate:
+ return runtime.TaskCreateEventTopic
+ case *events.TaskStart:
+ return runtime.TaskStartEventTopic
+ case *events.TaskOOM:
+ return runtime.TaskOOMEventTopic
+ case *events.TaskExit:
+ return runtime.TaskExitEventTopic
+ case *events.TaskDelete:
+ return runtime.TaskDeleteEventTopic
+ case *events.TaskExecAdded:
+ return runtime.TaskExecAddedEventTopic
+ case *events.TaskExecStarted:
+ return runtime.TaskExecStartedEventTopic
+ default:
+ log.L.Printf("no topic for type %#v", e)
+ }
+ return runtime.TaskUnknownTopic
+}
+
+func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) {
+ spec, err := utils.ReadSpec(r.Bundle)
+ if err != nil {
+ return nil, fmt.Errorf("read oci spec: %w", err)
+ }
+ if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil {
+ return nil, fmt.Errorf("update volume annotations: %w", err)
+ }
+ runsc.FormatLogPath(r.ID, options.RunscConfig)
+ runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig)
+ p := proc.New(r.ID, runtime, stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ })
+ p.Bundle = r.Bundle
+ p.Platform = platform
+ p.Rootfs = rootfs
+ p.WorkDir = workDir
+ p.IoUID = int(options.IoUid)
+ p.IoGID = int(options.IoGid)
+ p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox
+ p.UserLog = utils.UserLogPath(spec)
+ p.Monitor = reaper.Default
+ return p, nil
+}
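
Create above accepts three option payloads via typeurl: runctypes.CreateOptions and runctypes.RuncOptions from containerd 1.2, and the vendored runtimeoptions.Options from 1.3+. The sketch below shows how a 1.3+ caller might build that third payload for CreateTaskRequest.Options; the config path is a hypothetical example.

package main

import (
	"fmt"

	"github.com/containerd/typeurl"

	"gvisor.dev/gvisor/pkg/shim/v2/options"
	"gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions"
)

func main() {
	// Options as a containerd 1.3+ client would attach them; Create resolves
	// the concrete type with typeurl.UnmarshalAny and then decodes the
	// referenced toml file.
	opts := &runtimeoptions.Options{
		TypeUrl:    options.OptionType,
		ConfigPath: "/etc/containerd/runsc.toml", // hypothetical path
	}
	a, err := typeurl.MarshalAny(opts)
	if err != nil {
		panic(err)
	}
	fmt.Println(a.TypeUrl) // the URL under which Create recognizes the payload
}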
diff --git a/pkg/shim/v2/service_linux.go b/pkg/shim/v2/service_linux.go
new file mode 100644
index 000000000..1800ab90b
--- /dev/null
+++ b/pkg/shim/v2/service_linux.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/console"
+ "github.com/containerd/fifo"
+)
+
+type linuxPlatform struct {
+ epoller *console.Epoller
+}
+
+func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) {
+ if p.epoller == nil {
+ return nil, fmt.Errorf("uninitialized epoller")
+ }
+
+ epollConsole, err := p.epoller.Add(console)
+ if err != nil {
+ return nil, err
+ }
+
+ if stdin != "" {
+ in, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return nil, err
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(epollConsole, in, *p)
+ }()
+ }
+
+ outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(outw, epollConsole, *p)
+ epollConsole.Close()
+ outr.Close()
+ outw.Close()
+ wg.Done()
+ }()
+ return epollConsole, nil
+}
+
+func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error {
+ if p.epoller == nil {
+ return fmt.Errorf("uninitialized epoller")
+ }
+ epollConsole, ok := cons.(*console.EpollConsole)
+ if !ok {
+ return fmt.Errorf("expected EpollConsole, got %#v", cons)
+ }
+ return epollConsole.Shutdown(p.epoller.CloseConsole)
+}
+
+func (p *linuxPlatform) Close() error {
+ return p.epoller.Close()
+}
+
+// initPlatform initializes a single epoll fd to manage our consoles. It
+// should only be called once.
+func (s *service) initPlatform() error {
+ if s.platform != nil {
+ return nil
+ }
+ epoller, err := console.NewEpoller()
+ if err != nil {
+ return fmt.Errorf("failed to initialize epoller: %w", err)
+ }
+ s.platform = &linuxPlatform{
+ epoller: epoller,
+ }
+ go epoller.Wait()
+ return nil
+}
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index e131455f7..ae0fe1522 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -12,6 +12,7 @@ go_library(
"sleep_unsafe.go",
],
visibility = ["//:sandbox"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go
index af47e2ba1..1dd11707d 100644
--- a/pkg/sleep/sleep_test.go
+++ b/pkg/sleep/sleep_test.go
@@ -379,10 +379,7 @@ func TestRace(t *testing.T) {
// TestRaceInOrder tests that multiple wakers can continuously send wake requests to
// the sleeper and that the wakers are retrieved in the order asserted.
func TestRaceInOrder(t *testing.T) {
- const wakers = 100
- const wakeRequests = 10000
-
- w := make([]Waker, wakers)
+ w := make([]Waker, 10000)
s := Sleeper{}
// Associate each waker and start goroutines that will assert them.
@@ -390,19 +387,16 @@ func TestRaceInOrder(t *testing.T) {
s.AddWaker(&w[i], i)
}
go func() {
- n := 0
- for n < wakeRequests {
- wk := w[n%len(w)]
- wk.Assert()
- n++
+ for i := range w {
+ w[i].Assert()
}
}()
// Wait for all wake up notifications from all wakers.
- for i := 0; i < wakeRequests; i++ {
- v, _ := s.Fetch(true)
- if got, want := v, i%wakers; got != want {
- t.Fatalf("got %d want %d", got, want)
+ for want := range w {
+ got, _ := s.Fetch(true)
+ if got != want {
+ t.Fatalf("got %d want %d", got, want)
}
}
}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index f68c12620..118805492 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -75,6 +75,8 @@ package sleep
import (
"sync/atomic"
"unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
@@ -323,7 +325,12 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
//
// This struct is thread-safe, that is, its methods can be called concurrently
// by multiple goroutines.
+//
+// Note: it is not safe to copy a Waker, as its fields are modified in place
+// (the pointer fields are individually updated with atomic operations).
type Waker struct {
+ _ sync.NoCopy
+
// s is the sleeper that this waker can wake up. Only one sleeper at a
// time is allowed. This field can have three classes of values:
// nil -- the waker is not asserted: it either is not associated with
diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD
index d0d77e19c..4d47207f7 100644
--- a/pkg/sync/BUILD
+++ b/pkg/sync/BUILD
@@ -33,6 +33,7 @@ go_library(
"aliases.go",
"memmove_unsafe.go",
"mutex_unsafe.go",
+ "nocopy.go",
"norace_unsafe.go",
"race_unsafe.go",
"rwmutex_unsafe.go",
diff --git a/pkg/sync/nocopy.go b/pkg/sync/nocopy.go
new file mode 100644
index 000000000..722b29501
--- /dev/null
+++ b/pkg/sync/nocopy.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sync
+
+// NoCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://golang.org/issues/8005#issuecomment-190753527
+// for details.
+type NoCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*NoCopy) Lock() {}
+
+// Unlock is a no-op used by -copylocks checker from `go vet`.
+func (*NoCopy) Unlock() {}
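
NoCopy relies on go vet's -copylocks pass, which flags any by-value copy of a type that has Lock and Unlock methods (pointer receivers included). A small sketch of the pattern as Waker uses it above; the counter type and the vet message mentioned in the comment are illustrative.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/sync"
)

// counter must not be copied after first use; the NoCopy field costs nothing
// at runtime but lets `go vet` catch accidental copies.
type counter struct {
	_ sync.NoCopy
	n int
}

// byValue copies its argument, so `go vet` reports the call below with a
// "passes lock by value" diagnostic (the program itself still compiles).
func byValue(c counter) int { return c.n }

func main() {
	var c counter
	c.n++
	fmt.Println(byValue(c))
}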
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 8ff922c69..5ae10939d 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -22,7 +22,7 @@ import (
// Mapping for tcpip.Error types.
var (
ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL)
- ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL)
+ ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV)
ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV)
ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT)
ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST)
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 0cde694dc..d87797617 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -48,7 +48,7 @@ go_test(
"//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
@@ -64,6 +64,6 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
index 718a4720a..83189676e 100644
--- a/pkg/tcpip/header/arp.go
+++ b/pkg/tcpip/header/arp.go
@@ -14,14 +14,33 @@
package header
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
const (
// ARPProtocolNumber is the ARP network protocol number.
ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
// ARPSize is the size of an IPv4-over-Ethernet ARP packet.
- ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+ ARPSize = 28
+)
+
+// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header.
+type ARPHardwareType uint16
+
+// Typical ARP HardwareType values. Some of the constants must keep these exact
+// values because they are sent on the wire in the HTYPE field of an ARP header.
+const (
+ ARPHardwareNone ARPHardwareType = 0
+ // ARPHardwareEther specifically is the HTYPE for Ethernet as specified
+ // in the IANA list here:
+ //
+ // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2
+ ARPHardwareEther ARPHardwareType = 1
+ ARPHardwareLoopback ARPHardwareType = 2
)
// ARPOp is an ARP opcode.
@@ -36,54 +55,64 @@ const (
// ARP is an ARP packet stored in a byte array as described in RFC 826.
type ARP []byte
-func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
-func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
-func (a ARP) hardwareAddressSize() int { return int(a[4]) }
-func (a ARP) protocolAddressSize() int { return int(a[5]) }
+const (
+ hTypeOffset = 0
+ protocolOffset = 2
+ haAddressSizeOffset = 4
+ protoAddressSizeOffset = 5
+ opCodeOffset = 6
+ senderHAAddressOffset = 8
+ senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize
+ targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize
+ targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize
+)
+
+func (a ARP) hardwareAddressType() ARPHardwareType {
+ return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:]))
+}
+
+func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) }
+func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) }
+func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) }
// Op is the ARP opcode.
-func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) }
// SetOp sets the ARP opcode.
func (a ARP) SetOp(op ARPOp) {
- a[6] = uint8(op >> 8)
- a[7] = uint8(op)
+ binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op))
}
// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
func (a ARP) SetIPv4OverEthernet() {
- a[0], a[1] = 0, 1 // htypeEthernet
- a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
- a[4] = 6 // macSize
- a[5] = uint8(IPv4AddressSize)
+ binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther))
+ binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber))
+ a[haAddressSizeOffset] = EthernetAddressSize
+ a[protoAddressSizeOffset] = uint8(IPv4AddressSize)
}
// HardwareAddressSender is the link address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressSender() []byte {
- const s = 8
- return a[s : s+6]
+ return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressSender is the protocol address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressSender() []byte {
- const s = 8 + 6
- return a[s : s+4]
+ return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize]
}
// HardwareAddressTarget is the link address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressTarget() []byte {
- const s = 8 + 6 + 4
- return a[s : s+6]
+ return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressTarget is the protocol address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressTarget() []byte {
- const s = 8 + 6 + 4 + 6
- return a[s : s+4]
+ return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize]
}
// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
@@ -91,10 +120,8 @@ func (a ARP) IsValid() bool {
if len(a) < ARPSize {
return false
}
- const htypeEthernet = 1
- const macSize = 6
- return a.hardwareAddressSpace() == htypeEthernet &&
+ return a.hardwareAddressType() == ARPHardwareEther &&
a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
- a.hardwareAddressSize() == macSize &&
+ a.hardwareAddressSize() == EthernetAddressSize &&
a.protocolAddressSize() == IPv4AddressSize
}
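
With the named offsets above, composing an ARP request reduces to setting the fixed IPv4-over-Ethernet preamble and copying addresses into the sender/target views. A minimal sketch follows; the MAC and IPv4 addresses are made up for illustration, and ARPRequest is the existing opcode constant in this package.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/header"
)

func main() {
	buf := make([]byte, header.ARPSize)
	a := header.ARP(buf)
	a.SetIPv4OverEthernet()
	a.SetOp(header.ARPRequest)

	// Sender is "us"; the target protocol address is the one being resolved.
	// HardwareAddressTarget is left as zero for a request.
	copy(a.HardwareAddressSender(), "\x02\x00\x00\x00\x00\x01")
	copy(a.ProtocolAddressSender(), "\xc0\xa8\x01\x01") // 192.168.1.1
	copy(a.ProtocolAddressTarget(), "\xc0\xa8\x01\x02") // 192.168.1.2

	fmt.Println(a.IsValid(), a.Op() == header.ARPRequest) // true true
}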
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index b1e92d2d7..eaface8cb 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -53,6 +53,10 @@ const (
// (all bits set to 0).
unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ // EthernetBroadcastAddress is an ethernet address that addresses every node
+ // on a local link.
+ EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff")
+
// unicastMulticastFlagMask is the mask of the least significant bit in
// the first octet (in network byte order) of an ethernet address that
// determines whether the ethernet address is a unicast or multicast. If
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 7908c5744..1a631b31a 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -72,6 +72,7 @@ const (
// Values for ICMP code as defined in RFC 792.
const (
ICMPv4TTLExceeded = 0
+ ICMPv4HostUnreachable = 1
ICMPv4PortUnreachable = 3
ICMPv4FragmentationNeeded = 4
)
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index c7ee2de57..a13b4b809 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -110,9 +110,16 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// Values for ICMP code as defined in RFC 4443.
+// Values for ICMP destination unreachable codes as defined in RFC 4443 section
+// 3.1.
const (
- ICMPv6PortUnreachable = 4
+ ICMPv6NetworkUnreachable = 0
+ ICMPv6Prohibited = 1
+ ICMPv6BeyondScope = 2
+ ICMPv6AddressUnreachable = 3
+ ICMPv6PortUnreachable = 4
+ ICMPv6Policy = 5
+ ICMPv6RejectRoute = 6
)
// Type is the ICMP type field.
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index b8b93e78e..39ca774ef 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 20b183da0..e12a5929b 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -296,3 +297,12 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
e.q.RemoveNotify(handle)
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index aa6db9aea..507b44abc 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -15,6 +15,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/binary",
+ "//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index f34082e1a..c18bb91fb 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -45,6 +45,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -385,26 +386,35 @@ const (
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
// Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
}
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
+
+ var builder iovec.Builder
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
@@ -430,47 +440,28 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
}
vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
- return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView())
+ builder.Add(vnetHdrBuf)
}
- if pkt.Data.Size() == 0 {
- return rawfile.NonBlockingWrite(fd, pkt.Header.View())
- }
- if pkt.Header.UsedLength() == 0 {
- return rawfile.NonBlockingWrite(fd, pkt.Data.ToView())
+ builder.Add(pkt.Header.View())
+ for _, v := range pkt.Data.Views() {
+ builder.Add(v)
}
- return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil)
+ return rawfile.NonBlockingWriteIovec(fd, builder.Build())
}
func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) {
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
- var ethHdrBuf []byte
- iovLen := 0
if e.hdrSize > 0 {
- // Add ethernet header if needed.
- ethHdrBuf = make([]byte, header.EthernetMinimumSize)
- eth := header.Ethernet(ethHdrBuf)
- ethHdr := &header.EthernetFields{
- DstAddr: pkt.EgressRoute.RemoteLinkAddress,
- Type: pkt.NetworkProtocolNumber,
- }
-
- // Preserve the src address if it's set in the route.
- if pkt.EgressRoute.LocalLinkAddress != "" {
- ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
- iovLen++
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
- vnetHdr := virtioNetHdr{}
var vnetHdrBuf []byte
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ vnetHdr := virtioNetHdr{}
if pkt.GSOOptions != nil {
vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
if pkt.GSOOptions.NeedsCsum {
@@ -491,45 +482,19 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
}
}
vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
- iovLen++
}
- iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views()))
+ var builder iovec.Builder
+ builder.Add(vnetHdrBuf)
+ builder.Add(pkt.Header.View())
+ for _, v := range pkt.Data.Views() {
+ builder.Add(v)
+ }
+ iovecs := builder.Build()
+
var mmsgHdr rawfile.MMsgHdr
mmsgHdr.Msg.Iov = &iovecs[0]
- iovecIdx := 0
- if vnetHdrBuf != nil {
- v := &iovecs[iovecIdx]
- v.Base = &vnetHdrBuf[0]
- v.Len = uint64(len(vnetHdrBuf))
- iovecIdx++
- }
- if ethHdrBuf != nil {
- v := &iovecs[iovecIdx]
- v.Base = &ethHdrBuf[0]
- v.Len = uint64(len(ethHdrBuf))
- iovecIdx++
- }
- pktSize := uint64(0)
- // Encode L3 Header
- v := &iovecs[iovecIdx]
- hdr := &pkt.Header
- hdrView := hdr.View()
- v.Base = &hdrView[0]
- v.Len = uint64(len(hdrView))
- pktSize += v.Len
- iovecIdx++
-
- // Now encode the Transport Payload.
- pktViews := pkt.Data.Views()
- for i := range pktViews {
- vec := &iovecs[iovecIdx]
- iovecIdx++
- vec.Base = &pktViews[i][0]
- vec.Len = uint64(len(pktViews[i]))
- pktSize += vec.Len
- }
- mmsgHdr.Msg.Iovlen = uint64(iovecIdx)
+ mmsgHdr.Msg.Iovlen = uint64(len(iovecs))
mmsgHdrs = append(mmsgHdrs, mmsgHdr)
}
@@ -626,6 +591,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
return e.gsoMaxSize
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
// to the FD, but does not read from it. All reads come from injected packets.
type InjectableEndpoint struct {
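The rewritten WritePacket and sendBatch above share one pattern: gather the optional virtio-net header, the prependable packet header, and every payload view into a single scatter/gather write. A minimal sketch of that pattern, assuming the pkg/iovec Builder added in this change (Add appends a buffer and skips empty ones, Build returns the accumulated []syscall.Iovec); writeViews is a hypothetical helper, not part of the diff:

    package example

    import (
        "gvisor.dev/gvisor/pkg/iovec"
        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
    )

    // writeViews issues one writev covering an optional prefix (e.g. a
    // virtio-net header), the packet headers, and all payload views.
    func writeViews(fd int, prefix, hdr []byte, payload [][]byte) *tcpip.Error {
        var builder iovec.Builder
        builder.Add(prefix) // Assumed to be skipped by the builder when empty.
        builder.Add(hdr)
        for _, v := range payload {
            builder.Add(v)
        }
        return rawfile.NonBlockingWriteIovec(fd, builder.Build())
    }

This replaces the fixed three-slot NonBlockingWrite3 call with an iovec list that grows with the number of payload views.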
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index eaee7e5d7..7b995b85a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin
c.ch <- packetInfo{remote, protocol, pkt}
}
+func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNoEthernetProperties(t *testing.T) {
c := newContext(t, &Options{MTU: mtu})
defer c.cleanup()
@@ -500,3 +504,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) {
}
}
+
+// fakeNetworkDispatcher delivers packets to pkts.
+type fakeNetworkDispatcher struct {
+ pkts []*stack.PacketBuffer
+}
+
+func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ d.pkts = append(d.pkts, pkt)
+}
+
+func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
+func TestDispatchPacketFormat(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ newDispatcher func(fd int, e *endpoint) (linkDispatcher, error)
+ }{
+ {
+ name: "readVDispatcher",
+ newDispatcher: newReadVDispatcher,
+ },
+ {
+ name: "recvMMsgDispatcher",
+ newDispatcher: newRecvMMsgDispatcher,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ // Create a socket pair to send/recv.
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer syscall.Close(fds[0])
+ defer syscall.Close(fds[1])
+
+ data := []byte{
+ // Ethernet header.
+ 1, 2, 3, 4, 5, 60,
+ 1, 2, 3, 4, 5, 61,
+ 8, 0,
+ // Mock network header.
+ 40, 41, 42, 43,
+ }
+ err = syscall.Sendmsg(fds[1], data, nil, nil, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create and run dispatcher once.
+ sink := &fakeNetworkDispatcher{}
+ d, err := test.newDispatcher(fds[0], &endpoint{
+ hdrSize: header.EthernetMinimumSize,
+ dispatcher: sink,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if ok, err := d.dispatch(); !ok || err != nil {
+ t.Fatalf("d.dispatch() = %v, %v", ok, err)
+ }
+
+ // Verify packet.
+ if got, want := len(sink.pkts), 1; got != want {
+ t.Fatalf("len(sink.pkts) = %d, want %d", got, want)
+ }
+ pkt := sink.pkts[0]
+ if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want {
+ t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want)
+ }
+ if got, want := pkt.Data.Size(), 4; got != want {
+ t.Errorf("pkt.Data.Size() = %d, want %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index f04738cfb..d8f2504b3 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -278,7 +278,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(d.views[k][0])
+ eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize])
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 568c6874f..781cdd317 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -113,3 +113,11 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
return nil
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareLoopback
+}
+
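+// AddHeader implements stack.LinkEndpoint.AddHeader.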
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 82b441b79..e7493e5c5 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -9,6 +9,7 @@ go_library(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index c69d6b7e9..56a611825 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -18,6 +18,7 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() {
}
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unsupported operation")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
index bdd5276ad..2cdb23475 100644
--- a/pkg/tcpip/link/nested/BUILD
+++ b/pkg/tcpip/link/nested/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 2998f9c4f..d40de54df 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -60,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverOutboundPacket(remote, local, protocol, pkt)
+ }
+}
+
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -129,3 +140,13 @@ func (e *Endpoint) GSOMaxSize() uint32 {
}
return 0
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.child.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.child.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
index c1a219f02..7d9249c1c 100644
--- a/pkg/tcpip/link/nested/nested_test.go
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd
d.count++
}
+func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNestedLinkEndpoint(t *testing.T) {
const emptyAddress = tcpip.LinkAddress("")
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
new file mode 100644
index 000000000..6fff160ce
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "packetsocket",
+ srcs = ["endpoint.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
new file mode 100644
index 000000000..3922c2a04
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packetsocket provides a link layer endpoint that loops outbound
+// packets to any AF_PACKET sockets that may be interested in the
+// outgoing packet.
+package packetsocket
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
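+// endpoint loops outgoing packets to interested AF_PACKET sockets (via
+// DeliverOutboundPacket) before handing them to the wrapped lower endpoint.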
+type endpoint struct {
+ nested.Endpoint
+}
+
+// New creates a new packetsocket LinkEndpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ e := &endpoint{}
+ e.Endpoint.Init(lower, e)
+ return e
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
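The wrapper is intended to sit directly under the NIC so that every write passes through it. A minimal sketch of wiring it in, assuming an existing stack, NIC ID, and lower link endpoint (createNIC is a hypothetical helper):

    package example

    import (
        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/link/packetsocket"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
    )

    // createNIC wraps the lower endpoint so every outbound packet is first
    // handed to DeliverOutboundPacket (and from there to interested AF_PACKET
    // endpoints) before being written out unchanged.
    func createNIC(s *stack.Stack, id tcpip.NICID, lower stack.LinkEndpoint) *tcpip.Error {
        return s.CreateNIC(id, packetsocket.New(lower))
    }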
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
index 054c213bc..1d0079bd6 100644
--- a/pkg/tcpip/link/qdisc/fifo/BUILD
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index b5dfb7850..467083239 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -106,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
@@ -193,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3267/): Queue these packets as well once
+ // WriteRawPacket takes PacketBuffer instead of VectorisedView.
return e.lower.WriteRawPacket(vv)
}
@@ -207,3 +215,13 @@ func (e *endpoint) Wait() {
e.wg.Wait()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 44e25d475..f4c32c2da 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -66,39 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
return nil
}
-// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a
-// single syscall. It fails if partial data is written.
-func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
- // If the is no second buffer, issue a regular write.
- if len(b2) == 0 {
- return NonBlockingWrite(fd, b1)
- }
-
- // We have two buffers. Build the iovec that represents them and issue
- // a writev syscall.
- iovec := [3]syscall.Iovec{
- {
- Base: &b1[0],
- Len: uint64(len(b1)),
- },
- {
- Base: &b2[0],
- Len: uint64(len(b2)),
- },
- }
- iovecLen := uintptr(2)
-
- if len(b3) > 0 {
- iovecLen++
- iovec[2].Base = &b3[0]
- iovec[2].Len = uint64(len(b3))
- }
-
+// NonBlockingWriteIovec writes the given iovecs to a file descriptor with a
+// single writev syscall. It fails if partial data is written.
+func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error {
+ iovecLen := uintptr(len(iovec))
_, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
if e != 0 {
return TranslateErrno(e)
}
-
return nil
}
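Callers are no longer limited to NonBlockingWrite3's three fixed slots; they pass however many iovecs they need. A short sketch of calling the new helper directly, assuming two non-empty buffers (writePair is a hypothetical helper):

    package example

    import (
        "syscall"

        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
    )

    // writePair sends two non-empty buffers with a single writev; a non-zero
    // errno from the syscall is translated into a *tcpip.Error by the helper.
    func writePair(fd int, b1, b2 []byte) *tcpip.Error {
        iov := []syscall.Iovec{
            {Base: &b1[0], Len: uint64(len(b1))},
            {Base: &b2[0], Len: uint64(len(b2))},
        }
        return rawfile.NonBlockingWriteIovec(fd, iov)
    }

In practice callers build the slice with iovec.Builder, as the fdbased endpoint above now does.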
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 0374a2441..507c76b76 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Add the ethernet header here.
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
v := pkt.Data.ToView()
// Transmit the packet.
@@ -287,3 +294,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
e.completed.Done()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 28a2e88ba..8f3cd9449 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L
c.packetCh <- struct{}{}
}
+func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (c *testContext) cleanup() {
c.ep.Close()
closeFDs(&c.txCfg)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index d9cd4e83a..509076643 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -123,6 +123,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 6bc9033d0..04ae58e59 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -139,6 +139,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
stack: s,
nicID: id,
name: name,
+ isTap: prefix == "tap",
}
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
@@ -271,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader == nil {
- hdr := &header.EthernetFields{
- SrcAddr: info.Route.LocalLinkAddress,
- DstAddr: info.Route.RemoteLinkAddress,
- Type: info.Proto,
- }
- if hdr.SrcAddr == "" {
- hdr.SrcAddr = d.endpoint.LinkAddress()
- }
-
- eth := make(header.Ethernet, header.EthernetMinimumSize)
- eth.Encode(hdr)
- vv.AppendView(buffer.View(eth))
- } else {
- vv.AppendView(info.Pkt.LinkHeader)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
+ vv.AppendView(info.Pkt.LinkHeader)
}
// Append upper headers.
@@ -348,6 +337,7 @@ type tunEndpoint struct {
stack *stack.Stack
nicID tcpip.NICID
name string
+ isTap bool
}
// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
@@ -356,3 +346,38 @@ func (e *tunEndpoint) DecRef() {
e.stack.RemoveNIC(e.nicID)
})
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.isTap {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.isTap {
+ return
+ }
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
+ hdr := &header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: protocol,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = e.LinkAddress()
+ }
+
+ eth.Encode(hdr)
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header.
+func (e *tunEndpoint) MaxHeaderLength() uint16 {
+ if e.isTap {
+ return header.EthernetMinimumSize
+ }
+ return 0
+}
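The new isTap flag is what separates the L2 (tap) and L3 (tun) behavior across ARPHardwareType, AddHeader, and MaxHeaderLength above. A compact sketch of the split; describe is a hypothetical helper that would live inside package tun, since tunEndpoint is unexported:

    // describe summarizes the per-mode behavior introduced by isTap.
    func describe(e *tunEndpoint) (header.ARPHardwareType, uint16) {
        if e.isTap {
            // tap: an L2 device; reserve room for the 14-byte ethernet header.
            return header.ARPHardwareEther, header.EthernetMinimumSize
        }
        // tun: L3 only; no ARP hardware type and no link header.
        return header.ARPHardwareNone, 0
    }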
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 0956d2c65..ee84c3d96 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/gate",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
@@ -25,6 +26,7 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 949b3f2b2..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -59,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() {
// Wait implements stack.LinkEndpoint.Wait.
func (e *Endpoint) Wait() {}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 63bf40562..c448a888f 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -39,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress,
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -81,9 +86,19 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
return nil
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unimplemented")
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7f27a840d..b0f57040c 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -162,7 +162,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
- RemoteLinkAddress: broadcastMAC,
+ RemoteLinkAddress: header.EthernetBroadcastAddress,
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
@@ -181,7 +181,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
- return broadcastMAC, true
+ return header.EthernetBroadcastAddress, true
}
if header.IsV4MulticastAddress(addr) {
return header.EthernetAddressFromMulticastIPv4Address(addr), true
@@ -216,8 +216,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return 0, false, true
}
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-
// NewProtocol returns an ARP network protocol.
func NewProtocol() stack.NetworkProtocol {
return &protocol{}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 7c8fb3e0a..615bae648 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -172,14 +172,24 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
-func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testObject) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 78420d6e6..d142b4ffa 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -34,6 +34,6 @@ go_test(
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 1b67aa066..83e71cb8c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -129,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
+ case header.ICMPv4HostUnreachable:
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+
case header.ICMPv4PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 7e9f16c90..b1776e5ee 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -225,12 +225,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
length := uint16(hdr.UsedLength() + payloadSize)
- id := uint32(0)
- if length > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
- }
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
+ // datagrams. Since the DF bit is never set here, all datagrams
+ // are non-atomic and need an ID.
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: length,
@@ -376,13 +374,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// Set the packet ID when zero.
if ip.ID() == 0 {
- id := uint32(0)
- if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for
+ // non-atomic datagrams, so assign an ID to all such datagrams
+ // according to the definition given in RFC 6864 section 4.
+ if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
- ip.SetID(uint16(id))
}
// Always set the checksum.
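Both hunks above hinge on RFC 6864's notion of an atomic datagram: DF set, MF clear, fragment offset zero. Only non-atomic datagrams must carry a unique ID. A small sketch of that predicate; isAtomic is a hypothetical helper, the inverse of the condition WriteHeaderIncludedPacket now checks:

    package example

    import "gvisor.dev/gvisor/pkg/tcpip/header"

    // isAtomic reports whether an IPv4 datagram is atomic per RFC 6864
    // section 4; an ID is assigned above exactly when this is false.
    func isAtomic(ip header.IPv4) bool {
        return ip.Flags()&header.IPv4FlagDontFragment != 0 &&
            ip.Flags()&header.IPv4FlagMoreFragments == 0 &&
            ip.FragmentOffset() == 0
    }

In addIPHeader the DF bit is never set, so every datagram it emits is non-atomic and unconditionally gets an ID.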
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 3f71fc520..feada63dc 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -39,6 +39,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 2ff7eedf4..ff1cb53dd 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -128,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6NetworkUnreachable:
+ e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
@@ -494,8 +496,6 @@ const (
icmpV6LengthOffset = 25
)
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 794ddb5c8..6b9a6b316 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -27,6 +27,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tuple_list",
+ out = "tuple_list.go",
+ package = "stack",
+ prefix = "tuple",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*tuple",
+ "Linker": "*tuple",
+ },
+)
+
go_library(
name = "stack",
srcs = [
@@ -35,6 +47,7 @@ go_library(
"forwarder.go",
"icmp_rate_limit.go",
"iptables.go",
+ "iptables_state.go",
"iptables_targets.go",
"iptables_types.go",
"linkaddrcache.go",
@@ -50,6 +63,7 @@ go_library(
"stack_global_state.go",
"stack_options.go",
"transport_demuxer.go",
+ "tuple_list.go",
],
visibility = ["//visibility:public"],
deps = [
@@ -79,6 +93,7 @@ go_test(
"transport_demuxer_test.go",
"transport_test.go",
],
+ shard_count = 20,
deps = [
":stack",
"//pkg/rand",
@@ -94,7 +109,7 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index af9c325ca..559a1c4dd 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -15,9 +15,12 @@
package stack
import (
+ "encoding/binary"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
)
@@ -30,6 +33,10 @@ import (
//
// Currently, only TCP tracking is supported.
+// Our hash table has 16K buckets.
+// TODO(gvisor.dev/issue/170): These should be tunable.
+const numBuckets = 1 << 14
+
// Direction of the tuple.
type direction int
@@ -42,13 +49,19 @@ const (
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable.
+//
+// +stateify savable
type tuple struct {
+ // tupleEntry is used to build an intrusive list of tuples.
+ tupleEntry
+
tupleID
// conn is the connection tracking entry this tuple belongs to.
@@ -61,6 +74,8 @@ type tuple struct {
// tupleID uniquely identifies a connection in one direction. It currently
// contains enough information to distinguish between any TCP or UDP
// connection, and will need to be extended to support other protocols.
+//
+// +stateify savable
type tupleID struct {
srcAddr tcpip.Address
srcPort uint16
@@ -83,6 +98,8 @@ func (ti tupleID) reply() tupleID {
}
// conn is a tracked connection.
+//
+// +stateify savable
type conn struct {
// original is the tuple in original direction. It is immutable.
original tuple
@@ -97,23 +114,84 @@ type conn struct {
// update the state of tcb. It is immutable.
tcbHook Hook
- // mu protects tcb.
- mu sync.Mutex
-
+ // mu protects all mutable state.
+ mu sync.Mutex `state:"nosave"`
// tcb is TCB control block. It is used to keep track of states
// of tcp connection and is protected by mu.
tcb tcpconntrack.TCB
+ // lastUsed is the last time the connection saw a relevant packet, and
+ // is updated by each packet on the connection. It is protected by mu.
+ lastUsed time.Time `state:".(unixTime)"`
+}
+
+// timedOut returns whether the connection timed out based on its state.
+func (cn *conn) timedOut(now time.Time) bool {
+ const establishedTimeout = 5 * 24 * time.Hour
+ const defaultTimeout = 120 * time.Second
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+ if cn.tcb.State() == tcpconntrack.ResultAlive {
+ // Use the same default as Linux, which doesn't delete
+ // established connections for 5(!) days.
+ return now.Sub(cn.lastUsed) > establishedTimeout
+ }
+ // Use the same default as Linux, which lets connections in most states
+ // other than established remain for <= 120 seconds.
+ return now.Sub(cn.lastUsed) > defaultTimeout
+}
+
+// updateLocked updates the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
}
// ConnTrack tracks all connections created for NAT rules. Most users are
-// expected to only call handlePacket and createConnFor.
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
+//
+// ConnTrack keeps all connections in a slice of buckets, each of which holds a
+// linked list of tuples. This gives us some desirable properties:
+// - Each bucket has its own lock, lessening lock contention.
+// - The slice is large enough that lists stay short (<10 elements on average).
+// Thus traversal is fast.
+// - During linked list traversal we reap expired connections. This amortizes
+// the cost of reaping them and makes reapUnused faster.
+//
+// Locks are ordered by their location in the buckets slice. That is, a
+// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
+//
+// +stateify savable
type ConnTrack struct {
- // mu protects conns.
- mu sync.RWMutex
+ // seed is a one-time random value initialized at stack startup
+ // and is used in the calculation of hash keys for the list of buckets.
+ // It is immutable.
+ seed uint32
- // conns maintains a map of tuples needed for connection tracking for
- // iptables NAT rules. It is protected by mu.
- conns map[tupleID]tuple
+ // mu protects the buckets slice, but not buckets' contents. Only take
+ // the write lock if you are modifying the slice or saving for S/R.
+ mu sync.RWMutex `state:"nosave"`
+
+ // buckets is protected by mu.
+ buckets []bucket
+}
+
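+// bucket is a set of tuples that hash to the same index, guarded by its own
+// lock.
+//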
+// +stateify savable
+type bucket struct {
+ // mu protects tuples.
+ mu sync.Mutex `state:"nosave"`
+ tuples tupleList
}
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
@@ -143,8 +221,9 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
// newConn creates new connection.
func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
conn := conn{
- manip: manip,
- tcbHook: hook,
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
}
conn.original = tuple{conn: &conn, tupleID: orig}
conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
@@ -162,18 +241,31 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
return nil, dirOriginal
}
- ct.mu.Lock()
- defer ct.mu.Unlock()
-
- tuple, ok := ct.conns[tid]
- if !ok {
- return nil, dirOriginal
+ bucket := ct.bucket(tid)
+ now := time.Now()
+
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ ct.buckets[bucket].mu.Lock()
+ defer ct.buckets[bucket].mu.Unlock()
+
+ // Iterate over the tuples in a bucket, cleaning up any unused
+ // connections we find.
+ for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
+ // Clean up any timed-out connections we happen to find.
+ if ct.reapTupleLocked(other, bucket, now) {
+ // The tuple expired.
+ continue
+ }
+ if tid == other.tupleID {
+ return other.conn, other.direction
+ }
}
- return tuple.conn, tuple.direction
+
+ return nil, dirOriginal
}
-// createConnFor creates a new conn for pkt.
-func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
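+// insertRedirectConn creates a new conn for pkt based on the REDIRECT rule
+// rt, inserts it into the conntrack table, and returns it.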
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -196,21 +288,59 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
manip = manipDstOutput
}
conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
+}
- // Add the changed tuple to the map.
- // TODO(gvisor.dev/issue/170): Need to support collisions using linked
- // list.
- ct.mu.Lock()
- defer ct.mu.Unlock()
- ct.conns[tid] = conn.original
- ct.conns[replyTID] = conn.reply
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
+ // Lock the buckets in the correct order.
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ if tupleBucket < replyBucket {
+ ct.buckets[tupleBucket].mu.Lock()
+ ct.buckets[replyBucket].mu.Lock()
+ } else if tupleBucket > replyBucket {
+ ct.buckets[replyBucket].mu.Lock()
+ ct.buckets[tupleBucket].mu.Lock()
+ } else {
+ // Both tuples are in the same bucket.
+ ct.buckets[tupleBucket].mu.Lock()
+ }
- return conn
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
+
+ if !alreadyInserted {
+ // Add the tuples to their buckets.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ }
+
+ // Unlocking can happen in any order.
+ ct.buckets[tupleBucket].mu.Unlock()
+ if tupleBucket != replyBucket {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -228,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
netHeader.SetSourceAddress(conn.original.dstAddr)
}
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
// handlePacketOutput manipulates ports for packets in Output hook.
func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -268,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
}
// handlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
if conn == nil {
- // Connection not found for the packet or the packet is invalid.
- return
+ return true
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if tcpHeader == nil {
+ return false
}
switch hook {
@@ -297,35 +448,159 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
// other tcp states.
conn.mu.Lock()
defer conn.mu.Unlock()
- var st tcpconntrack.Result
- tcpHeader := header.TCP(pkt.TransportHeader)
- if conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else {
- switch hook {
- case conn.tcbHook:
- st = conn.tcb.UpdateStateOutbound(tcpHeader)
- default:
- st = conn.tcb.UpdateStateInbound(tcpHeader)
- }
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ conn.lastUsed = time.Now()
+ // Update connection state.
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
}
- // Delete conn if tcp connection is closed.
- if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
- ct.deleteConn(conn)
+ // We only track TCP connections.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return
}
-}
-// deleteConn deletes the connection.
-func (ct *ConnTrack) deleteConn(conn *conn) {
- if conn == nil {
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
return
}
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ ct.insertConn(conn)
+}
+
+// bucket gets the conntrack bucket for a tupleID.
+func (ct *ConnTrack) bucket(id tupleID) int {
+ h := jenkins.Sum32(ct.seed)
+ h.Write([]byte(id.srcAddr))
+ h.Write([]byte(id.dstAddr))
+ shortBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
+ h.Write([]byte(shortBuf))
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ return int(h.Sum32()) % len(ct.buckets)
+}
+
+// reapUnused deletes timed out entries from the conntrack map. The rules for
+// reaping are:
+// - Most reaping occurs in connFor, which is called on each packet. connFor
+// cleans up the bucket the packet's connection maps to. Thus calls to
+// reapUnused should be fast.
+// - Each call to reapUnused traverses a fraction of the conntrack table.
+// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
+// - After reaping, reapUnused decides when it should next run based on the
+// ratio of expired connections to examined connections. If the ratio is
+// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
+// slightly increases the interval between runs.
+// - maxFullTraversal caps the time it takes to traverse the entire table.
+//
+// reapUnused returns the next bucket that should be checked and the time after
+// which it should be called again.
+func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
+ // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
+ // it is in Linux via sysctl.
+ const fractionPerReaping = 128
+ const maxExpiredPct = 50
+ const maxFullTraversal = 60 * time.Second
+ const minInterval = 10 * time.Millisecond
+ const maxInterval = maxFullTraversal / fractionPerReaping
+
+ now := time.Now()
+ checked := 0
+ expired := 0
+ var idx int
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
+ idx = (i + start) % len(ct.buckets)
+ ct.buckets[idx].mu.Lock()
+ for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ checked++
+ if ct.reapTupleLocked(tuple, idx, now) {
+ expired++
+ }
+ }
+ ct.buckets[idx].mu.Unlock()
+ }
+ // We already checked buckets[idx].
+ idx++
+
+ // If more than half of the checked connections are expired, the
+ // table has gotten stale. Reschedule quickly.
+ expiredPct := 0
+ if checked != 0 {
+ expiredPct = expired * 100 / checked
+ }
+ if expiredPct > maxExpiredPct {
+ return idx, minInterval
+ }
+ if interval := prevInterval + minInterval; interval <= maxInterval {
+ // Increment the interval between runs.
+ return idx, interval
+ }
+ // We've hit the maximum interval.
+ return idx, maxInterval
+}
- ct.mu.Lock()
- defer ct.mu.Unlock()
+// reapTupleLocked tries to remove tuple and its reply from the table. It
+// returns whether the tuple's connection has timed out.
+//
+// Preconditions: ct.mu is locked for reading and bucket is locked.
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+ if !tuple.conn.timedOut(now) {
+ return false
+ }
+
+ // To maintain lock order, we can only reap these tuples if the reply
+ // appears later in the table.
+ replyBucket := ct.bucket(tuple.reply())
+ if bucket > replyBucket {
+ return true
+ }
+
+ // Don't re-lock if both tuples are in the same bucket.
+ differentBuckets := bucket != replyBucket
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Lock()
+ }
+
+ // We have the buckets locked and can remove both tuples.
+ if tuple.direction == dirOriginal {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ } else {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
+ }
+ ct.buckets[bucket].tuples.Remove(tuple)
+
+ // Don't re-unlock if both tuples are in the same bucket.
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
- delete(ct.conns, conn.original.tupleID)
- delete(ct.conns, conn.reply.tupleID)
+ return true
}
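For a sense of scale, the reaping parameters above work out as follows, using the defaults in this change (numBuckets = 1<<14, fractionPerReaping = 128, maxFullTraversal = 60s, minInterval = 10ms); the snippet only restates the constants to make the arithmetic explicit:

    package example

    import "time"

    const (
        numBuckets         = 1 << 14 // 16384 buckets in the table
        fractionPerReaping = 128     // each reapUnused call scans 16384/128 = 128 buckets
        maxFullTraversal   = 60 * time.Second
        minInterval        = 10 * time.Millisecond
        maxInterval        = maxFullTraversal / fractionPerReaping // ~469ms between calls
    )

    // Covering all 16384 buckets takes 128 calls, so a full sweep of the table
    // takes roughly 128*10ms ≈ 1.3s when the table is stale (>50% expired) and
    // up to 128*469ms ≈ 60s when it is healthy.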
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index a6546cef0..bca1d940b 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -301,6 +302,16 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er
// Wait implements stack.LinkEndpoint.Wait.
func (*fwdTestLinkEndpoint) Wait() {}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 974d77c36..cbbae4224 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -16,39 +16,49 @@ package stack
import (
"fmt"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// Table names.
+// tableID is an index into IPTables.tables.
+type tableID int
+
const (
- TablenameNat = "nat"
- TablenameMangle = "mangle"
- TablenameFilter = "filter"
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
)
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+// Table names.
const (
- ChainNamePrerouting = "PREROUTING"
- ChainNameInput = "INPUT"
- ChainNameForward = "FORWARD"
- ChainNameOutput = "OUTPUT"
- ChainNamePostrouting = "POSTROUTING"
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
)
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
+
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
+// reaperDelay is how long to wait before starting to reap connections.
+const reaperDelay = 5 * time.Second
+
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables() *IPTables {
- // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
- // iotas.
return &IPTables{
- tables: map[string]Table{
- TablenameNat: Table{
+ tables: [numTables]Table{
+ natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
@@ -56,64 +66,71 @@ func DefaultTables() *IPTables {
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- UserChains: map[string]int{},
},
- TablenameMangle: Table{
+ mangleID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- Underflows: map[Hook]int{
- Prerouting: 0,
- Output: 1,
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
- TablenameFilter: Table{
+ filterID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
},
- priorities: map[Hook][]string{
- Input: []string{TablenameNat, TablenameFilter},
- Prerouting: []string{TablenameMangle, TablenameNat},
- Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
},
connections: ConnTrack{
- conns: make(map[tupleID]tuple),
+ seed: generateRandUint32(),
},
+ reaperDone: make(chan struct{}, 1),
}
}
@@ -122,62 +139,59 @@ func DefaultTables() *IPTables {
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// EmptyNatTable returns a Table with no rules and the filter table chains
+// EmptyNATTable returns a Table with no rules and the NAT table chains
// mapped to HookUnset.
-func EmptyNatTable() Table {
+func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
},
- Underflows: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
},
- UserChains: map[string]int{},
}
}
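Note on the switch from maps to fixed-size arrays above: an unspecified array element defaults to 0, which is a legitimate rule index, so hooks a table does not support must now be marked HookUnset explicitly (the map versions simply omitted them). A minimal, self-contained sketch of the effect, reproducing the relevant constants locally so it compiles on its own (the Hook ordering is an assumption based on the surrounding code):

package main

import "fmt"

// Local mirrors of the stack package's constants, reproduced only so this
// sketch is self-contained; the ordering is an assumption.
const (
	Prerouting = iota
	Input
	Forward
	Output
	Postrouting
	NumHooks
)

// HookUnset indicates that there is no hook set for an entrypoint or underflow.
const HookUnset = -1

func main() {
	// Only the unsupported hooks are set explicitly; the rest default to 0,
	// i.e. the first rule in the table.
	chains := [NumHooks]int{
		Prerouting:  HookUnset,
		Postrouting: HookUnset,
	}
	fmt.Println(chains) // Output: [-1 0 0 0 -1]
}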
-// GetTable returns table by name.
+// GetTable returns a table by name.
func (it *IPTables) GetTable(name string) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
it.mu.RLock()
defer it.mu.RUnlock()
- t, ok := it.tables[name]
- return t, ok
+ return it.tables[id], true
}
// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table) {
+func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
it.mu.Lock()
defer it.mu.Unlock()
+ // If iptables is being enabled, initialize the conntrack table and
+ // reaper.
+ if !it.modified {
+ it.connections.buckets = make([]bucket, numBuckets)
+ it.startReaper(reaperDelay)
+ }
it.modified = true
- it.tables[name] = table
-}
-
-// GetPriorities returns slice of priorities associated with hook.
-func (it *IPTables) GetPriorities(hook Hook) []string {
- it.mu.RLock()
- defer it.mu.RUnlock()
- return it.priorities[hook]
+ it.tables[id] = table
+ return nil
}
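With tables keyed by fixed IDs internally, callers still address them by the exported names. A rough sketch (not part of this change; the package and helper names are assumptions) of how a caller might swap in a new filter table using the API above:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// installFilter is a hypothetical helper: unknown table names now fail fast
// with ErrInvalidOptionValue instead of silently growing a map.
func installFilter(it *stack.IPTables, table stack.Table) *tcpip.Error {
	if _, ok := it.GetTable(stack.FilterTable); !ok {
		// Builtin names always resolve; anything else is rejected.
		return tcpip.ErrInvalidOptionValue
	}
	return it.ReplaceTable(stack.FilterTable, table)
}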
// A chainVerdict is what a table decides should be done with a packet.
@@ -212,11 +226,19 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Packets are manipulated only if connection and matching
// NAT rule exists.
- it.connections.handlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.GetPriorities(hook) {
- table, _ := it.GetTable(tablename)
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ table := it.tables[tableID]
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
@@ -245,17 +267,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
}
}
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+ // needeed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
// Every table returned Accept.
return true
}
+// beforeSave is invoked by stateify.
+func (it *IPTables) beforeSave() {
+ // Ensure the reaper exits cleanly.
+ it.reaperDone <- struct{}{}
+ // Prevent others from modifying the connection table.
+ it.connections.mu.Lock()
+}
+
+// afterLoad is invoked by stateify.
+func (it *IPTables) afterLoad() {
+ it.startReaper(reaperDelay)
+}
+
+// startReaper starts a goroutine that wakes up periodically to reap timed out
+// connections.
+func (it *IPTables) startReaper(interval time.Duration) {
+ go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved.
+ bucket := 0
+ for {
+ select {
+ case <-it.reaperDone:
+ return
+ case <-time.After(interval):
+ bucket, interval = it.connections.reapUnused(bucket, interval)
+ }
+ }
+ }()
+}
+
// CheckPackets runs pkts through the rules for hook and returns a map of packets that
// should not go forward.
//
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-//
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
@@ -279,9 +343,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
return drop, natPkts
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
@@ -326,23 +390,12 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
return chainDrop
}
-// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
-// precondition.
+// Preconditions:
+// - pkt is an IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
- // If pkt.NetworkHeader hasn't been set yet, it will be contained in
- // pkt.Data.
- if pkt.NetworkHeader == nil {
- var ok bool
- pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- // Precondition has been violated.
- panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
- }
- }
-
// Check whether the packet matches the IP header filter.
if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) {
// Continue on to the next rule.
diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go
new file mode 100644
index 000000000..529e02a07
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastUsed is invoked by stateify.
+func (cn *conn) saveLastUsed() unixTime {
+ return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
+}
+
+// loadLastUsed is invoked by stateify.
+func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.lastUsed = time.Unix(unix.second, unix.nano)
+}
+
+// beforeSave is invoked by stateify.
+func (ct *ConnTrack) beforeSave() {
+ ct.mu.Lock()
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index d43f60c67..dc88033c7 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
// Set up connection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.createConnFor(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index c528ec381..73274ada9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -78,18 +78,20 @@ const (
)
// IPTables holds all the tables for a netstack.
+//
+// +stateify savable
type IPTables struct {
// mu protects tables, priorities, and modified.
mu sync.RWMutex
- // tables maps table names to tables. User tables have arbitrary names.
- // mu needs to be locked for accessing.
- tables map[string]Table
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be locked for accessing.
+ tables [numTables]Table
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. mu needs to be locked for accessing.
- priorities map[Hook][]string
+ priorities [NumHooks][]tableID
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
@@ -97,31 +99,34 @@ type IPTables struct {
modified bool
connections ConnTrack
+
+ // reaperDone can be signalled to stop the reaper goroutine.
+ reaperDone chan struct{}
}
// A Table defines a set of chains and hooks into the network stack. It is
// really just a list of rules.
+//
+// +stateify savable
type Table struct {
// Rules holds the rules that make up the table.
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]int
+ Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook := range table.BuiltinChains {
- hooks |= 1 << hook
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
}
return hooks
}
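For example, with the default filter table built above (Input=0, Forward=1, Output=2, and Prerouting/Postrouting left at HookUnset), the rewritten ValidHooks returns 1<<Input | 1<<Forward | 1<<Output; assuming the Hook constants are declared in the order Prerouting, Input, Forward, Output, Postrouting, that is binary 01110 (0x0e), which should line up with the valid-hooks mask Linux advertises for its filter table.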
@@ -130,6 +135,8 @@ func (table *Table) ValidHooks() uint32 {
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
// applies to any packet.
+//
+// +stateify savable
type Rule struct {
// Filter holds basic IP filtering fields common to every rule.
Filter IPHeaderFilter
@@ -142,6 +149,8 @@ type Rule struct {
}
// IPHeaderFilter holds basic IP filtering data common to every rule.
+//
+// +stateify savable
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index e28c23d66..9dce11a97 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -469,7 +469,7 @@ type ndpState struct {
rtrSolicit struct {
// The timer used to send the next router solicitation message.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the Router Solicitation timer know that it has been stopped.
//
@@ -503,7 +503,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
@@ -515,38 +515,38 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- // Timer to invalidate the default router.
+ // Job to invalidate the default router.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- // Timer to invalidate the on-link prefix.
+ // Job to invalidate the on-link prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// tempSLAACAddrState holds state associated with a temporary SLAAC address.
type tempSLAACAddrState struct {
- // Timer to deprecate the temporary SLAAC address.
+ // Job to deprecate the temporary SLAAC address.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the temporary SLAAC address.
+ // Job to invalidate the temporary SLAAC address.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
- // Timer to regenerate the temporary SLAAC address.
+ // Job to regenerate the temporary SLAAC address.
//
// Must not be nil.
- regenTimer *tcpip.CancellableTimer
+ regenJob *tcpip.Job
createdAt time.Time
@@ -561,15 +561,15 @@ type tempSLAACAddrState struct {
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
- // Timer to deprecate the prefix.
+ // Job to deprecate the prefix.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the prefix.
+ // Job to invalidate the prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
validUntil time.Time
@@ -651,12 +651,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
var done bool
- var timer *time.Timer
+ var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
// cannot be done while holding the NIC's lock. This is effectively the same
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
- timer = time.AfterFunc(0, func() {
+ timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
ndp.nic.mu.Lock()
defer ndp.nic.mu.Unlock()
@@ -871,9 +871,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
- // the invalidation timer.
- rtr.invalidationTimer.StopLocked()
- rtr.invalidationTimer.Reset(rl)
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
@@ -950,7 +950,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.StopLocked()
+ rtr.invalidationJob.Cancel()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
@@ -979,12 +979,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
state := defaultRouterState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
- state.invalidationTimer.Reset(rl)
+ state.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = state
}
@@ -1009,13 +1009,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
state := onLinkPrefixState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
if l < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(l)
+ state.invalidationJob.Schedule(l)
}
ndp.onLinkPrefixes[prefix] = state
@@ -1033,7 +1033,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- s.invalidationTimer.StopLocked()
+ s.invalidationJob.Cancel()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
@@ -1082,14 +1082,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
//
- // Update the invalidation timer.
+ // Update the invalidation job.
- prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationJob.Cancel()
if vl < header.NDPInfiniteLifetime {
- // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
// the new valid lifetime.
- prefixState.invalidationTimer.Reset(vl)
+ prefixState.invalidationJob.Schedule(vl)
}
ndp.onLinkPrefixes[prefix] = prefixState
@@ -1154,7 +1154,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1162,7 +1162,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1184,19 +1184,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if !ndp.generateSLAACAddr(prefix, &state) {
// We were unable to generate an address for the prefix, we do nothing
- // further as there is no reason to maintain state or timers for a prefix we
+ // further as there is no reason to maintain state or jobs for a prefix we
// do not have an address for.
return
}
- // Setup the initial timers to deprecate and invalidate prefix.
+ // Set up the initial jobs to deprecate and invalidate the prefix.
if pl < header.NDPInfiniteLifetime && pl != 0 {
- state.deprecationTimer.Reset(pl)
+ state.deprecationJob.Schedule(pl)
}
if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
+ state.invalidationJob.Schedule(vl)
state.validUntil = now.Add(vl)
}
@@ -1428,7 +1428,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1441,7 +1441,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1454,7 +1454,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1481,9 +1481,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ref: ref,
}
- state.deprecationTimer.Reset(pl)
- state.invalidationTimer.Reset(vl)
- state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
prefixState.generationAttempts++
prefixState.tempAddrs[generatedAddr.Address] = state
@@ -1518,16 +1518,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
prefixState.stableAddr.ref.deprecated = false
}
- // If prefix was preferred for some finite lifetime before, stop the
- // deprecation timer so it can be reset.
- prefixState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
now := time.Now()
- // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
- prefixState.deprecationTimer.Reset(pl)
+ prefixState.deprecationJob.Schedule(pl)
}
prefixState.preferredUntil = now.Add(pl)
} else {
@@ -1546,9 +1546,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
if vl >= header.NDPInfiniteLifetime {
- // Handle the infinite valid lifetime separately as we do not keep a timer
- // in this case.
- prefixState.invalidationTimer.StopLocked()
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
prefixState.validUntil = time.Time{}
} else {
var effectiveVl time.Duration
@@ -1569,8 +1569,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
if effectiveVl != 0 {
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
prefixState.validUntil = now.Add(effectiveVl)
}
}
@@ -1582,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// Note, we do not need to update the entries in the temporary address map
- // after updating the timers because the timers are held as pointers.
+ // after updating the jobs because the jobs are held as pointers.
var regenForAddr tcpip.Address
allAddressesRegenerated := true
for tempAddr, tempAddrState := range prefixState.tempAddrs {
@@ -1596,14 +1596,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
- // reset the invalidation timer.
+ // reset the invalidation job.
newValidLifetime := validUntil.Sub(now)
if newValidLifetime <= 0 {
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
continue
}
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.invalidationTimer.Reset(newValidLifetime)
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
// As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
// address is the lower of the preferred lifetime of the stable address or
@@ -1616,17 +1616,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer preferred, deprecate it immediately.
- // Otherwise, reset the deprecation timer.
+ // Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
- tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.ref)
} else {
tempAddrState.ref.deprecated = false
- tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.regenJob.Cancel()
if tempAddrState.regenerated {
} else {
allAddressesRegenerated = false
@@ -1637,7 +1637,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// immediately after we finish iterating over the temporary addresses.
regenForAddr = tempAddr
} else {
- tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
}
}
}
@@ -1717,7 +1717,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
//
// Panics if the SLAAC prefix is not known.
//
@@ -1729,8 +1729,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
}
state.stableAddr.ref = nil
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
}
@@ -1775,13 +1775,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
}
// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
-// timers and entry.
+// jobs and entry.
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
- tempAddrState.deprecationTimer.StopLocked()
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
@@ -1860,7 +1860,7 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
+ ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
ndp.nic.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6f86abc98..644ba7c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1254,7 +1254,7 @@ func TestRouterDiscovery(t *testing.T) {
default:
}
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1271,7 +1271,7 @@ func TestRouterDiscovery(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
expectRouterEvent(llAddr2, false)
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1502,7 +1502,7 @@ func TestPrefixDiscovery(t *testing.T) {
default:
}
- // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
select {
case e := <-ndpDisp.prefixC:
@@ -2395,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
for _, addr := range tempAddrs {
// Wait for a deprecation then invalidation event, or just an invalidation
// event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation timers could fire in any
+ // cases because the deprecation and invalidation jobs could execute in any
// order.
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2432,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
}
-// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
-// regeneration timer gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
nicID = 1
regenAfter = 2 * time.Second
@@ -2533,7 +2533,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
//
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a timer.
+ // - this regeneration does not depend on a job.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(tempAddr2, newAddr)
@@ -2559,11 +2559,11 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
}
// Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration timer gets reset.
+ // RA, the regeneration job gets scheduled again.
//
// The maximum lifetime is the sum of the minimum lifetimes for temporary
// addresses + the time that has already passed since the last address was
- // generated so that the regeneration timer is needed to generate the next
+ // generated so that the regeneration job is needed to generate the next
// address.
newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
@@ -2993,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr(addr2)
}
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
@@ -3513,8 +3513,8 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
expectAutoGenAddrEvent(addr, invalidatedAddr)
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index afb7dfeaf..fea0ce7e8 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// Are any packet sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+ // Add any other packet sockets that may be listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
@@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol specific packet endpoints as on Linux
+ // only ETH_P_ALL endpoints get outbound packets.
+ // Add any other packet sockets that may be listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
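To illustrate how DeliverOutboundPacket is meant to be driven, here is a rough sketch of a nested link endpoint that mirrors outgoing packets to ETH_P_ALL packet sockets before transmission. The wrapper type, its fields, and where it is attached are assumptions for the example, not part of this change; only the dispatcher method comes from the diff above.

package mirror

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// endpoint is a hypothetical LinkEndpoint wrapper; dispatcher is the
// NetworkDispatcher recorded when the NIC attaches to it.
type endpoint struct {
	stack.LinkEndpoint
	dispatcher stack.NetworkDispatcher
}

// WritePacket hands the packet to packet sockets first (the NIC clones it
// internally), then writes it out through the wrapped endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
	e.dispatcher.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, proto, pkt)
	return e.LinkEndpoint.WritePacket(r, gso, proto, pkt)
}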
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
// TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
@@ -1358,16 +1374,19 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// TransportHeader is nil only when pkt is an ICMP packet or was reassembled
// from fragments.
if pkt.TransportHeader == nil {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their
- // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
// full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
+ // ICMP packets may be longer, but until icmp.Parse is implemented, here
+ // we parse it using the minimum size.
transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
if !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
pkt.TransportHeader = transHeader
+ pkt.Data.TrimFront(len(pkt.TransportHeader))
} else {
// This is either a bad packet or was re-assembled from fragments.
transProto.Parse(pkt)
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 31f865260..c477e31d8 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -84,6 +84,16 @@ func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
// An IPv6 NetworkEndpoint that throws away outgoing packets.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 1b5da6017..5d6865e35 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -14,6 +14,7 @@
package stack
import (
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -24,7 +25,7 @@ import (
// multiple endpoints. Clone() should be called in such cases so that
// modifications to the Data field do not affect other copies.
type PacketBuffer struct {
- _ noCopy
+ _ sync.NoCopy
// PacketBufferEntry is used to build an intrusive list of
// PacketBuffers.
@@ -78,6 +79,10 @@ type PacketBuffer struct {
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
@@ -102,14 +107,3 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NatDone: pk.NatDone,
}
}
-
-// noCopy may be embedded into structs which must not be copied
-// after the first use.
-//
-// See https://golang.org/issues/8005#issuecomment-190753527
-// for details.
-type noCopy struct{}
-
-// Lock is a no-op used by -copylocks checker from `go vet`.
-func (*noCopy) Lock() {}
-func (*noCopy) Unlock() {}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 5cbc946b6..9e1b2d25f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -51,8 +52,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -329,8 +333,7 @@ type NetworkProtocol interface {
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network/packet (if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -341,6 +344,16 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -436,6 +449,15 @@ type LinkEndpoint interface {
// Wait will not block if the endpoint hasn't started any goroutines
// yet, even if it might later.
Wait()
+
+ // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
+ ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index cdcfb8321..a6faa22c2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -425,6 +425,7 @@ type Stack struct {
handleLocal bool
// tables are the iptables packet filtering and manipulation rules.
+ // TODO(gvisor.dev/issue/170): S/R this field.
tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
@@ -727,6 +728,11 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
// UniqueID returns a unique identifier.
func (s *Stack) UniqueID() uint64 {
return s.uniqueIDGenerator.UniqueID()
@@ -800,9 +806,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
// Stats returns a mutable copy of the current stats.
@@ -1094,6 +1101,11 @@ type NICInfo struct {
// Context is user-supplied data optionally supplied in CreateNICWithOptions.
// See type NICOptions for more details.
Context NICContext
+
+ // ARPHardwareType holds the ARP Hardware type of the NIC. This is the
+ // value sent in the haType field of an ARP Request sent by this NIC and the
+ // value expected in the haType field of an ARP response.
+ ARPHardwareType header.ARPHardwareType
}
// HasNIC returns true if the NICID is defined in the stack.
@@ -1125,6 +1137,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
Context: nic.context,
+ ARPHardwareType: nic.linkEP.ARPHardwareType(),
}
}
return nics
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 2be1c107a..21aafb0a2 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -192,7 +192,7 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
-// A Clock provides the current time.
+// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
@@ -203,6 +203,31 @@ type Clock interface {
// NowMonotonic returns a monotonic time value.
NowMonotonic() int64
+
+ // AfterFunc waits for the duration to elapse and then calls f in its own
+ // goroutine. It returns a Timer that can be used to cancel the call using
+ // its Stop method.
+ AfterFunc(d time.Duration, f func()) Timer
+}
+
+// Timer represents a single event. A Timer must be created with
+// Clock.AfterFunc.
+type Timer interface {
+ // Stop prevents the Timer from firing. It returns true if the call stops the
+ // timer, false if the timer has already expired or been stopped.
+ //
+ // If Stop returns false, then the timer has already expired and the function
+ // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
+ // does not wait for f to complete before returning. If the caller needs to
+ // know whether f is completed, it must coordinate with f explicitly.
+ Stop() bool
+
+ // Reset changes the timer to expire after duration d.
+ //
+ // Reset should be invoked only on stopped or expired timers. If the timer is
+ // known to have expired, Reset can be used directly. Otherwise, the caller
+ // must coordinate with the function f of Clock.AfterFunc(d, f).
+ Reset(d time.Duration)
}
// Address is a byte slice cast as a string that represents the address of a
@@ -316,6 +341,28 @@ const (
ShutdownWrite
)
+// PacketType is used to indicate the destination of the packet.
+type PacketType uint8
+
+const (
+ // PacketHost indicates a packet addressed to the local host.
+ PacketHost PacketType = iota
+
+ // PacketOtherHost indicates an outgoing packet addressed to
+ // another host caught by a NIC in promiscuous mode.
+ PacketOtherHost
+
+ // PacketOutgoing for a packet originating from the local host
+ // that is looped back to a packet socket.
+ PacketOutgoing
+
+ // PacketBroadcast indicates a link layer broadcast packet.
+ PacketBroadcast
+
+ // PacketMulticast indicates a link layer multicast packet.
+ PacketMulticast
+)
+
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
@@ -549,6 +596,28 @@ type Endpoint interface {
SetOwner(owner PacketOwner)
}
+// LinkPacketInfo holds Link layer information for a received packet.
+//
+// +stateify savable
+type LinkPacketInfo struct {
+ // Protocol is the NetworkProtocolNumber for the packet.
+ Protocol NetworkProtocolNumber
+
+ // PktType is used to indicate the destination of the packet.
+ PktType PacketType
+}
+
+// PacketEndpoint are additional methods that are only implemented by Packet
+// endpoints.
+type PacketEndpoint interface {
+ // ReadPacket reads a datagram/packet from the endpoint and optionally
+ // returns the sender and additional LinkPacketInfo.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
+}
+
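A short usage sketch of the new read path (the helper and its package are assumptions; only the ReadPacket signature, FullAddress, and LinkPacketInfo come from this change):

package example

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// readOne drains a single packet along with its link-layer metadata.
func readOne(ep tcpip.PacketEndpoint) *tcpip.Error {
	var (
		sender tcpip.FullAddress
		info   tcpip.LinkPacketInfo
	)
	v, _, err := ep.ReadPacket(&sender, &info)
	if err != nil {
		return err
	}
	fmt.Printf("%d bytes from NIC %d, proto %#x, type %d\n", len(v), sender.NIC, info.Protocol, info.PktType)
	return nil
}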
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@@ -648,6 +717,11 @@ const (
// whether an IPv6 socket is to be restricted to sending and receiving
// IPv6 packets only.
V6OnlyOption
+
+ // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
+ // endpoint that all packets being written have an IP header and the
+ // endpoint should not attach an IP header.
+ IPHdrIncludedOption
)
// SockOptInt represents socket options which values have the int type.
@@ -673,6 +747,13 @@ const (
// TCP_MAXSEG option.
MaxSegOption
+ // MTUDiscoverOption is used to set/get the path MTU discovery setting.
+ //
+ // NOTE: Setting this option to any value other than PMTUDiscoveryDont
+ // is not supported and will fail; getting this option will always
+ // return PMTUDiscoveryDont.
+ MTUDiscoverOption
+
// MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control
// the default TTL value for multicast messages. The default is 1.
MulticastTTLOption
@@ -714,6 +795,24 @@ const (
TCPWindowClampOption
)
+const (
+ // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use
+ // per-route settings.
+ PMTUDiscoveryWant int = iota
+
+ // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable
+ // path MTU discovery.
+ PMTUDiscoveryDont
+
+ // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do
+ // path MTU discovery.
+ PMTUDiscoveryDo
+
+ // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF
+ // but ignore path MTU.
+ PMTUDiscoveryProbe
+)
+
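A hedged usage sketch (the endpoint and helper are assumptions; only the option and constants are introduced here): since only PMTUDiscoveryDont is accepted for now, a caller that wants Linux-compatible behaviour pins the option to that value and treats any other request as unsupported.

package example

import "gvisor.dev/gvisor/pkg/tcpip"

// disablePMTUD is a hypothetical helper; per the option's comment above, any
// value other than PMTUDiscoveryDont is rejected by the endpoint.
func disablePMTUD(ep tcpip.Endpoint) *tcpip.Error {
	return ep.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont)
}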
// ErrorOption is used in GetSockOpt to specify that the last error reported by
// the endpoint should be cleared and returned.
type ErrorOption struct{}
@@ -752,7 +851,7 @@ type CongestionControlOption string
// control algorithms.
type AvailableCongestionControlOption string
-// buffer moderation.
+// ModerateReceiveBufferOption is used by buffer moderation.
type ModerateReceiveBufferOption bool
// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
@@ -825,7 +924,10 @@ type OutOfBandInlineOption int
// a default TTL.
type DefaultTTLOption uint8
-//
+// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
+// classic BPF filter on a given endpoint.
+type SocketDetachFilterOption int
+
// IPPacketInfo is the message structure for IP_PKTINFO.
//
// +stateify savable
@@ -1214,6 +1316,9 @@ type UDPStats struct {
// ChecksumErrors is the number of datagrams dropped due to bad checksums.
ChecksumErrors *StatCounter
+
+ // InvalidSourceAddress is the number of datagrams dropped due to an invalid source address.
+ InvalidSourceAddress *StatCounter
}
// Stats holds statistics about the networking stack.
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 7f172f978..f32d58091 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -20,7 +20,7 @@
package tcpip
import (
- _ "time" // Used with go:linkname.
+ "time" // Used with go:linkname.
_ "unsafe" // Required for go:linkname.
)
@@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 {
_, _, mono := now()
return mono
}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
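A small, self-contained sketch of the new Clock/Timer surface (nothing here beyond what the interfaces and StdClock above define): code that used to call time.AfterFunc directly now goes through a Clock so tests can substitute a fake implementation.

package main

import (
	"fmt"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	var clock tcpip.StdClock

	// Schedule work through the Clock rather than the time package directly.
	t := clock.AfterFunc(100*time.Millisecond, func() {
		fmt.Println("fired")
	})

	// Stop reports whether it prevented the function from running; Reset is
	// only safe on stopped or expired timers, so reuse the timer only then.
	if t.Stop() {
		t.Reset(time.Second)
	}

	time.Sleep(2 * time.Second)
}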
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
index 59f3b391f..f1dd7c310 100644
--- a/pkg/tcpip/timer.go
+++ b/pkg/tcpip/timer.go
@@ -15,54 +15,54 @@
package tcpip
import (
- "sync"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
-// cancellableTimerInstance is a specific instance of CancellableTimer.
+// jobInstance is a specific instance of Job.
//
-// Different instances are created each time CancellableTimer is Reset so each
-// timer has its own earlyReturn signal. This is to address a bug when a
-// CancellableTimer is stopped and reset in quick succession resulting in a
-// timer instance's earlyReturn signal being affected or seen by another timer
-// instance.
+// Different instances are created each time Job is scheduled so each timer has
+// its own earlyReturn signal. This is to address a bug where a Job is cancelled
+// and rescheduled in quick succession, resulting in a timer instance's earlyReturn
+// signal being affected or seen by another timer instance.
//
// Consider the following scenario where timer instances share a common
// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
// (B), third (C), and fourth (D) instance of the timer firing, respectively):
// T1: Obtain L
-// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T1: Create a new Job w/ lock L (create instance A)
// T2: instance A fires, blocked trying to obtain L.
// T1: Attempt to stop instance A (set earlyReturn = true)
-// T1: Reset timer (create instance B)
+// T1: Schedule timer (create instance B)
// T3: instance B fires, blocked trying to obtain L.
// T1: Attempt to stop instance B (set earlyReturn = true)
-// T1: Reset timer (create instance C)
+// T1: Schedule timer (create instance C)
// T4: instance C fires, blocked trying to obtain L.
// T1: Attempt to stop instance C (set earlyReturn = true)
-// T1: Reset timer (create instance D)
+// T1: Schedule timer (create instance D)
// T5: instance D fires, blocked trying to obtain L.
// T1: Release L
//
-// Now that T1 has released L, any of the 4 timer instances can take L and check
-// earlyReturn. If the timers simply check earlyReturn and then do nothing
-// further, then instance D will never early return even though it was not
-// requested to stop. If the timers reset earlyReturn before early returning,
-// then all but one of the timers will do work when only one was expected to.
-// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// Now that T1 has released L, any of the 4 timer instances can take L and
+// check earlyReturn. If the timers simply check earlyReturn and then do
+// nothing further, then instance D will never early return even though it was
+// not requested to stop. If the timers reset earlyReturn before early
+// returning, then all but one of the timers will do work when only one was
+// expected to. If Job resets earlyReturn when resetting, then all the timers
// will fire (again, when only one was expected to).
//
// To address the above concerns the simplest solution was to give each timer
// its own earlyReturn signal.
-type cancellableTimerInstance struct {
- timer *time.Timer
+type jobInstance struct {
+ timer Timer
// Used to inform the timer to early return when it gets stopped while the
// lock the timer tries to obtain when fired is held (T1 is a goroutine that
// tries to cancel the timer and T2 is the goroutine that handles the timer
// firing):
- // T1: Obtain the lock, then call StopLocked()
+ // T1: Obtain the lock, then call Cancel()
// T2: timer fires, and gets blocked on obtaining the lock
// T1: Releases lock
// T2: Obtains lock does unintended work
@@ -73,27 +73,33 @@ type cancellableTimerInstance struct {
earlyReturn *bool
}
-// stop stops the timer instance t from firing if it hasn't fired already. If it
+// stop stops the job instance j from firing if it hasn't fired already. If it
// has fired and is blocked at obtaining the lock, earlyReturn will be set to
// true so that it will early return when it obtains the lock.
-func (t *cancellableTimerInstance) stop() {
- if t.timer != nil {
- t.timer.Stop()
- *t.earlyReturn = true
+func (j *jobInstance) stop() {
+ if j.timer != nil {
+ j.timer.Stop()
+ *j.earlyReturn = true
}
}
-// CancellableTimer is a timer that does some work and can be safely cancelled
-// when it fires at the same time some "related work" is being done.
+// Job represents some work that can be scheduled for execution. The work can
+// be safely cancelled when it fires at the same time some "related work" is
+// being done.
//
// The term "related work" is defined as some work that needs to be done while
// holding some lock that the timer must also hold while doing some work.
//
-// Note, it is not safe to copy a CancellableTimer as its timer instance creates
-// a closure over the address of the CancellableTimer.
-type CancellableTimer struct {
+// Note, it is not safe to copy a Job as its timer instance creates
+// a closure over the address of the Job.
+type Job struct {
+ _ sync.NoCopy
+
+ // The clock used to schedule the backing timer
+ clock Clock
+
// The active instance of a cancellable timer.
- instance cancellableTimerInstance
+ instance jobInstance
// locker is the lock taken by the timer immediately after it fires and must
// be held when attempting to stop the timer.
@@ -110,75 +116,91 @@ type CancellableTimer struct {
fn func()
}
-// StopLocked prevents the Timer from firing if it has not fired already.
+// Cancel prevents the Job from executing if it has not executed already.
//
-// If the timer is blocked on obtaining the t.locker lock when StopLocked is
-// called, it will early return instead of calling t.fn.
+// Cancel requires appropriate locking to be in place for any resources managed
+// by the Job. If the Job is blocked on obtaining the lock when Cancel is
+// called, it will early return.
//
// Note, t will be modified.
//
-// t.locker MUST be locked.
-func (t *CancellableTimer) StopLocked() {
- t.instance.stop()
+// j.locker MUST be locked.
+func (j *Job) Cancel() {
+ j.instance.stop()
// Nothing to do with the stopped instance anymore.
- t.instance = cancellableTimerInstance{}
+ j.instance = jobInstance{}
}
-// Reset changes the timer to expire after duration d.
+// Schedule schedules the Job for execution after duration d. This can be
+// called on cancelled or completed Jobs to schedule them again.
//
-// Note, t will be modified.
+// Schedule should be invoked only on unscheduled, cancelled, or completed
+// Jobs. To be safe, callers should always call Cancel before calling Schedule.
//
-// Reset should only be called on stopped or expired timers. To be safe, callers
-// should always call StopLocked before calling Reset.
-func (t *CancellableTimer) Reset(d time.Duration) {
+// Note, j will be modified.
+func (j *Job) Schedule(d time.Duration) {
// Create a new instance.
earlyReturn := false
// Capture the locker so that updating the timer does not cause a data race
// when a timer fires and tries to obtain the lock (read the timer's locker).
- locker := t.locker
- t.instance = cancellableTimerInstance{
- timer: time.AfterFunc(d, func() {
+ locker := j.locker
+ j.instance = jobInstance{
+ timer: j.clock.AfterFunc(d, func() {
locker.Lock()
defer locker.Unlock()
if earlyReturn {
// If we reach this point, it means that the timer fired while another
- // goroutine called StopLocked while it had the lock. Simply return
- // here and do nothing further.
+ // goroutine called Cancel while it had the lock. Simply return here
+ // and do nothing further.
earlyReturn = false
return
}
- t.fn()
+ j.fn()
}),
earlyReturn: &earlyReturn,
}
}
-// Lock is a no-op used by the copylocks checker from go vet.
-//
-// See CancellableTimer for details about why it shouldn't be copied.
-//
-// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
-// details about the copylocks checker.
-func (*CancellableTimer) Lock() {}
-
-// Unlock is a no-op used by the copylocks checker from go vet.
-//
-// See CancellableTimer for details about why it shouldn't be copied.
-//
-// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more
-// details about the copylocks checker.
-func (*CancellableTimer) Unlock() {}
-
-// NewCancellableTimer returns an unscheduled CancellableTimer with the given
-// locker and fn.
-//
-// fn MUST NOT attempt to lock locker.
-//
-// Callers must call Reset to schedule the timer to fire.
-func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer {
- return &CancellableTimer{locker: locker, fn: fn}
+// NewJob returns a new Job that can be used to schedule f to run in its own
+// goroutine. l will be locked before calling f, then unlocked after f returns.
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// message = "bar"
+// mu.Unlock()
+//
+// // Output: bar
+//
+// f MUST NOT attempt to lock l.
+//
+// l MUST be locked prior to calling the returned job's Cancel().
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// job.Cancel()
+// mu.Unlock()
+func NewJob(c Clock, l sync.Locker, f func()) *Job {
+ return &Job{
+ clock: c,
+ locker: l,
+ fn: f,
+ }
}
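For readers skimming the new API, here is a minimal sketch of the Cancel-then-Schedule discipline the comments above require. It assumes nothing beyond the identifiers visible in this hunk (tcpip.StdClock, tcpip.NewJob, Schedule, Cancel):

	var clock tcpip.StdClock
	var mu sync.Mutex
	job := tcpip.NewJob(&clock, &mu, func() {
		// mu is already held here; do the periodic work.
	})
	job.Schedule(time.Second)

	// Later, to move the deadline safely: cancel under the lock, then
	// schedule again (Schedule itself may be called without the lock).
	mu.Lock()
	job.Cancel()
	mu.Unlock()
	job.Schedule(500 * time.Millisecond)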
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index b4940e397..a82384c49 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -28,8 +28,8 @@ const (
longDuration = 1 * time.Second
)
-func TestCancellableTimerReassignment(t *testing.T) {
- var timer tcpip.CancellableTimer
+func TestJobReschedule(t *testing.T) {
+ var clock tcpip.StdClock
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- timer = *tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
wg.Done()
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
lock.Unlock()
}()
}
wg.Wait()
}
-func TestCancellableTimerFire(t *testing.T) {
+func TestJobExecution(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
ch <- struct{}{}
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(middleDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(middleDuration)
lock.Lock()
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
}
}
-func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
// Wait for timer to fire if it wasn't correctly stopped.
@@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
case <-time.After(middleDuration):
}
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
}
}
-func TestCancellableTimerImmediatelyStop(t *testing.T) {
+func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
}
@@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) {
}
}
-func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
for i := 0; i < 10; i++ {
- timer.Reset(middleDuration)
+ job.Schedule(middleDuration)
lock.Lock()
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration * 2)
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
}
@@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration)
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
@@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
}
}
-func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 8ce294002..4612be4e7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+ }
return nil
}
@@ -744,15 +748,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
- if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
+ h := header.ICMPv4(pkt.TransportHeader)
+ if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
- if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
+ h := header.ICMPv6(pkt.TransportHeader)
+ if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
@@ -786,12 +790,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
},
}
- packet.data = pkt.Data
+ // An ICMP socket's data includes the ICMP header.
+ packet.data = pkt.TransportHeader.ToVectorisedView()
+ packet.data.Append(pkt.Data)
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
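A self-contained sketch of the data-layout change above: the bytes delivered to an ICMP socket now start with the ICMP header, followed by the payload. Only buffer.View/VectorisedView methods that already appear in this diff are used; the byte-slice arguments are placeholders.

// combineICMP mirrors how the endpoint now builds packet.data:
// transport header first, then the payload.
func combineICMP(icmpHeader, payload []byte) buffer.View {
	data := buffer.View(icmpHeader).ToVectorisedView()
	data.Append(buffer.View(payload).ToVectorisedView())
	return data.ToView()
}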
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index baf08eda6..0e46e6355 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,6 +25,8 @@
package packet
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -43,6 +45,9 @@ type packet struct {
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information about the packet, such
+ // as its protocol.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -71,11 +76,17 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+ boundNIC tcpip.NICID
+
+ // lastErrorMu protects lastError.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
}
// NewEndpoint returns a new packet endpoint.
@@ -92,6 +103,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
sndBufSize: 32 * 1024,
}
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.rcvBufSizeMax = rs.Default
+ }
+
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
return nil, err
}
@@ -132,8 +154,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// ReadPacket implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -158,9 +180,18 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
+
func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// TODO(b/129292371): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -215,12 +246,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound {
- return tcpip.ErrAlreadyBound
+ if ep.bound && ep.boundNIC == addr.NIC {
+ // If the NIC being bound is the same then just return success.
+ return nil
}
// Unregister endpoint with all the nics.
ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.bound = false
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
@@ -228,6 +261,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
ep.bound = true
+ ep.boundNIC = addr.NIC
return nil
}
@@ -264,7 +298,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
@@ -274,11 +314,63 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := ep.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ ep.mu.Lock()
+ ep.sndBufSizeMax = v
+ ep.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := ep.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ ep.rcvMu.Lock()
+ ep.rcvBufSizeMax = v
+ ep.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (ep *endpoint) takeLastError() *tcpip.Error {
+ ep.lastErrorMu.Lock()
+ defer ep.lastErrorMu.Unlock()
+
+ err := ep.lastError
+ ep.lastError = nil
+ return err
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return ep.takeLastError()
+ }
return tcpip.ErrNotSupported
}
@@ -289,7 +381,32 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() {
+ p := ep.rcvList.Front()
+ v = p.data.Size()
+ }
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSizeMax
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
}
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
@@ -323,40 +440,66 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip the link header from pkt.Header.
+ pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):])
+ combinedVV := pkt.Header.View().ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ var combinedVV buffer.VectorisedView
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ combinedVV = linkHeader.ToVectorisedView()
+ }
+ if pkt.PktType == tcpip.PacketOutgoing {
+ // For outgoing packets the Link, Network and Transport
+ // headers are normally in pkt.Header, unless a Raw socket
+ // is in use, in which case pkt.Header may be empty.
+ combinedVV.AppendView(pkt.Header.View())
}
- combinedVV := linkHeader.ToVectorisedView()
combinedVV.Append(pkt.Data)
packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
panic(*err)
}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+ if ep.lastError == nil {
+ return ""
+ }
+
+ return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ ep.lastError = tcpip.StringToError(s)
+}
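The two hooks above rely on tcpip.StringToError being the inverse of (*tcpip.Error).String() for the canonical errors; a short sketch of the intended round trip (that inverse relationship is an assumption, not verified here):

	saved := tcpip.ErrNoRoute.String()     // what saveLastError stores
	restored := tcpip.StringToError(saved) // what loadLastError restores
	_ = restored                           // expected to be tcpip.ErrNoRoute again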
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 766c7648e..f85a68554 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -63,6 +63,7 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
+ hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -108,6 +109,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSizeMax: 32 * 1024,
associated: associated,
+ hdrIncluded: !associated,
}
// Override with stack defaults.
@@ -182,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !e.associated {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
- }
-
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -263,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !e.associated {
+ if e.hdrIncluded {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -353,7 +351,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- if !e.associated {
+ if e.hdrIncluded {
if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
Data: buffer.View(payloadBytes).ToVectorisedView(),
}); err != nil {
@@ -458,7 +456,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
defer e.mu.Unlock()
// If a local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -508,11 +506,24 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
return tcpip.ErrUnknownProtocolOption
}
@@ -577,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.KeepaliveEnabledOption:
return false, nil
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -616,8 +633,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
- // Drop the packet if our buffer is currently full.
- if e.rcvClosed {
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e. an endpoint created with IPPROTO_RAW). Such endpoints are send
+ // only. See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
@@ -676,7 +700,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
}
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
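A usage sketch for the IPHdrIncludedOption plumbing added above, presumably what a caller wiring up IP_HDRINCL would do; ep stands for any tcpip.Endpoint backed by this raw endpoint (a placeholder, not shown in the diff):

	// Enable header-included mode: subsequent writes are expected to carry
	// a complete IPv4 header and are routed by its destination address.
	if err := ep.SetSockOptBool(tcpip.IPHdrIncludedOption, true); err != nil {
		return err
	}
	// Read back the current setting.
	hdrIncluded, err := ep.GetSockOptBool(tcpip.IPHdrIncludedOption)
	if err != nil {
		return err
	}
	_ = hdrIncluded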
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 3601207be..18ff89ffc 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -58,7 +58,6 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
- "//pkg/binary",
"//pkg/log",
"//pkg/rand",
"//pkg/sleep",
@@ -87,6 +86,7 @@ go_test(
"tcp_test.go",
"tcp_timestamp_test.go",
],
+ shard_count = 10,
deps = [
":tcp",
"//pkg/sync",
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 81b740115..1798510bc 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
}
// Wait for notification.
@@ -616,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 43b76bee5..98aecab9e 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -15,7 +15,8 @@
package tcp
import (
- "gvisor.dev/gvisor/pkg/binary"
+ "encoding/binary"
+
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index bd3ec5a8d..0f7487963 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1209,6 +1209,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1589,6 +1597,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.UnlockUser()
e.notifyProtocolGoroutine(notifyMSSChanged)
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if attempting to set this option to
+ // anything other than path MTU discovery disabled.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1785,6 +1800,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.deferAccept = time.Duration(v)
e.UnlockUser()
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
default:
return nil
}
@@ -1896,6 +1914,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
v := header.TCPDefaultMSS
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // Always return the path MTU discovery disabled setting since
+ // it's the only one supported.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1941,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
+ return e.takeLastError()
case *tcpip.BindToDeviceOption:
e.LockUser()
@@ -2531,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
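With the ControlNoRoute/ControlNetworkUnreachable handling above, the recorded error surfaces to callers through the usual tcpip.ErrorOption path. A sketch of how a caller drains it, assuming tcpip.ErrorOption is the usual empty-struct option and ep is the endpoint's tcpip.Endpoint interface:

	// takeLastError clears the stored error, so each failure is reported once.
	if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
		// err may now be tcpip.ErrNoRoute or tcpip.ErrNetworkUnreachable.
	}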
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 169adb16b..e67ec42b1 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -3095,6 +3095,63 @@ func TestMaxRTO(t *testing.T) {
}
}
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ // Expect two retransmitted packets, and that all packets received have
+ // unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
+ }
+}
+
func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 12bc1b5b5..558b06df0 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
return st
}
+// State returns the current state of the TCB.
+func (t *TCB) State() Result {
+ return t.state
+}
+
// IsAlive returns true as long as the connection is established(Alive)
// or connecting state.
func (t *TCB) IsAlive() bool {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index cae29fbff..6e692da07 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -612,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
@@ -809,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
e.bindToDevice = id
e.mu.Unlock()
+
+ case tcpip.SocketDetachFilterOption:
+ return nil
}
return nil
}
@@ -906,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.MulticastTTLOption:
e.mu.Lock()
v := int(e.multicastTTL)
@@ -1366,6 +1380,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
+ // Never receive from a multicast address.
+ if header.IsV4MulticastAddress(id.RemoteAddress) ||
+ header.IsV6MulticastAddress(id.RemoteAddress) {
+ e.stack.Stats().UDP.InvalidSourceAddress.Increment()
+ e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+
// Verify checksum unless RX checksum offload is enabled.
// On IPv4, UDP checksum is optional, and a zero value means
// the transmitter omitted the checksum generation (RFC768).
@@ -1384,10 +1407,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
}
}
- e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed {
e.rcvMu.Unlock()
@@ -1428,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
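The new guard above rejects datagrams whose source address is multicast. The check itself is just the pair of header helpers already used in the hunk; as a tiny sketch:

// sourceIsMulticast reports whether a datagram's source address is an IPv4 or
// IPv6 multicast address and should therefore be dropped on receive.
func sourceIsMulticast(addr tcpip.Address) bool {
	return header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
}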
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index db59eb5a0..90781cf49 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -83,16 +83,18 @@ type header4Tuple struct {
type testFlow int
const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ reverseMulticast4 // V4 multicast src. Must fail.
+ reverseMulticast6 // V6 multicast src. Must fail.
)
func (flow testFlow) String() string {
@@ -117,6 +119,10 @@ func (flow testFlow) String() string {
return "broadcast"
case broadcastIn6:
return "broadcastIn6"
+ case reverseMulticast4:
+ return "reverseMulticast4"
+ case reverseMulticast6:
+ return "reverseMulticast6"
default:
return "unknown"
}
@@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
h.dstAddr.Addr = multicastV6Addr
}
}
+ if flow.isReverseMulticast() {
+ h.srcAddr.Addr = flow.getMcastAddr()
+ }
return h
}
@@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
// endpoint for this flow.
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
+ case unicastV4, multicastV4, broadcast, reverseMulticast4:
return ipv4.ProtocolNumber
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool {
switch flow {
case unicastV6Only, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool {
switch flow {
case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool {
switch flow {
case broadcast, broadcastIn6:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool {
switch flow {
case unicastV4in6, multicastV4in6, broadcastIn6:
return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
}
}
+func (flow testFlow) isReverseMulticast() bool {
+ switch flow {
+ case reverseMulticast4, reverseMulticast6:
+ return true
+ default:
+ return false
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -872,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
}
}
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
+// TestReadFromMulticastStats checks that a discarded packet
+// that was sent with a multicast SOURCE address increments
+// the correct counters and that a regular packet does not.
+func TestReadFromMulticastStats(t *testing.T) {
+ t.Helper()
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ c.injectPacket(flow, payload)
+
+ var want uint64 = 0
+ if flow.isReverseMulticast() {
+ want = 1
+ }
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
+ t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
+ t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -1721,9 +1793,11 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
payload := newPayload()
h := unicastV6.header4Tuple(incoming)
buf := c.buildV6Packet(payload, &h)
- // Invalidate the packet length field in the UDP header by adding one.
+
+ // Invalidate the UDP header length field.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.SetLength(u.Length() + 1)
+
c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -1803,9 +1877,16 @@ func TestIncrementChecksumErrorsV4(t *testing.T) {
payload := newPayload()
h := unicastV4.header4Tuple(incoming)
buf := c.buildV4Packet(payload, &h)
- // Invalidate the checksum field in the UDP header by adding one.
- u := header.UDP(buf[header.IPv4MinimumSize:])
- u.SetChecksum(u.Checksum() + 1)
+
+ // Invalidate the UDP header checksum field, taking care to avoid
+ // overflow to zero, which would disable checksum validation.
+ for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+ u.SetChecksum(u.Checksum() + 1)
+ if u.Checksum() != 0 {
+ break
+ }
+ }
+
c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -1834,9 +1915,11 @@ func TestIncrementChecksumErrorsV6(t *testing.T) {
payload := newPayload()
h := unicastV6.header4Tuple(incoming)
buf := c.buildV6Packet(payload, &h)
- // Invalidate the checksum field in the UDP header by adding one.
+
+ // Invalidate the UDP header checksum field.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.SetChecksum(u.Checksum() + 1)
+
c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go
index 8fed29ff5..70945f234 100644
--- a/pkg/test/criutil/criutil.go
+++ b/pkg/test/criutil/criutil.go
@@ -22,6 +22,9 @@ import (
"fmt"
"os"
"os/exec"
+ "path"
+ "regexp"
+ "strconv"
"strings"
"time"
@@ -33,28 +36,44 @@ import (
type Crictl struct {
logger testutil.Logger
endpoint string
+ runpArgs []string
cleanup []func()
}
-// resolvePath attempts to find binary paths. It may set the path to invalid,
+// ResolvePath attempts to find binary paths. It may return an invalid path,
// which will cause the execution to fail with a sensible error.
-func resolvePath(executable string) string {
+func ResolvePath(executable string) string {
+ runtime, err := dockerutil.RuntimePath()
+ if err == nil {
+ // Check first the directory of the runtime itself.
+ if dir := path.Dir(runtime); dir != "" && dir != "." {
+ guess := path.Join(dir, executable)
+ if fi, err := os.Stat(guess); err == nil && (fi.Mode()&0111) != 0 {
+ return guess
+ }
+ }
+ }
+
+ // Try to find via the path.
guess, err := exec.LookPath(executable)
- if err != nil {
- guess = fmt.Sprintf("/usr/local/bin/%s", executable)
+ if err == nil {
+ return guess
}
- return guess
+
+ // Return a default path.
+ return fmt.Sprintf("/usr/local/bin/%s", executable)
}
// NewCrictl returns a Crictl configured with a timeout and an endpoint over
// which it will talk to containerd.
-func NewCrictl(logger testutil.Logger, endpoint string) *Crictl {
+func NewCrictl(logger testutil.Logger, endpoint string, runpArgs []string) *Crictl {
// Attempt to find the executable, but don't bother propagating the
// error at this point. The first command executed will return with a
// binary not found error.
return &Crictl{
logger: logger,
endpoint: endpoint,
+ runpArgs: runpArgs,
}
}
@@ -67,8 +86,8 @@ func (cc *Crictl) CleanUp() {
}
// RunPod creates a sandbox. It corresponds to `crictl runp`.
-func (cc *Crictl) RunPod(sbSpecFile string) (string, error) {
- podID, err := cc.run("runp", sbSpecFile)
+func (cc *Crictl) RunPod(runtime, sbSpecFile string) (string, error) {
+ podID, err := cc.run("runp", "--runtime", runtime, sbSpecFile)
if err != nil {
return "", fmt.Errorf("runp failed: %v", err)
}
@@ -79,10 +98,42 @@ func (cc *Crictl) RunPod(sbSpecFile string) (string, error) {
// Create creates a container within a sandbox. It corresponds to `crictl
// create`.
func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) {
- podID, err := cc.run("create", podID, contSpecFile, sbSpecFile)
+ // In version 1.16.0, crictl annoyingly started attempting to pull the
+ // image, even if it was already available locally. We therefore
+ // need to parse the version and add an appropriate --no-pull argument
+ // since the image has already been loaded locally.
+ out, err := cc.run("-v")
+ if err != nil {
+ return "", err
+ }
+ r := regexp.MustCompile("crictl version ([0-9]+)\\.([0-9]+)\\.([0-9]+)")
+ vs := r.FindStringSubmatch(out)
+ if len(vs) != 4 {
+ return "", fmt.Errorf("crictl -v had unexpected output: %s", out)
+ }
+ major, err := strconv.ParseUint(vs[1], 10, 64)
if err != nil {
+ return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out)
+ }
+ minor, err := strconv.ParseUint(vs[2], 10, 64)
+ if err != nil {
+ return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out)
+ }
+
+ args := []string{"create"}
+ if (major == 1 && minor >= 16) || major > 1 {
+ args = append(args, "--no-pull")
+ }
+ args = append(args, podID)
+ args = append(args, contSpecFile)
+ args = append(args, sbSpecFile)
+
+ podID, err = cc.run(args...)
+ if err != nil {
return "", fmt.Errorf("create failed: %v", err)
}
+
// Strip the trailing newline from crictl output.
return strings.TrimSpace(podID), nil
}
@@ -179,7 +230,7 @@ func (cc *Crictl) Import(image string) error {
// be pushing a lot of bytes in order to import the image. The connect
// timeout stays the same and is inherited from the Crictl instance.
cmd := testutil.Command(cc.logger,
- resolvePath("ctr"),
+ ResolvePath("ctr"),
fmt.Sprintf("--connect-timeout=%s", 30*time.Second),
fmt.Sprintf("--address=%s", cc.endpoint),
"-n", "k8s.io", "images", "import", "-")
@@ -260,7 +311,7 @@ func (cc *Crictl) StopContainer(contID string) error {
// StartPodAndContainer starts a sandbox and container in that sandbox. It
// returns the pod ID and container ID.
-func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) {
+func (cc *Crictl) StartPodAndContainer(runtime, image, sbSpec, contSpec string) (string, string, error) {
if err := cc.Import(image); err != nil {
return "", "", err
}
@@ -277,7 +328,7 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string,
}
cc.cleanup = append(cc.cleanup, cleanup)
- podID, err := cc.RunPod(sbSpecFile)
+ podID, err := cc.RunPod(runtime, sbSpecFile)
if err != nil {
return "", "", err
}
@@ -307,7 +358,7 @@ func (cc *Crictl) StopPodAndContainer(podID, contID string) error {
// run runs crictl with the given args.
func (cc *Crictl) run(args ...string) (string, error) {
defaultArgs := []string{
- resolvePath("crictl"),
+ ResolvePath("crictl"),
"--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
"--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
}
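The version parsing added to Create reduces to a single predicate: --no-pull is required for crictl 1.16.0 and newer. A sketch of that gate, matching the condition in the hunk:

// needsNoPull mirrors the check used in Create: crictl started pulling images
// on `crictl create` in 1.16.0, so --no-pull is added from that version on.
func needsNoPull(major, minor uint64) bool {
	return (major == 1 && minor >= 16) || major > 1
}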
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
index 7c8758e35..83b80c8bc 100644
--- a/pkg/test/dockerutil/BUILD
+++ b/pkg/test/dockerutil/BUILD
@@ -5,10 +5,21 @@ package(licenses = ["notice"])
go_library(
name = "dockerutil",
testonly = 1,
- srcs = ["dockerutil.go"],
+ srcs = [
+ "container.go",
+ "dockerutil.go",
+ "exec.go",
+ "network.go",
+ ],
visibility = ["//:sandbox"],
deps = [
"//pkg/test/testutil",
- "@com_github_kr_pty//:go_default_library",
+ "@com_github_docker_docker//api/types:go_default_library",
+ "@com_github_docker_docker//api/types/container:go_default_library",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ "@com_github_docker_docker//api/types/network:go_default_library",
+ "@com_github_docker_docker//client:go_default_library",
+ "@com_github_docker_docker//pkg/stdcopy:go_default_library",
+ "@com_github_docker_go_connections//nat:go_default_library",
],
)
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
new file mode 100644
index 000000000..17acdaf6f
--- /dev/null
+++ b/pkg/test/dockerutil/container.go
@@ -0,0 +1,501 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/container"
+ "github.com/docker/docker/api/types/mount"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/client"
+ "github.com/docker/docker/pkg/stdcopy"
+ "github.com/docker/go-connections/nat"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Container represents a Docker Container, allowing the
+// user to configure and control it as one would with the 'docker'
+// client. Container is backed by the official golang docker API.
+// See: https://pkg.go.dev/github.com/docker/docker.
+type Container struct {
+ Name string
+ Runtime string
+
+ logger testutil.Logger
+ client *client.Client
+ id string
+ mounts []mount.Mount
+ links []string
+ cleanups []func()
+ copyErr error
+
+ // Stores streams attached to the container. Used by WaitForOutputSubmatch.
+ streams types.HijackedResponse
+
+ // stores previously read data from the attached streams.
+ streamBuf bytes.Buffer
+}
+
+// RunOpts are options for running a container.
+type RunOpts struct {
+ // Image is the image relative to images/. This will be mangled
+ // appropriately, to ensure that only first-party images are used.
+ Image string
+
+ // Memory is the memory limit in bytes.
+ Memory int
+
+ // Cpus in which to allow execution. ("0", "1", "0-2").
+ CpusetCpus string
+
+ // Ports are the ports to be allocated.
+ Ports []int
+
+ // WorkDir sets the working directory.
+ WorkDir string
+
+ // ReadOnly sets the read-only flag.
+ ReadOnly bool
+
+ // Env are additional environment variables.
+ Env []string
+
+ // User is the user to use.
+ User string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // CapAdd are the extra set of capabilities to add.
+ CapAdd []string
+
+ // CapDrop are the extra set of capabilities to drop.
+ CapDrop []string
+
+ // Mounts is the list of directories/files to be mounted inside the container.
+ Mounts []mount.Mount
+
+ // Links is the list of containers to be connected to the container.
+ Links []string
+}
+
+// MakeContainer sets up the struct for a Docker container.
+//
+// Names of containers will be unique.
+func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ // Slashes are not allowed in container names.
+ name := testutil.RandomID(logger.Name())
+ name = strings.ReplaceAll(name, "/", "-")
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ return nil
+ }
+
+ client.NegotiateAPIVersion(ctx)
+
+ return &Container{
+ logger: logger,
+ Name: name,
+ Runtime: *runtime,
+ client: client,
+ }
+}
+
+// Spawn is analogous to 'docker run -d'.
+func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error {
+ if err := c.create(ctx, r, args); err != nil {
+ return err
+ }
+ return c.Start(ctx)
+}
+
+// SpawnProcess is analogous to 'docker run -it'. It returns a process
+// which represents the root process.
+func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) {
+ config, hostconf, netconf := c.ConfigsFrom(r, args...)
+ config.Tty = true
+ config.OpenStdin = true
+
+ if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil {
+ return Process{}, err
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return Process{}, err
+ }
+
+ return Process{container: c, conn: c.streams}, nil
+}
+
+// Run is analogous to 'docker run'.
+func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) {
+ if err := c.create(ctx, r, args); err != nil {
+ return "", err
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return "", err
+ }
+
+ if err := c.Wait(ctx); err != nil {
+ return "", err
+ }
+
+ return c.Logs(ctx)
+}
+
+// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom'
+// and Start.
+func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) {
+ return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{}
+}
+
+// MakeLink formats a link to add to a RunOpts.
+func (c *Container) MakeLink(target string) string {
+ return fmt.Sprintf("%s:%s", c.Name, target)
+}
+
+// CreateFrom creates a container from the given configs.
+func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+ cont, err := c.client.ContainerCreate(ctx, conf, hostconf, netconf, c.Name)
+ if err != nil {
+ return err
+ }
+ c.id = cont.ID
+ return nil
+}
+
+// Create is analogous to 'docker create'.
+func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error {
+ return c.create(ctx, r, args)
+}
+
+func (c *Container) create(ctx context.Context, r RunOpts, args []string) error {
+ conf := c.config(r, args)
+ hostconf := c.hostConfig(r)
+ cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name)
+ if err != nil {
+ return err
+ }
+ c.id = cont.ID
+ return nil
+}
+
+func (c *Container) config(r RunOpts, args []string) *container.Config {
+ ports := nat.PortSet{}
+ for _, p := range r.Ports {
+ port := nat.Port(fmt.Sprintf("%d", p))
+ ports[port] = struct{}{}
+ }
+ env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name))
+
+ return &container.Config{
+ Image: testutil.ImageByName(r.Image),
+ Cmd: args,
+ ExposedPorts: ports,
+ Env: env,
+ WorkingDir: r.WorkDir,
+ User: r.User,
+ }
+}
+
+func (c *Container) hostConfig(r RunOpts) *container.HostConfig {
+ c.mounts = append(c.mounts, r.Mounts...)
+
+ return &container.HostConfig{
+ Runtime: c.Runtime,
+ Mounts: c.mounts,
+ PublishAllPorts: true,
+ Links: r.Links,
+ CapAdd: r.CapAdd,
+ CapDrop: r.CapDrop,
+ Privileged: r.Privileged,
+ ReadonlyRootfs: r.ReadOnly,
+ Resources: container.Resources{
+ Memory: int64(r.Memory), // In bytes.
+ CpusetCpus: r.CpusetCpus,
+ },
+ }
+}
+
+// Start is analogous to 'docker start'.
+func (c *Container) Start(ctx context.Context) error {
+ // Open a connection to the container for parsing logs and for TTY.
+ streams, err := c.client.ContainerAttach(ctx, c.id,
+ types.ContainerAttachOptions{
+ Stream: true,
+ Stdin: true,
+ Stdout: true,
+ Stderr: true,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to connect to container: %v", err)
+ }
+
+ c.streams = streams
+ c.cleanups = append(c.cleanups, func() {
+ c.streams.Close()
+ })
+
+ return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{})
+}
+
+// Stop is analogous to 'docker stop'.
+func (c *Container) Stop(ctx context.Context) error {
+ return c.client.ContainerStop(ctx, c.id, nil)
+}
+
+// Pause is analogous to 'docker pause'.
+func (c *Container) Pause(ctx context.Context) error {
+ return c.client.ContainerPause(ctx, c.id)
+}
+
+// Unpause is analogous to 'docker unpause'.
+func (c *Container) Unpause(ctx context.Context) error {
+ return c.client.ContainerUnpause(ctx, c.id)
+}
+
+// Checkpoint is analogous to 'docker checkpoint'.
+func (c *Container) Checkpoint(ctx context.Context, name string) error {
+ return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true})
+}
+
+// Restore is analogous to 'docker start --checkpoint [name]'.
+func (c *Container) Restore(ctx context.Context, name string) error {
+ return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name})
+}
+
+// Logs is analogous to 'docker logs'.
+func (c *Container) Logs(ctx context.Context) (string, error) {
+ var out bytes.Buffer
+ err := c.logs(ctx, &out, &out)
+ return out.String(), err
+}
+
+func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error {
+ opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true}
+ writer, err := c.client.ContainerLogs(ctx, c.id, opts)
+ if err != nil {
+ return err
+ }
+ defer writer.Close()
+ _, err = stdcopy.StdCopy(stdout, stderr, writer)
+
+ return err
+}
+
+// ID returns the container id.
+func (c *Container) ID() string {
+ return c.id
+}
+
+// SandboxPid returns the container's pid.
+func (c *Container) SandboxPid(ctx context.Context) (int, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, err
+ }
+ return resp.ContainerJSONBase.State.Pid, nil
+}
+
+// FindIP returns the IP address of the container.
+func (c *Container) FindIP(ctx context.Context) (net.IP, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return nil, err
+ }
+
+ ip := net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress)
+ if ip == nil {
+ return net.IP{}, fmt.Errorf("invalid IP: %q", ip)
+ }
+ return ip, nil
+}
+
+// FindPort returns the host port that is mapped to 'sandboxPort'.
+func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) {
+ desc, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+ }
+
+ format := fmt.Sprintf("%d/tcp", sandboxPort)
+ ports, ok := desc.NetworkSettings.Ports[nat.Port(format)]
+ if !ok {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+
+ }
+
+ port, err := strconv.Atoi(ports[0].HostPort)
+ if err != nil {
+ return -1, fmt.Errorf("error parsing port %q: %v", port, err)
+ }
+ return port, nil
+}
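+
+// A hypothetical usage sketch for FindPort, assuming a *Container c whose
+// workload listens on 8080 inside the sandbox; ports are published because
+// hostConfig sets PublishAllPorts above. The port and test variable names are
+// illustrative:
+//
+//    hostPort, err := c.FindPort(ctx, 8080)
+//    if err != nil {
+//        t.Fatalf("FindPort failed: %v", err)
+//    }
+//    conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", hostPort))
+//    if err != nil {
+//        t.Fatalf("dial failed: %v", err)
+//    }
+//    defer conn.Close()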
+
+// CopyFiles copies in and mounts the given files. They are always ReadOnly.
+func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) {
+ dir, err := ioutil.TempDir("", c.Name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err)
+ return
+ }
+ c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) })
+ if err := os.Chmod(dir, 0755); err != nil {
+ c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err)
+ return
+ }
+ for _, name := range sources {
+ src, err := testutil.FindFile(name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
+ return
+ }
+ dst := path.Join(dir, path.Base(name))
+ if err := testutil.Copy(src, dst); err != nil {
+ c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
+ return
+ }
+ c.logger.Logf("copy: %s -> %s", src, dst)
+ }
+ opts.Mounts = append(opts.Mounts, mount.Mount{
+ Type: mount.TypeBind,
+ Source: dir,
+ Target: target,
+ ReadOnly: true,
+ })
+}
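+
+// A rough usage sketch for CopyFiles, assuming a *Container c and a RunOpts
+// value opts prepared by the caller; the target and source paths are
+// illustrative:
+//
+//    c.CopyFiles(&opts, "/test", "path/to/test_binary")
+//    // Failures are recorded in c.copyErr rather than returned here.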
+
+// Status inspects the container and returns its status.
+func (c *Container) Status(ctx context.Context) (types.ContainerState, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return types.ContainerState{}, err
+ }
+ return *resp.State, err
+}
+
+// Wait waits for the container to exit.
+func (c *Container) Wait(ctx context.Context) error {
+ statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+ select {
+ case err := <-errChan:
+ return err
+ case <-statusChan:
+ return nil
+ }
+}
+
+// WaitTimeout waits for the container to exit with a timeout.
+func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error {
+ timeoutChan := time.After(timeout)
+ statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+ select {
+ case err := <-errChan:
+ return err
+ case <-statusChan:
+ return nil
+ case <-timeoutChan:
+ return fmt.Errorf("container %s timed out after %v seconds", c.Name, timeout.Seconds())
+ }
+}
+
+// WaitForOutput searches container logs for the given pattern and returns the
+// first match, or times out.
+func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) {
+ matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout)
+ if err != nil {
+ return "", err
+ }
+ if len(matches) == 0 {
+ return "", fmt.Errorf("didn't find pattern %s logs", pattern)
+ }
+ return matches[0], nil
+}
+
+// WaitForOutputSubmatch searches container logs for the given
+// pattern or times out. It returns any regexp submatches as well.
+func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) {
+ re := regexp.MustCompile(pattern)
+ if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil {
+ return matches, nil
+ }
+
+ for exp := time.Now().Add(timeout); time.Now().Before(exp); {
+ c.streams.Conn.SetDeadline(time.Now().Add(50 * time.Millisecond))
+ _, err := stdcopy.StdCopy(&c.streamBuf, &c.streamBuf, c.streams.Reader)
+
+ if err != nil {
+ // A read deadline timeout just means no new output yet; any other error is fatal.
+ if nerr, ok := err.(net.Error); !ok || !nerr.Timeout() {
+ return nil, err
+ }
+ }
+
+ if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil {
+ return matches, nil
+ }
+ }
+
+ return nil, fmt.Errorf("timeout waiting for output %q: out: %s", re.String(), c.streamBuf.String())
+}
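+
+// An illustrative sketch of waiting for readiness output, assuming a started
+// *Container c whose workload prints a line such as "listening on port 8080";
+// the pattern and timeout are made up:
+//
+//    if _, err := c.WaitForOutput(ctx, "listening on port [0-9]+", 30*time.Second); err != nil {
+//        t.Fatalf("container did not become ready: %v", err)
+//    }
+//
+// WaitForOutputSubmatch can be used instead when the capture groups (e.g. the
+// port number) are needed.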
+
+// Kill kills the container.
+func (c *Container) Kill(ctx context.Context) error {
+ return c.client.ContainerKill(ctx, c.id, "")
+}
+
+// Remove is analogous to 'docker rm'.
+func (c *Container) Remove(ctx context.Context) error {
+ // Remove the container, dropping its volumes and links if they were used.
+ remove := types.ContainerRemoveOptions{
+ RemoveVolumes: c.mounts != nil,
+ RemoveLinks: c.links != nil,
+ Force: true,
+ }
+ return c.client.ContainerRemove(ctx, c.Name, remove)
+}
+
+// CleanUp kills and deletes the container (best effort).
+func (c *Container) CleanUp(ctx context.Context) {
+ // Kill the container.
+ if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") {
+ // Just log; can't do anything here.
+ c.logger.Logf("error killing container %q: %v", c.Name, err)
+ }
+ // Remove the container.
+ if err := c.Remove(ctx); err != nil {
+ c.logger.Logf("error removing container %q: %v", c.Name, err)
+ }
+ // Forget all mounts.
+ c.mounts = nil
+ // Execute all cleanups.
+ for _, cleanup := range c.cleanups {
+ cleanup()
+ }
+ c.cleanups = nil
+}
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
index 819dd0a59..df09babf3 100644
--- a/pkg/test/dockerutil/dockerutil.go
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -22,17 +22,10 @@ import (
"io"
"io/ioutil"
"log"
- "net"
- "os"
"os/exec"
- "path"
"regexp"
"strconv"
- "strings"
- "syscall"
- "time"
- "github.com/kr/pty"
"gvisor.dev/gvisor/pkg/test/testutil"
)
@@ -127,595 +120,7 @@ func Save(logger testutil.Logger, image string, w io.Writer) error {
return cmd.Run()
}
-// MountMode describes if the mount should be ro or rw.
-type MountMode int
-
-const (
- // ReadOnly is what the name says.
- ReadOnly MountMode = iota
- // ReadWrite is what the name says.
- ReadWrite
-)
-
-// String returns the mount mode argument for this MountMode.
-func (m MountMode) String() string {
- switch m {
- case ReadOnly:
- return "ro"
- case ReadWrite:
- return "rw"
- }
- panic(fmt.Sprintf("invalid mode: %d", m))
-}
-
-// DockerNetwork contains the name of a docker network.
-type DockerNetwork struct {
- logger testutil.Logger
- Name string
- Subnet *net.IPNet
- containers []*Docker
-}
-
-// NewDockerNetwork sets up the struct for a Docker network. Names of networks
-// will be unique.
-func NewDockerNetwork(logger testutil.Logger) *DockerNetwork {
- return &DockerNetwork{
- logger: logger,
- Name: testutil.RandomID(logger.Name()),
- }
-}
-
-// Create calls 'docker network create'.
-func (n *DockerNetwork) Create(args ...string) error {
- a := []string{"docker", "network", "create"}
- if n.Subnet != nil {
- a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet))
- }
- a = append(a, args...)
- a = append(a, n.Name)
- return testutil.Command(n.logger, a...).Run()
-}
-
-// Connect calls 'docker network connect' with the arguments provided.
-func (n *DockerNetwork) Connect(container *Docker, args ...string) error {
- a := []string{"docker", "network", "connect"}
- a = append(a, args...)
- a = append(a, n.Name, container.Name)
- if err := testutil.Command(n.logger, a...).Run(); err != nil {
- return err
- }
- n.containers = append(n.containers, container)
- return nil
-}
-
-// Cleanup cleans up the docker network and all the containers attached to it.
-func (n *DockerNetwork) Cleanup() error {
- for _, c := range n.containers {
- // Don't propagate the error, it might be that the container
- // was already cleaned up.
- if err := c.Kill(); err != nil {
- n.logger.Logf("unable to kill container during cleanup: %s", err)
- }
- }
-
- if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil {
- return err
- }
- return nil
-}
-
-// Docker contains the name and the runtime of a docker container.
-type Docker struct {
- logger testutil.Logger
- Runtime string
- Name string
- copyErr error
- cleanups []func()
-}
-
-// MakeDocker sets up the struct for a Docker container.
-//
-// Names of containers will be unique.
-func MakeDocker(logger testutil.Logger) *Docker {
- // Slashes are not allowed in container names.
- name := testutil.RandomID(logger.Name())
- name = strings.ReplaceAll(name, "/", "-")
-
- return &Docker{
- logger: logger,
- Name: name,
- Runtime: *runtime,
- }
-}
-
-// CopyFiles copies in and mounts the given files. They are always ReadOnly.
-func (d *Docker) CopyFiles(opts *RunOpts, targetDir string, sources ...string) {
- dir, err := ioutil.TempDir("", d.Name)
- if err != nil {
- d.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err)
- return
- }
- d.cleanups = append(d.cleanups, func() { os.RemoveAll(dir) })
- if err := os.Chmod(dir, 0755); err != nil {
- d.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err)
- return
- }
- for _, name := range sources {
- src, err := testutil.FindFile(name)
- if err != nil {
- d.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
- return
- }
- dst := path.Join(dir, path.Base(name))
- if err := testutil.Copy(src, dst); err != nil {
- d.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
- return
- }
- d.logger.Logf("copy: %s -> %s", src, dst)
- }
- opts.Mounts = append(opts.Mounts, Mount{
- Source: dir,
- Target: targetDir,
- Mode: ReadOnly,
- })
-}
-
-// Mount describes a mount point inside the container.
-type Mount struct {
- // Source is the path outside the container.
- Source string
-
- // Target is the path inside the container.
- Target string
-
- // Mode tells whether the mount inside the container should be readonly.
- Mode MountMode
-}
-
-// Link informs dockers that a given container needs to be made accessible from
-// the container being configured.
-type Link struct {
- // Source is the container to connect to.
- Source *Docker
-
- // Target is the alias for the container.
- Target string
-}
-
-// RunOpts are options for running a container.
-type RunOpts struct {
- // Image is the image relative to images/. This will be mangled
- // appropriately, to ensure that only first-party images are used.
- Image string
-
- // Memory is the memory limit in kB.
- Memory int
-
- // Ports are the ports to be allocated.
- Ports []int
-
- // WorkDir sets the working directory.
- WorkDir string
-
- // ReadOnly sets the read-only flag.
- ReadOnly bool
-
- // Env are additional environment variables.
- Env []string
-
- // User is the user to use.
- User string
-
- // Privileged enables privileged mode.
- Privileged bool
-
- // CapAdd are the extra set of capabilities to add.
- CapAdd []string
-
- // CapDrop are the extra set of capabilities to drop.
- CapDrop []string
-
- // Pty indicates that a pty will be allocated. If this is non-nil, then
- // this will run after start-up with the *exec.Command and Pty file
- // passed in to the function.
- Pty func(*exec.Cmd, *os.File)
-
- // Foreground indicates that the container should be run in the
- // foreground. If this is true, then the output will be available as a
- // return value from the Run function.
- Foreground bool
-
- // Mounts is the list of directories/files to be mounted inside the container.
- Mounts []Mount
-
- // Links is the list of containers to be connected to the container.
- Links []Link
-
- // Extra are extra arguments that may be passed.
- Extra []string
-}
-
-// args returns common arguments.
-//
-// Note that this does not define the complete behavior.
-func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) {
- isExec := command == "exec"
- isRun := command == "run"
-
- if isRun || isExec {
- rv = append(rv, "-i")
- }
- if r.Pty != nil {
- rv = append(rv, "-t")
- }
- if r.User != "" {
- rv = append(rv, fmt.Sprintf("--user=%s", r.User))
- }
- if r.Privileged {
- rv = append(rv, "--privileged")
- }
- for _, c := range r.CapAdd {
- rv = append(rv, fmt.Sprintf("--cap-add=%s", c))
- }
- for _, c := range r.CapDrop {
- rv = append(rv, fmt.Sprintf("--cap-drop=%s", c))
- }
- for _, e := range r.Env {
- rv = append(rv, fmt.Sprintf("--env=%s", e))
- }
- if r.WorkDir != "" {
- rv = append(rv, fmt.Sprintf("--workdir=%s", r.WorkDir))
- }
- if !isExec {
- if r.Memory != 0 {
- rv = append(rv, fmt.Sprintf("--memory=%dk", r.Memory))
- }
- for _, p := range r.Ports {
- rv = append(rv, fmt.Sprintf("--publish=%d", p))
- }
- if r.ReadOnly {
- rv = append(rv, fmt.Sprintf("--read-only"))
- }
- if len(p) > 0 {
- rv = append(rv, "--entrypoint=")
- }
- }
-
- // Always attach the test environment & Extra.
- rv = append(rv, fmt.Sprintf("--env=RUNSC_TEST_NAME=%s", d.Name))
- rv = append(rv, r.Extra...)
-
- // Attach necessary bits.
- if isExec {
- rv = append(rv, d.Name)
- } else {
- for _, m := range r.Mounts {
- rv = append(rv, fmt.Sprintf("-v=%s:%s:%v", m.Source, m.Target, m.Mode))
- }
- for _, l := range r.Links {
- rv = append(rv, fmt.Sprintf("--link=%s:%s", l.Source.Name, l.Target))
- }
-
- if len(d.Runtime) > 0 {
- rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime))
- }
- rv = append(rv, fmt.Sprintf("--name=%s", d.Name))
- rv = append(rv, testutil.ImageByName(r.Image))
- }
-
- // Attach other arguments.
- rv = append(rv, p...)
- return rv
-}
-
-// run runs a complete command.
-func (d *Docker) run(r RunOpts, command string, p ...string) (string, error) {
- if d.copyErr != nil {
- return "", d.copyErr
- }
- basicArgs := []string{"docker"}
- if command == "spawn" {
- command = "run"
- basicArgs = append(basicArgs, command)
- basicArgs = append(basicArgs, "-d")
- } else {
- basicArgs = append(basicArgs, command)
- }
- customArgs := d.argsFor(&r, command, p)
- cmd := testutil.Command(d.logger, append(basicArgs, customArgs...)...)
- if r.Pty != nil {
- // If allocating a terminal, then we just ignore the output
- // from the command.
- ptmx, err := pty.Start(cmd.Cmd)
- if err != nil {
- return "", err
- }
- defer cmd.Wait() // Best effort.
- r.Pty(cmd.Cmd, ptmx)
- } else {
- // Can't support PTY or streaming.
- out, err := cmd.CombinedOutput()
- return string(out), err
- }
- return "", nil
-}
-
-// Create calls 'docker create' with the arguments provided.
-func (d *Docker) Create(r RunOpts, args ...string) error {
- out, err := d.run(r, "create", args...)
- if strings.Contains(out, "Unable to find image") {
- return fmt.Errorf("unable to find image, did you remember to `make load-%s`: %w", r.Image, err)
- }
- return err
-}
-
-// Start calls 'docker start'.
-func (d *Docker) Start() error {
- return testutil.Command(d.logger, "docker", "start", d.Name).Run()
-}
-
-// Stop calls 'docker stop'.
-func (d *Docker) Stop() error {
- return testutil.Command(d.logger, "docker", "stop", d.Name).Run()
-}
-
-// Run calls 'docker run' with the arguments provided.
-func (d *Docker) Run(r RunOpts, args ...string) (string, error) {
- return d.run(r, "run", args...)
-}
-
-// Spawn starts the container and detaches.
-func (d *Docker) Spawn(r RunOpts, args ...string) error {
- _, err := d.run(r, "spawn", args...)
- return err
-}
-
-// Logs calls 'docker logs'.
-func (d *Docker) Logs() (string, error) {
- // Don't capture the output; since it will swamp the logs.
- out, err := exec.Command("docker", "logs", d.Name).CombinedOutput()
- return string(out), err
-}
-
-// Exec calls 'docker exec' with the arguments provided.
-func (d *Docker) Exec(r RunOpts, args ...string) (string, error) {
- return d.run(r, "exec", args...)
-}
-
-// Pause calls 'docker pause'.
-func (d *Docker) Pause() error {
- return testutil.Command(d.logger, "docker", "pause", d.Name).Run()
-}
-
-// Unpause calls 'docker pause'.
-func (d *Docker) Unpause() error {
- return testutil.Command(d.logger, "docker", "unpause", d.Name).Run()
-}
-
-// Checkpoint calls 'docker checkpoint'.
-func (d *Docker) Checkpoint(name string) error {
- return testutil.Command(d.logger, "docker", "checkpoint", "create", d.Name, name).Run()
-}
-
-// Restore calls 'docker start --checkname [name]'.
-func (d *Docker) Restore(name string) error {
- return testutil.Command(d.logger, "docker", "start", fmt.Sprintf("--checkpoint=%s", name), d.Name).Run()
-}
-
-// Kill calls 'docker kill'.
-func (d *Docker) Kill() error {
- // Skip logging this command, it will likely be an error.
- out, err := exec.Command("docker", "kill", d.Name).CombinedOutput()
- if err != nil && !strings.Contains(string(out), "is not running") {
- return err
- }
- return nil
-}
-
-// Remove calls 'docker rm'.
-func (d *Docker) Remove() error {
- return testutil.Command(d.logger, "docker", "rm", d.Name).Run()
-}
-
-// CleanUp kills and deletes the container (best effort).
-func (d *Docker) CleanUp() {
- // Kill the container.
- if err := d.Kill(); err != nil {
- // Just log; can't do anything here.
- d.logger.Logf("error killing container %q: %v", d.Name, err)
- }
- // Remove the image.
- if err := d.Remove(); err != nil {
- d.logger.Logf("error removing container %q: %v", d.Name, err)
- }
- // Execute all cleanups.
- for _, c := range d.cleanups {
- c()
- }
- d.cleanups = nil
-}
-
-// FindPort returns the host port that is mapped to 'sandboxPort'. This calls
-// docker to allocate a free port in the host and prevent conflicts.
-func (d *Docker) FindPort(sandboxPort int) (int, error) {
- format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort)
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
- if err != nil {
- return -1, fmt.Errorf("error retrieving port: %v", err)
- }
- port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- return -1, fmt.Errorf("error parsing port %q: %v", out, err)
- }
- return port, nil
-}
-
-// FindIP returns the IP address of the container.
-func (d *Docker) FindIP() (net.IP, error) {
- const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}`
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
- if err != nil {
- return net.IP{}, fmt.Errorf("error retrieving IP: %v", err)
- }
- ip := net.ParseIP(strings.TrimSpace(string(out)))
- if ip == nil {
- return net.IP{}, fmt.Errorf("invalid IP: %q", string(out))
- }
- return ip, nil
-}
-
-// A NetworkInterface is container's network interface information.
-type NetworkInterface struct {
- IPv4 net.IP
- MAC net.HardwareAddr
-}
-
-// ListNetworks returns the network interfaces of the container, keyed by
-// Docker network name.
-func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) {
- const format = `{{json .NetworkSettings.Networks}}`
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
- if err != nil {
- return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err)
- }
-
- networks := map[string]map[string]string{}
- if err := json.Unmarshal(out, &networks); err != nil {
- return nil, fmt.Errorf("error decoding network interfaces: %w", err)
- }
-
- interfaces := map[string]NetworkInterface{}
- for name, iface := range networks {
- var netface NetworkInterface
-
- rawIP := strings.TrimSpace(iface["IPAddress"])
- if rawIP != "" {
- ip := net.ParseIP(rawIP)
- if ip == nil {
- return nil, fmt.Errorf("invalid IP: %q", rawIP)
- }
- // Docker's IPAddress field is IPv4. The IPv6 address
- // is stored in the GlobalIPv6Address field.
- netface.IPv4 = ip
- }
-
- rawMAC := strings.TrimSpace(iface["MacAddress"])
- if rawMAC != "" {
- mac, err := net.ParseMAC(rawMAC)
- if err != nil {
- return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err)
- }
- netface.MAC = mac
- }
-
- interfaces[name] = netface
- }
-
- return interfaces, nil
-}
-
-// SandboxPid returns the PID to the sandbox process.
-func (d *Docker) SandboxPid() (int, error) {
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput()
- if err != nil {
- return -1, fmt.Errorf("error retrieving pid: %v", err)
- }
- pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- return -1, fmt.Errorf("error parsing pid %q: %v", out, err)
- }
- return pid, nil
-}
-
-// ID returns the container ID.
-func (d *Docker) ID() (string, error) {
- out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.Id}}", d.Name).CombinedOutput()
- if err != nil {
- return "", fmt.Errorf("error retrieving ID: %v", err)
- }
- return strings.TrimSpace(string(out)), nil
-}
-
-// Wait waits for container to exit, up to the given timeout. Returns error if
-// wait fails or timeout is hit. Returns the application return code otherwise.
-// Note that the application may have failed even if err == nil, always check
-// the exit code.
-func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) {
- timeoutChan := time.After(timeout)
- waitChan := make(chan (syscall.WaitStatus))
- errChan := make(chan (error))
-
- go func() {
- out, err := testutil.Command(d.logger, "docker", "wait", d.Name).CombinedOutput()
- if err != nil {
- errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err)
- }
- exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err)
- }
- waitChan <- syscall.WaitStatus(uint32(exit))
- }()
-
- select {
- case ws := <-waitChan:
- return ws, nil
- case err := <-errChan:
- return syscall.WaitStatus(1), err
- case <-timeoutChan:
- return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name)
- }
-}
-
-// WaitForOutput calls 'docker logs' to retrieve containers output and searches
-// for the given pattern.
-func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) {
- matches, err := d.WaitForOutputSubmatch(pattern, timeout)
- if err != nil {
- return "", err
- }
- if len(matches) == 0 {
- return "", nil
- }
- return matches[0], nil
-}
-
-// WaitForOutputSubmatch calls 'docker logs' to retrieve containers output and
-// searches for the given pattern. It returns any regexp submatches as well.
-func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) {
- re := regexp.MustCompile(pattern)
- var (
- lastOut string
- stopped bool
- )
- for exp := time.Now().Add(timeout); time.Now().Before(exp); {
- out, err := d.Logs()
- if err != nil {
- return nil, err
- }
- if out != lastOut {
- if lastOut == "" {
- d.logger.Logf("output (start): %s", out)
- } else if strings.HasPrefix(out, lastOut) {
- d.logger.Logf("output (contn): %s", out[len(lastOut):])
- } else {
- d.logger.Logf("output (trunc): %s", out)
- }
- lastOut = out // Save for future.
- if matches := re.FindStringSubmatch(lastOut); matches != nil {
- return matches, nil // Success!
- }
- } else if stopped {
- // The sandbox stopped and we looked at the
- // logs at least once since determining that.
- return nil, fmt.Errorf("no longer running: %v", err)
- } else if pid, err := d.SandboxPid(); pid == 0 || err != nil {
- // The sandbox may have stopped, but it's
- // possible that it has emitted the terminal
- // line between the last call to Logs and here.
- stopped = true
- }
- time.Sleep(100 * time.Millisecond)
- }
- return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), lastOut)
+// Runtime returns the value of the runtime flag.
+func Runtime() string {
+ return *runtime
}
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
new file mode 100644
index 000000000..4c739c9e9
--- /dev/null
+++ b/pkg/test/dockerutil/exec.go
@@ -0,0 +1,193 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/pkg/stdcopy"
+)
+
+// ExecOpts holds arguments for Exec calls.
+type ExecOpts struct {
+ // Env are additional environment variables.
+ Env []string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // User is the user to use.
+ User string
+
+ // UseTTY enables a TTY and attaches stdin for the created process.
+ UseTTY bool
+
+ // WorkDir is the working directory of the process.
+ WorkDir string
+}
+
+// Exec creates a process inside the container.
+func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) {
+ p, err := c.doExec(ctx, opts, args)
+ if err != nil {
+ return "", err
+ }
+
+ if exitStatus, err := p.WaitExitStatus(ctx); err != nil {
+ return "", err
+ } else if exitStatus != 0 {
+ out, _ := p.Logs()
+ return out, fmt.Errorf("process terminated with status: %d", exitStatus)
+ }
+
+ return p.Logs()
+}
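+
+// A minimal usage sketch for Exec, assuming a started *Container c; the
+// command and options are illustrative:
+//
+//    out, err := c.Exec(ctx, dockerutil.ExecOpts{WorkDir: "/"}, "echo", "hello")
+//    if err != nil {
+//        t.Fatalf("exec failed: %v, out: %s", err, out)
+//    }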
+
+// ExecProcess creates a process inside the container and returns a process struct
+// for the caller to use.
+func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) {
+ return c.doExec(ctx, opts, args)
+}
+
+func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) {
+ config := c.execConfig(r, args)
+ resp, err := c.client.ContainerExecCreate(ctx, c.id, config)
+ if err != nil {
+ return Process{}, fmt.Errorf("exec create failed with err: %v", err)
+ }
+
+ hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{})
+ if err != nil {
+ return Process{}, fmt.Errorf("exec attach failed with err: %v", err)
+ }
+
+ if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil {
+ hijack.Close()
+ return Process{}, fmt.Errorf("exec start failed with err: %v", err)
+ }
+
+ return Process{
+ container: c,
+ execid: resp.ID,
+ conn: hijack,
+ }, nil
+}
+
+func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig {
+ env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name))
+ return types.ExecConfig{
+ AttachStdin: r.UseTTY,
+ AttachStderr: true,
+ AttachStdout: true,
+ Cmd: cmd,
+ Privileged: r.Privileged,
+ WorkingDir: r.WorkDir,
+ Env: env,
+ Tty: r.UseTTY,
+ User: r.User,
+ }
+}
+
+// Process represents a containerized process.
+type Process struct {
+ container *Container
+ execid string
+ conn types.HijackedResponse
+}
+
+// Write writes buf to the process's stdin.
+func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) {
+ p.conn.Conn.SetDeadline(time.Now().Add(timeout))
+ return p.conn.Conn.Write(buf)
+}
+
+// Read returns process's stdout and stderr.
+func (p *Process) Read() (string, string, error) {
+ var stdout, stderr bytes.Buffer
+ if err := p.read(&stdout, &stderr); err != nil {
+ return "", "", err
+ }
+ return stdout.String(), stderr.String(), nil
+}
+
+// Logs returns combined stdout/stderr from the process.
+func (p *Process) Logs() (string, error) {
+ var out bytes.Buffer
+ if err := p.read(&out, &out); err != nil {
+ return "", err
+ }
+ return out.String(), nil
+}
+
+func (p *Process) read(stdout, stderr *bytes.Buffer) error {
+ _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader)
+ return err
+}
+
+// ExitCode returns the process's exit code.
+func (p *Process) ExitCode(ctx context.Context) (int, error) {
+ _, exitCode, err := p.runningExitCode(ctx)
+ return exitCode, err
+}
+
+// IsRunning checks if the process is running.
+func (p *Process) IsRunning(ctx context.Context) (bool, error) {
+ running, _, err := p.runningExitCode(ctx)
+ return running, err
+}
+
+// WaitExitStatus waits until the process completes and returns its exit status.
+func (p *Process) WaitExitStatus(ctx context.Context) (int, error) {
+ waitChan := make(chan int)
+ errChan := make(chan error)
+
+ go func() {
+ for {
+ running, exitcode, err := p.runningExitCode(ctx)
+ if err != nil {
+ errChan <- fmt.Errorf("error waiting for process %s in container %s: %v", p.execid, p.container.Name, err)
+ return
+ }
+ if !running {
+ waitChan <- exitcode
+ return
+ }
+ time.Sleep(500 * time.Millisecond)
+ }
+ }()
+
+ select {
+ case ws := <-waitChan:
+ return ws, nil
+ case err := <-errChan:
+ return -1, err
+ }
+}
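+
+// A hedged sketch of using ExecProcess with WaitExitStatus for a longer
+// running command, assuming a started *Container c; the command is
+// illustrative:
+//
+//    p, err := c.ExecProcess(ctx, dockerutil.ExecOpts{}, "sleep", "5")
+//    if err != nil {
+//        t.Fatalf("ExecProcess failed: %v", err)
+//    }
+//    if status, err := p.WaitExitStatus(ctx); err != nil || status != 0 {
+//        t.Fatalf("process failed: status %d, err %v", status, err)
+//    }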
+
+// runningExitCode reports whether the process is still running and returns its exit code.
+// The exit code is only valid if the process has exited.
+func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) {
+ // If execid is not empty, this is an exec'd process.
+ if p.execid != "" {
+ status, err := p.container.client.ContainerExecInspect(ctx, p.execid)
+ return status.Running, status.ExitCode, err
+ }
+ // Otherwise this is the container's root process.
+ status, err := p.container.Status(ctx)
+ return status.Running, status.ExitCode, err
+}
diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go
new file mode 100644
index 000000000..047091e75
--- /dev/null
+++ b/pkg/test/dockerutil/network.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "net"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/client"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Network is a docker network.
+type Network struct {
+ client *client.Client
+ id string
+ logger testutil.Logger
+ Name string
+ containers []*Container
+ Subnet *net.IPNet
+}
+
+// NewNetwork sets up the struct for a Docker network. Names of networks
+// will be unique.
+func NewNetwork(ctx context.Context, logger testutil.Logger) *Network {
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ logger.Logf("create client failed with: %v", err)
+ return nil
+ }
+ client.NegotiateAPIVersion(ctx)
+
+ return &Network{
+ logger: logger,
+ Name: testutil.RandomID(logger.Name()),
+ client: client,
+ }
+}
+
+func (n *Network) networkCreate() types.NetworkCreate {
+
+ var subnet string
+ if n.Subnet != nil {
+ subnet = n.Subnet.String()
+ }
+
+ ipam := network.IPAM{
+ Config: []network.IPAMConfig{{
+ Subnet: subnet,
+ }},
+ }
+
+ return types.NetworkCreate{
+ CheckDuplicate: true,
+ IPAM: &ipam,
+ }
+}
+
+// Create is analogous to 'docker network create'.
+func (n *Network) Create(ctx context.Context) error {
+
+ opts := n.networkCreate()
+ resp, err := n.client.NetworkCreate(ctx, n.Name, opts)
+ if err != nil {
+ return err
+ }
+ n.id = resp.ID
+ return nil
+}
+
+// Connect is analogous to 'docker network connect' with the arguments provided.
+func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error {
+ settings := network.EndpointSettings{
+ IPAMConfig: &network.EndpointIPAMConfig{
+ IPv4Address: ipv4,
+ IPv6Address: ipv6,
+ },
+ }
+ err := n.client.NetworkConnect(ctx, n.id, container.id, &settings)
+ if err == nil {
+ n.containers = append(n.containers, container)
+ }
+ return err
+}
+
+// Inspect returns this network's info.
+func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) {
+ return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true})
+}
+
+// Cleanup cleans up the docker network and all the containers attached to it.
+func (n *Network) Cleanup(ctx context.Context) error {
+ for _, c := range n.containers {
+ c.CleanUp(ctx)
+ }
+ n.containers = nil
+
+ return n.client.NetworkRemove(ctx, n.id)
+}
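+
+// A rough end-to-end sketch, assuming a testutil.Logger l and an existing
+// *Container c; the empty address strings assume Docker will auto-assign
+// addresses on the network:
+//
+//    n := dockerutil.NewNetwork(ctx, l)
+//    if err := n.Create(ctx); err != nil {
+//        t.Fatalf("network create failed: %v", err)
+//    }
+//    defer n.Cleanup(ctx)
+//    if err := n.Connect(ctx, c, "", ""); err != nil {
+//        t.Fatalf("network connect failed: %v", err)
+//    }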
diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD
index 03b1b4677..2d8f56bc0 100644
--- a/pkg/test/testutil/BUILD
+++ b/pkg/test/testutil/BUILD
@@ -15,6 +15,6 @@ go_library(
"//runsc/boot",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
)
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index f21d6769a..64c292698 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -482,6 +482,21 @@ func IsStatic(filename string) (bool, error) {
return true, nil
}
+// TouchShardStatusFile indicates to Bazel that the test runner supports
+// sharding by creating or updating the last modified date of the file
+// specified by TEST_SHARD_STATUS_FILE.
+//
+// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner.
+func TouchShardStatusFile() error {
+ if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" {
+ cmd := exec.Command("touch", statusFile)
+ if b, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error())
+ }
+ }
+ return nil
+}
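+
+// A possible call site, sketched under the assumption that the test runner
+// has its own main and that sharding is enabled for the Bazel target:
+//
+//    func main() {
+//        if err := testutil.TouchShardStatusFile(); err != nil {
+//            fmt.Fprintf(os.Stderr, "%v\n", err)
+//            os.Exit(1)
+//        }
+//        // ... run this shard's tests ...
+//    }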
+
// TestIndicesForShard returns indices for this test shard based on the
// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
//