Diffstat (limited to 'pkg')
-rw-r--r--  pkg/abi/linux/BUILD | 4
-rw-r--r--  pkg/abi/linux/fuse.go | 143
-rw-r--r--  pkg/abi/linux/futex.go | 18
-rw-r--r--  pkg/abi/linux/netdevice.go | 4
-rw-r--r--  pkg/abi/linux/netfilter.go | 146
-rw-r--r--  pkg/abi/linux/socket.go | 17
-rw-r--r--  pkg/p9/messages.go | 2
-rw-r--r--  pkg/sentry/arch/arch_aarch64.go | 23
-rw-r--r--  pkg/sentry/arch/arch_amd64.go | 4
-rw-r--r--  pkg/sentry/arch/arch_arm64.go | 4
-rw-r--r--  pkg/sentry/arch/arch_x86.go | 16
-rw-r--r--  pkg/sentry/fs/file_operations.go | 1
-rw-r--r--  pkg/sentry/fs/fsutil/BUILD | 7
-rw-r--r--  pkg/sentry/fs/fsutil/dirty_set.go | 7
-rw-r--r--  pkg/sentry/fs/fsutil/file_range_set.go | 15
-rw-r--r--  pkg/sentry/fs/fsutil/frame_ref_set.go | 10
-rw-r--r--  pkg/sentry/fs/fsutil/host_file_mapper.go | 5
-rw-r--r--  pkg/sentry/fs/fsutil/host_mappable.go | 19
-rw-r--r--  pkg/sentry/fs/fsutil/inode_cached.go | 25
-rw-r--r--  pkg/sentry/fsimpl/devpts/slave.go | 2
-rw-r--r--  pkg/sentry/fsimpl/fuse/BUILD | 45
-rw-r--r--  pkg/sentry/fsimpl/fuse/connection.go | 255
-rw-r--r--  pkg/sentry/fsimpl/fuse/dev.go | 317
-rw-r--r--  pkg/sentry/fsimpl/fuse/dev_test.go | 429
-rw-r--r--  pkg/sentry/fsimpl/fuse/fusefs.go | 224
-rw-r--r--  pkg/sentry/fsimpl/fuse/register.go | 42
-rw-r--r--  pkg/sentry/fsimpl/gofer/filesystem.go | 19
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer.go | 29
-rw-r--r--  pkg/sentry/fsimpl/gofer/regular_file.go | 68
-rw-r--r--  pkg/sentry/fsimpl/gofer/special_file.go | 56
-rw-r--r--  pkg/sentry/fsimpl/host/BUILD | 1
-rw-r--r--  pkg/sentry/fsimpl/host/host.go | 7
-rw-r--r--  pkg/sentry/fsimpl/host/mmap.go | 21
-rw-r--r--  pkg/sentry/fsimpl/kernfs/fd_impl_util.go | 6
-rw-r--r--  pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 2
-rw-r--r--  pkg/sentry/fsimpl/overlay/filesystem.go | 2
-rw-r--r--  pkg/sentry/fsimpl/overlay/non_directory.go | 4
-rw-r--r--  pkg/sentry/fsimpl/proc/subtasks.go | 2
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/filesystem.go | 2
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/regular_file.go | 35
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/tmpfs.go | 7
-rw-r--r--  pkg/sentry/kernel/BUILD | 1
-rw-r--r--  pkg/sentry/kernel/futex/futex.go | 8
-rw-r--r--  pkg/sentry/kernel/kernel.go | 5
-rw-r--r--  pkg/sentry/kernel/shm/BUILD | 1
-rw-r--r--  pkg/sentry/kernel/shm/shm.go | 3
-rw-r--r--  pkg/sentry/kernel/task.go | 19
-rw-r--r--  pkg/sentry/kernel/task_exec.go | 3
-rw-r--r--  pkg/sentry/kernel/task_exit.go | 3
-rw-r--r--  pkg/sentry/kernel/task_futex.go | 125
-rw-r--r--  pkg/sentry/kernel/task_run.go | 17
-rw-r--r--  pkg/sentry/kernel/task_work.go | 38
-rw-r--r--  pkg/sentry/kernel/time/BUILD | 1
-rw-r--r--  pkg/sentry/kernel/time/tcpip.go | 131
-rw-r--r--  pkg/sentry/kernel/timekeeper.go | 4
-rw-r--r--  pkg/sentry/kernel/vdso.go | 6
-rw-r--r--  pkg/sentry/memmap/BUILD | 14
-rw-r--r--  pkg/sentry/memmap/memmap.go | 60
-rw-r--r--  pkg/sentry/mm/BUILD | 4
-rw-r--r--  pkg/sentry/mm/aio_context.go | 3
-rw-r--r--  pkg/sentry/mm/mm.go | 10
-rw-r--r--  pkg/sentry/mm/pma.go | 25
-rw-r--r--  pkg/sentry/mm/special_mappable.go | 7
-rw-r--r--  pkg/sentry/pgalloc/BUILD | 10
-rw-r--r--  pkg/sentry/pgalloc/pgalloc.go | 66
-rw-r--r--  pkg/sentry/platform/BUILD | 20
-rw-r--r--  pkg/sentry/platform/kvm/BUILD | 1
-rw-r--r--  pkg/sentry/platform/kvm/address_space.go | 3
-rw-r--r--  pkg/sentry/platform/platform.go | 50
-rw-r--r--  pkg/sentry/platform/ptrace/BUILD | 1
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess.go | 3
-rw-r--r--  pkg/sentry/platform/ring0/kernel_arm64.go | 6
-rw-r--r--  pkg/sentry/socket/BUILD | 1
-rw-r--r--  pkg/sentry/socket/hostinet/BUILD | 2
-rw-r--r--  pkg/sentry/socket/hostinet/socket.go | 9
-rw-r--r--  pkg/sentry/socket/netfilter/netfilter.go | 53
-rw-r--r--  pkg/sentry/socket/netlink/BUILD | 2
-rw-r--r--  pkg/sentry/socket/netlink/socket.go | 14
-rw-r--r--  pkg/sentry/socket/netstack/BUILD | 2
-rw-r--r--  pkg/sentry/socket/netstack/netstack.go | 278
-rw-r--r--  pkg/sentry/socket/netstack/netstack_vfs2.go | 16
-rw-r--r--  pkg/sentry/socket/socket.go | 3
-rw-r--r--  pkg/sentry/socket/unix/BUILD | 1
-rw-r--r--  pkg/sentry/socket/unix/unix.go | 3
-rw-r--r--  pkg/sentry/socket/unix/unix_vfs2.go | 3
-rw-r--r--  pkg/sentry/syscalls/linux/BUILD | 2
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go | 4
-rw-r--r--  pkg/sentry/syscalls/linux/sys_futex.go | 48
-rw-r--r--  pkg/sentry/syscalls/linux/sys_socket.go | 17
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/BUILD | 2
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/mount.go | 9
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/setstat.go | 5
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/socket.go | 17
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/splice.go | 321
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/vfs2.go | 104
-rw-r--r--  pkg/sentry/vfs/options.go | 17
-rw-r--r--  pkg/sentry/vfs/permissions.go | 8
-rw-r--r--  pkg/tcpip/header/icmpv4.go | 1
-rw-r--r--  pkg/tcpip/header/icmpv6.go | 11
-rw-r--r--  pkg/tcpip/link/channel/channel.go | 4
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint.go | 36
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint_test.go | 8
-rw-r--r--  pkg/tcpip/link/loopback/loopback.go | 3
-rw-r--r--  pkg/tcpip/link/muxed/injectable.go | 4
-rw-r--r--  pkg/tcpip/link/nested/nested.go | 15
-rw-r--r--  pkg/tcpip/link/nested/nested_test.go | 4
-rw-r--r--  pkg/tcpip/link/packetsocket/BUILD | 14
-rw-r--r--  pkg/tcpip/link/packetsocket/endpoint.go | 50
-rw-r--r--  pkg/tcpip/link/qdisc/fifo/endpoint.go | 12
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem.go | 21
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem_test.go | 4
-rw-r--r--  pkg/tcpip/link/sniffer/sniffer.go | 5
-rw-r--r--  pkg/tcpip/link/tun/device.go | 43
-rw-r--r--  pkg/tcpip/link/waitable/waitable.go | 14
-rw-r--r--  pkg/tcpip/link/waitable/waitable_test.go | 9
-rw-r--r--  pkg/tcpip/network/arp/arp.go | 7
-rw-r--r--  pkg/tcpip/network/arp/arp_test.go | 58
-rw-r--r--  pkg/tcpip/network/ip_test.go | 5
-rw-r--r--  pkg/tcpip/network/ipv4/icmp.go | 3
-rw-r--r--  pkg/tcpip/network/ipv6/icmp.go | 10
-rw-r--r--  pkg/tcpip/network/ipv6/icmp_test.go | 52
-rw-r--r--  pkg/tcpip/stack/conntrack.go | 136
-rw-r--r--  pkg/tcpip/stack/forwarder_test.go | 23
-rw-r--r--  pkg/tcpip/stack/iptables.go | 173
-rw-r--r--  pkg/tcpip/stack/iptables_targets.go | 2
-rw-r--r--  pkg/tcpip/stack/iptables_types.go | 22
-rw-r--r--  pkg/tcpip/stack/linkaddrcache.go | 2
-rw-r--r--  pkg/tcpip/stack/linkaddrcache_test.go | 2
-rw-r--r--  pkg/tcpip/stack/ndp.go | 140
-rw-r--r--  pkg/tcpip/stack/ndp_test.go | 28
-rw-r--r--  pkg/tcpip/stack/nic.go | 30
-rw-r--r--  pkg/tcpip/stack/nic_test.go | 7
-rw-r--r--  pkg/tcpip/stack/packet_buffer.go | 4
-rw-r--r--  pkg/tcpip/stack/registration.go | 28
-rw-r--r--  pkg/tcpip/stack/stack.go | 12
-rw-r--r--  pkg/tcpip/tcpip.go | 52
-rw-r--r--  pkg/tcpip/time_unsafe.go | 30
-rw-r--r--  pkg/tcpip/timer.go | 147
-rw-r--r--  pkg/tcpip/timer_test.go | 91
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go | 2
-rw-r--r--  pkg/tcpip/transport/packet/endpoint.go | 54
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go | 4
-rw-r--r--  pkg/tcpip/transport/tcp/BUILD | 1
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go | 6
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go | 26
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go | 2
-rw-r--r--  pkg/tcpip/transport/tcp/rcv.go | 4
-rw-r--r--  pkg/tcpip/transport/tcp/segment.go | 6
-rw-r--r--  pkg/tcpip/transport/tcp/segment_unsafe.go | 23
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/context.go | 8
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go | 2
-rw-r--r--  pkg/test/criutil/criutil.go | 8
-rw-r--r--  pkg/test/dockerutil/BUILD | 19
-rw-r--r--  pkg/test/dockerutil/README.md | 86
-rw-r--r--  pkg/test/dockerutil/container.go | 82
-rw-r--r--  pkg/test/dockerutil/dockerutil.go | 21
-rw-r--r--  pkg/test/dockerutil/exec.go | 1
-rw-r--r--  pkg/test/dockerutil/profile.go | 152
-rw-r--r--  pkg/test/dockerutil/profile_test.go | 117
159 files changed, 4577 insertions(+), 1003 deletions(-)
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 2b789c4ec..05ca5342f 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -29,6 +29,7 @@ go_library(
"file_amd64.go",
"file_arm64.go",
"fs.go",
+ "fuse.go",
"futex.go",
"inotify.go",
"ioctl.go",
@@ -72,6 +73,9 @@ go_library(
"//pkg/abi",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go
new file mode 100644
index 000000000..d3ebbccc4
--- /dev/null
+++ b/pkg/abi/linux/fuse.go
@@ -0,0 +1,143 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// +marshal
+type FUSEOpcode uint32
+
+// +marshal
+type FUSEOpID uint64
+
+// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h.
+const (
+ FUSE_LOOKUP FUSEOpcode = 1
+ FUSE_FORGET = 2 /* no reply */
+ FUSE_GETATTR = 3
+ FUSE_SETATTR = 4
+ FUSE_READLINK = 5
+ FUSE_SYMLINK = 6
+ _
+ FUSE_MKNOD = 8
+ FUSE_MKDIR = 9
+ FUSE_UNLINK = 10
+ FUSE_RMDIR = 11
+ FUSE_RENAME = 12
+ FUSE_LINK = 13
+ FUSE_OPEN = 14
+ FUSE_READ = 15
+ FUSE_WRITE = 16
+ FUSE_STATFS = 17
+ FUSE_RELEASE = 18
+ _
+ FUSE_FSYNC = 20
+ FUSE_SETXATTR = 21
+ FUSE_GETXATTR = 22
+ FUSE_LISTXATTR = 23
+ FUSE_REMOVEXATTR = 24
+ FUSE_FLUSH = 25
+ FUSE_INIT = 26
+ FUSE_OPENDIR = 27
+ FUSE_READDIR = 28
+ FUSE_RELEASEDIR = 29
+ FUSE_FSYNCDIR = 30
+ FUSE_GETLK = 31
+ FUSE_SETLK = 32
+ FUSE_SETLKW = 33
+ FUSE_ACCESS = 34
+ FUSE_CREATE = 35
+ FUSE_INTERRUPT = 36
+ FUSE_BMAP = 37
+ FUSE_DESTROY = 38
+ FUSE_IOCTL = 39
+ FUSE_POLL = 40
+ FUSE_NOTIFY_REPLY = 41
+ FUSE_BATCH_FORGET = 42
+)
+
+const (
+ // FUSE_MIN_READ_BUFFER is the minimum size the read can be for any FUSE filesystem.
+ // This is the minimum size Linux supports. See linux/fuse.h.
+ FUSE_MIN_READ_BUFFER uint32 = 8192
+)
+
+// FUSEHeaderIn is the header read by the daemon with each request.
+//
+// +marshal
+type FUSEHeaderIn struct {
+ // Len specifies the total length of the data, including this header.
+ Len uint32
+
+ // Opcode specifies the kind of operation of the request.
+ Opcode FUSEOpcode
+
+ // Unique specifies the unique identifier for this request.
+ Unique FUSEOpID
+
+ // NodeID is the ID of the filesystem object being operated on.
+ NodeID uint64
+
+ // UID is the UID of the requesting process.
+ UID uint32
+
+ // GID is the GID of the requesting process.
+ GID uint32
+
+ // PID is the PID of the requesting process.
+ PID uint32
+
+ _ uint32
+}
+
+// FUSEHeaderOut is the header written by the daemon when it processes
+// a request and wants to send a reply (almost all operations require a
+// reply; if they do not, this will be explicitly documented).
+//
+// +marshal
+type FUSEHeaderOut struct {
+ // Len specifies the total length of the data, including this header.
+ Len uint32
+
+ // Error specifies the error that occurred (0 if none).
+ Error int32
+
+ // Unique specifies the unique identifier of the corresponding request.
+ Unique FUSEOpID
+}
+
+// FUSEWriteIn is the header written by a daemon when it makes a
+// write request to the FUSE filesystem.
+//
+// +marshal
+type FUSEWriteIn struct {
+ // Fh specifies the file handle that is being written to.
+ Fh uint64
+
+ // Offset is the offset of the write.
+ Offset uint64
+
+ // Size is the size of data being written.
+ Size uint32
+
+ // WriteFlags is the flags used during the write.
+ WriteFlags uint32
+
+ // LockOwner is the ID of the lock owner.
+ LockOwner uint64
+
+ // Flags is the flags for the request.
+ Flags uint32
+
+ _ uint32
+}
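
Since the structures above are plain +marshal types, a request on the wire is just a FUSEHeaderIn followed by an opcode-specific payload, with Len covering both. A minimal sketch of that layout, mirroring what Connection.NewRequest does later in this change (buildRequest itself is a hypothetical helper, not part of the diff):

package fuse

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/tools/go_marshal/marshal"
)

// buildRequest sketches the wire layout: the fixed header comes first, then
// the opcode-specific payload, with hdr.Len covering both.
func buildRequest(opcode linux.FUSEOpcode, unique linux.FUSEOpID, payload marshal.Marshallable) []byte {
	hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes()
	hdr := linux.FUSEHeaderIn{
		Len:    uint32(hdrLen + payload.SizeBytes()), // total length, header included.
		Opcode: opcode,
		Unique: unique,
	}
	buf := make([]byte, hdr.Len)
	hdr.MarshalUnsafe(buf[:hdrLen])     // header bytes,
	payload.MarshalUnsafe(buf[hdrLen:]) // then the payload bytes.
	return buf
}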
diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go
index 08bfde3b5..8138088a6 100644
--- a/pkg/abi/linux/futex.go
+++ b/pkg/abi/linux/futex.go
@@ -60,3 +60,21 @@ const (
FUTEX_WAITERS = 0x80000000
FUTEX_OWNER_DIED = 0x40000000
)
+
+// FUTEX_BITSET_MATCH_ANY has all bits set.
+const FUTEX_BITSET_MATCH_ANY = 0xffffffff
+
+// ROBUST_LIST_LIMIT protects against a deliberately circular list.
+const ROBUST_LIST_LIMIT = 2048
+
+// RobustListHead corresponds to Linux's struct robust_list_head.
+//
+// +marshal
+type RobustListHead struct {
+ List uint64
+ FutexOffset uint64
+ ListOpPending uint64
+}
+
+// SizeOfRobustListHead is the size of a RobustListHead struct.
+var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes()
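
For context, ROBUST_LIST_LIMIT matters because the robust list is a user-controlled linked list that the kernel walks at task exit; without a cap, a deliberately circular list would spin forever. A hedged sketch of the walk (the real implementation lands in pkg/sentry/kernel/task_futex.go in this change; fetchNext here is a hypothetical stand-in for the usermem copy-in of the next pointer):

// walkRobustList follows the user-supplied list of held futexes, giving up
// after ROBUST_LIST_LIMIT entries so a circular list cannot hang the walk.
func walkRobustList(head uint64, fetchNext func(entry uint64) (uint64, error)) error {
	entry := head
	for i := 0; i < linux.ROBUST_LIST_LIMIT; i++ { // linux = pkg/abi/linux.
		next, err := fetchNext(entry) // copy in this entry's next pointer.
		if err != nil {
			return err
		}
		if next == head {
			return nil // Walked the whole circle exactly once.
		}
		entry = next
	}
	return nil // Limit reached: treat the list as circular and stop.
}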
diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go
index 7866352b4..0faf015c7 100644
--- a/pkg/abi/linux/netdevice.go
+++ b/pkg/abi/linux/netdevice.go
@@ -22,6 +22,8 @@ const (
)
// IFReq is an interface request.
+//
+// +marshal
type IFReq struct {
// IFName is an encoded name, normally null-terminated. This should be
// accessed via the Name and SetName functions.
@@ -79,6 +81,8 @@ type IFMap struct {
// IFConf is used to return a list of interfaces and their addresses. See
// netdevice(7) and struct ifconf for more detail on its use.
+//
+// +marshal
type IFConf struct {
Len int32
_ [4]byte // Pad to sizeof(struct ifconf).
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 46d8b0b42..a91f9f018 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -14,6 +14,14 @@
package linux
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
// This file contains structures required to support netfilter, specifically
// the iptables tool.
@@ -76,6 +84,8 @@ const (
// IPTEntry is an iptable rule. It corresponds to struct ipt_entry in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTEntry struct {
// IP is used to filter packets based on the IP header.
IP IPTIP
@@ -112,21 +122,41 @@ type IPTEntry struct {
// SizeOfIPTEntry is the size of an IPTEntry.
const SizeOfIPTEntry = 112
-// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This
-// struct marshaled via the binary package to write an IPTEntry to userspace.
+// KernelIPTEntry is identical to IPTEntry, but includes the Elems field.
+// KernelIPTEntry itself is not Marshallable but it implements some methods of
+// marshal.Marshallable that help in other implementations of Marshallable.
type KernelIPTEntry struct {
- IPTEntry
+ Entry IPTEntry
// Elems holds the data for all this rule's matches followed by the
// target. It is variable length -- users have to iterate over any
// matches and use TargetOffset and NextOffset to make sense of the
// data.
- Elems []byte
+ Elems primitive.ByteSlice
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTEntry) SizeBytes() int {
+ return ke.Entry.SizeBytes() + ke.Elems.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTEntry) MarshalBytes(dst []byte) {
+ ke.Entry.MarshalBytes(dst)
+ ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
+ ke.Entry.UnmarshalBytes(src)
+ ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
// IPTIP contains information for matching a packet's IP header.
// It corresponds to struct ipt_ip in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTIP struct {
// Src is the source IP address.
Src InetAddr
@@ -189,6 +219,8 @@ const SizeOfIPTIP = 84
// XTCounters holds packet and byte counts for a rule. It corresponds to struct
// xt_counters in include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTCounters struct {
// Pcnt is the packet count.
Pcnt uint64
@@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetinfo struct {
Name TableName
ValidHooks uint32
@@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84
// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
// corresponds to struct ipt_get_entries in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetEntries struct {
Name TableName
Size uint32
@@ -350,13 +386,103 @@ type IPTGetEntries struct {
const SizeOfIPTGetEntries = 40
// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
-// Entrytable field. This struct marshaled via the binary package to write an
-// KernelIPTGetEntries to userspace.
+// Entrytable field. This has been manually made marshal.Marshallable since it
+// is dynamically sized.
type KernelIPTGetEntries struct {
IPTGetEntries
Entrytable []KernelIPTEntry
}
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTGetEntries) SizeBytes() int {
+ res := ke.IPTGetEntries.SizeBytes()
+ for _, entry := range ke.Entrytable {
+ res += entry.SizeBytes()
+ }
+ return res
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
+ ke.IPTGetEntries.MarshalBytes(dst)
+ marshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
+ marshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
+ ke.IPTGetEntries.UnmarshalBytes(src)
+ unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := 0; i < len(ke.Entrytable); i++ {
+ ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
+ unmarshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (ke *KernelIPTGetEntries) Packed() bool {
+ // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an
+ // indirection to the actual data we want to marshal (the slice data
+ // pointer), and the memory for KernelIPTGetEntries contains the slice
+ // header which we don't want to marshal.
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) {
+ // Fall back to safe Marshal because the type is not packed.
+ ke.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
+ // Fall back to safe Unmarshal because the type is not packed.
+ ke.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results in a
+ // partially unmarshalled struct.
+ ke.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task))
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+}
+
+func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte {
+ buf := task.CopyScratchBuffer(ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ return buf
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) {
+ buf := make([]byte, ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ length, err := w.Write(buf)
+ return int64(length), err
+}
+
+var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
+
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
@@ -374,12 +500,6 @@ type IPTReplace struct {
// Entries [0]IPTEntry
}
-// KernelIPTReplace is identical to IPTReplace, but includes the Entries field.
-type KernelIPTReplace struct {
- IPTReplace
- Entries [0]IPTEntry
-}
-
// SizeOfIPTReplace is the size of an IPTReplace.
const SizeOfIPTReplace = 96
@@ -392,6 +512,8 @@ func (en ExtensionName) String() string {
}
// TableName holds the name of a netfilter table.
+//
+// +marshal
type TableName [XT_TABLE_MAXNAMELEN]byte
// String implements fmt.Stringer.
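
A short usage sketch of the hand-written marshalling above: because KernelIPTGetEntries is dynamically sized (and therefore not Packed), SizeBytes must walk every entry and MarshalBytes is the only correct encoding path (encodeEntries is a hypothetical helper):

func encodeEntries(ke *linux.KernelIPTGetEntries) []byte {
	buf := make([]byte, ke.SizeBytes()) // fixed struct + every variable-length entry.
	ke.MarshalBytes(buf)                // safe path; MarshalUnsafe just falls back here.
	return buf
}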
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 4a14ef691..c24a8216e 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -134,6 +134,15 @@ const (
SHUT_RDWR = 2
)
+// Packet types from <linux/if_packet.h>
+const (
+ PACKET_HOST = 0 // To us
+ PACKET_BROADCAST = 1 // To all
+ PACKET_MULTICAST = 2 // To group
+ PACKET_OTHERHOST = 3 // To someone else
+ PACKET_OUTGOING = 4 // Outgoing of any type
+)
+
// Socket options from socket.h.
const (
SO_DEBUG = 1
@@ -225,6 +234,8 @@ const (
const SockAddrMax = 128
// InetAddr is struct in_addr, from uapi/linux/in.h.
+//
+// +marshal
type InetAddr [4]byte
// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
@@ -294,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
// Linger is struct linger, from include/linux/socket.h.
+//
+// +marshal
type Linger struct {
OnOff int32
Linger int32
@@ -308,6 +321,8 @@ const SizeOfLinger = 8
// the end of this struct or within existing unused space, so its size grows
// over time. The current iteration is based on linux v4.17. New versions are
// always backwards compatible.
+//
+// +marshal
type TCPInfo struct {
State uint8
CaState uint8
@@ -405,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
//
// ControlMessageCredentials represents struct ucred from linux/socket.h.
+//
+// +marshal
type ControlMessageCredentials struct {
PID int32
UID uint32
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index 57b89ad7d..2cb59f934 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -2506,7 +2506,7 @@ type msgFactory struct {
var msgRegistry registry
type registry struct {
- factories [math.MaxUint8]msgFactory
+ factories [math.MaxUint8 + 1]msgFactory
// largestFixedSize is computed so that given some message size M, you can
// compute the maximum payload size (e.g. for Twrite, Rread) with
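
The one-line p9 change above fixes an off-by-one: message types are uint8 values, so a lookup table indexed by them needs math.MaxUint8 + 1 (256) slots. With only math.MaxUint8 (255) elements, registering or dispatching message type 255 would index out of bounds. A tiny illustration of the general rule:

var factories [math.MaxUint8 + 1]msgFactory // indices 0..255, one per possible type.

func register(t uint8, f msgFactory) {
	factories[t] = f // safe for every uint8 value, including 255.
}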
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index daba8b172..fd95eb2d2 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -28,7 +28,14 @@ import (
)
// Registers represents the CPU registers for this architecture.
-type Registers = linux.PtraceRegs
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+
+ // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register.
+ TPIDR_EL0 uint64
+}
const (
// SyscallWidth is the width of instructions.
@@ -101,9 +108,6 @@ type State struct {
// Our floating point state.
aarch64FPState `state:"wait"`
- // TLS pointer
- TPValue uint64
-
// FeatureSet is a pointer to the currently active feature set.
FeatureSet *cpuid.FeatureSet
@@ -157,7 +161,6 @@ func (s *State) Fork() State {
return State{
Regs: s.Regs,
aarch64FPState: s.aarch64FPState.fork(),
- TPValue: s.TPValue,
FeatureSet: s.FeatureSet,
OrigR0: s.OrigR0,
}
@@ -241,18 +244,18 @@ func (s *State) ptraceGetRegs() Registers {
return s.Regs
}
-var registersSize = (*Registers)(nil).SizeBytes()
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
var regs Registers
- buf := make([]byte, registersSize)
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
regs.UnmarshalUnsafe(buf)
s.Regs = regs
- return registersSize, nil
+ return ptraceRegistersSize, nil
}
// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
@@ -278,7 +281,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -291,7 +294,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 3b3a0a272..1c3e3c14c 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -300,7 +300,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
// PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and
// u_debugreg, returning 0 or silently no-oping for other fields
// respectively.
- if addr < uintptr(registersSize) {
+ if addr < uintptr(ptraceRegistersSize) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
@@ -315,7 +315,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
if addr&7 != 0 || addr >= userStructSize {
return syscall.EIO
}
- if addr < uintptr(registersSize) {
+ if addr < uintptr(ptraceRegistersSize) {
regs := c.ptraceGetRegs()
buf := make([]byte, regs.SizeBytes())
regs.MarshalUnsafe(buf)
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index ada7ac7b8..cabbf60e0 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -142,7 +142,7 @@ func (c *context64) SetStack(value uintptr) {
// TLS returns the current TLS pointer.
func (c *context64) TLS() uintptr {
- return uintptr(c.TPValue)
+ return uintptr(c.Regs.TPIDR_EL0)
}
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
@@ -151,7 +151,7 @@ func (c *context64) SetTLS(value uintptr) bool {
return false
}
- c.TPValue = uint64(value)
+ c.Regs.TPIDR_EL0 = uint64(value)
return true
}
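
Taken together with the arch_aarch64.go change above, the TLS pointer now lives in the saved register file (Registers.TPIDR_EL0) rather than the separate State.TPValue field, so Fork and register save/restore carry it automatically. A minimal sketch of the new accessors' effect (tlsRoundTrip is hypothetical, not part of the diff):

func tlsRoundTrip(c *context64, value uintptr) uintptr {
	c.Regs.TPIDR_EL0 = uint64(value) // what SetTLS stores after validating value.
	return uintptr(c.Regs.TPIDR_EL0) // what TLS reads back.
}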
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index dc458b37f..b9405b320 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -31,7 +31,11 @@ import (
)
// Registers represents the CPU registers for this architecture.
-type Registers = linux.PtraceRegs
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+}
// System-related constants for x86.
const (
@@ -311,12 +315,12 @@ func (s *State) ptraceGetRegs() Registers {
return regs
}
-var registersSize = (*Registers)(nil).SizeBytes()
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
var regs Registers
- buf := make([]byte, registersSize)
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
@@ -374,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
}
regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable)
s.Regs = regs
- return registersSize, nil
+ return ptraceRegistersSize, nil
}
// isUserSegmentSelector returns true if the given segment selector specifies a
@@ -543,7 +547,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -563,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < registersSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index beba0f771..f5537411e 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -160,6 +160,7 @@ type FileOperations interface {
// refer.
//
// Preconditions: The AddressSpace (if any) that io refers to is activated.
+ // Must only be called from a task goroutine.
Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error)
}
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 789369220..5fb419bcd 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -8,7 +8,6 @@ go_template_instance(
out = "dirty_set_impl.go",
imports = {
"memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
prefix = "Dirty",
@@ -25,14 +24,14 @@ go_template_instance(
name = "frame_ref_set_impl",
out = "frame_ref_set_impl.go",
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "fsutil",
prefix = "FrameRef",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "uint64",
"Functions": "FrameRefSetFunctions",
},
@@ -43,7 +42,6 @@ go_template_instance(
out = "file_range_set_impl.go",
imports = {
"memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
prefix = "FileRange",
@@ -86,7 +84,6 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/platform",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/state",
diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go
index c6cd45087..2c9446c1d 100644
--- a/pkg/sentry/fs/fsutil/dirty_set.go
+++ b/pkg/sentry/fs/fsutil/dirty_set.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) {
// repeatedly until all bytes have been written. max is the true size of the
// cached object; offsets beyond max will not be passed to writeAt, even if
// they are marked dirty.
-func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
var changedDirty bool
defer func() {
if changedDirty {
@@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet
// successful partial write, SyncDirtyAll will call it repeatedly until all
// bytes have been written. max is the true size of the cached object; offsets
// beyond max will not be passed to writeAt, even if they are marked dirty.
-func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
dseg := dirty.FirstSegment()
for dseg.Ok() {
if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil {
@@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max
}
// Preconditions: mr must be page-aligned.
-func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() {
wbr := cseg.Range().Intersect(mr)
if max < wbr.Start {
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 5643cdac9..bbafebf03 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -23,13 +23,12 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/usermem"
)
// FileRangeSet maps offsets into a memmap.Mappable to offsets into a
-// platform.File. It is used to implement Mappables that store data in
+// memmap.File. It is used to implement Mappables that store data in
// sparsely-allocated memory.
//
// type FileRangeSet <generated by go_generics>
@@ -65,20 +64,20 @@ func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, spli
}
// FileRange returns the FileRange mapped by seg.
-func (seg FileRangeIterator) FileRange() platform.FileRange {
+func (seg FileRangeIterator) FileRange() memmap.FileRange {
return seg.FileRangeOf(seg.Range())
}
// FileRangeOf returns the FileRange mapped by mr.
//
// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0.
-func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange {
+func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange {
frstart := seg.Value() + (mr.Start - seg.Start())
- return platform.FileRange{frstart, frstart + mr.Length()}
+ return memmap.FileRange{frstart, frstart + mr.Length()}
}
// Fill attempts to ensure that all memmap.Mappable offsets in required are
-// mapped to a platform.File offset, by allocating from mf with the given
+// mapped to a memmap.File offset, by allocating from mf with the given
// memory usage kind and invoking readAt to store data into memory. (If readAt
// returns a successful partial read, Fill will call it repeatedly until all
// bytes have been read.) EOF is handled consistently with the requirements of
@@ -141,7 +140,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map
}
// Drop removes segments for memmap.Mappable offsets in mr, freeing the
-// corresponding platform.FileRanges.
+// corresponding memmap.FileRanges.
//
// Preconditions: mr must be page-aligned.
func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
@@ -154,7 +153,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
}
// DropAll removes all segments in mr, freeing the corresponding
-// platform.FileRanges.
+// memmap.FileRanges.
func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) {
for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
mf.DecRef(seg.FileRange())
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go
index dd6f5aba6..a808894df 100644
--- a/pkg/sentry/fs/fsutil/frame_ref_set.go
+++ b/pkg/sentry/fs/fsutil/frame_ref_set.go
@@ -17,7 +17,7 @@ package fsutil
import (
"math"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
)
@@ -39,7 +39,7 @@ func (FrameRefSetFunctions) ClearValue(val *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) {
+func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) {
if val1 != val2 {
return 0, false
}
@@ -47,13 +47,13 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.
}
// Split implements segment.Functions.Split.
-func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
+func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) {
return val, val
}
// IncRefAndAccount adds a reference on the range fr. All newly inserted segments
// are accounted as host page cache memory mappings.
-func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) {
+func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) {
seg, gap := refs.Find(fr.Start)
for {
switch {
@@ -74,7 +74,7 @@ func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) {
// DecRefAndAccount removes a reference on the range fr and untracks segments
// that are removed from memory accounting.
-func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) {
+func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) {
seg := refs.FindSegment(fr.Start)
for seg.Ok() && seg.Start() < fr.End {
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index e82afd112..ef0113b52 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -126,7 +125,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
// offsets in fr or until the next call to UnmapAll.
//
// Preconditions: The caller must hold a reference on all offsets in fr.
-func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
+func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift)
f.mapsMu.Lock()
defer f.mapsMu.Unlock()
@@ -146,7 +145,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool)
}
// Preconditions: f.mapsMu must be locked.
-func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error {
+func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error {
prot := syscall.PROT_READ
if write {
prot |= syscall.PROT_WRITE
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index 78fec553e..c15d8a946 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -21,18 +21,17 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
-// HostMappable implements memmap.Mappable and platform.File over a
+// HostMappable implements memmap.Mappable and memmap.File over a
// CachedFileObject.
//
// Lock order (compare the lock order model in mm/mm.go):
// truncateMu ("fs locks")
// mu ("memmap.Mappable locks not taken by Translate")
-// ("platform.File locks")
+// ("memmap.File locks")
// backingFile ("CachedFileObject locks")
//
// +stateify savable
@@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error {
return nil
}
-// MapInternal implements platform.File.MapInternal.
-func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (h *HostMappable) FD() int {
return h.backingFile.FD()
}
-// IncRef implements platform.File.IncRef.
-func (h *HostMappable) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (h *HostMappable) IncRef(fr memmap.FileRange) {
mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
h.hostFileMapper.IncRefOn(mr)
}
-// DecRef implements platform.File.DecRef.
-func (h *HostMappable) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (h *HostMappable) DecRef(fr memmap.FileRange) {
mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
h.hostFileMapper.DecRefOn(mr)
}
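
For reference, this is the shape of the interface HostMappable satisfies, reconstructed from the methods above; the authoritative definition is memmap.File in pkg/sentry/memmap/memmap.go (also modified by this change), so treat this as a sketch:

type File interface {
	IncRef(fr FileRange) // take a reference on a range of the file.
	DecRef(fr FileRange) // drop it again.
	MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
	FD() int // host file descriptor backing the mappings.
}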
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 800c8b4e1..fe8b0b6ac 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -26,7 +26,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -934,7 +933,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error {
- // Whether we have a host fd (and consequently what platform.File is
+ // Whether we have a host fd (and consequently what memmap.File is
// mapped) can change across save/restore, so invalidate all translations
// unconditionally.
c.mapsMu.Lock()
@@ -999,10 +998,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable
}
}
-// IncRef implements platform.File.IncRef. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// IncRef implements memmap.File.IncRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
-func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
+func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) {
// Hot path. Avoid defers.
c.dataMu.Lock()
seg, gap := c.refs.Find(fr.Start)
@@ -1024,10 +1023,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
}
}
-// DecRef implements platform.File.DecRef. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// DecRef implements memmap.File.DecRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
-func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
+func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) {
// Hot path. Avoid defers.
c.dataMu.Lock()
seg := c.refs.FindSegment(fr.Start)
@@ -1046,15 +1045,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
c.dataMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal. This is used when we
+// MapInternal implements memmap.File.MapInternal. This is used when we
// directly map an underlying host fd and CachingInodeOperations is used as the
-// platform.File during translation.
-func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// memmap.File during translation.
+func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write)
}
-// FD implements platform.File.FD. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// FD implements memmap.File.FD. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
func (c *CachingInodeOperations) FD() int {
return c.backingFile.FD()
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
index 2018b978a..a91cae3ef 100644
--- a/pkg/sentry/fsimpl/devpts/slave.go
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -132,7 +132,7 @@ func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequen
return sfd.inode.t.ld.outputQueueWrite(ctx, src)
}
-// Ioctl implements vfs.FileDescripionImpl.Ioctl.
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := args[1].Uint(); cmd {
case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 3e00c2abb..67649e811 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -1,20 +1,63 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
+go_template_instance(
+ name = "request_list",
+ out = "request_list.go",
+ package = "fuse",
+ prefix = "request",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Request",
+ "Linker": "*Request",
+ },
+)
+
go_library(
name = "fuse",
srcs = [
+ "connection.go",
"dev.go",
+ "fusefs.go",
+ "register.go",
+ "request_list.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/log",
"//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "dev_test",
+ size = "small",
+ srcs = ["dev_test.go"],
+ library = ":fuse",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go
new file mode 100644
index 000000000..f330da0bd
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/connection.go
@@ -0,0 +1,255 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "errors"
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// MaxActiveRequestsDefault is the default setting controlling the upper bound
+// on the number of active requests at any given time.
+const MaxActiveRequestsDefault = 10000
+
+var (
+ // Ordinary requests have even IDs, while interrupt IDs are odd.
+ InitReqBit uint64 = 1
+ ReqIDStep uint64 = 2
+)
+
+// Request represents a FUSE operation request that hasn't been sent to the
+// server yet.
+//
+// +stateify savable
+type Request struct {
+ requestEntry
+
+ id linux.FUSEOpID
+ hdr *linux.FUSEHeaderIn
+ data []byte
+}
+
+// Response represents an actual response from the server, including the
+// response payload.
+//
+// +stateify savable
+type Response struct {
+ opcode linux.FUSEOpcode
+ hdr linux.FUSEHeaderOut
+ data []byte
+}
+
+// Connection is the struct by which the sentry communicates with the FUSE server daemon.
+type Connection struct {
+ fd *DeviceFD
+
+ // MaxWrite is the daemon's maximum size of a write buffer.
+ // This is negotiated during FUSE_INIT.
+ MaxWrite uint32
+}
+
+// NewFUSEConnection creates a FUSE connection to fd.
+func NewFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*Connection, error) {
+ // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
+ // mount a FUSE filesystem.
+ fuseFD := fd.Impl().(*DeviceFD)
+ fuseFD.mounted = true
+
+ // Create the writeBuf for the header to be stored in.
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ fuseFD.writeBuf = make([]byte, hdrLen)
+ fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse)
+ fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests)
+ fuseFD.writeCursor = 0
+
+ return &Connection{
+ fd: fuseFD,
+ }, nil
+}
+
+// NewRequest creates a new request that can be sent to the FUSE server.
+func (conn *Connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+ conn.fd.nextOpID += linux.FUSEOpID(ReqIDStep)
+
+ hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes()
+ hdr := linux.FUSEHeaderIn{
+ Len: uint32(hdrLen + payload.SizeBytes()),
+ Opcode: opcode,
+ Unique: conn.fd.nextOpID,
+ NodeID: ino,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ PID: pid,
+ }
+
+ buf := make([]byte, hdr.Len)
+ hdr.MarshalUnsafe(buf[:hdrLen])
+ payload.MarshalUnsafe(buf[hdrLen:])
+
+ return &Request{
+ id: hdr.Unique,
+ hdr: &hdr,
+ data: buf,
+ }, nil
+}
+
+// Call makes a request to the server and blocks the invoking task until a
+// server responds with a response.
+// NOTE: If no task is provided then the Call will simply enqueue the request
+// and return a nil response. No blocking will happen in this case. Instead,
+// this is used to signify that the processing of this request will happen by
+// the kernel.Task that writes the response. See FUSE_INIT for such an
+// invocation.
+func (conn *Connection) Call(t *kernel.Task, r *Request) (*Response, error) {
+ fut, err := conn.callFuture(t, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return fut.resolve(t)
+}
+
+// Error returns the error of the FUSE call.
+func (r *Response) Error() error {
+ errno := r.hdr.Error
+ if errno >= 0 {
+ return nil
+ }
+
+ sysErrNo := syscall.Errno(-errno)
+ return error(sysErrNo)
+}
+
+// UnmarshalPayload unmarshals the response data into m.
+func (r *Response) UnmarshalPayload(m marshal.Marshallable) error {
+ hdrLen := r.hdr.SizeBytes()
+ haveDataLen := r.hdr.Len - uint32(hdrLen)
+ wantDataLen := uint32(m.SizeBytes())
+
+ if haveDataLen < wantDataLen {
+ return fmt.Errorf("payload too small. Minimum data length required: %d, but got data length %d", wantDataLen, haveDataLen)
+ }
+
+ m.UnmarshalUnsafe(r.data[hdrLen:])
+ return nil
+}
+
+// callFuture makes a request to the server and returns a future response.
+// Call resolve() when the response needs to be fulfilled.
+func (conn *Connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+
+ // Is the queue full?
+ //
+ // We must busy wait here until the request can be queued. We don't
+ // block on the fd.fullQueueCh with a lock - so after being signalled,
+ // before we acquire the lock, it is possible that a barging task enters
+ // and queues a request. As a result, upon acquiring the lock we must
+ // again check if the room is available.
+ //
+ // This can potentially starve a request forever but this can only happen
+ // if there are always too many ongoing requests all the time. The
+ // supported maxActiveRequests setting should be really high to avoid this.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ if t == nil {
+ // Since there is no task waiting, we must error out.
+ return nil, errors.New("FUSE request queue full")
+ }
+
+ log.Infof("Blocking request %v from being queued. Too many active requests: %v",
+ r.id, conn.fd.numActiveRequests)
+ conn.fd.mu.Unlock()
+ err := t.Block(conn.fd.fullQueueCh)
+ conn.fd.mu.Lock()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return conn.callFutureLocked(t, r)
+}
+
+// callFutureLocked makes a request to the server and returns a future response.
+func (conn *Connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.queue.PushBack(r)
+ conn.fd.numActiveRequests += 1
+ fut := newFutureResponse(r.hdr.Opcode)
+ conn.fd.completions[r.id] = fut
+
+ // Signal the readers that there is something to read.
+ conn.fd.waitQueue.Notify(waiter.EventIn)
+
+ return fut, nil
+}
+
+// futureResponse represents an in-flight request that may or may not have
+// completed yet. Convert it to a resolved Response by calling resolve, but
+// note that this may block.
+//
+// +stateify savable
+type futureResponse struct {
+ opcode linux.FUSEOpcode
+ ch chan struct{}
+ hdr *linux.FUSEHeaderOut
+ data []byte
+}
+
+// newFutureResponse creates a future response to a FUSE request.
+func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse {
+ return &futureResponse{
+ opcode: opcode,
+ ch: make(chan struct{}),
+ }
+}
+
+// resolve blocks the task until the server responds to its corresponding request,
+// then returns a resolved response.
+func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) {
+ // If there is no Task associated with this request - then we don't try to resolve
+ // the response. Instead, the task writing the response (proxy to the server) will
+ // process the response on our behalf.
+ if t == nil {
+ log.Infof("fuse.Response.resolve: Not waiting on a response from server.")
+ return nil, nil
+ }
+
+ if err := t.Block(f.ch); err != nil {
+ return nil, err
+ }
+
+ return f.getResponse(), nil
+}
+
+// getResponse creates a Response from the data the futureResponse has.
+func (f *futureResponse) getResponse() *Response {
+ return &Response{
+ opcode: f.opcode,
+ hdr: *f.hdr,
+ data: f.data,
+ }
+}
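
A hedged end-to-end sketch of the API defined above, wiring NewRequest, Call, Error, and UnmarshalPayload together (doFUSECall and its parameters are hypothetical placeholders, not part of this change):

func doFUSECall(conn *Connection, t *kernel.Task, creds *auth.Credentials, pid uint32, ino uint64, op linux.FUSEOpcode, in, out marshal.Marshallable) error {
	req, err := conn.NewRequest(creds, pid, ino, op, in)
	if err != nil {
		return err
	}
	res, err := conn.Call(t, req) // blocks t until the server replies.
	if err != nil {
		return err
	}
	if err := res.Error(); err != nil { // server-reported errno, if any.
		return err
	}
	return res.UnmarshalPayload(out) // decode the reply body into out.
}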
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index dc33268af..f3443ac71 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -15,13 +15,17 @@
package fuse
import (
+ "syscall"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const fuseDevMinor = 229
@@ -51,9 +55,46 @@ type DeviceFD struct {
vfs.DentryMetadataFileDescriptionImpl
vfs.NoLockFD
- // TODO(gvisor.dev/issue/2987): Add all the data structures needed to enqueue
- // and deque requests, control synchronization and establish communication
- // between the FUSE kernel module and the /dev/fuse character device.
+ // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD.
+ mounted bool
+
+ // nextOpID is used to create new requests.
+ nextOpID linux.FUSEOpID
+
+ // queue is the list of requests that need to be processed by the FUSE server.
+ queue requestList
+
+ // numActiveRequests is the number of requests made by the Sentry that have
+ // yet to be responded to.
+ numActiveRequests uint64
+
+ // completions is used to map a request to its response. A Writer will use this
+ // to notify the caller of a completed response.
+ completions map[linux.FUSEOpID]*futureResponse
+
+ writeCursor uint32
+
+ // writeBuf is the memory buffer used to copy in the FUSE out header from
+ // userspace.
+ writeBuf []byte
+
+ // writeCursorFR is the current futureResponse being copied from the server.
+ writeCursorFR *futureResponse
+
+ // mu protects all the queues, maps, buffers and cursors and nextOpID.
+ mu sync.Mutex
+
+ // waitQueue is used to notify interested parties when the device becomes
+ // readable or writable.
+ waitQueue waiter.Queue
+
+ // fullQueueCh is a channel used to synchronize the readers with the writers.
+ // Writers (inbound requests to the filesystem) block if there are too many
+ // unprocessed in-flight requests.
+ fullQueueCh chan struct{}
+
+ // fs is the FUSE filesystem that this FD is being used for.
+ fs *filesystem
}
// Release implements vfs.FileDescriptionImpl.Release.
@@ -61,45 +102,293 @@ func (fd *DeviceFD) Release() {}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// Read implements vfs.FileDescriptionImpl.Read.
func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- return 0, syserror.ENOSYS
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ // We require that any Read done on this filesystem use a sane minimum read
+ // buffer. It must have capacity for the fixed parts of any request header
+ // (Linux uses the request header and the FUSEWriteIn header for this
+ // calculation) plus the negotiated MaxWrite bytes of room for the data.
+ minBuffSize := linux.FUSE_MIN_READ_BUFFER
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ writeHdrLen := uint32((*linux.FUSEWriteIn)(nil).SizeBytes())
+ negotiatedMinBuffSize := inHdrLen + writeHdrLen + fd.fs.conn.MaxWrite
+ if minBuffSize < negotiatedMinBuffSize {
+ minBuffSize = negotiatedMinBuffSize
+ }
+
+ // If the read buffer is too small, error out.
+ if dst.NumBytes() < int64(minBuffSize) {
+ return 0, syserror.EINVAL
+ }
+
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.readLocked(ctx, dst, opts)
+}
+
+// readLocked reads from the fuse device. The caller must hold DeviceFD.mu.
+func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ if fd.queue.Empty() {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var readCursor uint32
+ var bytesRead int64
+ for {
+ req := fd.queue.Front()
+ if dst.NumBytes() < int64(req.hdr.Len) {
+ // The request is too large to process. All requests must be smaller
+ // than the negotiated size specified by Connection.MaxWrite, which is
+ // set as part of the FUSE_INIT handshake.
+ errno := -int32(syscall.EIO)
+ if req.hdr.Opcode == linux.FUSE_SETXATTR {
+ errno = -int32(syscall.E2BIG)
+ }
+
+ // Return the error to the calling task.
+ if err := fd.sendError(ctx, errno, req); err != nil {
+ return 0, err
+ }
+
+ // We're done with this request.
+ fd.queue.Remove(req)
+
+ // Restart the read as this request was invalid.
+ log.Warningf("fuse.DeviceFD.Read: request too large; restarting read.")
+ return fd.readLocked(ctx, dst, opts)
+ }
+
+ n, err := dst.CopyOut(ctx, req.data[readCursor:])
+ if err != nil {
+ return 0, err
+ }
+ readCursor += uint32(n)
+ bytesRead += int64(n)
+
+ if readCursor >= req.hdr.Len {
+ // Fully done with this req, remove it from the queue.
+ fd.queue.Remove(req)
+ break
+ }
+ }
+
+ return bytesRead, nil
}
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
- return 0, syserror.ENOSYS
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.writeLocked(ctx, src, opts)
+}
+
+// writeLocked writes to the fuse device. The caller must hold DeviceFD.mu.
+func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ var cn, n int64
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+
+ for src.NumBytes() > 0 {
+ if fd.writeCursorFR != nil {
+ // Already have common header, and we're now copying the payload.
+ wantBytes := fd.writeCursorFR.hdr.Len
+
+ // Note that the FR data doesn't include the header. Allocate its
+ // buffer if necessary.
+ if fd.writeCursorFR.data == nil {
+ fd.writeCursorFR.data = make([]byte, wantBytes)
+ }
+
+ bytesCopied, err := src.CopyIn(ctx, fd.writeCursorFR.data[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == wantBytes {
+ // Done reading this full response. Clean up and unblock the
+ // initiator.
+ break
+ }
+
+ // Check if we have more data in src.
+ continue
+ }
+
+ // Assert that the header isn't read into the writeBuf yet.
+ if fd.writeCursor >= hdrLen {
+ return 0, syserror.EINVAL
+ }
+
+ // We don't have the full common response header yet.
+ wantBytes := hdrLen - fd.writeCursor
+ bytesCopied, err := src.CopyIn(ctx, fd.writeBuf[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == hdrLen {
+ // Have full header in the writeBuf. Use it to fetch the actual futureResponse
+ // from the device's completions map.
+ var hdr linux.FUSEHeaderOut
+ hdr.UnmarshalBytes(fd.writeBuf)
+
+ // We have the header now, so the writeBuf has served its purpose.
+ // We could reset it manually here, but instead the writeCursor is set
+ // to 0 at the end of the write, allowing the next response to
+ // overwrite what's in the buffer.
+
+ fut, ok := fd.completions[hdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return 0, syserror.EINVAL
+ }
+
+ delete(fd.completions, hdr.Unique)
+
+ // Copy over the header into the future response. The rest of the payload
+ // will be copied over to the FR's data in the next iteration.
+ fut.hdr = &hdr
+ fd.writeCursorFR = fut
+
+ // The next iteration will try to read the complete payload, if src
+ // has any data remaining. Otherwise we're done.
+ }
+ }
+
+ if fd.writeCursorFR != nil {
+ if err := fd.sendResponse(ctx, fd.writeCursorFR); err != nil {
+ return 0, err
+ }
+
+ // Ready the device for the next request.
+ fd.writeCursorFR = nil
+ fd.writeCursor = 0
+ }
+
+ return n, nil
+}
+
+// Readiness implements vfs.FileDescriptionImpl.Readiness.
+func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ ready |= waiter.EventOut // FD is always writable
+ if !fd.queue.Empty() {
+ // Have reqs available, FD is readable.
+ ready |= waiter.EventIn
+ }
+
+ return ready & mask
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *DeviceFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.waitQueue.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *DeviceFD) EventUnregister(e *waiter.Entry) {
+ fd.waitQueue.EventUnregister(e)
}
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
return 0, syserror.ENOSYS
}
-// Register registers the FUSE device with vfsObj.
-func Register(vfsObj *vfs.VirtualFilesystem) error {
- if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{
- GroupName: "misc",
- }); err != nil {
+// sendResponse sends a response to the waiting task (if any).
+func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error {
+ // See if the running task needs to perform some action before returning.
+ // Since we just finished writing the future, we can be sure that
+ // getResponse generates a populated response.
+ if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil {
return err
}
+ // Signal that the queue is no longer full.
+ select {
+ case fd.fullQueueCh <- struct{}{}:
+ default:
+ }
+ fd.numActiveRequests--
+
+ // Signal the task waiting on a response.
+ close(fut.ch)
return nil
}
-// CreateDevtmpfsFile creates a device special file in devtmpfs.
-func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error {
- if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil {
+// sendError sends an error response to the waiting task (if any).
+func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error {
+ // Return the error to the calling task.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ respHdr := linux.FUSEHeaderOut{
+ Len: outHdrLen,
+ Error: errno,
+ Unique: req.hdr.Unique,
+ }
+
+ fut, ok := fd.completions[respHdr.Unique]
+ if !ok {
+ // The request has no outstanding completion entry; this should never
+ // happen for a request we queued ourselves.
+ return syserror.EINVAL
+ }
+ delete(fd.completions, respHdr.Unique)
+
+ fut.hdr = &respHdr
+ if err := fd.sendResponse(ctx, fut); err != nil {
return err
}
return nil
}
+
+// noReceiverAction has the calling kernel.Task do some action if it's known
+// that no receiver is going to be waiting on the future channel. This is
+// currently only used for FUSE_INIT.
+func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error {
+ if r.opcode == linux.FUSE_INIT {
+ // TODO: process init response here.
+ // Maybe get the creds from the context?
+ // creds := auth.CredentialsFromContext(ctx)
+ }
+
+ return nil
+}
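
Note: the header-then-payload state machine in writeLocked above is easier to see outside the diff. The device first accumulates the fixed-size FUSEHeaderOut in writeBuf (possibly across several short writes), uses hdr.Unique to look up the in-flight futureResponse, then streams the remaining hdr.Len bytes into that response. The sketch below reimplements the two-phase parse standalone; the 12-byte wire layout and all demo* names are assumptions for illustration, not the real FUSE format.

package main

import (
    "encoding/binary"
    "errors"
    "fmt"
)

const demoHdrLen = 12 // 4-byte Len + 8-byte Unique; illustrative layout only.

// demoDev keeps the same cross-call state as DeviceFD.writeLocked: a header
// staging buffer, a cursor, and a completions map keyed by Unique.
type demoDev struct {
    writeBuf    [demoHdrLen]byte
    writeCursor int
    payload     []byte
    payloadLen  int
    completions map[uint64]string
}

// write consumes src incrementally: first the fixed header, then the
// payload, tolerating arbitrarily fragmented writes.
func (d *demoDev) write(src []byte) error {
    for len(src) > 0 {
        if d.payload == nil {
            // Phase 1: accumulate the fixed-size header.
            n := copy(d.writeBuf[d.writeCursor:], src)
            src, d.writeCursor = src[n:], d.writeCursor+n
            if d.writeCursor < demoHdrLen {
                return nil // Wait for the rest of the header.
            }
            total := binary.LittleEndian.Uint32(d.writeBuf[0:4])
            unique := binary.LittleEndian.Uint64(d.writeBuf[4:12])
            if _, ok := d.completions[unique]; !ok {
                return errors.New("response for a request we never sent")
            }
            d.payload = make([]byte, 0, total-demoHdrLen)
            d.payloadLen = int(total) - demoHdrLen
            continue
        }
        // Phase 2: stream the payload until hdr.Len bytes have arrived.
        want := d.payloadLen - len(d.payload)
        if want > len(src) {
            want = len(src)
        }
        d.payload = append(d.payload, src[:want]...)
        src = src[want:]
        if len(d.payload) == d.payloadLen {
            fmt.Printf("completed response: %q\n", d.payload)
            d.payload, d.writeCursor = nil, 0
        }
    }
    return nil
}

func main() {
    d := &demoDev{completions: map[uint64]string{7: "pending"}}
    var msg [demoHdrLen + 5]byte
    binary.LittleEndian.PutUint32(msg[0:4], demoHdrLen+5)
    binary.LittleEndian.PutUint64(msg[4:12], 7)
    copy(msg[demoHdrLen:], "hello")
    // Deliver the response in two fragmented writes, as a server might.
    for _, frag := range [][]byte{msg[:3], msg[3:]} {
        if err := d.write(frag); err != nil {
            panic(err)
        }
    }
}
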
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
new file mode 100644
index 000000000..fcd77832a
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -0,0 +1,429 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// echoTestOpcode is the Opcode used during testing. The server used in tests
+// will simply echo the payload back with the appropriate headers.
+const echoTestOpcode linux.FUSEOpcode = 1000
+
+type testPayload struct {
+ data uint32
+}
+
+// TestFUSECommunication tests that the communication layer between the Sentry and the
+// FUSE server daemon works as expected.
+func TestFUSECommunication(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ creds := auth.CredentialsFromContext(s.Ctx)
+
+ // Create test cases with different number of concurrent clients and servers.
+ testCases := []struct {
+ Name string
+ NumClients int
+ NumServers int
+ MaxActiveRequests uint64
+ }{
+ {
+ Name: "SingleClientSingleServer",
+ NumClients: 1,
+ NumServers: 1,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "SingleClientMultipleServers",
+ NumClients: 1,
+ NumServers: 10,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsSingleServer",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsMultipleServers",
+ NumClients: 10,
+ NumServers: 10,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "RequestCapacityFull",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: 1,
+ },
+ {
+ Name: "RequestCapacityContinuouslyFull",
+ NumClients: 100,
+ NumServers: 2,
+ MaxActiveRequests: 2,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests)
+ if err != nil {
+ t.Fatalf("newTestConnection: %v", err)
+ }
+
+ clientsDone := make([]chan struct{}, testCase.NumClients)
+ serversDone := make([]chan struct{}, testCase.NumServers)
+ serversKill := make([]chan struct{}, testCase.NumServers)
+
+ // FUSE clients.
+ for i := 0; i < testCase.NumClients; i++ {
+ clientsDone[i] = make(chan struct{})
+ go func(i int) {
+ fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i])
+ }(i)
+ }
+
+ // FUSE servers.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversDone[j] = make(chan struct{})
+ serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block.
+ go func(j int) {
+ fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j])
+ }(j)
+ }
+
+ // Tear down.
+ //
+ // Make sure all the clients are done.
+ for i := 0; i < testCase.NumClients; i++ {
+ <-clientsDone[i]
+ }
+
+ // Kill any server that is potentially waiting.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversKill[j] <- struct{}{}
+ }
+
+ // Make sure all the servers are done.
+ for j := 0; j < testCase.NumServers; j++ {
+ <-serversDone[j]
+ }
+ })
+ }
+}
+
+// CallTest makes a request to the server and blocks the invoking goroutine
+// until the server responds. It doesn't block a kernel.Task. Analogous to
+// Connection.Call but used for testing.
+func CallTest(conn *Connection, t *kernel.Task, r *Request, i uint32) (*Response, error) {
+ conn.fd.mu.Lock()
+
+ // Wait until we're certain that a new request can be processed.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ conn.fd.mu.Unlock()
+ <-conn.fd.fullQueueCh
+ conn.fd.mu.Lock()
+ }
+
+ fut, err := conn.callFutureLocked(t, r)
+ conn.fd.mu.Unlock()
+
+ if err != nil {
+ return nil, err
+ }
+
+ // Resolve the response.
+ //
+ // Block without a task.
+ <-fut.ch
+
+ // A response is ready. Resolve and return it.
+ return fut.getResponse(), nil
+}
+
+// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE
+// device. However, it does so without blocking the calling task; instead it
+// just waits on a channel. The behaviour is essentially the same as
+// DeviceFD.Read except it guarantees that the task is not blocked.
+func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) {
+ var err error
+ var n, total int64
+
+ dev := fd.Impl().(*DeviceFD)
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ dev.EventRegister(&w, waiter.EventIn)
+ for {
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{})
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry. This emulates the
+ // blocking that occurs when no requests are available.
+ select {
+ case <-ch:
+ case <-killServer:
+ // Server killed by the main program.
+ return 0, true, nil
+ }
+ }
+
+ dev.EventUnregister(&w)
+ return total, false, err
+}
+
+// fuseClientRun emulates all the actions of a normal FUSE request. It creates
+// a header, a payload, calls the server, waits for the response, and processes
+// the response.
+func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *Connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) {
+ defer func() { clientDone <- struct{}{} }()
+
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testObj := &testPayload{
+ data: rand.Uint32(),
+ }
+
+ req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
+ if err != nil {
+ t.Fatalf("NewRequest creation failed: %v", err)
+ }
+
+ // Queue up a request.
+ // Analogous to Call except it doesn't block on the task.
+ resp, err := CallTest(conn, clientTask, req, pid)
+ if err != nil {
+ t.Fatalf("CallTest failed: %v", err)
+ }
+
+ if err = resp.Error(); err != nil {
+ t.Fatalf("Server responded with an error: %v", err)
+ }
+
+ var respTestPayload testPayload
+ if err := resp.UnmarshalPayload(&respTestPayload); err != nil {
+ t.Fatalf("Unmarshalling payload error: %v", err)
+ }
+
+ if resp.hdr.Unique != req.hdr.Unique {
+ t.Fatalf("got response for another request. Expected response for req %v but got response for req %v",
+ req.hdr.Unique, resp.hdr.Unique)
+ }
+
+ if respTestPayload.data != testObj.data {
+ t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data)
+ }
+}
+
+// fuseServerRun creates a task and emulates the actions of a simple FUSE
+// server that reads each request and echoes the same struct back as a
+// response with the appropriate headers.
+func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) {
+ defer func() { serverDone <- struct{}{} }()
+
+ // Create the tasks that the server will be using.
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ var readPayload testPayload
+
+ serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Read and echo requests until the server is killed.
+ for {
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ payloadLen := uint32(readPayload.SizeBytes())
+
+ // The read buffer must meet a minimum size requirement.
+ buffSize := inHdrLen + payloadLen
+ if buffSize < linux.FUSE_MIN_READ_BUFFER {
+ buffSize = linux.FUSE_MIN_READ_BUFFER
+ }
+ inBuf := make([]byte, buffSize)
+ inIOseq := usermem.BytesIOSequence(inBuf)
+
+ n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ // Server should shut down. No new requests are going to be made.
+ if serverKilled {
+ break
+ }
+
+ if n <= 0 {
+ t.Fatalf("Read returned no bytes")
+ }
+
+ var readFUSEHeaderIn linux.FUSEHeaderIn
+ readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen])
+ readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen])
+
+ if readFUSEHeaderIn.Opcode != echoTestOpcode {
+ t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload)
+ }
+
+ // Write the response.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ outBuf := make([]byte, outHdrLen+payloadLen)
+ outHeader := linux.FUSEHeaderOut{
+ Len: outHdrLen + payloadLen,
+ Error: 0,
+ Unique: readFUSEHeaderIn.Unique,
+ }
+
+ // Echo the payload back.
+ outHeader.MarshalUnsafe(outBuf[:outHdrLen])
+ readPayload.MarshalUnsafe(outBuf[outHdrLen:])
+ outIOseq := usermem.BytesIOSequence(outBuf)
+
+ n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+ }
+}
+
+func setup(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Error creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ AllowUserMount: true,
+ })
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
+}
+
+// newTestConnection creates a FUSE connection that the sentry can communicate
+// over, along with the FD the server uses to communicate.
+func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*Connection, *vfs.FileDescription, error) {
+ vfsObj := &vfs.VirtualFilesystem{}
+ fuseDev := &DeviceFD{}
+
+ if err := vfsObj.Init(); err != nil {
+ return nil, nil, err
+ }
+
+ vd := vfsObj.NewAnonVirtualDentry("genCountFD")
+ defer vd.DecRef()
+ if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, nil, err
+ }
+
+ fsopts := filesystemOptions{
+ maxActiveRequests: maxActiveRequests,
+ }
+ fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return fs.conn, &fuseDev.vfsfd, nil
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *testPayload) SizeBytes() int {
+ return 4
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *testPayload) MarshalBytes(dst []byte) {
+ usermem.ByteOrder.PutUint32(dst[:4], t.data)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *testPayload) UnmarshalBytes(src []byte) {
+ *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (t *testPayload) Packed() bool {
+ return true
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (t *testPayload) MarshalUnsafe(dst []byte) {
+ t.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (t *testPayload) UnmarshalUnsafe(src []byte) {
+ t.UnmarshalBytes(src)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ panic("not implemented")
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
+ panic("not implemented")
+}
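
Note: the buffer sizing fuseServerRun works around above is the rule DeviceFD.Read enforces: a read buffer must hold at least max(FUSE_MIN_READ_BUFFER, in-header + write-header + negotiated MaxWrite) bytes. A small sketch of that calculation; the header sizes below are placeholders, the real values come from the go_marshal SizeBytes implementations.

package main

import "fmt"

// Illustrative sizes; the real values come from linux.FUSE_MIN_READ_BUFFER
// and the SizeBytes of linux.FUSEHeaderIn / linux.FUSEWriteIn.
const (
    fuseMinReadBuffer uint32 = 8192
    inHdrLen          uint32 = 40
    writeHdrLen       uint32 = 40
)

// minReadBufferSize mirrors DeviceFD.Read: the server's buffer must hold
// the fixed headers plus the negotiated MaxWrite worth of data.
func minReadBufferSize(maxWrite uint32) uint32 {
    minSize := fuseMinReadBuffer
    if negotiated := inHdrLen + writeHdrLen + maxWrite; negotiated > minSize {
        minSize = negotiated
    }
    return minSize
}

func main() {
    for _, maxWrite := range []uint32{4096, 1 << 20} {
        fmt.Printf("MaxWrite=%d => minimum read buffer %d bytes\n",
            maxWrite, minReadBufferSize(maxWrite))
    }
}
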
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
new file mode 100644
index 000000000..911b6f7cb
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -0,0 +1,224 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fuse implements fusefs.
+package fuse
+
+import (
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "fuse"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+type filesystemOptions struct {
+ // userID specifies the numeric uid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see the man
+ // page for fuse(8).
+ userID uint32
+
+ // groupID specifies the numeric gid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see the man
+ // page for fuse(8).
+ groupID uint32
+
+ // rootMode specifies the file mode of the filesystem's root.
+ rootMode linux.FileMode
+
+ // maxActiveRequests specifies the maximum number of active requests that can
+ // exist at any time. Any further requests will block when trying to
+ // Call the server.
+ maxActiveRequests uint64
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+ devMinor uint32
+
+ // conn is used for communication between the FUSE server
+ // daemon and the sentry fusefs.
+ conn *Connection
+
+ // opts is the options the fusefs is initialized with.
+ opts *filesystemOptions
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ var fsopts filesystemOptions
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ deviceDescriptorStr, ok := mopts["fd"]
+ if !ok {
+ log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "fd")
+
+ deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor))
+
+ // Parse and set all the other supported FUSE mount options.
+ // TODO(gVisor.dev/issue/3229): Expand the supported mount options.
+ if userIDStr, ok := mopts["user_id"]; ok {
+ delete(mopts, "user_id")
+ userID, err := strconv.ParseUint(userIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.userID = uint32(userID)
+ }
+
+ if groupIDStr, ok := mopts["group_id"]; ok {
+ delete(mopts, "group_id")
+ groupID, err := strconv.ParseUint(groupIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.groupID = uint32(groupID)
+ }
+
+ rootMode := linux.FileMode(0777)
+ modeStr, ok := mopts["rootmode"]
+ if ok {
+ delete(mopts, "rootmode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode)
+ }
+ fsopts.rootMode = rootMode
+
+ // Set the maxActiveRequests option.
+ fsopts.maxActiveRequests = MaxActiveRequestsDefault
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Create a new FUSE filesystem.
+ fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd)
+ if err != nil {
+ log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err)
+ return nil, nil, err
+ }
+
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ // TODO: dispatch a FUSE_INIT request to the FUSE server daemon before
+ // returning. Mount will not block on this dispatched request.
+
+ // root is the fusefs root directory.
+ root := fs.newInode(creds, fsopts.rootMode)
+
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// NewFUSEFilesystem creates a new FUSE filesystem.
+func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) {
+ fs := &filesystem{
+ devMinor: devMinor,
+ opts: opts,
+ }
+
+ conn, err := NewFUSEConnection(ctx, device, opts.maxActiveRequests)
+ if err != nil {
+ log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err)
+ return nil, syserror.EINVAL
+ }
+
+ fs.conn = conn
+ fuseFD := device.Impl().(*DeviceFD)
+ fuseFD.fs = fs
+
+ return fs, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release()
+}
+
+// Inode implements kernfs.Inode.
+type Inode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+ i := &Inode{}
+ i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ i.dentry.Init(i)
+
+ return &i.dentry
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *Inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
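
Note: GetFilesystem's option handling above follows a consume-and-check pattern: parse the comma-separated mount data into a map, delete each recognized key as it is handled, and fail with EINVAL if anything is left. A standalone sketch of that flow; parseOpts merely stands in for vfs.GenericParseMountOptions and is not the real implementation.

package main

import (
    "fmt"
    "strconv"
    "strings"
)

// parseOpts stands in for vfs.GenericParseMountOptions: split the mount
// data on commas into "key=value" pairs.
func parseOpts(data string) map[string]string {
    m := make(map[string]string)
    for _, opt := range strings.Split(data, ",") {
        if opt == "" {
            continue
        }
        if kv := strings.SplitN(opt, "=", 2); len(kv) == 2 {
            m[kv[0]] = kv[1]
        } else {
            m[kv[0]] = ""
        }
    }
    return m
}

func main() {
    mopts := parseOpts("fd=7,rootmode=40000,user_id=0,group_id=0")

    // Consume each recognized option, as GetFilesystem does; anything left
    // over afterwards is unknown and fails the mount with EINVAL.
    fd, errFd := strconv.ParseInt(mopts["fd"], 10 /* base */, 32 /* bitSize */)
    delete(mopts, "fd")
    mode, errMode := strconv.ParseUint(mopts["rootmode"], 8, 32)
    delete(mopts, "rootmode")
    delete(mopts, "user_id")
    delete(mopts, "group_id")
    if errFd != nil || errMode != nil || len(mopts) != 0 {
        fmt.Println("EINVAL: bad or unknown mount options")
        return
    }
    fmt.Printf("device fd=%d rootmode=%#o\n", fd, mode)
}
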
diff --git a/pkg/sentry/fsimpl/fuse/register.go b/pkg/sentry/fsimpl/fuse/register.go
new file mode 100644
index 000000000..b5b581152
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/register.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Register registers the FUSE device with vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{
+ GroupName: "misc",
+ }); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// CreateDevtmpfsFile creates a device special file in devtmpfs.
+func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error {
+ if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil {
+ return err
+ }
+
+ return nil
+}
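
Note: a sketch of how a sandbox boot path might wire these two helpers together. The setupFUSE wrapper and its placement are hypothetical; only Register and CreateDevtmpfsFile (with the signatures above) come from this file.

package boot

import (
    "gvisor.dev/gvisor/pkg/context"
    "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
    "gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse"
    "gvisor.dev/gvisor/pkg/sentry/vfs"
)

// setupFUSE is a hypothetical init-time helper: it makes major/minor
// 10:229 resolvable by the VFS, then surfaces it as /dev/fuse.
func setupFUSE(ctx context.Context, vfsObj *vfs.VirtualFilesystem, a *devtmpfs.Accessor) error {
    if err := fuse.Register(vfsObj); err != nil {
        return err
    }
    return fuse.CreateDevtmpfsFile(ctx, a)
}
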
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index cd5f5049e..00e3c99cd 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -150,11 +150,9 @@ afterSymlink:
return nil, err
}
if d != d.parent && !d.cachedMetadataAuthoritative() {
- _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask())
- if err != nil {
+ if err := d.parent.updateFromGetattr(ctx); err != nil {
return nil, err
}
- d.parent.updateFromP9Attrs(attrMask, &attr)
}
rp.Advance()
return d.parent, nil
@@ -209,17 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil
// Preconditions: As for getChildLocked. !parent.isSynthetic().
func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
+ if child != nil {
+ // Need to lock child.metadataMu because we might be updating child
+ // metadata. We need to hold the lock *before* getting metadata from the
+ // server and release it after updating local metadata.
+ child.metadataMu.Lock()
+ }
qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
if err != nil && err != syserror.ENOENT {
+ if child != nil {
+ child.metadataMu.Unlock()
+ }
return nil, err
}
if child != nil {
if !file.isNil() && inoFromPath(qid.Path) == child.ino {
// The file at this path hasn't changed. Just update cached metadata.
file.close(ctx)
- child.updateFromP9Attrs(attrMask, &attr)
+ child.updateFromP9AttrsLocked(attrMask, &attr)
+ child.metadataMu.Unlock()
return child, nil
}
+ child.metadataMu.Unlock()
if file.isNil() && child.isSynthetic() {
// We have a synthetic file, and no remote file has arisen to
// replace it.
@@ -1325,7 +1334,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
- if err := d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()); err != nil {
+ if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil {
fs.renameMuRUnlockAndCheckCaching(&ds)
return err
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index b74d489a0..e20de84b5 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -785,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool {
// updateFromP9Attrs is called to update d's metadata after an update from the
// remote filesystem.
-func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
- d.metadataMu.Lock()
+// Precondition: d.metadataMu must be locked.
+func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
if mask.Mode {
if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
d.metadataMu.Unlock()
@@ -822,7 +822,6 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
if mask.Size {
d.updateFileSizeLocked(attr.Size)
}
- d.metadataMu.Unlock()
}
// Preconditions: !d.isSynthetic()
@@ -834,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
file p9file
handleMuRLocked bool
)
+ // d.metadataMu must be locked *before* we getAttr so that we do not end up
+ // updating stale attributes in d.updateFromP9AttrsLocked().
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
d.handleMu.RLock()
if !d.handle.file.isNil() {
file = d.handle.file
@@ -849,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error {
if err != nil {
return err
}
- d.updateFromP9Attrs(attrMask, &attr)
+ d.updateFromP9AttrsLocked(attrMask, &attr)
return nil
}
@@ -885,7 +888,8 @@ func (d *dentry) statTo(stat *linux.Statx) {
stat.DevMinor = d.fs.devMinor
}
-func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error {
+func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error {
+ stat := &opts.Stat
if stat.Mask == 0 {
return nil
}
@@ -893,7 +897,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
return syserror.EPERM
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
if err := mnt.CheckBeginWrite(); err != nil {
@@ -934,6 +938,17 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
}
if !d.isSynthetic() {
if stat.Mask != 0 {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // Check whether to allow a truncate request to be made.
+ switch d.mode & linux.S_IFMT {
+ case linux.S_IFREG:
+ // Allow.
+ case linux.S_IFDIR:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
+ }
if err := d.file.setAttr(ctx, p9.SetAttrMask{
Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
@@ -1495,7 +1510,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()); err != nil {
+ if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil {
return err
}
if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
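
Note: the metadataMu reordering in this file is a general lock-before-fetch rule. If the lock is taken only after the getattr RPC returns, two racing getattrs can apply their replies in the wrong order and cache stale attributes; holding the lock across the RPC serializes fetch and apply. A reduced sketch with illustrative names:

package main

import "sync"

type attrs struct{ size, version uint64 }

type dentry struct {
    metadataMu sync.Mutex
    cached     attrs
}

// fetchAttrs stands in for the p9 getattr round trip to the gofer.
func fetchAttrs() attrs { return attrs{size: 42, version: 7} }

// updateFromGetattr holds metadataMu *across* the fetch, as the client now
// does, so a slower, staler reply can never overwrite a newer one.
func (d *dentry) updateFromGetattr() {
    d.metadataMu.Lock()
    defer d.metadataMu.Unlock()
    attr := fetchAttrs() // The RPC happens inside the critical section.
    d.updateLocked(attr)
}

// updateLocked mirrors updateFromP9AttrsLocked: metadataMu must be held.
func (d *dentry) updateLocked(attr attrs) {
    d.cached = attr
}

func main() {
    var d dentry
    d.updateFromGetattr()
}
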
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index f10350c97..09f142cfc 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -155,26 +154,53 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, the final offset, and an
+// error. The final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported.
//
// TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
if opts.Flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the fd was opened with O_APPEND, make sure the file size is updated.
+ // There is a possible race here if the size is modified externally after
+ // the metadata cache is updated.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ // Set offset to file size if the fd was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(limit)
+ n, err := fd.pwriteLocked(ctx, src, offset, opts)
+ return n, offset + n, err
+}
+
+// Preconditions: fd.dentry().metadataMu must be locked.
+func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
d := fd.dentry()
- d.metadataMu.Lock()
- defer d.metadataMu.Unlock()
if d.fs.opts.interop != InteropModeShared {
// Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
// file_update_time(). This is d.touchCMtime(), but without locking
@@ -194,12 +220,12 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
return 0, syserror.EINVAL
}
mr := memmap.MappableRange{pgstart, pgend}
- var freed []platform.FileRange
+ var freed []memmap.FileRange
d.dataMu.Lock()
cseg := d.cache.LowerBoundSegment(mr.Start)
for cseg.Ok() && cseg.Start() < mr.End {
cseg = d.cache.Isolate(cseg, mr)
- freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
+ freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
cseg = d.cache.Remove(cseg).NextSegment()
}
d.dataMu.Unlock()
@@ -237,8 +263,8 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
fd.mu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.mu.Unlock()
return n, err
}
@@ -794,7 +820,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
func (d *dentry) InvalidateUnsavable(ctx context.Context) error {
- // Whether we have a host fd (and consequently what platform.File is
+ // Whether we have a host fd (and consequently what memmap.File is
// mapped) can change across save/restore, so invalidate all translations
// unconditionally.
d.mapsMu.Lock()
@@ -842,8 +868,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
}
}
-// dentryPlatformFile implements platform.File. It exists solely because dentry
-// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef.
+// dentryPlatformFile implements memmap.File. It exists solely because dentry
+// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef.
//
// dentryPlatformFile is only used when a host FD representing the remote file
// is available (i.e. dentry.handle.fd >= 0), and that FD is used for
@@ -851,7 +877,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
type dentryPlatformFile struct {
*dentry
- // fdRefs counts references on platform.File offsets. fdRefs is protected
+ // fdRefs counts references on memmap.File offsets. fdRefs is protected
// by dentry.dataMu.
fdRefs fsutil.FrameRefSet
@@ -863,29 +889,29 @@ type dentryPlatformFile struct {
hostFileMapperInitOnce sync.Once
}
-// IncRef implements platform.File.IncRef.
-func (d *dentryPlatformFile) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) {
d.dataMu.Lock()
d.fdRefs.IncRefAndAccount(fr)
d.dataMu.Unlock()
}
-// DecRef implements platform.File.DecRef.
-func (d *dentryPlatformFile) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) {
d.dataMu.Lock()
d.fdRefs.DecRefAndAccount(fr)
d.dataMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal.
-func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
d.handleMu.RLock()
bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write)
d.handleMu.RUnlock()
return bs, err
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (d *dentryPlatformFile) FD() int {
d.handleMu.RLock()
fd := d.handle.fd
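
Note: the PWrite/pwrite split above (repeated for specialFileFD and tmpfs below) encodes a POSIX distinction: pwrite(2) must never move the file offset, while write(2) on an O_APPEND fd must land at end-of-file, so the old `fd.off += n` was wrong for appends. A toy model of the shape all three implementations now share; the types are illustrative:

package main

import "fmt"

type fileFD struct {
    off    int64
    append bool
    size   int64 // stands in for dentry.size under metadataMu
}

// pwrite returns (written, finalOff); finalOff reflects any O_APPEND
// repositioning and is ignored by the positional entry point.
func (fd *fileFD) pwrite(p []byte, offset int64) (int64, int64) {
    if fd.append {
        offset = fd.size // the write lands at EOF, not the caller's offset
    }
    n := int64(len(p))
    if end := offset + n; end > fd.size {
        fd.size = end
    }
    return n, offset + n
}

// PWrite discards the final offset: pwrite(2) never moves fd.off.
func (fd *fileFD) PWrite(p []byte, offset int64) int64 {
    n, _ := fd.pwrite(p, offset)
    return n
}

// Write adopts the final offset, so an O_APPEND write seeks to EOF first.
func (fd *fileFD) Write(p []byte) int64 {
    n, off := fd.pwrite(p, fd.off)
    fd.off = off
    return n
}

func main() {
    fd := &fileFD{append: true, size: 100}
    fd.Write([]byte("abc"))
    fmt.Println(fd.off, fd.size) // 103 103, not 3 3
}
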
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index a7b53b2d2..811528982 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -16,6 +16,7 @@ package gofer
import (
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -144,7 +145,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
// mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
// hold here since specialFileFD doesn't client-cache data. Just buffer the
// read instead.
- if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ if d := fd.dentry(); d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
buf := make([]byte, dst.NumBytes())
@@ -176,39 +177,76 @@ func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, the final offset, and an
+// error. The final offset should be ignored by PWrite.
+func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if fd.seekable && offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported.
//
// TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
if opts.Flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the regular file fd was opened with O_APPEND, make sure the file size
+ // is updated. There is a possible race here if the size is modified
+ // externally after the metadata cache is updated.
+ if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
if fd.seekable {
+ // We need to hold the metadataMu *while* writing to a regular file.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ // Set offset to file size if the regular file was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(limit)
}
// Do a buffered write. See rationale in PRead.
- if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ if d.cachedMetadataAuthoritative() {
d.touchCMtime()
}
buf := make([]byte, src.NumBytes())
// Don't do partial writes if we get a partial read from src.
if _, err := src.CopyIn(ctx, buf); err != nil {
- return 0, err
+ return 0, offset, err
}
n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
if err == syserror.EAGAIN {
err = syserror.ErrWouldBlock
}
- return int64(n), err
+ finalOff = offset
+ // Update file size for regular files.
+ if fd.seekable {
+ finalOff += int64(n)
+ // d.metadataMu is already locked at this point.
+ if uint64(finalOff) > d.size {
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ atomic.StoreUint64(&d.size, uint64(finalOff))
+ }
+ }
+ return int64(n), finalOff, err
}
// Write implements vfs.FileDescriptionImpl.Write.
@@ -218,8 +256,8 @@ func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts
}
fd.mu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.mu.Unlock()
return n, err
}
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index e86fbe2d5..bd701bbc7 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -34,7 +34,6 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
- "//pkg/sentry/platform",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 1a88cb657..c894f2ca0 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -373,7 +373,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) {
// SetStat implements kernfs.Inode.
func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
- s := opts.Stat
+ s := &opts.Stat
m := s.Mask
if m == 0 {
@@ -386,7 +386,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
if err := syscall.Fstat(i.hostFD, &hostStat); err != nil {
return err
}
- if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
return err
}
@@ -396,6 +396,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
}
if m&linux.STATX_SIZE != 0 {
+ if hostStat.Mode&linux.S_IFMT != linux.S_IFREG {
+ return syserror.EINVAL
+ }
if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go
index 8545a82f0..65d3af38c 100644
--- a/pkg/sentry/fsimpl/host/mmap.go
+++ b/pkg/sentry/fsimpl/host/mmap.go
@@ -19,13 +19,12 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
-// inodePlatformFile implements platform.File. It exists solely because inode
-// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef.
+// inodePlatformFile implements memmap.File. It exists solely because inode
+// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef.
//
// inodePlatformFile should only be used if inode.canMap is true.
type inodePlatformFile struct {
@@ -34,7 +33,7 @@ type inodePlatformFile struct {
// fdRefsMu protects fdRefs.
fdRefsMu sync.Mutex
- // fdRefs counts references on platform.File offsets. It is used solely for
+ // fdRefs counts references on memmap.File offsets. It is used solely for
// memory accounting.
fdRefs fsutil.FrameRefSet
@@ -45,32 +44,32 @@ type inodePlatformFile struct {
fileMapperInitOnce sync.Once
}
-// IncRef implements platform.File.IncRef.
+// IncRef implements memmap.File.IncRef.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) IncRef(fr platform.FileRange) {
+func (i *inodePlatformFile) IncRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.IncRefAndAccount(fr)
i.fdRefsMu.Unlock()
}
-// DecRef implements platform.File.DecRef.
+// DecRef implements memmap.File.DecRef.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) DecRef(fr platform.FileRange) {
+func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.DecRefAndAccount(fr)
i.fdRefsMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal.
+// MapInternal implements memmap.File.MapInternal.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (i *inodePlatformFile) FD() int {
return i.hostFD
}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 3f0aea73a..1d37ccb98 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -112,7 +112,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence
return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts)
}
-// Release implements vfs.FileDecriptionImpl.Release.
+// Release implements vfs.FileDescriptionImpl.Release.
func (fd *GenericDirectoryFD) Release() {}
func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
@@ -123,7 +123,7 @@ func (fd *GenericDirectoryFD) inode() Inode {
return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
}
-// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
// o.mu when calling cb.
func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
fd.mu.Lock()
@@ -198,7 +198,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
return err
}
-// Seek implements vfs.FileDecriptionImpl.Seek.
+// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
fd.mu.Lock()
defer fd.mu.Unlock()
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 2ab3f1761..579e627f0 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -267,7 +267,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
return syserror.EPERM
}
- if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index ff82e1f20..6b705e955 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -1104,7 +1104,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
}
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
mnt := rp.Mount()
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
index a3c1f7a8d..c0749e711 100644
--- a/pkg/sentry/fsimpl/overlay/non_directory.go
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -151,7 +151,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux
func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
d := fd.dentry()
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
- if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
return err
}
mnt := fd.vfsfd.Mount()
@@ -176,7 +176,7 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return nil
}
-// StatFS implements vfs.FileDesciptionImpl.StatFS.
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) {
return fd.filesystem().statFS(ctx)
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index dad4db1a7..79c2725f3 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -128,7 +128,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac
return fd.GenericDirectoryFD.IterDirents(ctx, cb)
}
-// Seek implements vfs.FileDecriptionImpl.Seek.
+// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
if fd.task.ExitState() >= kernel.TaskExitZombie {
return 0, syserror.ENOENT
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index a0f20c2d4..ef210a69b 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -649,7 +649,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
fs.mu.RUnlock()
return err
}
- if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil {
+ if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil {
fs.mu.RUnlock()
return err
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index 1cdb46e6f..abbaa5d60 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -325,8 +325,15 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts
// PWrite implements vfs.FileDescriptionImpl.PWrite.
func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, the final offset, and an
+// error. The final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
if offset < 0 {
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
// Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
@@ -334,40 +341,44 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
//
// TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
- return 0, syserror.EOPNOTSUPP
+ return 0, offset, syserror.EOPNOTSUPP
}
srclen := src.NumBytes()
if srclen == 0 {
- return 0, nil
+ return 0, offset, nil
}
f := fd.inode().impl.(*regularFile)
+ f.inode.mu.Lock()
+ defer f.inode.mu.Unlock()
+ // If the file is opened with O_APPEND, update offset to file size.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Locking f.inode.mu is sufficient for reading f.size.
+ offset = int64(f.size)
+ }
if end := offset + srclen; end < offset {
// Overflow.
- return 0, syserror.EINVAL
+ return 0, offset, syserror.EINVAL
}
- var err error
srclen, err = vfs.CheckLimit(ctx, offset, srclen)
if err != nil {
- return 0, err
+ return 0, offset, err
}
src = src.TakeFirst64(srclen)
- f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
- fd.inode().touchCMtimeLocked()
- f.inode.mu.Unlock()
+ f.inode.touchCMtimeLocked()
putRegularFileReadWriter(rw)
- return n, err
+ return n, n + offset, err
}
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
fd.offMu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
fd.offMu.Unlock()
return n, err
}
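
The Write/pwrite split above exists because O_APPEND repositions the write under the inode lock: the final offset is the append position plus the byte count, which Write must store, while PWrite discards it. A self-contained toy model of the same contract (illustrative only; types and names are not gVisor's):

    package main

    import (
        "fmt"
        "sync"
    )

    // appendFile mimics the patched contract: pwrite returns bytes written
    // and the final offset, because under O_APPEND the starting offset
    // itself moves to EOF under the lock.
    type appendFile struct {
        mu         sync.Mutex
        data       []byte
        appendMode bool // opened with O_APPEND
    }

    func (f *appendFile) pwrite(p []byte, offset int64) (n, finalOff int64, err error) {
        f.mu.Lock()
        defer f.mu.Unlock()
        if f.appendMode {
            offset = int64(len(f.data)) // reposition to EOF under the lock
        }
        end := offset + int64(len(p))
        if end > int64(len(f.data)) {
            f.data = append(f.data, make([]byte, end-int64(len(f.data)))...)
        }
        copy(f.data[offset:end], p)
        return int64(len(p)), end, nil
    }

    func main() {
        f := &appendFile{data: []byte("abc"), appendMode: true}
        n, off, _ := f.pwrite([]byte("xyz"), 0)
        fmt.Println(n, off) // 3 6: Write must store off=6, not fd.off+n=3.
    }
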
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index d7f4f0779..2545d88e9 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -452,7 +452,8 @@ func (i *inode) statTo(stat *linux.Statx) {
}
}
-func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error {
+func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error {
+ stat := &opts.Stat
if stat.Mask == 0 {
return nil
}
@@ -460,7 +461,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu
return syserror.EPERM
}
mode := linux.FileMode(atomic.LoadUint32(&i.mode))
- if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
return err
}
i.mu.Lock()
@@ -695,7 +696,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
creds := auth.CredentialsFromContext(ctx)
d := fd.dentry()
- if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil {
+ if err := d.inode.setStat(ctx, creds, &opts); err != nil {
return err
}
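
Passing the full *vfs.SetStatOptions (rather than just &opts.Stat) lets the permission check consult option fields beyond the Statx payload. A minimal sketch of that shape, with hypothetical stand-in types:

    package main

    import "fmt"

    // Hypothetical stand-ins for linux.Statx / vfs.SetStatOptions.
    type Statx struct{ Mask uint32 }

    type SetStatOptions struct {
        Stat Statx
        // Fields besides Stat can now influence CheckSetStat-style checks.
        NeedWritePerm bool
    }

    // checkSetStat sketches a check that needs more than the Statx alone.
    func checkSetStat(opts *SetStatOptions, haveWrite bool) error {
        if opts.NeedWritePerm && !haveWrite {
            return fmt.Errorf("write permission required")
        }
        _ = opts.Stat.Mask // the usual mask-driven checks would go here
        return nil
    }

    func main() {
        opts := SetStatOptions{NeedWritePerm: true}
        fmt.Println(checkSetStat(&opts, false)) // write permission required
    }
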
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 25fe1921b..f6886a758 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -132,6 +132,7 @@ go_library(
"task_stop.go",
"task_syscall.go",
"task_usermem.go",
+ "task_work.go",
"thread_group.go",
"threads.go",
"timekeeper.go",
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 732e66da4..bcc1b29a8 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -717,10 +717,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3
}
}
-// UnlockPI unlock the futex following the Priority-inheritance futex
-// rules. The address provided must contain the caller's TID. If there are
-// waiters, TID of the next waiter (FIFO) is set to the given address, and the
-// waiter woken up. If there are no waiters, 0 is set to the address.
+// UnlockPI unlocks the futex following the Priority-inheritance futex rules.
+// The address provided must contain the caller's TID. If there are waiters,
+// TID of the next waiter (FIFO) is set to the given address, and the waiter
+// woken up. If there are no waiters, 0 is set to the address.
func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error {
k, err := getKey(t, addr, private)
if err != nil {
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 240cd6fe0..15dae0f5b 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -1469,6 +1469,11 @@ func (k *Kernel) NowMonotonic() int64 {
return now
}
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ return ktime.TcpipAfterFunc(k.realtimeClock, d, f)
+}
+
// SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or
// LoadFrom.
func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) {
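
AfterFunc rounds out Kernel's tcpip.Clock implementation, so netstack code can schedule one-shot callbacks against sentry time. A self-contained sketch of the assumed Clock/Timer contract, backed here by Go runtime timers for illustration:

    package main

    import (
        "fmt"
        "time"
    )

    // Minimal stand-ins for the assumed tcpip.Clock / tcpip.Timer contract.
    type Timer interface {
        Stop() bool
        Reset(time.Duration)
    }

    type Clock interface {
        AfterFunc(d time.Duration, f func()) Timer
    }

    // hostClock backs the contract with Go runtime timers, the way Kernel
    // backs it with its realtime clock via ktime.TcpipAfterFunc.
    type hostClock struct{}

    type hostTimer struct{ t *time.Timer }

    func (h hostTimer) Stop() bool            { return h.t.Stop() }
    func (h hostTimer) Reset(d time.Duration) { h.t.Reset(d) }

    func (hostClock) AfterFunc(d time.Duration, f func()) Timer {
        return hostTimer{t: time.AfterFunc(d, f)}
    }

    func main() {
        done := make(chan struct{})
        var c Clock = hostClock{}
        c.AfterFunc(10*time.Millisecond, func() { close(done) })
        <-done
        fmt.Println("fired once")
    }
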
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index bfd779837..c211fc8d0 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -20,7 +20,6 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/platform",
"//pkg/sentry/usage",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index f66cfcc7f..55b4c2cdb 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -45,7 +45,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -370,7 +369,7 @@ type Shm struct {
// fr is the offset into mfp.MemoryFile() that backs the contents of this
// segment. Immutable.
- fr platform.FileRange
+ fr memmap.FileRange
// mu protects all fields below.
mu sync.Mutex `state:"nosave"`
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index f48247c94..c4db05bd8 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -68,6 +68,21 @@ type Task struct {
// runState is exclusive to the task goroutine.
runState taskRunState
+ // taskWorkCount represents the current size of the task work queue. It is
+ // used to avoid acquiring taskWorkMu when the queue is empty.
+ //
+ // Must be accessed with atomic memory operations.
+ taskWorkCount int32
+
+ // taskWorkMu protects taskWork.
+ taskWorkMu sync.Mutex `state:"nosave"`
+
+ // taskWork is a queue of work to be executed before resuming user execution.
+ // It is similar to the task_work mechanism in Linux.
+ //
+ // taskWork is exclusive to the task goroutine.
+ taskWork []TaskWorker
+
// haveSyscallReturn is true if tc.Arch().Return() represents a value
// returned by a syscall (or set by ptrace after a syscall).
//
@@ -550,6 +565,10 @@ type Task struct {
// futexWaiter is exclusive to the task goroutine.
futexWaiter *futex.Waiter `state:"nosave"`
+ // robustList is a pointer to the head of the task's robust futex
+ // list.
+ robustList usermem.Addr
+
// startTime is the real time at which the task started. It is set when
// a Task is created or invokes execve(2).
//
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 9b69f3cbe..7803b98d0 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -207,6 +207,9 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
return flags.CloseOnExec
})
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// NOTE(b/30815691): We currently do not implement privileged
// executables (set-user/group-ID bits and file capabilities). This
// allows us to unconditionally enable user dumpability on the new mm.
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index c4ade6e8e..231ac548a 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -253,6 +253,9 @@ func (*runExitMain) execute(t *Task) taskRunState {
}
}
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// Deactivate the address space and update max RSS before releasing the
// task's MM.
t.Deactivate()
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
index a53e77c9f..4b535c949 100644
--- a/pkg/sentry/kernel/task_futex.go
+++ b/pkg/sentry/kernel/task_futex.go
@@ -15,6 +15,7 @@
package kernel
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -52,3 +53,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) {
func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) {
return t.MemoryManager().GetSharedFutexKey(t, addr)
}
+
+// GetRobustList returns the address of the task's robust futex list.
+func (t *Task) GetRobustList() usermem.Addr {
+ t.mu.Lock()
+ addr := t.robustList
+ t.mu.Unlock()
+ return addr
+}
+
+// SetRobustList sets the robust futex list for the task.
+func (t *Task) SetRobustList(addr usermem.Addr) {
+ t.mu.Lock()
+ t.robustList = addr
+ t.mu.Unlock()
+}
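+
+These accessors are the storage half of set_robust_list(2)/get_robust_list(2). A sketch of how a handler might use them (the handler name and exact validation here are assumptions, not part of this patch):
+
+    // Sketch only: a set_robust_list handler validates the size and stores
+    // the head pointer; Linux rejects any size other than
+    // sizeof(struct robust_list_head).
+    func setRobustList(t *kernel.Task, head usermem.Addr, length uint64) error {
+        if length != uint64(linux.SizeOfRobustListHead) { // assumed constant
+            return syserror.EINVAL
+        }
+        t.SetRobustList(head)
+        return nil
+    }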
+
+// exitRobustList walks the robust futex list, marking locks dead and notifying
+// wakers. It corresponds to Linux's exit_robust_list(). Following Linux,
+// errors are silently ignored.
+func (t *Task) exitRobustList() {
+ t.mu.Lock()
+ addr := t.robustList
+ t.robustList = 0
+ t.mu.Unlock()
+
+ if addr == 0 {
+ return
+ }
+
+ var rl linux.RobustListHead
+ if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil {
+ return
+ }
+
+ next := rl.List
+ done := 0
+ var pendingLockAddr usermem.Addr
+ if rl.ListOpPending != 0 {
+ pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset)
+ }
+
+ // Wake up normal elements.
+ for usermem.Addr(next) != addr {
+ // We traverse to the next element of the list before we
+ // actually wake anything. This prevents the race where waking
+ // this futex causes a modification of the list.
+ thisLockAddr := usermem.Addr(next + rl.FutexOffset)
+
+ // Try to decode the next element in the list before waking the
+ // current futex. But don't check the error until after we've
+ // woken the current futex. Linux does it in this order, too.
+ _, nextErr := t.CopyIn(usermem.Addr(next), &next)
+
+ // Wake up the current futex if it's not pending.
+ if thisLockAddr != pendingLockAddr {
+ t.wakeRobustListOne(thisLockAddr)
+ }
+
+ // If there was an error copying the next futex, we must bail.
+ if nextErr != nil {
+ break
+ }
+
+ // This is a user structure, so it could be a massive list, or
+ // even contain a loop if they are trying to mess with us. We
+ // cap traversal to prevent that.
+ done++
+ if done >= linux.ROBUST_LIST_LIMIT {
+ break
+ }
+ }
+
+ // Is there a pending entry to wake?
+ if pendingLockAddr != 0 {
+ t.wakeRobustListOne(pendingLockAddr)
+ }
+}
+
+// wakeRobustListOne wakes a single futex from the robust list.
+func (t *Task) wakeRobustListOne(addr usermem.Addr) {
+ // Bit 0 in address signals PI futex.
+ pi := addr&1 == 1
+ addr = addr &^ 1
+
+ // Load the futex.
+ f, err := t.LoadUint32(addr)
+ if err != nil {
+ // Can't read this single value? Ignore the problem.
+ // We can wake the other futexes in the list.
+ return
+ }
+
+ tid := uint32(t.ThreadID())
+ for {
+ // Is this held by someone else?
+ if f&linux.FUTEX_TID_MASK != tid {
+ return
+ }
+
+ // This thread is dying and it's holding this futex. We need to
+ // set the owner died bit and wake up any waiters.
+ newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED
+ if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil {
+ return
+ } else if curF != f {
+ // Futex changed out from under us. Try again...
+ f = curF
+ continue
+ }
+
+ // Wake waiters if there are any.
+ if f&linux.FUTEX_WAITERS != 0 {
+ private := f&linux.FUTEX_PRIVATE_FLAG != 0
+ if pi {
+ t.Futex().UnlockPI(t, addr, tid, private)
+ return
+ }
+ t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1)
+ }
+
+ // Done.
+ return
+ }
+}
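
The address arithmetic in exitRobustList (next + rl.FutexOffset, with bit 0 marking a PI futex) follows the robust-list ABI: each list node is embedded in a userspace lock, and futex_offset is the fixed distance from the node to the 32-bit futex word. A runnable toy of that layout (illustrative only):

    package main

    import (
        "fmt"
        "unsafe"
    )

    // Illustrative layout: a robust_list node lives inside each user lock,
    // and futex_offset is the node-to-futex-word distance that userspace
    // publishes in its robust_list_head.
    type robustList struct{ next uintptr }

    type userMutex struct {
        list  robustList // node the kernel walks via .next
        state uint32     // futex word: owner TID | FUTEX_WAITERS | FUTEX_OWNER_DIED
    }

    func main() {
        off := unsafe.Offsetof(userMutex{}.state) - unsafe.Offsetof(userMutex{}.list)
        // exitRobustList computes each futex address as node + off; bit 0 of
        // that address is then borrowed to flag PI futexes.
        fmt.Println(off)
    }
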
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index d654dd997..7d4f44caf 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -167,7 +167,22 @@ func (app *runApp) execute(t *Task) taskRunState {
return (*runInterrupt)(nil)
}
- // We're about to switch to the application again. If there's still a
+ // Execute any task work callbacks before returning to user space.
+ if atomic.LoadInt32(&t.taskWorkCount) > 0 {
+ t.taskWorkMu.Lock()
+ queue := t.taskWork
+ t.taskWork = nil
+ atomic.StoreInt32(&t.taskWorkCount, 0)
+ t.taskWorkMu.Unlock()
+
+ // Do not hold taskWorkMu while executing task work, which may register
+ // more work.
+ for _, work := range queue {
+ work.TaskWork(t)
+ }
+ }
+
+ // We're about to switch to the application again. If there's still an
// unhandled SyscallRestartErrno that wasn't translated to an EINTR,
// restart the syscall that was interrupted. If there's a saved signal
// mask, restore it. (Note that restoring the saved signal mask may unblock
diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go
new file mode 100644
index 000000000..dda5a433a
--- /dev/null
+++ b/pkg/sentry/kernel/task_work.go
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "sync/atomic"
+
+// TaskWorker is a deferred task.
+//
+// This must be savable.
+type TaskWorker interface {
+ // TaskWork will be executed prior to returning to user space. Note that
+ // TaskWork may call RegisterWork again, but the newly registered work
+ // will not be executed until the next return to user space, unlike in
+ // Linux. This effectively allows registration of indefinite user return
+ // hooks, but not by default.
+ TaskWork(t *Task)
+}
+
+// RegisterWork can be used to register additional task work that will be
+// performed prior to returning to user space. See TaskWorker.TaskWork for
+// semantics regarding registration.
+func (t *Task) RegisterWork(work TaskWorker) {
+ t.taskWorkMu.Lock()
+ defer t.taskWorkMu.Unlock()
+ atomic.AddInt32(&t.taskWorkCount, 1)
+ t.taskWork = append(t.taskWork, work)
+}
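
Putting RegisterWork together with the drain loop added to runApp.execute above, the pattern is: a cheap atomic emptiness check on the hot path, swap the queue out under the mutex, then run callbacks with no locks held. A self-contained toy model (not gVisor code):

    package main

    import (
        "fmt"
        "sync"
        "sync/atomic"
    )

    type worker interface{ run() }

    type task struct {
        count int32 // accessed atomically; fast-path emptiness check
        mu    sync.Mutex
        queue []worker
    }

    func (t *task) register(w worker) {
        t.mu.Lock()
        defer t.mu.Unlock()
        atomic.AddInt32(&t.count, 1)
        t.queue = append(t.queue, w)
    }

    // flush mirrors the runApp hook: one atomic load when idle, then swap
    // the queue out and run callbacks lock-free (they may re-register).
    func (t *task) flush() {
        if atomic.LoadInt32(&t.count) == 0 {
            return
        }
        t.mu.Lock()
        q := t.queue
        t.queue = nil
        atomic.StoreInt32(&t.count, 0)
        t.mu.Unlock()
        for _, w := range q {
            w.run()
        }
    }

    type printWork string

    func (p printWork) run() { fmt.Println(string(p)) }

    func main() {
        t := &task{}
        t.register(printWork("runs before returning to user space"))
        t.flush()
    }
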
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 7ba7dc50c..2817aa3ba 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "time",
srcs = [
"context.go",
+ "tcpip.go",
"time.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go
new file mode 100644
index 000000000..c4474c0cf
--- /dev/null
+++ b/pkg/sentry/kernel/time/tcpip.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "sync"
+ "time"
+)
+
+// TcpipAfterFunc waits for duration to elapse according to clock, then runs fn.
+// The timer is started immediately and will fire exactly once.
+func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer {
+ timer := &TcpipTimer{
+ clock: clock,
+ }
+ timer.notifier = functionNotifier{
+ fn: func() {
+ // tcpip.Timer.Stop() explicitly states that the function is called in a
+ // separate goroutine that Stop() does not synchronize with.
+ // Timer.Destroy() synchronizes with calls to TimerListener.Notify().
+ // This is semantically meaningful because, in the former case, it's
+ // legal to call tcpip.Timer.Stop() while holding locks that may also be
+ // taken by the function, but this isn't so in the latter case. Most
+ // immediately, Timer calls TimerListener.Notify() while holding
+ // Timer.mu. A deadlock occurs without spawning a goroutine:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify()
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock!
+ //
+ // Spawning a goroutine avoids the deadlock:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify() <- Launches T2
+ // T2:
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, blocks
+ // T1:
+ // => (returns) <- Timer.mu.Unlock() called
+ // T2:
+ // => (continues) <- No deadlock!
+ go func() {
+ timer.Stop()
+ fn()
+ }()
+ },
+ }
+ timer.Reset(duration)
+ return timer
+}
+
+// TcpipTimer is a resettable timer with variable duration expirations.
+// Implements tcpip.Timer, which does not define a Destroy method; instead, all
+// resources are released after timer expiration and calls to Timer.Stop.
+//
+// Must be created by AfterFunc.
+type TcpipTimer struct {
+ // clock is the time source. clock is immutable.
+ clock Clock
+
+ // notifier is called when the Timer expires. notifier is immutable.
+ notifier functionNotifier
+
+ // mu protects t.
+ mu sync.Mutex
+
+ // t stores the latest running Timer. This is replaced whenever Reset is
+ // called since Timer cannot be restarted once it has been Destroyed by Stop.
+ //
+ // This field is nil iff Stop has been called.
+ t *Timer
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (r *TcpipTimer) Stop() bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ return false
+ }
+ _, lastSetting := r.t.Swap(Setting{})
+ r.t.Destroy()
+ r.t = nil
+ return lastSetting.Enabled
+}
+
+// Reset implements tcpip.Timer.Reset.
+func (r *TcpipTimer) Reset(d time.Duration) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ r.t = NewTimer(r.clock, &r.notifier)
+ }
+
+ r.t.Swap(Setting{
+ Enabled: true,
+ Period: 0,
+ Next: r.clock.Now().Add(d),
+ })
+}
+
+// functionNotifier is a TimerListener that runs a function.
+//
+// functionNotifier cannot be saved or loaded.
+type functionNotifier struct {
+ fn func()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) {
+ f.fn()
+ return Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (f *functionNotifier) Destroy() {}
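
Behaviorally, TcpipTimer matches the familiar time.AfterFunc contract, even though it destroys and lazily recreates its backing Timer across Stop/Reset. A standard-library sketch of the observable semantics:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        fired := make(chan struct{}, 2)
        t := time.AfterFunc(20*time.Millisecond, func() { fired <- struct{}{} })
        <-fired
        // Stop on an already-expired timer reports false, as TcpipTimer.Stop
        // does via lastSetting.Enabled.
        fmt.Println("stopped a pending run:", t.Stop())
        t.Reset(20 * time.Millisecond) // re-arm; TcpipTimer recreates its Timer here
        <-fired
        fmt.Println("fired again after Reset")
    }
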
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index 5f3908d8b..7c4fefb16 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -21,8 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/log"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -90,7 +90,7 @@ type Timekeeper struct {
// NewTimekeeper does not take ownership of paramPage.
//
// SetClocks must be called on the returned Timekeeper before it is usable.
-func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) {
+func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) {
return &Timekeeper{
params: NewVDSOParamPage(mfp, paramPage),
}, nil
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index f1b3c212c..290c32466 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -19,8 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -58,7 +58,7 @@ type vdsoParams struct {
type VDSOParamPage struct {
// The parameter page is fr, allocated from mfp.MemoryFile().
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
// seq is the current sequence count written to the page.
//
@@ -81,7 +81,7 @@ type VDSOParamPage struct {
// * VDSOParamPage must be the only writer to fr.
//
// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
-func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage {
+func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage {
return &VDSOParamPage{mfp: mfp, fr: fr}
}
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index a98b66de1..2c95669cd 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -28,9 +28,21 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "file_range",
+ out = "file_range.go",
+ package = "memmap",
+ prefix = "File",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint64",
+ },
+)
+
go_library(
name = "memmap",
srcs = [
+ "file_range.go",
"mappable_range.go",
"mapping_set.go",
"mapping_set_impl.go",
@@ -40,7 +52,7 @@ go_library(
deps = [
"//pkg/context",
"//pkg/log",
- "//pkg/sentry/platform",
+ "//pkg/safemem",
"//pkg/syserror",
"//pkg/usermem",
],
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index c6db9fc8f..c188f6c29 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -19,12 +19,12 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/usermem"
)
// Mappable represents a memory-mappable object, a mutable mapping from uint64
-// offsets to (platform.File, uint64 File offset) pairs.
+// offsets to (File, uint64 File offset) pairs.
//
// See mm/mm.go for Mappable's place in the lock order.
//
@@ -74,7 +74,7 @@ type Mappable interface {
// Translations are valid until invalidated by a callback to
// MappingSpace.Invalidate or until the caller removes its mapping of the
// translated range. Mappable implementations must ensure that at least one
- // reference is held on all pages in a platform.File that may be the result
+ // reference is held on all pages in a File that may be the result
// of a valid Translation.
//
// Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
@@ -100,7 +100,7 @@ type Translation struct {
Source MappableRange
// File is the mapped file.
- File platform.File
+ File File
// Offset is the offset into File at which this Translation begins.
Offset uint64
@@ -110,9 +110,9 @@ type Translation struct {
Perms usermem.AccessType
}
-// FileRange returns the platform.FileRange represented by t.
-func (t Translation) FileRange() platform.FileRange {
- return platform.FileRange{t.Offset, t.Offset + t.Source.Length()}
+// FileRange returns the FileRange represented by t.
+func (t Translation) FileRange() FileRange {
+ return FileRange{t.Offset, t.Offset + t.Source.Length()}
}
// CheckTranslateResult returns an error if (ts, terr) does not satisfy all
@@ -361,3 +361,49 @@ type MMapOpts struct {
// TODO(jamieliu): Replace entirely with MappingIdentity?
Hint string
}
+
+// File represents a host file that may be mapped into a platform.AddressSpace.
+type File interface {
+ // All pages in a File are reference-counted.
+
+ // IncRef increments the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr. (The File
+ // interface does not provide a way to acquire an initial reference;
+ // implementors may define mechanisms for doing so.)
+ IncRef(fr FileRange)
+
+ // DecRef decrements the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr.
+ DecRef(fr FileRange)
+
+ // MapInternal returns a mapping of the given file offsets in the invoking
+ // process' address space for reading and writing.
+ //
+ // Note that fr.Start and fr.End need not be page-aligned.
+ //
+ // Preconditions: fr.Length() > 0. At least one reference must be held on
+ // all pages in fr.
+ //
+ // Postconditions: The returned mapping is valid as long as at least one
+ // reference is held on the mapped pages.
+ MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
+
+ // FD returns the file descriptor represented by the File.
+ //
+ // The only permitted operation on the returned file descriptor is to map
+ // pages from it consistent with the requirements of AddressSpace.MapFile.
+ FD() int
+}
+
+// FileRange represents a range of uint64 offsets into a File.
+//
+// type FileRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (fr FileRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
+}
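
The interface body is unchanged from the old platform.File; hoisting it into memmap lets Translation and Mappable refer to File directly instead of importing platform. For intuition about the IncRef/DecRef contract, a toy page-refcount model (illustrative, not gVisor code):

    package main

    import "fmt"

    const pageSize = 4096

    type fileRange struct{ start, end uint64 }

    // refFile models the memmap.File refcount contract: every page in a
    // range passed to IncRef/DecRef must already hold at least one reference.
    type refFile struct {
        refs map[uint64]int // page index -> refcount
    }

    func (f *refFile) incRef(fr fileRange) {
        for p := fr.start / pageSize; p < fr.end/pageSize; p++ {
            f.refs[p]++
        }
    }

    func (f *refFile) decRef(fr fileRange) {
        for p := fr.start / pageSize; p < fr.end/pageSize; p++ {
            if f.refs[p]--; f.refs[p] == 0 {
                delete(f.refs, p) // last reference dropped: page reclaimable
            }
        }
    }

    func main() {
        f := &refFile{refs: map[uint64]int{}}
        fr := fileRange{0, 2 * pageSize}
        f.incRef(fr)
        f.decRef(fr)
        fmt.Println("pages still referenced:", len(f.refs)) // 0
    }
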
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index a036ce53c..f9d0837a1 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -7,14 +7,14 @@ go_template_instance(
name = "file_refcount_set",
out = "file_refcount_set.go",
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "mm",
prefix = "fileRefcount",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "int32",
"Functions": "fileRefcountSetFunctions",
},
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 379148903..1999ec706 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -243,7 +242,7 @@ type aioMappable struct {
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
}
var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp())
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 6db7c3d40..3e85964e4 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -25,7 +25,7 @@
// Locks taken by memmap.Mappable.Translate
// mm.privateRefs.mu
// platform.AddressSpace locks
-// platform.File locks
+// memmap.File locks
// mm.aioManager.mu
// mm.AIOContext.mu
//
@@ -396,7 +396,7 @@ type pma struct {
// file is the file mapped by this pma. Only pmas for which file ==
// MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to
// the corresponding file range while they exist.
- file platform.File `state:"nosave"`
+ file memmap.File `state:"nosave"`
// off is the offset into file at which this pma begins.
//
@@ -436,7 +436,7 @@ type pma struct {
private bool
// If internalMappings is not empty, it is the cached return value of
- // file.MapInternal for the platform.FileRange mapped by this pma.
+ // file.MapInternal for the memmap.FileRange mapped by this pma.
internalMappings safemem.BlockSeq `state:"nosave"`
}
@@ -469,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 {
func (fileRefcountSetFunctions) ClearValue(_ *int32) {
}
-func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) {
+func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) {
return rc1, rc1 == rc2
}
-func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) {
+func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) {
return rc, rc
}
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 62e4c20af..930ec895f 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -604,7 +603,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat
}
}
-// Pin returns the platform.File ranges currently mapped by addresses in ar in
+// Pin returns the memmap.File ranges currently mapped by addresses in ar in
// mm, acquiring a reference on the returned ranges which the caller must
// release by calling Unpin. If not all addresses are mapped, Pin returns a
// non-nil error. Note that Pin may return both a non-empty slice of
@@ -674,15 +673,15 @@ type PinnedRange struct {
Source usermem.AddrRange
// File is the mapped file.
- File platform.File
+ File memmap.File
// Offset is the offset into File at which this PinnedRange begins.
Offset uint64
}
-// FileRange returns the platform.File offsets mapped by pr.
-func (pr PinnedRange) FileRange() platform.FileRange {
- return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
+// FileRange returns the memmap.File offsets mapped by pr.
+func (pr PinnedRange) FileRange() memmap.FileRange {
+ return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
}
// Unpin releases the reference held by prs.
@@ -857,7 +856,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf
}
// incPrivateRef acquires a reference on private pages in fr.
-func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
+func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) {
mm.privateRefs.mu.Lock()
defer mm.privateRefs.mu.Unlock()
refSet := &mm.privateRefs.refs
@@ -878,8 +877,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
}
// decPrivateRef releases a reference on private pages in fr.
-func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) {
- var freed []platform.FileRange
+func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) {
+ var freed []memmap.FileRange
mm.privateRefs.mu.Lock()
refSet := &mm.privateRefs.refs
@@ -951,7 +950,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa
// Discard internal mappings instead of trying to merge them, since merging
// them requires an allocation and getting them again from the
- // platform.File might not.
+ // memmap.File might not.
pma1.internalMappings = safemem.BlockSeq{}
return pma1, true
}
@@ -1012,12 +1011,12 @@ func (pseg pmaIterator) getInternalMappingsLocked() error {
return nil
}
-func (pseg pmaIterator) fileRange() platform.FileRange {
+func (pseg pmaIterator) fileRange() memmap.FileRange {
return pseg.fileRangeOf(pseg.Range())
}
// Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0.
-func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
+func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange {
if checkInvariants {
if !pseg.Ok() {
panic("terminal pma iterator")
@@ -1032,5 +1031,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
pma := pseg.ValuePtr()
pstart := pseg.Start()
- return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
+ return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
}
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index 9ad52082d..0e142fb11 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -35,7 +34,7 @@ type SpecialMappable struct {
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
name string
}
@@ -44,7 +43,7 @@ type SpecialMappable struct {
// SpecialMappable will use the given name in /proc/[pid]/maps.
//
// Preconditions: fr.Length() != 0.
-func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable {
+func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable {
m := SpecialMappable{mfp: mfp, fr: fr, name: name}
m.EnableLeakCheck("mm.SpecialMappable")
return &m
@@ -126,7 +125,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider {
// FileRange returns the offsets into MemoryFileProvider().MemoryFile() that
// store the SpecialMappable's contents.
-func (m *SpecialMappable) FileRange() platform.FileRange {
+func (m *SpecialMappable) FileRange() memmap.FileRange {
return m.fr
}
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index e1fcb175f..7a3311a70 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -36,14 +36,14 @@ go_template_instance(
"trackGaps": "1",
},
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "pgalloc",
prefix = "usage",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "usageInfo",
"Functions": "usageSetFunctions",
},
@@ -56,14 +56,14 @@ go_template_instance(
"minDegree": "10",
},
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "pgalloc",
prefix = "reclaim",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "reclaimSetValue",
"Functions": "reclaimSetFunctions",
},
@@ -89,7 +89,7 @@ go_library(
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/hostmm",
- "//pkg/sentry/platform",
+ "//pkg/sentry/memmap",
"//pkg/sentry/usage",
"//pkg/state",
"//pkg/state/wire",
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index afab97c0a..3243d7214 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -33,14 +33,14 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostmm"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
-// MemoryFile is a platform.File whose pages may be allocated to arbitrary
+// MemoryFile is a memmap.File whose pages may be allocated to arbitrary
// users.
type MemoryFile struct {
// opts holds options passed to NewMemoryFile. opts is immutable.
@@ -372,7 +372,7 @@ func (f *MemoryFile) Destroy() {
// to Allocate.
//
// Preconditions: length must be page-aligned and non-zero.
-func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) {
+func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) {
if length == 0 || length%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid allocation length: %#x", length))
}
@@ -390,7 +390,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
// Find a range in the underlying file.
fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment)
if !ok {
- return platform.FileRange{}, syserror.ENOMEM
+ return memmap.FileRange{}, syserror.ENOMEM
}
// Expand the file if needed.
@@ -398,7 +398,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
// Round the new file size up to be chunk-aligned.
newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask
if err := f.file.Truncate(newFileSize); err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
f.fileSize = newFileSize
f.mappingsMu.Lock()
@@ -416,7 +416,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
bs[i] = 0
}
}); err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
}
if !f.usage.Add(fr, usageInfo{
@@ -439,7 +439,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
// space for mappings to be allocated downwards.
//
// Precondition: alignment must be a power of 2.
-func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) {
+func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) {
alignmentMask := alignment - 1
// Search for space in existing gaps, starting at the current end of the
@@ -461,7 +461,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
break
}
if start := unalignedStart &^ alignmentMask; start >= gap.Start() {
- return platform.FileRange{start, start + length}, true
+ return memmap.FileRange{start, start + length}, true
}
gap = gap.PrevLargeEnoughGap(length)
@@ -475,7 +475,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
min = (min + alignmentMask) &^ alignmentMask
if min+length < min {
// Overflow: allocation would exceed the range of uint64.
- return platform.FileRange{}, false
+ return memmap.FileRange{}, false
}
// Determine the minimum file size required to fit this allocation at its end.
@@ -484,7 +484,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
if newFileSize <= fileSize {
if fileSize != 0 {
// Overflow: allocation would exceed the range of int64.
- return platform.FileRange{}, false
+ return memmap.FileRange{}, false
}
newFileSize = chunkSize
}
@@ -496,7 +496,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
continue
}
if start := unalignedStart &^ alignmentMask; start >= min {
- return platform.FileRange{start, start + length}, true
+ return memmap.FileRange{start, start + length}, true
}
}
}
@@ -508,22 +508,22 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
// by r.ReadToBlocks(), it returns that error.
//
// Preconditions: length > 0. length must be page-aligned.
-func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) {
+func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) {
fr, err := f.Allocate(length, kind)
if err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
dsts, err := f.MapInternal(fr, usermem.Write)
if err != nil {
f.DecRef(fr)
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
n, err := safemem.ReadFullToBlocks(r, dsts)
un := uint64(usermem.Addr(n).RoundDown())
if un < length {
// Free unused memory and update fr to contain only the memory that is
// still allocated.
- f.DecRef(platform.FileRange{fr.Start + un, fr.End})
+ f.DecRef(memmap.FileRange{fr.Start + un, fr.End})
fr.End = fr.Start + un
}
return fr, err
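
The partial-fill logic in AllocateAndFill rounds the filled byte count down to a page boundary, releases the unfilled tail, and shrinks the returned range. A runnable model of just that arithmetic (fileRange is a stand-in for memmap.FileRange):

    package main

    import "fmt"

    const pageSize = 4096

    type fileRange struct{ start, end uint64 }

    // trim mirrors AllocateAndFill's partial-fill handling: round the filled
    // byte count down to a page, free the unfilled tail, shrink the range.
    func trim(fr fileRange, filled uint64) (kept, freed fileRange) {
        un := filled &^ (pageSize - 1) // usermem.Addr(n).RoundDown() equivalent
        if un < fr.end-fr.start {
            freed = fileRange{fr.start + un, fr.end} // DecRef'd in the real code
            fr.end = fr.start + un
        }
        return fr, freed
    }

    func main() {
        kept, freed := trim(fileRange{0, 3 * pageSize}, pageSize+123)
        fmt.Printf("kept=%+v freed=%+v\n", kept, freed)
        // kept={start:0 end:4096} freed={start:4096 end:12288}
    }
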
@@ -540,7 +540,7 @@ const (
// will read zeroes.
//
// Preconditions: fr.Length() > 0.
-func (f *MemoryFile) Decommit(fr platform.FileRange) error {
+func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -560,7 +560,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error {
return nil
}
-func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
+func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
// Since we're changing the knownCommitted attribute, we need to merge
@@ -581,8 +581,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
f.usage.MergeRange(fr)
}
-// IncRef implements platform.File.IncRef.
-func (f *MemoryFile) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (f *MemoryFile) IncRef(fr memmap.FileRange) {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -600,8 +600,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) {
f.usage.MergeAdjacent(fr)
}
-// DecRef implements platform.File.DecRef.
-func (f *MemoryFile) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (f *MemoryFile) DecRef(fr memmap.FileRange) {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -637,8 +637,8 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
}
}
-// MapInternal implements platform.File.MapInternal.
-func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
if !fr.WellFormed() || fr.Length() == 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -664,7 +664,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (
// forEachMappingSlice invokes fn on a sequence of byte slices that
// collectively map all bytes in fr.
-func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error {
+func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error {
mappings := f.mappings.Load().([]uintptr)
for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
chunk := int(chunkStart >> chunkShift)
@@ -944,7 +944,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
continue
case !populated && populatedRun:
// Finish the run by changing this segment.
- runRange := platform.FileRange{
+ runRange := memmap.FileRange{
Start: r.Start + uint64(populatedRunStart*usermem.PageSize),
End: r.Start + uint64(i*usermem.PageSize),
}
@@ -1009,7 +1009,7 @@ func (f *MemoryFile) File() *os.File {
return f.file
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (f *MemoryFile) FD() int {
return int(f.file.Fd())
}
@@ -1090,13 +1090,13 @@ func (f *MemoryFile) runReclaim() {
//
// Note that the returned range will be removed from tracking. It
// must be reclaimed (removed from f.usage) at this point.
-func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
+func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) {
f.mu.Lock()
defer f.mu.Unlock()
for {
for {
if f.destroyed {
- return platform.FileRange{}, false
+ return memmap.FileRange{}, false
}
if f.reclaimable {
break
@@ -1120,7 +1120,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
}
}
-func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
+func (f *MemoryFile) markReclaimed(fr memmap.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
seg := f.usage.FindSegment(fr.Start)
@@ -1222,11 +1222,11 @@ func (usageSetFunctions) MaxKey() uint64 {
func (usageSetFunctions) ClearValue(val *usageInfo) {
}
-func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) {
+func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) {
return val1, val1 == val2
}
-func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
+func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
return val, val
}
@@ -1270,10 +1270,10 @@ func (reclaimSetFunctions) MaxKey() uint64 {
func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) {
}
-func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
+func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
return reclaimSetValue{}, true
}
-func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
+func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
return reclaimSetValue{}, reclaimSetValue{}
}
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 453241eca..209b28053 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -1,39 +1,21 @@
load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-go_template_instance(
- name = "file_range",
- out = "file_range.go",
- package = "platform",
- prefix = "File",
- template = "//pkg/segment:generic_range",
- types = {
- "T": "uint64",
- },
-)
-
go_library(
name = "platform",
srcs = [
"context.go",
- "file_range.go",
"mmap_min_addr.go",
"platform.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/atomicbitops",
"//pkg/context",
- "//pkg/log",
- "//pkg/safecopy",
- "//pkg/safemem",
"//pkg/seccomp",
"//pkg/sentry/arch",
- "//pkg/sentry/usage",
- "//pkg/syserror",
+ "//pkg/sentry/memmap",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 10a10bfe2..b5d27a72a 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -47,6 +47,7 @@ go_library(
"//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/platform/ring0",
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index faf1d5e1c..98a3e539d 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sync"
@@ -150,7 +151,7 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.
}
// MapFile implements platform.AddressSpace.MapFile.
-func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
as.mu.Lock()
defer as.mu.Unlock()
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index 171513f3f..4b13eec30 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -22,9 +22,9 @@ import (
"os"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -207,7 +207,7 @@ type AddressSpace interface {
// Preconditions: addr and fr must be page-aligned. fr.Length() > 0.
// at.Any() == true. At least one reference must be held on all pages in
// fr, and must continue to be held as long as pages are mapped.
- MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error
+ MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error
// Unmap unmaps the given range.
//
@@ -310,52 +310,6 @@ func (f SegmentationFault) Error() string {
return fmt.Sprintf("segmentation fault at %#x", f.Addr)
}
-// File represents a host file that may be mapped into an AddressSpace.
-type File interface {
- // All pages in a File are reference-counted.
-
- // IncRef increments the reference count on all pages in fr.
- //
- // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
- // 0. At least one reference must be held on all pages in fr. (The File
- // interface does not provide a way to acquire an initial reference;
- // implementors may define mechanisms for doing so.)
- IncRef(fr FileRange)
-
- // DecRef decrements the reference count on all pages in fr.
- //
- // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
- // 0. At least one reference must be held on all pages in fr.
- DecRef(fr FileRange)
-
- // MapInternal returns a mapping of the given file offsets in the invoking
- // process' address space for reading and writing.
- //
- // Note that fr.Start and fr.End need not be page-aligned.
- //
- // Preconditions: fr.Length() > 0. At least one reference must be held on
- // all pages in fr.
- //
- // Postconditions: The returned mapping is valid as long as at least one
- // reference is held on the mapped pages.
- MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
-
- // FD returns the file descriptor represented by the File.
- //
- // The only permitted operation on the returned file descriptor is to map
- // pages from it consistent with the requirements of AddressSpace.MapFile.
- FD() int
-}
-
-// FileRange represents a range of uint64 offsets into a File.
-//
-// type FileRange <generated using go_generics>
-
-// String implements fmt.Stringer.String.
-func (fr FileRange) String() string {
- return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
-}
-
// Requirements is used to specify platform specific requirements.
type Requirements struct {
// RequiresCurrentPIDNS indicates that the sandbox has to be started in the
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 30402c2df..29fd23cc3 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/hostcpu",
+ "//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sync",
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 2389423b0..c990f3454 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -616,7 +617,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp
}
// MapFile implements platform.AddressSpace.MapFile.
-func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
var flags int
if precommit {
flags |= syscall.MAP_POPULATE
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index ccacaea6b..fca3a5478 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -58,7 +58,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate &= ^uint64(UserFlagsClear)
regs.Pstate |= UserFlagsSet
+
+ SetTLS(regs.TPIDR_EL0)
+
kernelExitToEl0()
+
+ regs.TPIDR_EL0 = GetTLS()
+
vector = c.vecCode
// Perform the switch.
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index c40c6d673..c0fd3425b 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,5 +20,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index ff81ea6e6..e76e498de 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -40,6 +40,8 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index a92aed2c9..532a1ea5d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -36,6 +36,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
@@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
@@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
if err != nil {
return nil, syserr.FromError(err)
}
- return opt, nil
+ optP := primitive.ByteSlice(opt)
+ return &optP, nil
}
// SetSockOpt implements socket.Socket.SetSockOpt.
@@ -708,6 +711,6 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int
func init() {
for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
socket.RegisterProvider(family, &socketProvider{family})
- socket.RegisterProviderVFS2(family, &socketProviderVFS2{})
+ socket.RegisterProviderVFS2(family, &socketProviderVFS2{family})
}
}
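
The conversion-then-address dance above (optP := primitive.ByteSlice(opt); return &optP) is needed because the marshalling methods hang off the pointer to the named slice type. A toy illustration of the same Go pattern (sizeBytes is a made-up stand-in, not the marshal API):

    package main

    import "fmt"

    // byteSlice mimics the primitive.ByteSlice pattern: a named []byte whose
    // pointer carries the method set an interface expects.
    type byteSlice []byte

    func (b *byteSlice) sizeBytes() int { return len(*b) }

    type sized interface{ sizeBytes() int }

    func main() {
        opt := []byte{1, 0, 0, 0} // e.g. a raw int sockopt read from the host
        optP := byteSlice(opt)
        var m sized = &optP // only *byteSlice implements the interface
        fmt.Println(m.sizeBytes())
    }
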
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index f7abe77d3..a9f0604ae 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -66,7 +66,7 @@ func nflog(format string, args ...interface{}) {
func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
var info linux.IPTGetinfo
- if _, err := t.CopyIn(outPtr, &info); err != nil {
+ if _, err := info.CopyIn(t, outPtr); err != nil {
return linux.IPTGetinfo{}, syserr.FromError(err)
}
@@ -84,7 +84,7 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT
func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
// Read in the struct and table name.
var userEntries linux.IPTGetEntries
- if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ if _, err := userEntries.CopyIn(t, outPtr); err != nil {
nflog("couldn't copy in entries %q", userEntries.Name)
return linux.KernelIPTGetEntries{}, syserr.FromError(err)
}
@@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
// Each rule corresponds to an entry.
entry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
+ Entry: linux.IPTEntry{
IP: linux.IPTIP{
Protocol: uint16(rule.Filter.Protocol),
},
@@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
TargetOffset: linux.SizeOfIPTEntry,
},
}
- copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst)
- copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask)
- copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src)
- copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask)
- copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface)
- copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ copy(entry.Entry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IP.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
if rule.Filter.DstInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP
}
if rule.Filter.SrcInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP
}
if rule.Filter.OutputInterfaceInvert {
- entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
}
for _, matcher := range rule.Matchers {
@@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
- entry.TargetOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
}
// Serialize and append the target.
@@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin
panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
}
entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
+ entry.Entry.NextOffset += uint16(len(serialized))
nflog("convert to binary: adding entry: %+v", entry)
- entries.Size += uint32(entry.NextOffset)
+ entries.Size += uint32(entry.Entry.NextOffset)
entries.Entrytable = append(entries.Entrytable, entry)
info.NumEntries++
}
@@ -342,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
- case stack.TablenameFilter:
+ case stack.FilterTable:
table = stack.EmptyFilterTable()
- case stack.TablenameNat:
- table = stack.EmptyNatTable()
+ case stack.NATTable:
+ table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
@@ -431,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
+ table.BuiltinChains[hk] = stack.HookUnset
+ table.Underflows[hk] = stack.HookUnset
for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
@@ -456,8 +458,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Add the user chains.
for ruleIdx, rule := range table.Rules {
- target, ok := rule.Target.(stack.UserChainTarget)
- if !ok {
+ if _, ok := rule.Target.(stack.UserChainTarget); !ok {
continue
}
@@ -473,7 +474,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
nflog("user chain's first node must have no matchers")
return syserr.ErrInvalidArgument
}
- table.UserChains[target.Name] = ruleIdx + 1
}
// Set each jump to point to the appropriate rule. Right now they hold byte
@@ -499,7 +499,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Since we only support modifying the INPUT, PREROUTING, and OUTPUT chains right now,
// make sure all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook == stack.Forward || hook == stack.Postrouting {
+ if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
+ if ruleIdx == stack.HookUnset {
+ continue
+ }
if !isUnconditionalAccept(table.Rules[ruleIdx]) {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
@@ -512,9 +515,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- stk.IPTables().ReplaceTable(replace.Name.String(), table)
-
- return nil
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table))
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
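The HookEntry hunk above seeds every valid builtin chain with stack.HookUnset before scanning rule offsets; a chain is bound only when some offset matches the kernel-provided entry point. A sketch of the sentinel pattern under the diff's own names (surrounding loops elided):

    hk := hookFromLinux(hook)
    table.BuiltinChains[hk] = stack.HookUnset // explicitly "not bound yet"
    table.Underflows[hk] = stack.HookUnset
    for offset, ruleIdx := range offsets {
        if offset == replace.HookEntry[hook] {
            table.BuiltinChains[hk] = ruleIdx
        }
    }

The later FORWARD/POSTROUTING check pairs with this: a chain still equal to stack.HookUnset is skipped rather than used to index table.Rules with a stale default of zero.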
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index d5ca3ac56..0546801bf 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -36,6 +36,8 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 81f34c5a2..98ca7add0 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -38,6 +38,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
@@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
}
s.mu.Lock()
defer s.mu.Unlock()
- return int32(s.sendBufferSize), nil
+ sendBufferSizeP := primitive.Int32(s.sendBufferSize)
+ return &sendBufferSizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have a limit on the receive size.
- return int32(math.MaxInt32), nil
+ recvBufferSizeP := primitive.Int32(math.MaxInt32)
+ return &recvBufferSizeP, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var passcred int32
+ var passcred primitive.Int32
if s.Passcred() {
passcred = 1
}
- return passcred, nil
+ return &passcred, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
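The named locals above (sendBufferSizeP, recvBufferSizeP, passcred) are not just stylistic. The primitive wrappers implement marshal.Marshallable on pointer receivers, and Go forbids taking the address of a conversion, so the value must be bound to an addressable variable first:

    // return &primitive.Int32(1), nil  // does not compile
    passcred := primitive.Int32(0) // addressable local
    if s.Passcred() {
        passcred = 1
    }
    return &passcred, nil // *primitive.Int32 is Marshallable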
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ea6ebd0e2..1fb777a6c 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 49a04e613..44b3fff46 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,6 +26,7 @@ package netstack
import (
"bytes"
+ "fmt"
"io"
"math"
"reflect"
@@ -61,6 +62,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
@@ -909,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -919,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -955,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -970,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
@@ -980,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -1013,7 +1016,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1024,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// Get the last error and convert it.
err := ep.GetSockOpt(tcpip.ErrorOption{})
if err == nil {
- return int32(0), nil
+ optP := primitive.Int32(0)
+ return &optP, nil
}
- return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number())
+ return &optP, nil
case linux.SO_PEERCRED:
if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
@@ -1034,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
tcred := t.Credentials()
- return syscall.Ucred{
- Pid: int32(t.ThreadGroup().ID()),
- Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
- Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
- }, nil
+ creds := linux.ControlMessageCredentials{
+ PID: int32(t.ThreadGroup().ID()),
+ UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }
+ return &creds, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -1049,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1065,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
@@ -1081,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_REUSEADDR:
if outLen < sizeOfInt32 {
@@ -1092,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
@@ -1103,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1111,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
if v == 0 {
- return []byte{}, nil
+ var b primitive.ByteSlice
+ return &b, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
@@ -1126,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// interface was removed.
return nil, syserr.ErrUnknownDevice
}
- return append([]byte(nic.Name), 0), nil
+
+ name := primitive.ByteSlice(append([]byte(nic.Name), 0))
+ return &name, nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1137,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
@@ -1148,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- return linux.Linger{}, nil
+
+ linger := linux.Linger{}
+ return &linger, nil
case linux.SO_SNDTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1162,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.SendTimeout()), nil
+ sendTimeout := linux.NsecToTimeval(s.SendTimeout())
+ return &sendTimeout, nil
case linux.SO_RCVTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1170,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.RecvTimeout()), nil
+ recvTimeout := linux.NsecToTimeval(s.RecvTimeout())
+ return &recvTimeout, nil
case linux.SO_OOBINLINE:
if outLen < sizeOfInt32 {
@@ -1182,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1193,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1202,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
@@ -1213,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(!v), nil
+
+ vP := primitive.Int32(boolToInt32(!v))
+ return &vP, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
@@ -1224,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
@@ -1235,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1246,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_KEEPIDLE:
if outLen < sizeOfInt32 {
@@ -1258,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveIdle, nil
case linux.TCP_KEEPINTVL:
if outLen < sizeOfInt32 {
@@ -1270,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveInterval, nil
case linux.TCP_KEEPCNT:
if outLen < sizeOfInt32 {
@@ -1282,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_USER_TIMEOUT:
if outLen < sizeOfInt32 {
@@ -1294,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Millisecond), nil
+ tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond)
+ return &tcpUserTimeout, nil
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
@@ -1308,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
info := linux.TCPInfo{}
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &info)
- if len(ib) > outLen {
- ib = ib[:outLen]
+ buf := t.CopyScratchBuffer(info.SizeBytes())
+ info.MarshalUnsafe(buf)
+ if len(buf) > outLen {
+ buf = buf[:outLen]
}
-
- return ib, nil
+ bufP := primitive.ByteSlice(buf)
+ return &bufP, nil
case linux.TCP_CC_INFO,
linux.TCP_NOTSENT_LOWAT,
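The TCP_INFO case above replaces pkg/binary with go_marshal while keeping Linux's truncate-to-outLen behavior. The same shape works for any marshal.Marshallable value m (a sketch of the pattern, not a new helper in this change):

    buf := t.CopyScratchBuffer(m.SizeBytes()) // reuse the per-task scratch buffer
    m.MarshalUnsafe(buf)                      // serialize m into buf
    if len(buf) > outLen {
        buf = buf[:outLen] // Linux truncates the output to outLen
    }
    bufP := primitive.ByteSlice(buf)
    return &bufP, nil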
@@ -1343,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
b := make([]byte, toCopy)
copy(b, v)
- return b, nil
+
+ bP := primitive.ByteSlice(b)
+ return &bP, nil
case linux.TCP_LINGER2:
if outLen < sizeOfInt32 {
@@ -1355,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
if outLen < sizeOfInt32 {
@@ -1367,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second)
+ return &tcpDeferAccept, nil
case linux.TCP_SYNCNT:
if outLen < sizeOfInt32 {
@@ -1378,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_WINDOW_CLAMP:
if outLen < sizeOfInt32 {
@@ -1390,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1399,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
@@ -1410,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1418,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_TCLASS:
// Length handling for parity with Linux.
if outLen == 0 {
- return make([]byte, 0), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- uintv := uint32(v)
+ uintv := primitive.Uint32(v)
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ ib := t.CopyScratchBuffer(uintv.SizeBytes())
+ uintv.MarshalUnsafe(ib)
// Handle cases where outLen is less than sizeOfInt32.
if len(ib) > outLen {
ib = ib[:outLen]
}
- return ib, nil
+ ibP := primitive.ByteSlice(ib)
+ return &ibP, nil
case linux.IPV6_RECVTCLASS:
if outLen < sizeOfInt32 {
@@ -1443,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIPv6(t, name)
@@ -1452,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1465,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
// Fill in the default value, if needed.
- if v == 0 {
- v = DefaultTTL
+ vP := primitive.Int32(v)
+ if vP == 0 {
+ vP = DefaultTTL
}
- return int32(v), nil
+ return &vP, nil
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
@@ -1481,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_MULTICAST_IF:
if outLen < len(linux.InetAddr{}) {
@@ -1495,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(*linux.SockAddrInet).Addr, nil
+ return &a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
@@ -1506,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
- return []byte(nil), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
- return uint8(v), nil
+ vP := primitive.Uint8(v)
+ return &vP, nil
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_RECVTOS:
if outLen < sizeOfInt32 {
@@ -1531,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1542,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIP(t, name)
@@ -2468,6 +2524,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
}
+func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
+ switch pktType {
+ case tcpip.PacketHost:
+ return linux.PACKET_HOST
+ case tcpip.PacketOtherHost:
+ return linux.PACKET_OTHERHOST
+ case tcpip.PacketOutgoing:
+ return linux.PACKET_OUTGOING
+ case tcpip.PacketBroadcast:
+ return linux.PACKET_BROADCAST
+ case tcpip.PacketMulticast:
+ return linux.PACKET_MULTICAST
+ default:
+ panic(fmt.Sprintf("unknown packet type: %d", pktType))
+ }
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -2526,6 +2599,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
switch v := addr.(type) {
case *linux.SockAddrLink:
v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
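toLinuxPacketType is exhaustive over the tcpip.Packet* values and panics on anything else, so a drift in the mapping surfaces immediately instead of silently mislabeling packets. An illustrative check (not code from this change):

    if got := toLinuxPacketType(tcpip.PacketBroadcast); got != linux.PACKET_BROADCAST {
        panic("packet type mapping drifted from the ABI constants")
    }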
@@ -2761,6 +2835,11 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("ioctl(2) may only be called from a task goroutine")
+ }
+
// SIOCGSTAMP is implemented by netstack rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
@@ -2773,9 +2852,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
}
tv := linux.NsecToTimeval(s.timestampNS)
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := tv.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCINQ:
@@ -2794,9 +2871,8 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
}
@@ -2805,6 +2881,11 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
// Ioctl performs a socket ioctl.
func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("ioctl(2) may only be called from a task goroutine")
+ }
+
switch arg := int(args[1].Int()); arg {
case linux.SIOCGIFFLAGS,
linux.SIOCGIFADDR,
@@ -2821,37 +2902,28 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
linux.SIOCETHTOOL:
var ifr linux.IFReq
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil {
return 0, err
}
if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil {
return 0, err.ToError()
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := ifr.CopyOut(t, args[2].Pointer())
return 0, err
case linux.SIOCGIFCONF:
// Return a list of interface addresses or the buffer size
// necessary to hold the list.
var ifc linux.IFConf
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil {
return 0, err
}
- if err := ifconfIoctl(ctx, io, &ifc); err != nil {
+ if err := ifconfIoctl(ctx, t, io, &ifc); err != nil {
return 0, err
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
+ _, err := ifc.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCINQ:
@@ -2864,9 +2936,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
v = math.MaxInt32
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCOUTQ:
@@ -2880,9 +2951,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
}
// Copy result to userspace.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
@@ -3031,7 +3101,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error {
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
@@ -3068,9 +3138,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
// Copy the ifr to userspace.
dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
ifc.Len += int32(linux.SizeOfIFReq)
- if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil {
return err
}
}
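Both ioctl entry points in this file now open with the same guard: the go_marshal CopyIn/CopyOut helpers need a *kernel.Task rather than a bare usermem.IO, and a task exists only on task goroutines. The shared precondition, as a sketch:

    t := kernel.TaskFromContext(ctx)
    if t == nil {
        // ioctl(2) always arrives on a task goroutine; anything else
        // is a sentry bug, so fail loudly.
        panic("ioctl(2) may only be called from a task goroutine")
    }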
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index d65a89316..a9025b0ec 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -31,6 +31,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
@@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketVFS2 rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fcd7f9d7f..d112757fb 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
@@ -86,7 +87,7 @@ type SocketOps interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux syscall.
- GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux syscall.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cca5e70f1..061a689a9 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -35,5 +35,6 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4bb2b6ff4..0482d33cf 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -40,6 +40,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
@@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index ff2149250..05c16fcfe 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketVFS2 implements socket.SocketVFS2 (and by extension,
@@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 217fcfef2..4a9b04fd0 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -99,5 +99,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index ea4f9b1a7..80c65164a 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -325,8 +325,8 @@ var AMD64 = &kernel.SyscallTable{
270: syscalls.Supported("pselect", Pselect),
271: syscalls.Supported("ppoll", Ppoll),
272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
- 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 273: syscalls.Supported("set_robust_list", SetRobustList),
+ 274: syscalls.Supported("get_robust_list", GetRobustList),
275: syscalls.Supported("splice", Splice),
276: syscalls.Supported("tee", Tee),
277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index b68261f72..f04d78856 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
switch cmd {
case linux.FUTEX_WAIT:
// WAIT uses a relative timeout.
- mask = ^uint32(0)
+ mask = linux.FUTEX_BITSET_MATCH_ANY
var timeoutDur time.Duration
if !forever {
timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond
@@ -286,3 +286,49 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, syserror.ENOSYS
}
}
+
+// SetRobustList implements linux syscall set_robust_list(2).
+func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ head := args[0].Pointer()
+ length := args[1].SizeT()
+
+ if length != uint(linux.SizeOfRobustListHead) {
+ return 0, nil, syserror.EINVAL
+ }
+ t.SetRobustList(head)
+ return 0, nil, nil
+}
+
+// GetRobustList implements linux syscall get_robust_list(2).
+func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Despite the syscall using the name 'pid' for this variable, it is
+ // very much a tid.
+ tid := args[0].Int()
+ head := args[1].Pointer()
+ size := args[2].Pointer()
+
+ if tid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ot := t
+ if tid != 0 {
+ if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // Copy out head pointer.
+ if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil {
+ return 0, nil, err
+ }
+
+ // Copy out size, which is a constant.
+ if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
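For orientation, head points at a per-thread robust futex list header whose size the length check above pins to linux.SizeOfRobustListHead (24 bytes on 64-bit). An illustrative layout mirroring Linux's struct robust_list_head; the field names are descriptive, not identifiers from this change:

    type robustListHead struct {
        List          uint64 // pointer to the first robust-list entry
        FutexOffset   int64  // offset from an entry to its futex word
        ListOpPending uint64 // entry whose lock op was in flight at exit
    }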
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 0760af77b..414fce8e3 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -29,6 +29,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
@@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
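The SetSockOpt hunk swaps t.CopyIn(optValAddr, &buf) for t.CopyInBytes(optValAddr, buf): the byte-slice variant fills the caller's buffer in place instead of routing a []byte through the generic, reflection-based CopyIn. The idiom in isolation (a sketch):

    buf := t.CopyScratchBuffer(int(optLen)) // per-task reusable buffer
    if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
        return 0, nil, err
    }
    // buf now holds exactly optLen bytes of the user's option value.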
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 0c740335b..64696b438 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -72,5 +72,7 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
index adeaa39cc..ea337de7c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/mount.go
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -77,8 +77,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Silently allow MS_NOSUID, since we don't implement set-id bits
// anyway.
- const unsupportedFlags = linux.MS_NODEV |
- linux.MS_NODIRATIME | linux.MS_STRICTATIME
+ const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME
// Linux just allows passing any flags to mount(2) - it won't fail when
// unknown or unsupported flags are passed. Since we don't implement
@@ -94,6 +93,12 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
opts.Flags.NoExec = true
}
+ if flags&linux.MS_NODEV == linux.MS_NODEV {
+ opts.Flags.NoDev = true
+ }
+ if flags&linux.MS_NOSUID == linux.MS_NOSUID {
+ opts.Flags.NoSUID = true
+ }
if flags&linux.MS_RDONLY == linux.MS_RDONLY {
opts.ReadOnly = true
}
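MS_NODEV and MS_NOSUID move from the unsupported set to real mount options above. An equivalent table-driven form of that plumbing, purely illustrative (the diff keeps the explicit ifs, which match the surrounding style):

    for _, f := range []struct {
        mask uint64
        dst  *bool
    }{
        {linux.MS_NOEXEC, &opts.Flags.NoExec},
        {linux.MS_NODEV, &opts.Flags.NoDev},
        {linux.MS_NOSUID, &opts.Flags.NoSUID},
    } {
        if flags&f.mask == f.mask {
            *f.dst = true
        }
    }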
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 09ecfed26..6daedd173 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -178,6 +178,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
Mask: linux.STATX_SIZE,
Size: uint64(length),
},
+ NeedWritePerm: true,
})
return 0, nil, handleSetSizeError(t, err)
}
@@ -197,6 +198,10 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
defer file.DecRef()
+ if !file.IsWritable() {
+ return 0, nil, syserror.EINVAL
+ }
+
err := file.SetStat(t, vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_SIZE,
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 10b668477..8096a8f9c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -30,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// minListenBacklog is the minimum reasonable backlog for listening sockets.
@@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 945a364a7..63ab11f8c 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -15,12 +15,15 @@
package vfs2
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -110,16 +113,20 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Move data.
var (
- n int64
- err error
- inCh chan struct{}
- outCh chan struct{}
+ n int64
+ err error
)
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
for {
// If both input and output are pipes, delegate to the pipe
- // implementation. Otherwise, exactly one end is a pipe, which we
- // ensure is consistently ordered after the non-pipe FD's locks by
- // passing the pipe FD as usermem.IO to the non-pipe end.
+ // implementation. Otherwise, exactly one end is a pipe, which
+ // we ensure is consistently ordered after the non-pipe FD's
+ // locks by passing the pipe FD as usermem.IO to the non-pipe
+ // end.
switch {
case inIsPipe && outIsPipe:
n, err = pipe.Splice(t, outPipeFD, inPipeFD, count)
@@ -137,38 +144,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
} else {
n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
}
+ default:
+ panic("not possible")
}
+
if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
break
}
-
- // Note that the blocking behavior here is a bit different than the
- // normal pattern. Because we need to have both data to read and data
- // to write simultaneously, we actually explicitly block on both of
- // these cases in turn before returning to the splice operation.
- if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
- if inCh == nil {
- inCh = make(chan struct{}, 1)
- inW, _ := waiter.NewChannelEntry(inCh)
- inFile.EventRegister(&inW, eventMaskRead)
- defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(inCh); err != nil {
- break
- }
- }
- if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, eventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(outCh); err != nil {
- break
- }
+ if err = dw.waitForBoth(t); err != nil {
+ break
}
}
@@ -247,45 +231,256 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
// Copy data.
var (
- inCh chan struct{}
- outCh chan struct{}
+ n int64
+ err error
)
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
for {
- n, err := pipe.Tee(t, outPipeFD, inPipeFD, count)
- if n != 0 {
- return uintptr(n), nil, nil
+ n, err = pipe.Tee(t, outPipeFD, inPipeFD, count)
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+ if err = dw.waitForBoth(t); err != nil {
+ break
+ }
+ }
+ if n == 0 {
+ return 0, nil, err
+ }
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// Sendfile implements linux system call sendfile(2).
+func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ outFD := args[0].Int()
+ inFD := args[1].Int()
+ offsetAddr := args[2].Pointer()
+ count := int64(args[3].SizeT())
+
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+ if !inFile.IsReadable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef()
+ if !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Verify that the outFile Append flag is not set.
+ if outFile.StatusFlags()&linux.O_APPEND != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Verify that inFile is a regular file or block device. This is a
+ // requirement; the same check appears in Linux
+ // (fs/splice.c:splice_direct_to_actor).
+ if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil {
+ return 0, nil, err
+ } else if stat.Mask&linux.STATX_TYPE == 0 ||
+ (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy offset if it exists.
+ offset := int64(-1)
+ if offsetAddr != 0 {
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.ESPIPE
}
- if err != syserror.ErrWouldBlock || nonBlock {
+ if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
return 0, nil, err
}
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if offset+count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Validate count. This must come after offset checks.
+ if count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
- // Note that the blocking behavior here is a bit different than the
- // normal pattern. Because we need to have both data to read and data
- // to write simultaneously, we actually explicitly block on both of
- // these cases in turn before returning to the tee operation.
- if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
- if inCh == nil {
- inCh = make(chan struct{}, 1)
- inW, _ := waiter.NewChannelEntry(inCh)
- inFile.EventRegister(&inW, eventMaskRead)
- defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
+ // Copy data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ // Reading from the input file should never block, since it is a
+ // regular file or block device. We only need to check whether writing
+ // to the output file can block.
+ nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0
+ if outIsPipe {
+ for n < count {
+ var spliceN int64
+ if offset != -1 {
+ spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{})
+ offset += spliceN
+ } else {
+ spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
}
- if err := t.Block(inCh); err != nil {
- return 0, nil, err
+ n += spliceN
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
+ }
+ if err != nil {
+ break
}
}
- if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, eventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
+ } else {
+ // Read inFile to buffer, then write the contents to outFile.
+ buf := make([]byte, count)
+ for n < count {
+ var readN int64
+ if offset != -1 {
+ readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{})
+ offset += readN
+ } else {
+ readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ }
+ if readN == 0 && err == io.EOF {
+ // We reached the end of the file. Eat the
+ // error and exit the loop.
+ err = nil
+ break
}
- if err := t.Block(outCh); err != nil {
- return 0, nil, err
+ n += readN
+ if err != nil {
+ break
+ }
+
+ // Write all of the bytes that we read. This may need
+ // multiple write calls to complete.
+ wbuf := buf[:n]
+ for len(wbuf) > 0 {
+ var writeN int64
+ writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{})
+ wbuf = wbuf[writeN:]
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForOut(t)
+ }
+ if err != nil {
+ // We didn't complete the write. Only
+ // report the bytes that were actually
+ // written, and rewind the offset.
+ notWritten := int64(len(wbuf))
+ n -= notWritten
+ if offset != -1 {
+ offset -= notWritten
+ }
+ break
+ }
+ }
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
}
+ if err != nil {
+ break
+ }
+ }
+ }
+
+ if offsetAddr != 0 {
+ // Copy out the new offset.
+ if _, err := t.CopyOut(offsetAddr, offset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+
+ inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not
+// thread-safe, and does not take a reference on the vfs.FileDescriptions.
+//
+// Users must call destroy() when finished.
+type dualWaiter struct {
+ inFile *vfs.FileDescription
+ outFile *vfs.FileDescription
+
+ inW waiter.Entry
+ inCh chan struct{}
+ outW waiter.Entry
+ outCh chan struct{}
+}
+
+// waitForBoth waits for both dw.inFile and dw.outFile to be ready.
+func (dw *dualWaiter) waitForBoth(t *kernel.Task) error {
+ if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if dw.inCh == nil {
+ dw.inW, dw.inCh = waiter.NewChannelEntry(nil)
+ dw.inFile.EventRegister(&dw.inW, eventMaskRead)
+ // We might be ready now. Try again before blocking.
+ return nil
+ }
+ if err := t.Block(dw.inCh); err != nil {
+ return err
+ }
+ }
+ return dw.waitForOut(t)
+}
+
+// waitForOut waits for dw.outFile to be ready for writing.
+func (dw *dualWaiter) waitForOut(t *kernel.Task) error {
+ if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if dw.outCh == nil {
+ dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
+ dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
+ // We might be ready now. Try again before blocking.
+ return nil
}
+ if err := t.Block(dw.outCh); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// destroy cleans up resources held by dw. No more calls to wait* can occur
+// after destroy is called.
+func (dw *dualWaiter) destroy() {
+ if dw.inCh != nil {
+ dw.inFile.EventUnregister(&dw.inW)
+ dw.inCh = nil
+ }
+ if dw.outCh != nil {
+ dw.outFile.EventUnregister(&dw.outW)
+ dw.outCh = nil
}
+ dw.inFile = nil
+ dw.outFile = nil
}
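With dualWaiter in place, Splice, Tee, and the pipe path of Sendfile all reduce to one loop shape: attempt the transfer without blocking, and only wait (registering waiters lazily) when nothing moved and the error was ErrWouldBlock. A sketch, where transfer stands in for the pipe.Splice/pipe.Tee/Read step (hypothetical name):

    var (
        n   int64
        err error
    )
    dw := dualWaiter{inFile: inFile, outFile: outFile}
    defer dw.destroy() // unregisters any waiters installed below
    for {
        n, err = transfer(t) // one non-blocking attempt
        if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
            break
        }
        if err = dw.waitForBoth(t); err != nil {
            break
        }
    }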
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index 8f497ecc7..c576d9475 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -44,7 +44,7 @@ func Override() {
s.Table[23] = syscalls.Supported("select", Select)
s.Table[32] = syscalls.Supported("dup", Dup)
s.Table[33] = syscalls.Supported("dup2", Dup2)
- delete(s.Table, 40) // sendfile
+ s.Table[40] = syscalls.Supported("sendfile", Sendfile)
s.Table[41] = syscalls.Supported("socket", Socket)
s.Table[42] = syscalls.Supported("connect", Connect)
s.Table[43] = syscalls.Supported("accept", Accept)
@@ -62,7 +62,7 @@ func Override() {
s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt)
s.Table[59] = syscalls.Supported("execve", Execve)
s.Table[72] = syscalls.Supported("fcntl", Fcntl)
- s.Table[73] = syscalls.Supported("fcntl", Flock)
+ s.Table[73] = syscalls.Supported("flock", Flock)
s.Table[74] = syscalls.Supported("fsync", Fsync)
s.Table[75] = syscalls.Supported("fdatasync", Fdatasync)
s.Table[76] = syscalls.Supported("truncate", Truncate)
@@ -163,6 +163,106 @@ func Override() {
// Override ARM64.
s = linux.ARM64
+ s.Table[5] = syscalls.Supported("setxattr", Setxattr)
+ s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr)
+ s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr)
+ s.Table[8] = syscalls.Supported("getxattr", Getxattr)
+ s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr)
+ s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr)
+ s.Table[11] = syscalls.Supported("listxattr", Listxattr)
+ s.Table[12] = syscalls.Supported("llistxattr", Llistxattr)
+ s.Table[13] = syscalls.Supported("flistxattr", Flistxattr)
+ s.Table[14] = syscalls.Supported("removexattr", Removexattr)
+ s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr)
+ s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr)
+ s.Table[17] = syscalls.Supported("getcwd", Getcwd)
+ s.Table[19] = syscalls.Supported("eventfd2", Eventfd2)
+ s.Table[20] = syscalls.Supported("epoll_create1", EpollCreate1)
+ s.Table[21] = syscalls.Supported("epoll_ctl", EpollCtl)
+ s.Table[22] = syscalls.Supported("epoll_pwait", EpollPwait)
+ s.Table[23] = syscalls.Supported("dup", Dup)
+ s.Table[24] = syscalls.Supported("dup3", Dup3)
+ s.Table[25] = syscalls.Supported("fcntl", Fcntl)
+ s.Table[26] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil)
+ s.Table[27] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[28] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[29] = syscalls.Supported("ioctl", Ioctl)
+ s.Table[32] = syscalls.Supported("flock", Flock)
+ s.Table[33] = syscalls.Supported("mknodat", Mknodat)
+ s.Table[34] = syscalls.Supported("mkdirat", Mkdirat)
+ s.Table[35] = syscalls.Supported("unlinkat", Unlinkat)
+ s.Table[36] = syscalls.Supported("symlinkat", Symlinkat)
+ s.Table[37] = syscalls.Supported("linkat", Linkat)
+ s.Table[38] = syscalls.Supported("renameat", Renameat)
+ s.Table[39] = syscalls.Supported("umount2", Umount2)
+ s.Table[40] = syscalls.Supported("mount", Mount)
+ s.Table[43] = syscalls.Supported("statfs", Statfs)
+ s.Table[44] = syscalls.Supported("fstatfs", Fstatfs)
+ s.Table[45] = syscalls.Supported("truncate", Truncate)
+ s.Table[46] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[48] = syscalls.Supported("faccessat", Faccessat)
+ s.Table[49] = syscalls.Supported("chdir", Chdir)
+ s.Table[50] = syscalls.Supported("fchdir", Fchdir)
+ s.Table[51] = syscalls.Supported("chroot", Chroot)
+ s.Table[52] = syscalls.Supported("fchmod", Fchmod)
+ s.Table[53] = syscalls.Supported("fchmodat", Fchmodat)
+ s.Table[54] = syscalls.Supported("fchownat", Fchownat)
+ s.Table[55] = syscalls.Supported("fchown", Fchown)
+ s.Table[56] = syscalls.Supported("openat", Openat)
+ s.Table[57] = syscalls.Supported("close", Close)
+ s.Table[59] = syscalls.Supported("pipe2", Pipe2)
+ s.Table[61] = syscalls.Supported("getdents64", Getdents64)
+ s.Table[62] = syscalls.Supported("lseek", Lseek)
s.Table[63] = syscalls.Supported("read", Read)
+ s.Table[64] = syscalls.Supported("write", Write)
+ s.Table[65] = syscalls.Supported("readv", Readv)
+ s.Table[66] = syscalls.Supported("writev", Writev)
+ s.Table[67] = syscalls.Supported("pread64", Pread64)
+ s.Table[68] = syscalls.Supported("pwrite64", Pwrite64)
+ s.Table[69] = syscalls.Supported("preadv", Preadv)
+ s.Table[70] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[72] = syscalls.Supported("pselect", Pselect)
+ s.Table[73] = syscalls.Supported("ppoll", Ppoll)
+ s.Table[74] = syscalls.Supported("signalfd4", Signalfd4)
+ s.Table[76] = syscalls.Supported("splice", Splice)
+ s.Table[77] = syscalls.Supported("tee", Tee)
+ s.Table[78] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[80] = syscalls.Supported("fstat", Fstat)
+ s.Table[81] = syscalls.Supported("sync", Sync)
+ s.Table[82] = syscalls.Supported("fsync", Fsync)
+ s.Table[83] = syscalls.Supported("fdatasync", Fdatasync)
+ s.Table[84] = syscalls.Supported("sync_file_range", SyncFileRange)
+ s.Table[85] = syscalls.Supported("timerfd_create", TimerfdCreate)
+ s.Table[86] = syscalls.Supported("timerfd_settime", TimerfdSettime)
+ s.Table[87] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
+ s.Table[88] = syscalls.Supported("utimensat", Utimensat)
+ s.Table[198] = syscalls.Supported("socket", Socket)
+ s.Table[199] = syscalls.Supported("socketpair", SocketPair)
+ s.Table[200] = syscalls.Supported("bind", Bind)
+ s.Table[201] = syscalls.Supported("listen", Listen)
+ s.Table[202] = syscalls.Supported("accept", Accept)
+ s.Table[203] = syscalls.Supported("connect", Connect)
+ s.Table[204] = syscalls.Supported("getsockname", GetSockName)
+ s.Table[205] = syscalls.Supported("getpeername", GetPeerName)
+ s.Table[206] = syscalls.Supported("sendto", SendTo)
+ s.Table[207] = syscalls.Supported("recvfrom", RecvFrom)
+ s.Table[208] = syscalls.Supported("setsockopt", SetSockOpt)
+ s.Table[209] = syscalls.Supported("getsockopt", GetSockOpt)
+ s.Table[210] = syscalls.Supported("shutdown", Shutdown)
+ s.Table[211] = syscalls.Supported("sendmsg", SendMsg)
+ s.Table[212] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[221] = syscalls.Supported("execve", Execve)
+ s.Table[222] = syscalls.Supported("mmap", Mmap)
+ s.Table[242] = syscalls.Supported("accept4", Accept4)
+ s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg)
+ s.Table[267] = syscalls.Supported("syncfs", Syncfs)
+ s.Table[269] = syscalls.Supported("sendmmsg", SendMMsg)
+ s.Table[276] = syscalls.Supported("renameat2", Renameat2)
+ s.Table[279] = syscalls.Supported("memfd_create", MemfdCreate)
+ s.Table[281] = syscalls.Supported("execveat", Execveat)
+ s.Table[286] = syscalls.Supported("preadv2", Preadv2)
+ s.Table[287] = syscalls.Supported("pwritev2", Pwritev2)
+ s.Table[291] = syscalls.Supported("statx", Statx)
+
s.Init()
}
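
The Override changes above follow a single registration idiom: each entry in the per-architecture syscall table is swapped for a vfs2 implementation, and s.Init() runs once after all overrides (presumably to rebuild any derived lookup state). In sketch form:

    // Fully supported syscalls get a plain entry; numbers are per-arch.
    s.Table[40] = syscalls.Supported("sendfile", Sendfile)
    // Partially supported syscalls carry a user-visible caveat.
    s.Table[26] = syscalls.PartiallySupported("inotify_init1", InotifyInit1,
        "inotify events are only available inside the sandbox.", nil)
    // Re-initialize the table once all overrides are installed.
    s.Init()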
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index f223aeda8..dfc8573fd 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -79,6 +79,17 @@ type MountFlags struct {
// NoATime is equivalent to MS_NOATIME and indicates that the
// filesystem should not update access time in-place.
NoATime bool
+
+ // NoDev is equivalent to MS_NODEV and indicates that the
+ // filesystem should not allow access to devices (special files).
+ // TODO(gvisor.dev/issue/3186): respect this flag in non-FUSE
+ // filesystems.
+ NoDev bool
+
+ // NoSUID is equivalent to MS_NOSUID and indicates that the
+ // filesystem should not honor set-user-ID and set-group-ID bits or
+ // file capabilities when executing programs.
+ NoSUID bool
}
// MountOptions contains options to VirtualFilesystem.MountAt().
@@ -153,6 +164,12 @@ type SetStatOptions struct {
// == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask
// instead).
Stat linux.Statx
+
+ // NeedWritePerm indicates that write permission on the file is needed for
+ // this operation. This is needed for truncate(2) (note that ftruncate(2)
+ // does not require the same check--instead, it checks that the fd is
+ // writable).
+ NeedWritePerm bool
}
// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt()
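
As a rough illustration of where the new mount flags would be consulted (the guard below is hypothetical; per the TODO above, only FUSE respects NoDev at this point):

    // checkNoDev is a hypothetical guard a filesystem could run before
    // opening a device special file on a nodev mount.
    func checkNoDev(flags vfs.MountFlags, isDevice bool) error {
        if isDevice && flags.NoDev {
            // Linux fails such opens with EACCES.
            return syserror.EACCES
        }
        return nil
    }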
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index 9cb050597..33389c1df 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -183,7 +183,8 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
// CheckSetStat checks that creds has permission to change the metadata of a
// file with the given permissions, UID, and GID as specified by stat, subject
// to the rules of Linux's fs/attr.c:setattr_prepare().
-func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ stat := &opts.Stat
if stat.Mask&linux.STATX_SIZE != 0 {
limit, err := CheckLimit(ctx, 0, int64(stat.Size))
if err != nil {
@@ -215,6 +216,11 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat
return syserror.EPERM
}
}
+ if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 {
if !CanActAsOwner(creds, kuid) {
if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) ||
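
A truncate(2)-style caller would opt into the extra permission check roughly like this (sketch only; the surrounding path lookup and limit handling are elided):

    opts := vfs.SetStatOptions{
        Stat: linux.Statx{
            Mask: linux.STATX_SIZE,
            Size: uint64(length),
        },
        // truncate(2) requires write access to the file itself;
        // ftruncate(2) would leave this unset and rely on the fd.
        NeedWritePerm: true,
    }
    if err := vfs.CheckSetStat(ctx, creds, &opts, mode, kuid, kgid); err != nil {
        return err
    }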
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 7908c5744..1a631b31a 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -72,6 +72,7 @@ const (
// Values for ICMP code as defined in RFC 792.
const (
ICMPv4TTLExceeded = 0
+ ICMPv4HostUnreachable = 1
ICMPv4PortUnreachable = 3
ICMPv4FragmentationNeeded = 4
)
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index c7ee2de57..a13b4b809 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -110,9 +110,16 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// Values for ICMP code as defined in RFC 4443.
+// Values for ICMP destination unreachable code as defined in RFC 4443 section
+// 3.1.
const (
- ICMPv6PortUnreachable = 4
+ ICMPv6NetworkUnreachable = 0
+ ICMPv6Prohibited = 1
+ ICMPv6BeyondScope = 2
+ ICMPv6AddressUnreachable = 3
+ ICMPv6PortUnreachable = 4
+ ICMPv6Policy = 5
+ ICMPv6RejectRoute = 6
)
// Type is the ICMP type field.
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index a2bb773d4..e12a5929b 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -302,3 +302,7 @@ func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 6aa1badc7..c18bb91fb 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -386,26 +386,33 @@ const (
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
// Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
}
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
var builder iovec.Builder
@@ -448,22 +455,8 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
- var eth header.Ethernet
if e.hdrSize > 0 {
- // Add ethernet header if needed.
- eth = make(header.Ethernet, header.EthernetMinimumSize)
- ethHdr := &header.EthernetFields{
- DstAddr: pkt.EgressRoute.RemoteLinkAddress,
- Type: pkt.NetworkProtocolNumber,
- }
-
- // Preserve the src address if it's set in the route.
- if pkt.EgressRoute.LocalLinkAddress != "" {
- ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
var vnetHdrBuf []byte
@@ -493,7 +486,6 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
var builder iovec.Builder
builder.Add(vnetHdrBuf)
- builder.Add(eth)
builder.Add(pkt.Header.View())
for _, v := range pkt.Data.Views() {
builder.Add(v)
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 4bad930c7..7b995b85a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin
c.ch <- packetInfo{remote, protocol, pkt}
}
+func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNoEthernetProperties(t *testing.T) {
c := newContext(t, &Options{MTU: mtu})
defer c.cleanup()
@@ -510,6 +514,10 @@ func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAdd
d.pkts = append(d.pkts, pkt)
}
+func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestDispatchPacketFormat(t *testing.T) {
for _, test := range []struct {
name string
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 3b17d8c28..781cdd317 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -118,3 +118,6 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareLoopback
}
+
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index c305d9e86..56a611825 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -135,6 +135,10 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("unsupported operation")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 328bd048e..d40de54df 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -61,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
}
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverOutboundPacket(remote, local, protocol, pkt)
+ }
+}
+
// Attach implements stack.LinkEndpoint.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
@@ -135,3 +145,8 @@ func (e *Endpoint) GSOMaxSize() uint32 {
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.child.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.child.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
index c1a219f02..7d9249c1c 100644
--- a/pkg/tcpip/link/nested/nested_test.go
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd
d.count++
}
+func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestNestedLinkEndpoint(t *testing.T) {
const emptyAddress = tcpip.LinkAddress("")
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
new file mode 100644
index 000000000..6fff160ce
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "packetsocket",
+ srcs = ["endpoint.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
new file mode 100644
index 000000000..3922c2a04
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packetsocket provides a link layer endpoint that loops outbound
+// packets back to any AF_PACKET sockets that may be interested in them.
+package packetsocket
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ nested.Endpoint
+}
+
+// New creates a new packetsocket LinkEndpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ e := &endpoint{}
+ e.Endpoint.Init(lower, e)
+ return e
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
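
A sketch of how a netstack user might compose this wrapper when building a NIC (the fdbased options and the ordering here are illustrative, not mandated by this change):

    // Interpose packetsocket between the stack and the raw endpoint so
    // AF_PACKET sockets observe outbound packets.
    lower, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
    if err != nil {
        return err
    }
    if tcpipErr := s.CreateNIC(nicID, packetsocket.New(lower)); tcpipErr != nil {
        return fmt.Errorf("CreateNIC: %v", tcpipErr)
    }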
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index c84fe1bb9..467083239 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -107,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
@@ -194,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3267): Queue these packets as well once
+ // WriteRawPacket takes PacketBuffer instead of VectorisedView.
return e.lower.WriteRawPacket(vv)
}
@@ -213,3 +220,8 @@ func (e *endpoint) Wait() {
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return e.lower.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index a36862c67..507c76b76 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- // Add the ethernet header here.
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
pkt.LinkHeader = buffer.View(eth)
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
v := pkt.Data.ToView()
// Transmit the packet.
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 28a2e88ba..8f3cd9449 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L
c.packetCh <- struct{}{}
}
+func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (c *testContext) cleanup() {
c.ep.Close()
closeFDs(&c.txCfg)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index d9cd4e83a..509076643 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -123,6 +123,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 47446efec..04ae58e59 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -272,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
if d.hasFlags(linux.IFF_TAP) {
// Add ethernet header if not provided.
if info.Pkt.LinkHeader == nil {
- hdr := &header.EthernetFields{
- SrcAddr: info.Route.LocalLinkAddress,
- DstAddr: info.Route.RemoteLinkAddress,
- Type: info.Proto,
- }
- if hdr.SrcAddr == "" {
- hdr.SrcAddr = d.endpoint.LinkAddress()
- }
-
- eth := make(header.Ethernet, header.EthernetMinimumSize)
- eth.Encode(hdr)
- vv.AppendView(buffer.View(eth))
- } else {
- vv.AppendView(info.Pkt.LinkHeader)
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
}
+ vv.AppendView(info.Pkt.LinkHeader)
}
// Append upper headers.
@@ -366,3 +354,30 @@ func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
}
return header.ARPHardwareNone
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.isTap {
+ return
+ }
+ eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
+ pkt.LinkHeader = buffer.View(eth)
+ hdr := &header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: protocol,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = e.LinkAddress()
+ }
+
+ eth.Encode(hdr)
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header.
+func (e *tunEndpoint) MaxHeaderLength() uint16 {
+ if e.isTap {
+ return header.EthernetMinimumSize
+ }
+ return 0
+}
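
MaxHeaderLength feeds the prependable-buffer sizing used by writers further down in this change (see the LinkAddressRequest hunks): callers reserve the link-layer headroom up front so AddHeader can Prepend in place, for example:

    // Reserve link-layer headroom plus the payload being built;
    // AddHeader/WritePacket later Prepend into the reserved space.
    hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)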
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 24a8dc2eb..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -60,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco
e.dispatchGate.Leave()
}
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
// registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
@@ -153,3 +162,8 @@ func (e *Endpoint) Wait() {}
func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
return e.lower.ARPHardwareType()
}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index ffb2354be..c448a888f 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -40,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress,
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -90,6 +94,11 @@ func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index b0f57040c..31a242482 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -160,9 +160,12 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
- RemoteLinkAddress: header.EthernetBroadcastAddress,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 66e67429c..a35a64a0f 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -32,10 +32,14 @@ import (
)
const (
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+ stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
type testContext struct {
@@ -50,8 +54,7 @@ func newTestContext(t *testing.T) *testContext {
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
})
- const defaultMTU = 65536
- ep := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
@@ -119,7 +122,7 @@ func TestDirectRequest(t *testing.T) {
if !rep.IsValid() {
t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
}
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
}
if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
@@ -144,3 +147,44 @@ func TestDirectRequest(t *testing.T) {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: stackLinkAddr2,
+ expectLinkAddr: stackLinkAddr2,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: header.EthernetBroadcastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ p := arp.NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
+ if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a5b780ca2..615bae648 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -185,6 +185,11 @@ func (*testObject) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 1b67aa066..83e71cb8c 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -129,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
+ case header.ICMPv4HostUnreachable:
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+
case header.ICMPv4PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 3b79749b5..24600d877 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -128,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6NetworkUnreachable:
+ e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
@@ -502,7 +504,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
@@ -511,8 +513,12 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
- RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
+ RemoteLinkAddress: remoteLinkAddr,
}
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
+ }
+
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 52a01b44e..f86aaed1d 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -34,6 +34,9 @@ const (
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
var (
@@ -257,8 +260,7 @@ func newTestContext(t *testing.T) *testContext {
}),
}
- const defaultMTU = 65536
- c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
@@ -271,7 +273,7 @@ func newTestContext(t *testing.T) *testContext {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -951,3 +953,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
})
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ snaddr := header.SolicitedNodeAddr(lladdr0)
+ mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
+
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: mcaddr,
+ },
+ }
+
+ for _, test := range tests {
+ p := NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index d39baf620..559a1c4dd 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -49,7 +49,8 @@ const (
type manipType int
const (
- manipDstPrerouting manipType = iota
+ manipNone manipType = iota
+ manipDstPrerouting
manipDstOutput
)
@@ -113,13 +114,11 @@ type conn struct {
// update the state of tcb. It is immutable.
tcbHook Hook
- // mu protects tcb.
+ // mu protects all mutable state.
mu sync.Mutex `state:"nosave"`
-
// tcb is TCB control block. It is used to keep track of states
// of tcp connection and is protected by mu.
tcb tcpconntrack.TCB
-
// lastUsed is the last time the connection saw a relevant packet, and
// is updated by each packet on the connection. It is protected by mu.
lastUsed time.Time `state:".(unixTime)"`
@@ -141,8 +140,26 @@ func (cn *conn) timedOut(now time.Time) bool {
return now.Sub(cn.lastUsed) > defaultTimeout
}
+// updateLocked updates the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
+}
+
// ConnTrack tracks all connections created for NAT rules. Most users are
-// expected to only call handlePacket and createConnFor.
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
//
// ConnTrack keeps all connections in a slice of buckets, each of which holds a
// linked list of tuples. This gives us some desirable properties:
@@ -248,8 +265,7 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
return nil, dirOriginal
}
-// createConnFor creates a new conn for pkt.
-func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -272,10 +288,15 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
manip = manipDstOutput
}
conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
+}
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
// Lock the buckets in the correct order.
- tupleBucket := ct.bucket(tid)
- replyBucket := ct.bucket(replyTID)
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
ct.mu.RLock()
defer ct.mu.RUnlock()
if tupleBucket < replyBucket {
@@ -289,22 +310,37 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg
ct.buckets[tupleBucket].mu.Lock()
}
- // Add the tuple to the map.
- ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
- ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
+
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ }
// Unlocking can happen in any order.
ct.buckets[tupleBucket].mu.Unlock()
if tupleBucket != replyBucket {
ct.buckets[replyBucket].mu.Unlock()
}
-
- return conn
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -322,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
netHeader.SetSourceAddress(conn.original.dstAddr)
}
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
// handlePacketOutput manipulates ports for packets in Output hook.
func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
netHeader := header.IPv4(pkt.NetworkHeader)
tcpHeader := header.TCP(pkt.TransportHeader)
@@ -362,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
}
// handlePacket will manipulate the port and address of the packet if the
-// connection exists.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
if pkt.NatDone {
- return
+ return false
}
if hook != Prerouting && hook != Output {
- return
+ return false
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return false
}
conn, dir := ct.connFor(pkt)
+ // Connection or rule not found for the packet.
if conn == nil {
- // Connection not found for the packet or the packet is invalid.
- return
+ return true
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if tcpHeader == nil {
+ return false
}
switch hook {
@@ -395,14 +452,39 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
// Mark the connection as having been used recently so it isn't reaped.
conn.lastUsed = time.Now()
// Update connection state.
- if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() {
- conn.tcb.Init(tcpHeader)
- conn.tcbHook = hook
- } else if hook == conn.tcbHook {
- conn.tcb.UpdateStateOutbound(tcpHeader)
- } else {
- conn.tcb.UpdateStateInbound(tcpHeader)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should only be called after traversing iptables rules, so that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
+ }
+
+ // We only track TCP connections.
+ if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ return
+ }
+
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return
}
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ ct.insertConn(conn)
}
// bucket gets the conntrack bucket for a tupleID.
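
insertConn's two-bucket locking above is the standard ordered-locking idiom for avoiding deadlock between concurrent inserters; distilled, assuming the bucket fields shown in this file:

    a, b := ct.bucket(conn.original.tupleID), ct.bucket(conn.reply.tupleID)
    if a > b {
        a, b = b, a // Always acquire the lower-numbered bucket first.
    }
    ct.buckets[a].mu.Lock()
    if a != b {
        ct.buckets[b].mu.Lock()
    }
    // ... check for an existing tuple and insert both directions ...
    ct.buckets[a].mu.Unlock()
    if a != b {
        ct.buckets[b].mu.Unlock()
    }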
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index eefb4b07f..c962693f5 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -121,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {}
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
@@ -174,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {}
func (f *fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
if f.addrCache != nil && f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, addr)
+ f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr)
})
}
return nil
@@ -307,6 +309,11 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
@@ -400,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any address will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -458,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
@@ -510,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -554,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -611,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index bbf3b60e8..cbbae4224 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -22,22 +22,30 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// Table names.
+// tableID is an index into IPTables.tables.
+type tableID int
+
const (
- TablenameNat = "nat"
- TablenameMangle = "mangle"
- TablenameFilter = "filter"
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
)
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+// Table names.
const (
- ChainNamePrerouting = "PREROUTING"
- ChainNameInput = "INPUT"
- ChainNameForward = "FORWARD"
- ChainNameOutput = "OUTPUT"
- ChainNamePostrouting = "POSTROUTING"
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
)
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
+
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
@@ -48,11 +56,9 @@ const reaperDelay = 5 * time.Second
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables() *IPTables {
- // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
- // iotas.
return &IPTables{
- tables: map[string]Table{
- TablenameNat: Table{
+ tables: [numTables]Table{
+ natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
@@ -60,60 +66,66 @@ func DefaultTables() *IPTables {
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- Underflows: map[Hook]int{
+ Underflows: [NumHooks]int{
Prerouting: 0,
Input: 1,
+ Forward: HookUnset,
Output: 2,
Postrouting: 3,
},
- UserChains: map[string]int{},
},
- TablenameMangle: Table{
+ mangleID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
+ BuiltinChains: [NumHooks]int{
Prerouting: 0,
Output: 1,
},
- Underflows: map[Hook]int{
- Prerouting: 0,
- Output: 1,
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
- TablenameFilter: Table{
+ filterID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
- BuiltinChains: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: 0,
- Forward: 1,
- Output: 2,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
},
},
- priorities: map[Hook][]string{
- Input: []string{TablenameNat, TablenameFilter},
- Prerouting: []string{TablenameMangle, TablenameNat},
- Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -127,51 +139,48 @@ func DefaultTables() *IPTables {
func EmptyFilterTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- Underflows: map[Hook]int{
- Input: HookUnset,
- Forward: HookUnset,
- Output: HookUnset,
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// EmptyNatTable returns a Table with no rules and the filter table chains
+// EmptyNATTable returns a Table with no rules and the NAT table chains
// mapped to HookUnset.
-func EmptyNatTable() Table {
+func EmptyNATTable() Table {
return Table{
Rules: []Rule{},
- BuiltinChains: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
},
- Underflows: map[Hook]int{
- Prerouting: HookUnset,
- Input: HookUnset,
- Output: HookUnset,
- Postrouting: HookUnset,
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
},
- UserChains: map[string]int{},
}
}
-// GetTable returns table by name.
+// GetTable returns a table by name.
func (it *IPTables) GetTable(name string) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
it.mu.RLock()
defer it.mu.RUnlock()
- t, ok := it.tables[name]
- return t, ok
+ return it.tables[id], true
}
// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table) {
+func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -181,14 +190,8 @@ func (it *IPTables) ReplaceTable(name string, table Table) {
it.startReaper(reaperDelay)
}
it.modified = true
- it.tables[name] = table
-}
-
-// GetPriorities returns slice of priorities associated with hook.
-func (it *IPTables) GetPriorities(hook Hook) []string {
- it.mu.RLock()
- defer it.mu.RUnlock()
- return it.priorities[hook]
+ it.tables[id] = table
+ return nil
}
// A chainVerdict is what a table decides should be done with a packet.
@@ -223,11 +226,19 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Packets are manipulated only if connection and matching
// NAT rule exists.
- it.connections.handlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
// Go through each table containing the hook.
- for _, tablename := range it.GetPriorities(hook) {
- table, _ := it.GetTable(tablename)
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ table := it.tables[tableID]
ruleIdx := table.BuiltinChains[hook]
switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
// If the table returns Accept, move on to the next table.
@@ -256,6 +267,20 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
}
}
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+ // needed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
// Every table returned Accept.
return true
}
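The map-to-array change above is easier to see in isolation. Below is a minimal, self-contained sketch of the priority-ordered traversal with the NatDone skip; the types and constants are simplified stand-ins, not the gVisor stack API.

package main

import "fmt"

type hook int
type tableID int

const (
	prerouting hook = iota
	input
	forward
	output
	postrouting
	numHooks
)

const (
	natID tableID = iota
	filterID
	mangleID
)

func main() {
	// priorities mirrors the new [NumHooks][]tableID layout: each hook
	// lists the tables to visit, in order.
	priorities := [numHooks][]tableID{
		prerouting: {mangleID, natID},
		input:      {natID, filterID},
		output:     {mangleID, natID, filterID},
	}

	natDone := true // pretend conntrack already rewrote this packet
	for _, id := range priorities[output] {
		if id == natID && natDone {
			continue // handlePacket already NATed the packet
		}
		fmt.Printf("hook %d: checking table %d\n", output, id)
	}
}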
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index d43f60c67..dc88033c7 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
// Set up connection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.createConnFor(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index eb70e3104..73274ada9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -84,14 +84,14 @@ type IPTables struct {
// mu protects tables, priorities, and modified.
mu sync.RWMutex
- // tables maps table names to tables. User tables have arbitrary names.
- // mu needs to be locked for accessing.
- tables map[string]Table
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be locked for accessing.
+ tables [numTables]Table
// priorities maps each hook to a list of table IDs. The order of the
// list is the order in which each table should be visited for that
// hook. mu needs to be locked for accessing.
- priorities map[Hook][]string
+ priorities [NumHooks][]tableID
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
@@ -113,22 +113,20 @@ type Table struct {
Rules []Rule
// BuiltinChains maps builtin chains to their entrypoint rule in Rules.
- BuiltinChains map[Hook]int
+ BuiltinChains [NumHooks]int
// Underflows maps builtin chains to their underflow rule in Rules
// (i.e. the rule to execute if the chain returns without a verdict).
- Underflows map[Hook]int
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]int
+ Underflows [NumHooks]int
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook := range table.BuiltinChains {
- hooks |= 1 << hook
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
}
return hooks
}
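With fixed arrays every hook has an entry, so ValidHooks must now filter out unset chains explicitly, where map absence used to be implicit. A stand-alone sketch with toy constants:

package main

import "fmt"

const (
	hookUnset = -1
	numHooks  = 5 // Prerouting, Input, Forward, Output, Postrouting
)

// validHooks builds the hook bitmap, skipping chains left at hookUnset.
func validHooks(builtinChains [numHooks]int) uint32 {
	var hooks uint32
	for hook, ruleIdx := range builtinChains {
		if ruleIdx != hookUnset {
			hooks |= 1 << hook
		}
	}
	return hooks
}

func main() {
	// A filter-style table: only Input, Forward, and Output are set.
	filter := [numHooks]int{hookUnset, 0, 1, 2, hookUnset}
	fmt.Printf("%05b\n", validHooks(filter)) // 01110
}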
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 403557fd7..6f73a0ce4 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 1baa498d0..b15b8d1cb 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
if f := r.onLinkAddressRequest; f != nil {
f()
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index e28c23d66..9dce11a97 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -469,7 +469,7 @@ type ndpState struct {
rtrSolicit struct {
// The timer used to send the next router solicitation message.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the Router Solicitation timer know that it has been stopped.
//
@@ -503,7 +503,7 @@ type ndpState struct {
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
@@ -515,38 +515,38 @@ type dadState struct {
// defaultRouterState holds data associated with a default router discovered by
// a Router Advertisement (RA).
type defaultRouterState struct {
- // Timer to invalidate the default router.
+ // Job to invalidate the default router.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// onLinkPrefixState holds data associated with an on-link prefix discovered by
// a Router Advertisement's Prefix Information option (PI) when the NDP
// configurations was configured to do so.
type onLinkPrefixState struct {
- // Timer to invalidate the on-link prefix.
+ // Job to invalidate the on-link prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
}
// tempSLAACAddrState holds state associated with a temporary SLAAC address.
type tempSLAACAddrState struct {
- // Timer to deprecate the temporary SLAAC address.
+ // Job to deprecate the temporary SLAAC address.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the temporary SLAAC address.
+ // Job to invalidate the temporary SLAAC address.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
- // Timer to regenerate the temporary SLAAC address.
+ // Job to regenerate the temporary SLAAC address.
//
// Must not be nil.
- regenTimer *tcpip.CancellableTimer
+ regenJob *tcpip.Job
createdAt time.Time
@@ -561,15 +561,15 @@ type tempSLAACAddrState struct {
// slaacPrefixState holds state associated with a SLAAC prefix.
type slaacPrefixState struct {
- // Timer to deprecate the prefix.
+ // Job to deprecate the prefix.
//
// Must not be nil.
- deprecationTimer *tcpip.CancellableTimer
+ deprecationJob *tcpip.Job
- // Timer to invalidate the prefix.
+ // Job to invalidate the prefix.
//
// Must not be nil.
- invalidationTimer *tcpip.CancellableTimer
+ invalidationJob *tcpip.Job
// Nonzero only when the address is not valid forever.
validUntil time.Time
@@ -651,12 +651,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
}
var done bool
- var timer *time.Timer
+ var timer tcpip.Timer
// We initially start a timer to fire immediately because some of the DAD work
// cannot be done while holding the NIC's lock. This is effectively the same
// as starting a goroutine but we use a timer that fires immediately so we can
// reset it for the next DAD iteration.
- timer = time.AfterFunc(0, func() {
+ timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
ndp.nic.mu.Lock()
defer ndp.nic.mu.Unlock()
@@ -871,9 +871,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
case ok && rl != 0:
// This is an already discovered default router. Update
- // the invalidation timer.
- rtr.invalidationTimer.StopLocked()
- rtr.invalidationTimer.Reset(rl)
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = rtr
case ok && rl == 0:
@@ -950,7 +950,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
return
}
- rtr.invalidationTimer.StopLocked()
+ rtr.invalidationJob.Cancel()
delete(ndp.defaultRouters, ip)
// Let the integrator know a discovered default router is invalidated.
@@ -979,12 +979,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
state := defaultRouterState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
- state.invalidationTimer.Reset(rl)
+ state.invalidationJob.Schedule(rl)
ndp.defaultRouters[ip] = state
}
@@ -1009,13 +1009,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
state := onLinkPrefixState{
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
if l < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(l)
+ state.invalidationJob.Schedule(l)
}
ndp.onLinkPrefixes[prefix] = state
@@ -1033,7 +1033,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
return
}
- s.invalidationTimer.StopLocked()
+ s.invalidationJob.Cancel()
delete(ndp.onLinkPrefixes, prefix)
// Let the integrator know a discovered on-link prefix is invalidated.
@@ -1082,14 +1082,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio
// This is an already discovered on-link prefix with a
// new non-zero valid lifetime.
//
- // Update the invalidation timer.
+ // Update the invalidation job.
- prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationJob.Cancel()
if vl < header.NDPInfiniteLifetime {
- // Prefix is valid for a finite lifetime, reset the timer to expire after
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
// the new valid lifetime.
- prefixState.invalidationTimer.Reset(vl)
+ prefixState.invalidationJob.Schedule(vl)
}
ndp.onLinkPrefixes[prefix] = prefixState
@@ -1154,7 +1154,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1162,7 +1162,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1184,19 +1184,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if !ndp.generateSLAACAddr(prefix, &state) {
// We were unable to generate an address for the prefix, we do nothing
- // further as there is no reason to maintain state or timers for a prefix we
+ // further as there is no reason to maintain state or jobs for a prefix we
// do not have an address for.
return
}
- // Setup the initial timers to deprecate and invalidate prefix.
+ // Set up the initial jobs to deprecate and invalidate the prefix.
if pl < header.NDPInfiniteLifetime && pl != 0 {
- state.deprecationTimer.Reset(pl)
+ state.deprecationJob.Schedule(pl)
}
if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
+ state.invalidationJob.Schedule(vl)
state.validUntil = now.Add(vl)
}
@@ -1428,7 +1428,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1441,7 +1441,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.ref)
}),
- invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1454,7 +1454,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() {
+ regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1481,9 +1481,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ref: ref,
}
- state.deprecationTimer.Reset(pl)
- state.invalidationTimer.Reset(vl)
- state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration)
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
prefixState.generationAttempts++
prefixState.tempAddrs[generatedAddr.Address] = state
@@ -1518,16 +1518,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
prefixState.stableAddr.ref.deprecated = false
}
- // If prefix was preferred for some finite lifetime before, stop the
- // deprecation timer so it can be reset.
- prefixState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
now := time.Now()
- // Reset the deprecation timer if prefix has a finite preferred lifetime.
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
if pl < header.NDPInfiniteLifetime {
if !deprecated {
- prefixState.deprecationTimer.Reset(pl)
+ prefixState.deprecationJob.Schedule(pl)
}
prefixState.preferredUntil = now.Add(pl)
} else {
@@ -1546,9 +1546,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
if vl >= header.NDPInfiniteLifetime {
- // Handle the infinite valid lifetime separately as we do not keep a timer
- // in this case.
- prefixState.invalidationTimer.StopLocked()
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
prefixState.validUntil = time.Time{}
} else {
var effectiveVl time.Duration
@@ -1569,8 +1569,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
if effectiveVl != 0 {
- prefixState.invalidationTimer.StopLocked()
- prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
prefixState.validUntil = now.Add(effectiveVl)
}
}
@@ -1582,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// Note, we do not need to update the entries in the temporary address map
- // after updating the timers because the timers are held as pointers.
+ // after updating the jobs because the jobs are held as pointers.
var regenForAddr tcpip.Address
allAddressesRegenerated := true
for tempAddr, tempAddrState := range prefixState.tempAddrs {
@@ -1596,14 +1596,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
- // reset the invalidation timer.
+ // reset the invalidation job.
newValidLifetime := validUntil.Sub(now)
if newValidLifetime <= 0 {
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
continue
}
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.invalidationTimer.Reset(newValidLifetime)
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
// As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
// address is the lower of the preferred lifetime of the stable address or
@@ -1616,17 +1616,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
}
// If the address is no longer preferred, deprecate it immediately.
- // Otherwise, reset the deprecation timer.
+ // Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
- tempAddrState.deprecationTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.ref)
} else {
tempAddrState.ref.deprecated = false
- tempAddrState.deprecationTimer.Reset(newPreferredLifetime)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
}
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.regenJob.Cancel()
if tempAddrState.regenerated {
} else {
allAddressesRegenerated = false
@@ -1637,7 +1637,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// immediately after we finish iterating over the temporary addresses.
regenForAddr = tempAddr
} else {
- tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
}
}
}
@@ -1717,7 +1717,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
ndp.cleanupSLAACPrefixResources(prefix, state)
}
-// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry.
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
//
// Panics if the SLAAC prefix is not known.
//
@@ -1729,8 +1729,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
}
state.stableAddr.ref = nil
- state.deprecationTimer.StopLocked()
- state.invalidationTimer.StopLocked()
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
delete(ndp.slaacPrefixes, prefix)
}
@@ -1775,13 +1775,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi
}
// cleanupTempSLAACAddrResources cleans up a temporary SLAAC address's
-// timers and entry.
+// jobs and entry.
//
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
- tempAddrState.deprecationTimer.StopLocked()
- tempAddrState.invalidationTimer.StopLocked()
- tempAddrState.regenTimer.StopLocked()
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
delete(tempAddrs, tempAddr)
}
@@ -1860,7 +1860,7 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = time.AfterFunc(delay, func() {
+ ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
ndp.nic.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
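The Cancel-then-Schedule pattern that replaces StopLocked/Reset throughout ndp.go looks like this in miniature; the job type below is a stand-in built on time.AfterFunc, not the tcpip.Job implementation.

package main

import (
	"fmt"
	"time"
)

// job has the Cancel/Schedule surface used above.
type job struct {
	t *time.Timer
	f func()
}

func (j *job) Cancel() {
	if j.t != nil {
		j.t.Stop()
	}
}

func (j *job) Schedule(d time.Duration) {
	j.t = time.AfterFunc(d, j.f)
}

func main() {
	invalidate := func(ip string) { fmt.Println("invalidating default router", ip) }

	// Discover a router with a 50ms lifetime.
	j := &job{f: func() { invalidate("fe80::1") }}
	j.Schedule(50 * time.Millisecond)

	// A later RA refreshes the lifetime: Cancel, then Schedule again,
	// the same pattern as invalidationJob.Cancel()/Schedule(rl) above.
	j.Cancel()
	j.Schedule(100 * time.Millisecond)

	time.Sleep(200 * time.Millisecond)
}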
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 6f86abc98..644ba7c33 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1254,7 +1254,7 @@ func TestRouterDiscovery(t *testing.T) {
default:
}
- // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1271,7 +1271,7 @@ func TestRouterDiscovery(t *testing.T) {
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
expectRouterEvent(llAddr2, false)
- // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
// of the router should have been updated to the most recent (smaller)
// lifetime.
//
@@ -1502,7 +1502,7 @@ func TestPrefixDiscovery(t *testing.T) {
default:
}
- // Wait for prefix2's most recent invalidation timer plus some buffer to
+ // Wait for prefix2's most recent invalidation job plus some buffer to
// expire.
select {
case e := <-ndpDisp.prefixC:
@@ -2395,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
for _, addr := range tempAddrs {
// Wait for a deprecation then invalidation event, or just an invalidation
// event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation timers could fire in any
+ // cases because the deprecation and invalidation jobs could execute in any
// order.
select {
case e := <-ndpDisp.autoGenAddrC:
@@ -2432,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
}
-// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's
-// regeneration timer gets updated when refreshing the address's lifetimes.
-func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
nicID = 1
regenAfter = 2 * time.Second
@@ -2533,7 +2533,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
//
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
- // - this regeneration does not depend on a timer.
+ // - this regeneration does not depend on a job.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(tempAddr2, newAddr)
@@ -2559,11 +2559,11 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
}
// Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration timer gets reset.
+ // RA, the regeneration job gets scheduled again.
//
// The maximum lifetime is the sum of the minimum lifetimes for temporary
// addresses + the time that has already passed since the last address was
- // generated so that the regeneration timer is needed to generate the next
+ // generated so that the regeneration job is needed to generate the next
// address.
newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
@@ -2993,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr(addr2)
}
-// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
// when its preferred lifetime expires.
-func TestAutoGenAddrTimerDeprecation(t *testing.T) {
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
@@ -3513,8 +3513,8 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
expectAutoGenAddrEvent(addr, invalidatedAddr)
- // Wait for the original valid lifetime to make sure the original timer
- // got stopped/cleaned up.
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7b80534e6..fea0ce7e8 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// Are any packet sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
- }
+ // Add any other packet sockets that may be listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
@@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol-specific packet endpoints because on
+ // Linux only ETH_P_ALL endpoints get outbound packets.
+ // Add any other packet sockets that may be listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
+ }
+}
+
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
// TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
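A sketch of the clone-per-endpoint delivery above, with illustrative types: each packet socket gets its own copy, tagged with the direction-appropriate packet type, so one endpoint mutating its copy cannot affect another.

package main

import "fmt"

type packet struct {
	pktType string
	payload []byte
}

// clone copies the payload so each endpoint owns its bytes.
func (p *packet) clone() *packet {
	return &packet{payload: append([]byte(nil), p.payload...)}
}

func main() {
	pkt := &packet{payload: []byte("frame")}
	for _, ep := range []string{"tcpdump", "dhcp-client"} {
		p := pkt.clone()
		p.pktType = "PacketOutgoing" // tag the clone, not the original
		fmt.Printf("deliver to %s: %q (%s)\n", ep, p.payload, p.pktType)
	}
}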
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 3bc9fd831..a70792b50 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -89,6 +89,11 @@ func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
panic("not implemented")
}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
// An IPv6 NetworkEndpoint that throws away outgoing packets.
@@ -238,7 +243,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index e3556d5d2..5d6865e35 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -79,6 +79,10 @@ type PacketBuffer struct {
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index f260eeb7f..8604c4259 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -52,8 +52,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -330,8 +333,7 @@ type NetworkProtocol interface {
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network/packet (if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing.
@@ -342,6 +344,16 @@ type NetworkDispatcher interface {
//
// DeliverNetworkPacket takes ownership of pkt.
DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -443,6 +455,9 @@ type LinkEndpoint interface {
// See:
// https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -463,12 +478,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr.
- // The request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+ // the request on the local network if remoteLinkAddr is the zero value. The
+ // request is sent on linkEP with localAddr as the source.
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+ LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
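A sketch of the new LinkAddressRequest contract, broadcasting when remoteLinkAddr is the zero value. The broadcast fallback is an assumption about how a resolver would honor the documented contract; all names here are illustrative.

package main

import "fmt"

type linkAddress string

const ethernetBroadcast = linkAddress("\xff\xff\xff\xff\xff\xff")

// linkAddressRequest: broadcast when the remote link address is the zero
// value, unicast otherwise.
func linkAddressRequest(addr, localAddr string, remoteLinkAddr linkAddress) {
	if remoteLinkAddr == "" {
		remoteLinkAddr = ethernetBroadcast
	}
	fmt.Printf("who-has %s tell %s -> dst %x\n", addr, localAddr, remoteLinkAddr)
}

func main() {
	linkAddressRequest("192.168.1.2", "192.168.1.1", "")                          // broadcast
	linkAddressRequest("192.168.1.2", "192.168.1.1", "\x02\x00\x00\x00\x00\x01") // unicast
}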
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 2b7ece851..a6faa22c2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -728,6 +728,11 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
// UniqueID returns a unique identifier.
func (s *Stack) UniqueID() uint64 {
return s.uniqueIDGenerator.UniqueID()
@@ -801,9 +806,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
// Stats returns a mutable copy of the current stats.
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 48ad56d4d..21aafb0a2 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -192,7 +192,7 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
-// A Clock provides the current time.
+// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
@@ -203,6 +203,31 @@ type Clock interface {
// NowMonotonic returns a monotonic time value.
NowMonotonic() int64
+
+ // AfterFunc waits for the duration to elapse and then calls f in its own
+ // goroutine. It returns a Timer that can be used to cancel the call using
+ // its Stop method.
+ AfterFunc(d time.Duration, f func()) Timer
+}
+
+// Timer represents a single event. A Timer must be created with
+// Clock.AfterFunc.
+type Timer interface {
+ // Stop prevents the Timer from firing. It returns true if the call stops the
+ // timer, false if the timer has already expired or been stopped.
+ //
+ // If Stop returns false, then the timer has already expired and the function
+ // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
+ // does not wait for f to complete before returning. If the caller needs to
+ // know whether f is completed, it must coordinate with f explicitly.
+ Stop() bool
+
+ // Reset changes the timer to expire after duration d.
+ //
+ // Reset should be invoked only on stopped or expired timers. If the timer is
+ // known to have expired, Reset can be used directly. Otherwise, the caller
+ // must coordinate with the function f of Clock.AfterFunc(d, f).
+ Reset(d time.Duration)
}
// Address is a byte slice cast as a string that represents the address of a
@@ -316,6 +341,28 @@ const (
ShutdownWrite
)
+// PacketType is used to indicate the destination of the packet.
+type PacketType uint8
+
+const (
+ // PacketHost indicates a packet addressed to the local host.
+ PacketHost PacketType = iota
+
+ // PacketOtherHost indicates a packet addressed to
+ // another host caught by a NIC in promiscuous mode.
+ PacketOtherHost
+
+ // PacketOutgoing indicates a packet originating from the local host
+ // that is looped back to a packet socket.
+ PacketOutgoing
+
+ // PacketBroadcast indicates a link layer broadcast packet.
+ PacketBroadcast
+
+ // PacketMulticast indicates a link layer multicast packet.
+ PacketMulticast
+)
+
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
@@ -555,6 +602,9 @@ type Endpoint interface {
type LinkPacketInfo struct {
// Protocol is the NetworkProtocolNumber for the packet.
Protocol NetworkProtocolNumber
+
+ // PktType is used to indicate the destination of the packet.
+ PktType PacketType
}
// PacketEndpoint are additional methods that are only implemented by Packet
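The PacketType values mirror Linux's PACKET_* constants from <linux/if_packet.h>. The numeric mapping below is an assumption for illustration; the real conversion would live in the sentry's socket layer, outside this diff.

package main

import "fmt"

type packetType uint8

const (
	packetHost packetType = iota
	packetOtherHost
	packetOutgoing
	packetBroadcast
	packetMulticast
)

// toLinux maps the netstack-side enum to Linux sll_pkttype values.
func toLinux(pt packetType) uint8 {
	switch pt {
	case packetHost:
		return 0 // PACKET_HOST
	case packetBroadcast:
		return 1 // PACKET_BROADCAST
	case packetMulticast:
		return 2 // PACKET_MULTICAST
	case packetOtherHost:
		return 3 // PACKET_OTHERHOST
	case packetOutgoing:
		return 4 // PACKET_OUTGOING
	default:
		panic("unknown packet type")
	}
}

func main() {
	fmt.Println(toLinux(packetOutgoing)) // 4
}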
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 7f172f978..f32d58091 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -20,7 +20,7 @@
package tcpip
import (
- _ "time" // Used with go:linkname.
+ "time" // Used with go:linkname.
_ "unsafe" // Required for go:linkname.
)
@@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 {
_, _, mono := now()
return mono
}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
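Clock.AfterFunc exists so timers can be driven by something other than the wall clock, for example a fake clock in tests. A sketch of such a clock, assuming a design where time advances only when the test says so; this is not netstack's test clock.

package main

import (
	"fmt"
	"time"
)

type timer interface {
	Stop() bool
	Reset(d time.Duration)
}

type fakeTimer struct {
	c       *fakeClock
	when    time.Duration
	f       func()
	stopped bool
}

func (t *fakeTimer) Stop() bool {
	wasActive := !t.stopped
	t.stopped = true
	return wasActive
}

func (t *fakeTimer) Reset(d time.Duration) {
	t.stopped = false
	t.when = t.c.now + d
}

type fakeClock struct {
	now    time.Duration
	timers []*fakeTimer
}

func (c *fakeClock) AfterFunc(d time.Duration, f func()) timer {
	t := &fakeTimer{c: c, when: c.now + d, f: f}
	c.timers = append(c.timers, t)
	return t
}

// advance moves fake time forward, firing timers that come due.
func (c *fakeClock) advance(d time.Duration) {
	c.now += d
	for _, t := range c.timers {
		if !t.stopped && t.when <= c.now {
			t.stopped = true
			t.f()
		}
	}
}

func main() {
	var c fakeClock
	c.AfterFunc(time.Second, func() { fmt.Println("fired") })
	c.advance(500 * time.Millisecond) // nothing happens yet
	c.advance(time.Second)            // prints "fired"
}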
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
index 5554c573f..f1dd7c310 100644
--- a/pkg/tcpip/timer.go
+++ b/pkg/tcpip/timer.go
@@ -20,50 +20,49 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
-// cancellableTimerInstance is a specific instance of CancellableTimer.
+// jobInstance is a specific instance of Job.
//
-// Different instances are created each time CancellableTimer is Reset so each
-// timer has its own earlyReturn signal. This is to address a bug when a
-// CancellableTimer is stopped and reset in quick succession resulting in a
-// timer instance's earlyReturn signal being affected or seen by another timer
-// instance.
+// Different instances are created each time Job is scheduled so each timer has
+// its own earlyReturn signal. This is to address a bug when a Job is stopped
+// and reset in quick succession resulting in a timer instance's earlyReturn
+// signal being affected or seen by another timer instance.
//
// Consider the following scenario where timer instances share a common
// earlyReturn signal (T1 creates, stops and resets a Job under a
// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
// (B), third (C), and fourth (D) instance of the timer firing, respectively):
// T1: Obtain L
-// T1: Create a new CancellableTimer w/ lock L (create instance A)
+// T1: Create a new Job w/ lock L (create instance A)
// T2: instance A fires, blocked trying to obtain L.
// T1: Attempt to stop instance A (set earlyReturn = true)
-// T1: Reset timer (create instance B)
+// T1: Schedule timer (create instance B)
// T3: instance B fires, blocked trying to obtain L.
// T1: Attempt to stop instance B (set earlyReturn = true)
-// T1: Reset timer (create instance C)
+// T1: Schedule timer (create instance C)
// T4: instance C fires, blocked trying to obtain L.
// T1: Attempt to stop instance C (set earlyReturn = true)
-// T1: Reset timer (create instance D)
+// T1: Schedule timer (create instance D)
// T5: instance D fires, blocked trying to obtain L.
// T1: Release L
//
-// Now that T1 has released L, any of the 4 timer instances can take L and check
-// earlyReturn. If the timers simply check earlyReturn and then do nothing
-// further, then instance D will never early return even though it was not
-// requested to stop. If the timers reset earlyReturn before early returning,
-// then all but one of the timers will do work when only one was expected to.
-// If CancellableTimer resets earlyReturn when resetting, then all the timers
+// Now that T1 has released L, any of the 4 timer instances can take L and
+// check earlyReturn. If the timers simply check earlyReturn and then do
+// nothing further, then instance D will never early return even though it was
+// not requested to stop. If the timers reset earlyReturn before early
+// returning, then all but one of the timers will do work when only one was
+// expected to. If Job resets earlyReturn when resetting, then all the timers
// will fire (again, when only one was expected to).
//
// To address the above concerns the simplest solution was to give each timer
// its own earlyReturn signal.
-type cancellableTimerInstance struct {
- timer *time.Timer
+type jobInstance struct {
+ timer Timer
// Used to inform the timer to early return when it gets stopped while the
// lock the timer tries to obtain when fired is held (T1 is a goroutine that
// tries to cancel the timer and T2 is the goroutine that handles the timer
// firing):
- // T1: Obtain the lock, then call StopLocked()
+ // T1: Obtain the lock, then call Cancel()
// T2: timer fires, and gets blocked on obtaining the lock
// T1: Releases lock
// T2: Obtains lock and does unintended work
@@ -74,29 +73,33 @@ type cancellableTimerInstance struct {
earlyReturn *bool
}
-// stop stops the timer instance t from firing if it hasn't fired already. If it
+// stop stops the job instance j from firing if it hasn't fired already. If it
// has fired and is blocked at obtaining the lock, earlyReturn will be set to
// true so that it will early return when it obtains the lock.
-func (t *cancellableTimerInstance) stop() {
- if t.timer != nil {
- t.timer.Stop()
- *t.earlyReturn = true
+func (j *jobInstance) stop() {
+ if j.timer != nil {
+ j.timer.Stop()
+ *j.earlyReturn = true
}
}
-// CancellableTimer is a timer that does some work and can be safely cancelled
-// when it fires at the same time some "related work" is being done.
+// Job represents some work that can be scheduled for execution. The work can
+// be safely cancelled when it fires at the same time some "related work" is
+// being done.
//
// The term "related work" is defined as some work that needs to be done while
// holding some lock that the timer must also hold while doing some work.
//
-// Note, it is not safe to copy a CancellableTimer as its timer instance creates
-// a closure over the address of the CancellableTimer.
-type CancellableTimer struct {
+// Note, it is not safe to copy a Job as its timer instance creates
+// a closure over the address of the Job.
+type Job struct {
_ sync.NoCopy
+ // The clock used to schedule the backing timer
+ clock Clock
+
// The active instance of a cancellable timer.
- instance cancellableTimerInstance
+ instance jobInstance
// locker is the lock taken by the timer immediately after it fires and must
// be held when attempting to stop the timer.
@@ -113,59 +116,91 @@ type CancellableTimer struct {
fn func()
}
-// StopLocked prevents the Timer from firing if it has not fired already.
+// Cancel prevents the Job from executing if it has not executed already.
//
-// If the timer is blocked on obtaining the t.locker lock when StopLocked is
-// called, it will early return instead of calling t.fn.
+// Cancel requires appropriate locking to be in place for any resources managed
+// by the Job. If the Job is blocked on obtaining the lock when Cancel is
+// called, it will early return.
//
// Note, j will be modified.
//
-// t.locker MUST be locked.
-func (t *CancellableTimer) StopLocked() {
- t.instance.stop()
+// j.locker MUST be locked.
+func (j *Job) Cancel() {
+ j.instance.stop()
// Nothing to do with the stopped instance anymore.
- t.instance = cancellableTimerInstance{}
+ j.instance = jobInstance{}
}
-// Reset changes the timer to expire after duration d.
+// Schedule schedules the Job for execution after duration d. This can be
+// called on cancelled or completed Jobs to schedule them again.
//
-// Note, t will be modified.
+// Schedule should be invoked only on unscheduled, cancelled, or completed
+// Jobs. To be safe, callers should always call Cancel before calling Schedule.
//
-// Reset should only be called on stopped or expired timers. To be safe, callers
-// should always call StopLocked before calling Reset.
-func (t *CancellableTimer) Reset(d time.Duration) {
+// Note, j will be modified.
+func (j *Job) Schedule(d time.Duration) {
// Create a new instance.
earlyReturn := false
// Capture the locker so that updating the timer does not cause a data race
// when a timer fires and tries to obtain the lock (read the timer's locker).
- locker := t.locker
- t.instance = cancellableTimerInstance{
- timer: time.AfterFunc(d, func() {
+ locker := j.locker
+ j.instance = jobInstance{
+ timer: j.clock.AfterFunc(d, func() {
locker.Lock()
defer locker.Unlock()
if earlyReturn {
// If we reach this point, it means that the timer fired while another
- // goroutine called StopLocked while it had the lock. Simply return
- // here and do nothing further.
+ // goroutine called Cancel while it had the lock. Simply return here
+ // and do nothing further.
earlyReturn = false
return
}
- t.fn()
+ j.fn()
}),
earlyReturn: &earlyReturn,
}
}
-// NewCancellableTimer returns an unscheduled CancellableTimer with the given
-// locker and fn.
-//
-// fn MUST NOT attempt to lock locker.
-//
-// Callers must call Reset to schedule the timer to fire.
-func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer {
- return &CancellableTimer{locker: locker, fn: fn}
+// NewJob returns a new Job that can be used to schedule f to run in its own
+// goroutine. l will be locked before calling f, then unlocked after f returns.
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// message = "bar"
+// mu.Unlock()
+//
+// // Output: bar
+//
+// f MUST NOT attempt to lock l.
+//
+// l MUST be locked prior to calling the returned job's Cancel().
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// job.Cancel()
+// mu.Unlock()
+func NewJob(c Clock, l sync.Locker, f func()) *Job {
+ return &Job{
+ clock: c,
+ locker: l,
+ fn: f,
+ }
}
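The per-instance earlyReturn design that the long comment above motivates reduces to this miniature: each Schedule captures a fresh flag, so cancelling one instance can never suppress, or fail to suppress, a different one.

package main

import (
	"fmt"
	"sync"
	"time"
)

type miniJob struct {
	mu          sync.Mutex
	earlyReturn *bool // flag of the currently scheduled instance
	fn          func()
}

// schedule creates a fresh instance with its own earlyReturn flag.
func (j *miniJob) schedule(d time.Duration) {
	early := false
	j.earlyReturn = &early
	time.AfterFunc(d, func() {
		j.mu.Lock()
		defer j.mu.Unlock()
		if early {
			return // this instance was cancelled while blocked on mu
		}
		j.fn()
	})
}

// cancel must be called with mu held, mirroring Job.Cancel.
func (j *miniJob) cancel() {
	if j.earlyReturn != nil {
		*j.earlyReturn = true
	}
}

func main() {
	j := &miniJob{fn: func() { fmt.Println("ran") }}
	j.mu.Lock()
	j.schedule(10 * time.Millisecond)
	j.cancel() // instance A: will early return
	j.mu.Unlock()

	j.schedule(10 * time.Millisecond) // instance B: has its own flag
	time.Sleep(50 * time.Millisecond) // prints "ran" exactly once
}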
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index b4940e397..a82384c49 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -28,8 +28,8 @@ const (
longDuration = 1 * time.Second
)
-func TestCancellableTimerReassignment(t *testing.T) {
- var timer tcpip.CancellableTimer
+func TestJobReschedule(t *testing.T) {
+ var clock tcpip.StdClock
var wg sync.WaitGroup
var lock sync.Mutex
@@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) {
// that has an active timer (even if it has been stopped as a stopped
// timer may be blocked on a lock before it can check if it has been
// stopped while another goroutine holds the same lock).
- timer = *tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
wg.Done()
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
lock.Unlock()
}()
}
wg.Wait()
}
-func TestCancellableTimerFire(t *testing.T) {
+func TestJobExecution(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() {
+ job := tcpip.NewJob(&clock, &lock, func() {
ch <- struct{}{}
})
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) {
func TestCancellableTimerResetFromLongDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(middleDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(middleDuration)
lock.Lock()
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) {
}
}
-func TestCancellableTimerResetFromShortDuration(t *testing.T) {
+func TestJobRescheduleFromShortDuration(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
// Wait for timer to fire if it wasn't correctly stopped.
@@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
case <-time.After(middleDuration):
}
- timer.Reset(shortDuration)
+ job.Schedule(shortDuration)
// Wait for timer to fire.
select {
@@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) {
}
}
-func TestCancellableTimerImmediatelyStop(t *testing.T) {
+func TestJobImmediatelyCancel(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
for i := 0; i < 1000; i++ {
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
}
@@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) {
}
}
-func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
+func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
- timer.StopLocked()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
lock.Unlock()
for i := 0; i < 10; i++ {
- timer.Reset(middleDuration)
+ job.Schedule(middleDuration)
lock.Lock()
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration * 2)
- timer.StopLocked()
+ job.Cancel()
lock.Unlock()
}
@@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) {
func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
// Sleep until the timer fires and gets blocked trying to take the lock.
time.Sleep(middleDuration)
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
@@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
}
}
-func TestManyCancellableTimerResetUnderLock(t *testing.T) {
+func TestManyJobReschedulesUnderLock(t *testing.T) {
t.Parallel()
- ch := make(chan struct{})
+ var clock tcpip.StdClock
var lock sync.Mutex
+ ch := make(chan struct{})
lock.Lock()
- timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} })
- timer.Reset(shortDuration)
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
for i := 0; i < 10; i++ {
- timer.StopLocked()
- timer.Reset(shortDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
}
lock.Unlock()
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 678f4e016..4612be4e7 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -797,7 +797,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 7b2083a09..0e46e6355 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -441,6 +441,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
Addr: tcpip.Address(hdr.SourceAddress()),
}
packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
@@ -448,34 +449,57 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
Addr: tcpip.Address(localAddr),
}
packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = pkt.Data
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip the link header from pkt.Header.
+ pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):])
+ combinedVV := pkt.Header.View().ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
var linkHeader buffer.View
- if len(pkt.LinkHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ var combinedVV buffer.VectorisedView
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if len(pkt.LinkHeader) == 0 {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- linkHeader = buffer.View(fakeHeader)
- } else {
- linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
+ combinedVV = linkHeader.ToVectorisedView()
+ }
+ if pkt.PktType == tcpip.PacketOutgoing {
+ // For outgoing packets, the link, network, and transport
+ // headers normally live in pkt.Header, unless a raw socket
+ // is in use, in which case pkt.Header could be nil.
+ combinedVV.AppendView(pkt.Header.View())
}
- combinedVV := linkHeader.ToVectorisedView()
combinedVV.Append(pkt.Data)
packet.data = combinedVV
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
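A simplified byte-slice model of the cooked/outgoing queueing rules above; the function is a toy, not the endpoint code. For outgoing packets all headers live in the serialized header area (pkt.Header), and cooked sockets must not see the link layer.

package main

import "fmt"

// queuedView strips the link header for cooked sockets and keeps the
// header area as-is for raw sockets before appending the payload.
func queuedView(cooked bool, linkHdrLen int, hdr, data []byte) []byte {
	if cooked {
		hdr = hdr[linkHdrLen:] // strip the link header
	}
	return append(append([]byte(nil), hdr...), data...)
}

func main() {
	hdr := []byte{0xEE, 0xEE, 0x45, 0x00} // fake 2-byte link header + IP bytes
	data := []byte("payload")
	fmt.Printf("cooked: % x\n", queuedView(true, 2, hdr, data))
	fmt.Printf("raw:    % x\n", queuedView(false, 2, hdr, data))
}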
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index c2e9fd29f..f85a68554 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -456,7 +456,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
defer e.mu.Unlock()
// If a local address was specified, verify that it's valid.
- if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -700,7 +700,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
}
combinedVV.Append(pkt.Data)
packet.data = combinedVV
- packet.timestampNS = e.stack.NowNanoseconds()
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 18ff89ffc..e860ee484 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -49,6 +49,7 @@ go_library(
"segment_heap.go",
"segment_queue.go",
"segment_state.go",
+ "segment_unsafe.go",
"snd.go",
"snd_state.go",
"tcp_endpoint_list.go",
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 81b740115..1798510bc 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
}
// Wait for notification.
@@ -616,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error {
<-h.ep.undrain
h.ep.mu.Lock()
}
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 83dc10ed0..0f7487963 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1209,6 +1209,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@@ -1956,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
+ return e.takeLastError()
case *tcpip.BindToDeviceOption:
e.LockUser()
@@ -2546,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
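The hunks above wire ICMP "no route" and "network unreachable" control packets into the endpoint's last-error slot, which `takeLastError` consumes exactly once. A self-contained toy of that take-once pattern (names illustrative):

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

type ep struct {
	mu        sync.Mutex
	lastError error
}

// recordError stores the error under the mutex, as the ControlNoRoute
// and ControlNetworkUnreachable cases above do before notifying the
// protocol goroutine.
func (e *ep) recordError(err error) {
	e.mu.Lock()
	e.lastError = err
	e.mu.Unlock()
}

// takeLastError returns and clears the error, so SO_ERROR-style reads
// report each failure only once.
func (e *ep) takeLastError() error {
	e.mu.Lock()
	defer e.mu.Unlock()
	err := e.lastError
	e.lastError = nil
	return err
}

func main() {
	e := &ep{}
	e.recordError(errors.New("no route to host"))
	fmt.Println(e.takeLastError()) // no route to host
	fmt.Println(e.takeLastError()) // <nil>
}
```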
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b34e47bbd..5d6174a59 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -49,7 +49,7 @@ const (
// DefaultReceiveBufferSize is the default size of the receive buffer
// for an endpoint.
- DefaultReceiveBufferSize = 1 << 20 // 1MB
+ DefaultReceiveBufferSize = 32 << 10 // 32KB
// MaxBufferSize is the largest size a receive/send buffer can grow to.
MaxBufferSize = 4 << 20 // 4MB
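Callers that depended on the old 1 MB default can restore it at stack setup. A hedged fragment mirroring the option usage elsewhere in this change (assuming `s` is a *stack.Stack with the TCP protocol registered; values illustrative):

```go
// Restore a 1 MiB default receive buffer; endpoints may still tune
// within [Min, Max].
opt := tcp.ReceiveBufferSizeOption{Min: 4 << 10, Default: 1 << 20, Max: 4 << 20}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
	log.Fatalf("SetTransportProtocolOption failed: %v", err)
}
```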
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index dd89a292a..5e0bfe585 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -372,7 +372,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// We only store the segment if it's within our buffer
// size limit.
if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += s.logicalLen()
+ r.pendingBufUsed += seqnum.Size(s.segMemSize())
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -406,7 +406,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= s.logicalLen()
+ r.pendingBufUsed -= seqnum.Size(s.segMemSize())
s.decRef()
}
return false, nil
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 0280892a8..bb60dc29d 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -138,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// segMemSize returns the amount of memory used to hold the segment data and
+// the associated metadata.
+func (s *segment) segMemSize() int {
+ return segSize + s.data.Size()
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the header.
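A self-contained toy contrasting the two measures (the real `segment` type is unexported; this mirrors the idea only):

```go
package main

import (
	"fmt"
	"unsafe"
)

// seg mirrors the accounting change above: logicalLen counts TCP
// sequence space (payload plus one for SYN and one for FIN), while
// memSize counts bytes of memory actually held (struct overhead plus
// payload), which is what pendingBufUsed now tracks.
type seg struct {
	data     []byte
	syn, fin bool
}

const segSize = int(unsafe.Sizeof(seg{}))

func (s *seg) logicalLen() int {
	l := len(s.data)
	if s.syn {
		l++
	}
	if s.fin {
		l++
	}
	return l
}

func (s *seg) memSize() int { return segSize + len(s.data) }

func main() {
	s := &seg{fin: true}       // a bare FIN
	fmt.Println(s.logicalLen()) // 1: occupies sequence space...
	fmt.Println(s.memSize())    // ...but costs a whole struct in memory.
}
```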
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
new file mode 100644
index 000000000..0ab7b8f56
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "unsafe"
+)
+
+const (
+ segSize = int(unsafe.Sizeof(segment{}))
+)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 06fde2a79..37e7767d6 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -143,12 +143,14 @@ func New(t *testing.T, mtu uint32) *Context {
TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
})
+ const sendBufferSize = 1 << 20 // 1 MiB
+ const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
@@ -202,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context {
t: t,
s: s,
linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index a14643ae8..6e692da07 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1451,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
- packet.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go
index 66f10c016..70945f234 100644
--- a/pkg/test/criutil/criutil.go
+++ b/pkg/test/criutil/criutil.go
@@ -40,9 +40,9 @@ type Crictl struct {
cleanup []func()
}
-// resolvePath attempts to find binary paths. It may set the path to invalid,
+// ResolvePath attempts to find binary paths. It may return an invalid path,
// which will cause the execution to fail with a sensible error.
-func resolvePath(executable string) string {
+func ResolvePath(executable string) string {
runtime, err := dockerutil.RuntimePath()
if err == nil {
// Check first the directory of the runtime itself.
@@ -230,7 +230,7 @@ func (cc *Crictl) Import(image string) error {
// be pushing a lot of bytes in order to import the image. The connect
// timeout stays the same and is inherited from the Crictl instance.
cmd := testutil.Command(cc.logger,
- resolvePath("ctr"),
+ ResolvePath("ctr"),
fmt.Sprintf("--connect-timeout=%s", 30*time.Second),
fmt.Sprintf("--address=%s", cc.endpoint),
"-n", "k8s.io", "images", "import", "-")
@@ -358,7 +358,7 @@ func (cc *Crictl) StopPodAndContainer(podID, contID string) error {
// run runs crictl with the given args.
func (cc *Crictl) run(args ...string) (string, error) {
defaultArgs := []string{
- resolvePath("crictl"),
+ ResolvePath("crictl"),
"--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
"--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
}
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
index 83b80c8bc..a5e84658a 100644
--- a/pkg/test/dockerutil/BUILD
+++ b/pkg/test/dockerutil/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,6 +10,7 @@ go_library(
"dockerutil.go",
"exec.go",
"network.go",
+ "profile.go",
],
visibility = ["//:sandbox"],
deps = [
@@ -23,3 +24,19 @@ go_library(
"@com_github_docker_go_connections//nat:go_default_library",
],
)
+
+go_test(
+ name = "profile_test",
+ size = "large",
+ srcs = [
+ "profile_test.go",
+ ],
+ library = ":dockerutil",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ # Also requires the test to be run as root.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md
new file mode 100644
index 000000000..870292096
--- /dev/null
+++ b/pkg/test/dockerutil/README.md
@@ -0,0 +1,86 @@
+# dockerutil
+
+This package provides utilities for creating and controlling Docker containers
+for testing runsc, gVisor's OCI runtime for Docker and Kubernetes. A simple
+test may look like:
+
+```
+ func TestSuperCool(t *testing.T) {
+ ctx := context.Background()
+ c := dockerutil.MakeContainer(ctx, t)
+ got, err := c.Run(ctx, dockerutil.RunOpts{
+      Image: "basic/alpine",
+ }, "echo", "super cool")
+ if err != nil {
+ t.Fatalf("err was not nil: %v", err)
+ }
+ want := "super cool"
+    if !strings.Contains(got, want) {
+ t.Fatalf("want: %s, got: %s", want, got)
+ }
+ }
+```
+
+For further examples, see many of our end-to-end tests elsewhere in the repo,
+such as those in //test/e2e or benchmarks at //test/benchmarks.
+
+dockerutil uses the official
+[Docker Go API](https://godoc.org/github.com/docker/docker/client), which is
+very powerful. dockerutil is a thin wrapper around this API, so new use cases
+are easy to implement.
+
+## Profiling
+
+dockerutil is capable of generating profiles. Currently, the only option is to
+use pprof profiles generated by `runsc debug`. The profiler will generate Block,
+CPU, Heap, Goroutine, and Mutex profiles. To generate profiles:
+
+* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc
+  ARGS="--profile"`. Other flags, such as `--platform=kvm` or `--vfs2`, can be
+  added with ARGS as well.
+* Restart docker: `sudo service docker restart`
+
+To run and generate CPU profiles run:
+
+```
+make sudo TARGETS=//path/to:target \
+ ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt"
+```
+
+Profiles will be written to: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof`
+
+In most gVisor tests and benchmarks, the container name is the test name plus
+some random characters, e.g.:
+`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2`
+
+Profiling requires root, as `runsc debug` inspects running containers in
+`/var/run`, among other things.
+
+### Writing for Profiling
+
+Below is an example of using profiles with dockerutil.
+
+```
+func TestSuperCool(t *testing.T) {
+  ctx := context.Background()
+  // Profiled, using the runtime from the dockerutil.runtime flag.
+  profiled := dockerutil.MakeContainer(ctx, t)
+
+  // Not profiled, using the runc runtime.
+  native := dockerutil.MakeNativeContainer(ctx, t)
+
+  err := profiled.Spawn(ctx, dockerutil.RunOpts{
+    Image: "some/image",
+  }, "sleep", "100000")
+  // Profiling has begun here.
+  ...
+  expensive setup that I don't want to profile.
+  ...
+  profiled.RestartProfiles()
+  // Profiled activity.
+}
+```
+
+In the above example, `profiled` would be profiled and `native` would not. The
+call to `RestartProfiles()` restarts the clock on profiling. This is useful if
+the main activity being tested is done with `docker exec` or `container.Spawn()`
+followed by one or more `container.Exec()` calls.
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 17acdaf6f..b59503188 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -43,15 +43,21 @@ import (
// See: https://pkg.go.dev/github.com/docker/docker.
type Container struct {
Name string
- Runtime string
+ runtime string
logger testutil.Logger
client *client.Client
id string
mounts []mount.Mount
links []string
- cleanups []func()
copyErr error
+ cleanups []func()
+
+	// profiles are the profiles added to this container. They contain
+	// methods that run after Creation, Start, and Cleanup of this
+	// Container, along with a handle to restart the profile. Generally,
+	// tests/benchmarks using
+ // profiles need to run as root.
+ profiles []Profile
// Stores streams attached to the container. Used by WaitForOutputSubmatch.
streams types.HijackedResponse
@@ -106,7 +112,19 @@ type RunOpts struct {
// MakeContainer sets up the struct for a Docker container.
//
// Names of containers will be unique.
+// Containers will check flags for profiling requests.
func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ c := MakeNativeContainer(ctx, logger)
+ c.runtime = *runtime
+ if p := MakePprofFromFlags(c); p != nil {
+ c.AddProfile(p)
+ }
+ return c
+}
+
+// MakeNativeContainer sets up the struct for a Docker container using runc. Native
+// containers aren't profiled.
+func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container {
// Slashes are not allowed in container names.
name := testutil.RandomID(logger.Name())
name = strings.ReplaceAll(name, "/", "-")
@@ -114,20 +132,33 @@ func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
if err != nil {
return nil
}
-
client.NegotiateAPIVersion(ctx)
-
return &Container{
logger: logger,
Name: name,
- Runtime: *runtime,
+ runtime: "",
client: client,
}
}
+// AddProfile adds a profile to this container.
+func (c *Container) AddProfile(p Profile) {
+ c.profiles = append(c.profiles, p)
+}
+
+// RestartProfiles calls Restart on all profiles for this container.
+func (c *Container) RestartProfiles() error {
+ for _, profile := range c.profiles {
+ if err := profile.Restart(c); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// Spawn is analogous to 'docker run -d'.
func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error {
- if err := c.create(ctx, r, args); err != nil {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
return err
}
return c.Start(ctx)
@@ -153,7 +184,7 @@ func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string)
// Run is analogous to 'docker run'.
func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) {
- if err := c.create(ctx, r, args); err != nil {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
return "", err
}
@@ -181,27 +212,25 @@ func (c *Container) MakeLink(target string) string {
// CreateFrom creates a container from the given configs.
func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
- cont, err := c.client.ContainerCreate(ctx, conf, hostconf, netconf, c.Name)
- if err != nil {
- return err
- }
- c.id = cont.ID
- return nil
+ return c.create(ctx, conf, hostconf, netconf)
}
// Create is analogous to 'docker create'.
func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error {
- return c.create(ctx, r, args)
+ return c.create(ctx, c.config(r, args), c.hostConfig(r), nil)
}
-func (c *Container) create(ctx context.Context, r RunOpts, args []string) error {
- conf := c.config(r, args)
- hostconf := c.hostConfig(r)
+func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name)
if err != nil {
return err
}
c.id = cont.ID
+ for _, profile := range c.profiles {
+ if err := profile.OnCreate(c); err != nil {
+ return fmt.Errorf("OnCreate method failed with: %v", err)
+ }
+ }
return nil
}
@@ -227,7 +256,7 @@ func (c *Container) hostConfig(r RunOpts) *container.HostConfig {
c.mounts = append(c.mounts, r.Mounts...)
return &container.HostConfig{
- Runtime: c.Runtime,
+ Runtime: c.runtime,
Mounts: c.mounts,
PublishAllPorts: true,
Links: r.Links,
@@ -261,8 +290,15 @@ func (c *Container) Start(ctx context.Context) error {
c.cleanups = append(c.cleanups, func() {
c.streams.Close()
})
-
- return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{})
+ if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil {
+ return fmt.Errorf("ContainerStart failed: %v", err)
+ }
+ for _, profile := range c.profiles {
+ if err := profile.OnStart(c); err != nil {
+ return fmt.Errorf("OnStart method failed: %v", err)
+ }
+ }
+ return nil
}
// Stop is analogous to 'docker stop'.
@@ -482,6 +518,12 @@ func (c *Container) Remove(ctx context.Context) error {
// CleanUp kills and deletes the container (best effort).
func (c *Container) CleanUp(ctx context.Context) {
+ // Execute profile cleanups before the container goes down.
+ for _, profile := range c.profiles {
+ profile.OnCleanUp(c)
+ }
+ // Forget profiles.
+ c.profiles = nil
// Kill the container.
if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") {
// Just log; can't do anything here.
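A hedged usage fragment for the new profile hooks (inside a test; the image and paths are illustrative):

```go
// Attach a Pprof profile explicitly instead of via the --pprof-* flags.
ctx := context.Background()
c := dockerutil.MakeContainer(ctx, t)
defer c.CleanUp(ctx) // also runs OnCleanUp for attached profiles
c.AddProfile(&dockerutil.Pprof{
	BasePath:   filepath.Join("/tmp/profile/manual", c.Name),
	CPUProfile: true,
	Duration:   10 * time.Second,
})
if err := c.Spawn(ctx, dockerutil.RunOpts{Image: "basic/alpine"}, "sleep", "1000"); err != nil {
	t.Fatalf("Spawn failed: %v", err)
}
// ... expensive setup we don't want in the profile ...
if err := c.RestartProfiles(); err != nil {
	t.Fatalf("RestartProfiles failed: %v", err)
}
```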
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
index df09babf3..5a9dd8bd8 100644
--- a/pkg/test/dockerutil/dockerutil.go
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -25,6 +25,7 @@ import (
"os/exec"
"regexp"
"strconv"
+ "time"
"gvisor.dev/gvisor/pkg/test/testutil"
)
@@ -42,6 +43,26 @@ var (
// config is the default Docker daemon configuration path.
config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+
+ // The following flags are for the "pprof" profiler tool.
+
+ // pprofBaseDir allows the user to change the directory to which profiles are
+ // written. By default, profiles will appear under:
+ // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof.
+	pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTAINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)")
+
+ // duration is the max duration `runsc debug` will run and capture profiles.
+ // If the container's clean up method is called prior to duration, the
+ // profiling process will be killed.
+	duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profiler (e.g. '10s' or '1m')")
+
+ // The below flags enable each type of profile. Multiple profiles can be
+ // enabled for each run.
+ pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug")
+ pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug")
+ pprofGo = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug")
+ pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug")
+ pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug")
)
// EnsureSupportedDockerVersion checks if correct docker is installed.
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
index 921d1da9e..4c739c9e9 100644
--- a/pkg/test/dockerutil/exec.go
+++ b/pkg/test/dockerutil/exec.go
@@ -87,7 +87,6 @@ func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Proc
execid: resp.ID,
conn: hijack,
}, nil
-
}
func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig {
diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go
new file mode 100644
index 000000000..1fab33083
--- /dev/null
+++ b/pkg/test/dockerutil/profile.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "time"
+)
+
+// Profile represents profile-like operations on a container,
+// such as running perf or pprof. It is meant to be added to a Container
+// so that the Container runs the Profile's hooks during its lifecycle.
+type Profile interface {
+ // OnCreate is called just after the container is created when the container
+ // has a valid ID (e.g. c.ID()).
+ OnCreate(c *Container) error
+
+ // OnStart is called just after the container is started when the container
+ // has a valid Pid (e.g. c.SandboxPid()).
+ OnStart(c *Container) error
+
+ // Restart restarts the Profile on request.
+ Restart(c *Container) error
+
+ // OnCleanUp is called during the container's cleanup method.
+ // Cleanups should just log errors if they have them.
+ OnCleanUp(c *Container) error
+}
+
+// Pprof is for running profiles with 'runsc debug'. Pprof workloads
+// should be run as root and ONLY against runsc sandboxes. The runtime
+// should have --profile set as an option in /etc/docker/daemon.json in
+// order for profiling to work with Pprof.
+type Pprof struct {
+ BasePath string // path to put profiles
+ BlockProfile bool
+ CPUProfile bool
+ GoRoutineProfile bool
+ HeapProfile bool
+ MutexProfile bool
+ Duration time.Duration // duration to run profiler e.g. '10s' or '1m'.
+ shouldRun bool
+ cmd *exec.Cmd
+ stdout io.ReadCloser
+ stderr io.ReadCloser
+}
+
+// MakePprofFromFlags makes a Pprof profile from flags.
+func MakePprofFromFlags(c *Container) *Pprof {
+ if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) {
+ return nil
+ }
+ return &Pprof{
+ BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name),
+ BlockProfile: *pprofBlock,
+ CPUProfile: *pprofCPU,
+ GoRoutineProfile: *pprofGo,
+ HeapProfile: *pprofHeap,
+ MutexProfile: *pprofMutex,
+ Duration: *duration,
+ }
+}
+
+// OnCreate implements Profile.OnCreate.
+func (p *Pprof) OnCreate(c *Container) error {
+ return os.MkdirAll(p.BasePath, 0755)
+}
+
+// OnStart implements Profile.OnStart.
+func (p *Pprof) OnStart(c *Container) error {
+ path, err := RuntimePath()
+ if err != nil {
+ return fmt.Errorf("failed to get runtime path: %v", err)
+ }
+
+ // The root directory of this container's runtime.
+ root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime)
+ // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`.
+ args := []string{root, "debug"}
+ args = append(args, p.makeProfileArgs(c)...)
+ args = append(args, c.ID())
+
+ // Best effort wait until container is running.
+ for now := time.Now(); time.Since(now) < 5*time.Second; {
+ if status, err := c.Status(context.Background()); err != nil {
+			return fmt.Errorf("failed to get status with: %v", err)
+		} else if status.Running {
+ break
+ }
+ time.Sleep(500 * time.Millisecond)
+ }
+ p.cmd = exec.Command(path, args...)
+ if err := p.cmd.Start(); err != nil {
+ return fmt.Errorf("process failed: %v", err)
+ }
+ return nil
+}
+
+// Restart implements Profile.Restart.
+func (p *Pprof) Restart(c *Container) error {
+ p.OnCleanUp(c)
+ return p.OnStart(c)
+}
+
+// OnCleanUp implements Profile.OnCleanUp.
+func (p *Pprof) OnCleanUp(c *Container) error {
+ defer func() { p.cmd = nil }()
+	// Kill if the debug process was started and has not yet exited; a
+	// nil ProcessState means Wait has not run, so the process may still
+	// be alive.
+	if p.cmd != nil && p.cmd.Process != nil && (p.cmd.ProcessState == nil || !p.cmd.ProcessState.Exited()) {
+ return p.cmd.Process.Kill()
+ }
+ return nil
+}
+
+// makeProfileArgs turns Pprof fields into runsc debug flags.
+func (p *Pprof) makeProfileArgs(c *Container) []string {
+ var ret []string
+ if p.BlockProfile {
+ ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof")))
+ }
+ if p.CPUProfile {
+ ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof")))
+ }
+ if p.GoRoutineProfile {
+ ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof")))
+ }
+ if p.HeapProfile {
+ ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof")))
+ }
+ if p.MutexProfile {
+ ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof")))
+ }
+ ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration))
+ return ret
+}
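Since `Profile` is an interface, tests can supply their own implementations. A minimal hypothetical one (illustrative, not part of this change):

```go
package mytest

import (
	"log"

	"gvisor.dev/gvisor/pkg/test/dockerutil"
)

// logProfile shows the minimal surface a Profile must cover; it only
// logs container lifecycle events.
type logProfile struct{}

var _ dockerutil.Profile = logProfile{} // compile-time interface check

func (logProfile) OnCreate(c *dockerutil.Container) error {
	log.Printf("created: %s", c.Name)
	return nil
}

func (logProfile) OnStart(c *dockerutil.Container) error {
	log.Printf("started: %s", c.Name)
	return nil
}

func (logProfile) Restart(c *dockerutil.Container) error {
	log.Printf("restarted: %s", c.Name)
	return nil
}

func (logProfile) OnCleanUp(c *dockerutil.Container) error {
	log.Printf("cleaned up: %s", c.Name)
	return nil
}
```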
diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go
new file mode 100644
index 000000000..b7b4d7618
--- /dev/null
+++ b/pkg/test/dockerutil/profile_test.go
@@ -0,0 +1,117 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+)
+
+type testCase struct {
+ name string
+ pprof Pprof
+ expectedFiles []string
+}
+
+func TestPprof(t *testing.T) {
+ // Basepath and expected file names for each type of profile.
+ basePath := "/tmp/test/profile"
+ block := "block.pprof"
+ cpu := "cpu.pprof"
+	goprofile := "go.pprof"
+ heap := "heap.pprof"
+ mutex := "mutex.pprof"
+
+ testCases := []testCase{
+ {
+ name: "Cpu",
+ pprof: Pprof{
+ BasePath: basePath,
+ CPUProfile: true,
+ Duration: 2 * time.Second,
+ },
+ expectedFiles: []string{cpu},
+ },
+ {
+ name: "All",
+ pprof: Pprof{
+ BasePath: basePath,
+ BlockProfile: true,
+ CPUProfile: true,
+ GoRoutineProfile: true,
+ HeapProfile: true,
+ MutexProfile: true,
+ Duration: 2 * time.Second,
+ },
+			expectedFiles: []string{block, cpu, goprofile, heap, mutex},
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ ctx := context.Background()
+ c := MakeContainer(ctx, t)
+ // Set basepath to include the container name so there are no conflicts.
+ tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name)
+ c.AddProfile(&tc.pprof)
+
+ func() {
+ defer c.CleanUp(ctx)
+ // Start a container.
+ if err := c.Spawn(ctx, RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
+ t.Fatalf("run failed with: %v", err)
+ }
+
+ if status, err := c.Status(context.Background()); !status.Running {
+ t.Fatalf("container is not yet running: %+v err: %v", status, err)
+ }
+
+ // End early if the expected files exist and have data.
+ for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) {
+ if err := checkFiles(tc); err == nil {
+ break
+ }
+ }
+ }()
+
+ // Check all expected files exist and have data.
+ if err := checkFiles(tc); err != nil {
+				t.Fatal(err)
+ }
+ })
+ }
+}
+
+func checkFiles(tc testCase) error {
+ for _, file := range tc.expectedFiles {
+ stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file))
+ if err != nil {
+ return fmt.Errorf("stat failed with: %v", err)
+ } else if stat.Size() < 1 {
+ return fmt.Errorf("file not written to: %+v", stat)
+ }
+ }
+ return nil
+}
+
+func TestMain(m *testing.M) {
+ EnsureSupportedDockerVersion()
+ os.Exit(m.Run())
+}