summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/epoll.go7
-rw-r--r--pkg/abi/linux/epoll_amd64.go2
-rw-r--r--pkg/abi/linux/epoll_arm64.go2
-rw-r--r--pkg/abi/linux/file.go2
-rw-r--r--pkg/abi/linux/fs.go2
-rw-r--r--pkg/abi/linux/ioctl.go26
-rw-r--r--pkg/abi/linux/ioctl_tun.go (renamed from pkg/usermem/usermem_unsafe.go)24
-rw-r--r--pkg/abi/linux/netfilter.go9
-rw-r--r--pkg/abi/linux/signal.go2
-rw-r--r--pkg/abi/linux/time.go6
-rw-r--r--pkg/abi/linux/xattr.go1
-rw-r--r--pkg/atomicbitops/BUILD10
-rw-r--r--pkg/atomicbitops/atomicbitops.go (renamed from pkg/atomicbitops/atomic_bitops.go)31
-rw-r--r--pkg/atomicbitops/atomicbitops_amd64.s (renamed from pkg/atomicbitops/atomic_bitops_amd64.s)38
-rw-r--r--pkg/atomicbitops/atomicbitops_arm64.s (renamed from pkg/atomicbitops/atomic_bitops_arm64.s)34
-rw-r--r--pkg/atomicbitops/atomicbitops_noasm.go (renamed from pkg/atomicbitops/atomic_bitops_common.go)42
-rw-r--r--pkg/atomicbitops/atomicbitops_test.go (renamed from pkg/atomicbitops/atomic_bitops_test.go)64
-rw-r--r--pkg/cpuid/cpuid_x86.go19
-rw-r--r--pkg/fspath/BUILD4
-rw-r--r--pkg/fspath/builder.go8
-rw-r--r--pkg/fspath/fspath.go3
-rw-r--r--pkg/gohacks/BUILD11
-rw-r--r--pkg/gohacks/gohacks_unsafe.go57
-rw-r--r--pkg/log/glog.go6
-rw-r--r--pkg/log/json.go2
-rw-r--r--pkg/log/json_k8s.go2
-rw-r--r--pkg/log/log.go60
-rw-r--r--pkg/sentry/arch/arch_aarch64.go3
-rw-r--r--pkg/sentry/fs/dev/BUILD5
-rw-r--r--pkg/sentry/fs/dev/dev.go10
-rw-r--r--pkg/sentry/fs/dev/net_tun.go170
-rw-r--r--pkg/sentry/fs/mount_test.go11
-rw-r--r--pkg/sentry/fs/mounts.go10
-rw-r--r--pkg/sentry/fs/proc/mounts.go16
-rw-r--r--pkg/sentry/fs/proc/net.go5
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fsbridge/vfs.go10
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go4
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_net.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go4
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go233
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go16
-rw-r--r--pkg/sentry/inet/BUILD1
-rw-r--r--pkg/sentry/inet/namespace.go99
-rw-r--r--pkg/sentry/kernel/fd_table.go49
-rw-r--r--pkg/sentry/kernel/fs_context.go22
-rw-r--r--pkg/sentry/kernel/kernel.go30
-rw-r--r--pkg/sentry/kernel/rseq.go2
-rw-r--r--pkg/sentry/kernel/task.go27
-rw-r--r--pkg/sentry/kernel/task_clone.go16
-rw-r--r--pkg/sentry/kernel/task_context.go2
-rw-r--r--pkg/sentry/kernel/task_exec.go2
-rw-r--r--pkg/sentry/kernel/task_log.go6
-rw-r--r--pkg/sentry/kernel/task_net.go19
-rw-r--r--pkg/sentry/kernel/task_start.go8
-rw-r--r--pkg/sentry/kernel/task_usermem.go2
-rw-r--r--pkg/sentry/mm/address_space.go46
-rw-r--r--pkg/sentry/mm/lifecycle.go37
-rw-r--r--pkg/sentry/mm/mm.go5
-rw-r--r--pkg/sentry/mm/mm_test.go2
-rw-r--r--pkg/sentry/platform/kvm/machine.go18
-rw-r--r--pkg/sentry/socket/control/control.go2
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go63
-rw-r--r--pkg/sentry/socket/netfilter/targets.go35
-rw-r--r--pkg/sentry/socket/netstack/netstack.go27
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/strace/BUILD1
-rw-r--r--pkg/sentry/strace/epoll.go89
-rw-r--r--pkg/sentry/strace/linux64_amd64.go6
-rw-r--r--pkg/sentry/strace/linux64_arm64.go4
-rw-r--r--pkg/sentry/strace/strace.go8
-rw-r--r--pkg/sentry/strace/syscalls.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go40
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD28
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go225
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go44
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go137
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go147
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go326
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fscontext.go131
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go149
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go35
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go216
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go92
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/path.go94
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go584
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go511
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go380
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go346
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go87
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sys_read.go95
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go353
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/epoll.go3
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go12
-rw-r--r--pkg/sentry/vfs/resolving_path.go2
-rw-r--r--pkg/sentry/vfs/vfs.go10
-rw-r--r--pkg/syncevent/BUILD39
-rw-r--r--pkg/syncevent/broadcaster.go218
-rw-r--r--pkg/syncevent/broadcaster_test.go376
-rw-r--r--pkg/syncevent/receiver.go103
-rw-r--r--pkg/syncevent/source.go59
-rw-r--r--pkg/syncevent/syncevent.go32
-rw-r--r--pkg/syncevent/syncevent_example_test.go108
-rw-r--r--pkg/syncevent/waiter_amd64.s32
-rw-r--r--pkg/syncevent/waiter_arm64.s34
-rw-r--r--pkg/syncevent/waiter_asm_unsafe.go (renamed from pkg/fspath/builder_unsafe.go)15
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_test.go414
-rw-r--r--pkg/syncevent/waiter_unsafe.go206
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go63
-rw-r--r--pkg/tcpip/buffer/view.go6
-rw-r--r--pkg/tcpip/checker/checker.go14
-rw-r--r--pkg/tcpip/iptables/iptables.go104
-rw-r--r--pkg/tcpip/iptables/targets.go28
-rw-r--r--pkg/tcpip/iptables/types.go19
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go180
-rw-r--r--pkg/tcpip/link/tun/BUILD18
-rw-r--r--pkg/tcpip/link/tun/device.go352
-rw-r--r--pkg/tcpip/link/tun/protocol.go56
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go6
-rw-r--r--pkg/tcpip/stack/ndp.go25
-rw-r--r--pkg/tcpip/stack/ndp_test.go900
-rw-r--r--pkg/tcpip/stack/nic.go149
-rw-r--r--pkg/tcpip/stack/registration.go25
-rw-r--r--pkg/tcpip/stack/stack.go103
-rw-r--r--pkg/tcpip/stack/stack_test.go439
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go20
-rw-r--r--pkg/tcpip/stack/transport_test.go15
-rw-r--r--pkg/tcpip/tcpip.go23
-rw-r--r--pkg/tcpip/time_unsafe.go2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go16
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go5
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go9
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go31
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go33
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go43
-rw-r--r--pkg/tcpip/transport/udp/protocol.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go120
-rw-r--r--pkg/usermem/BUILD2
-rw-r--r--pkg/usermem/usermem.go9
164 files changed, 9034 insertions, 1254 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index a89f34d4b..322d1ccc4 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -30,6 +30,7 @@ go_library(
"futex.go",
"inotify.go",
"ioctl.go",
+ "ioctl_tun.go",
"ip.go",
"ipc.go",
"limits.go",
diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go
index 6e4de69da..1121a1a92 100644
--- a/pkg/abi/linux/epoll.go
+++ b/pkg/abi/linux/epoll.go
@@ -14,6 +14,10 @@
package linux
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+)
+
// Event masks.
const (
EPOLLIN = 0x1
@@ -53,3 +57,6 @@ const (
EPOLL_CTL_DEL = 0x2
EPOLL_CTL_MOD = 0x3
)
+
+// SizeOfEpollEvent is the size of EpollEvent struct.
+var SizeOfEpollEvent = int(binary.Size(EpollEvent{}))
diff --git a/pkg/abi/linux/epoll_amd64.go b/pkg/abi/linux/epoll_amd64.go
index 57041491c..34ff18009 100644
--- a/pkg/abi/linux/epoll_amd64.go
+++ b/pkg/abi/linux/epoll_amd64.go
@@ -15,6 +15,8 @@
package linux
// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
type EpollEvent struct {
Events uint32
// Linux makes struct epoll_event::data a __u64. We represent it as
diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go
index 62ef5821e..f86c35329 100644
--- a/pkg/abi/linux/epoll_arm64.go
+++ b/pkg/abi/linux/epoll_arm64.go
@@ -15,6 +15,8 @@
package linux
// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
type EpollEvent struct {
Events uint32
// Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index c3ab15a4f..e229ac21c 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -241,6 +241,8 @@ const (
)
// Statx represents struct statx.
+//
+// +marshal
type Statx struct {
Mask uint32
Blksize uint32
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 2c652baa2..158d2db5b 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -38,6 +38,8 @@ const (
)
// Statfs is struct statfs, from uapi/asm-generic/statfs.h.
+//
+// +marshal
type Statfs struct {
// Type is one of the filesystem magic values, defined above.
Type uint64
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 0e18db9ef..2062e6a4b 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -72,3 +72,29 @@ const (
SIOCGMIIPHY = 0x8947
SIOCGMIIREG = 0x8948
)
+
+// ioctl(2) directions. Used to calculate requests number.
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NONE = 0
+ _IOC_WRITE = 1
+ _IOC_READ = 2
+)
+
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NRBITS = 8
+ _IOC_TYPEBITS = 8
+ _IOC_SIZEBITS = 14
+ _IOC_DIRBITS = 2
+
+ _IOC_NRSHIFT = 0
+ _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS
+ _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS
+ _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
+)
+
+// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
+func IOC(dir, typ, nr, size uint32) uint32 {
+ return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
+}
diff --git a/pkg/usermem/usermem_unsafe.go b/pkg/abi/linux/ioctl_tun.go
index 876783e78..c59c9c136 100644
--- a/pkg/usermem/usermem_unsafe.go
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,16 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package linux
-import (
- "unsafe"
+// ioctl(2) request numbers from linux/if_tun.h
+var (
+ TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4)
+ TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4)
)
-// stringFromImmutableBytes is equivalent to string(bs), except that it never
-// copies even if escape analysis can't prove that bs does not escape. This is
-// only valid if bs is never mutated after stringFromImmutableBytes returns.
-func stringFromImmutableBytes(bs []byte) string {
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
-}
+// Flags from net/if_tun.h
+const (
+ IFF_TUN = 0x0001
+ IFF_TAP = 0x0002
+ IFF_NO_PI = 0x1000
+ IFF_NOFILTER = 0x1000
+)
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 2179ac995..314c318b6 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -225,11 +225,14 @@ type XTEntryTarget struct {
// SizeOfXTEntryTarget is the size of an XTEntryTarget.
const SizeOfXTEntryTarget = 32
-// XTStandardTarget is a builtin target, one of ACCEPT, DROP, JUMP, QUEUE, or
-// RETURN. It corresponds to struct xt_standard_target in
+// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
+// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
type XTStandardTarget struct {
- Target XTEntryTarget
+ Target XTEntryTarget
+ // A positive verdict indicates a jump, and is the offset from the
+ // start of the table to jump to. A negative value means one of the
+ // other built-in targets.
Verdict int32
_ [4]byte
}
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index c69b04ea9..1c330e763 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -115,6 +115,8 @@ const (
)
// SignalSet is a signal mask with a bit corresponding to each signal.
+//
+// +marshal
type SignalSet uint64
// SignalSetSize is the size in bytes of a SignalSet.
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index e562b46d9..e6860ed49 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -157,6 +157,8 @@ func DurationToTimespec(dur time.Duration) Timespec {
const SizeOfTimeval = 16
// Timeval represents struct timeval in <time.h>.
+//
+// +marshal
type Timeval struct {
Sec int64
Usec int64
@@ -230,6 +232,8 @@ type Tms struct {
type TimerID int32
// StatxTimestamp represents struct statx_timestamp.
+//
+// +marshal
type StatxTimestamp struct {
Sec int64
Nsec uint32
@@ -258,6 +262,8 @@ func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
}
// Utime represents struct utimbuf used by utimes(2).
+//
+// +marshal
type Utime struct {
Actime int64
Modtime int64
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
index a3b6406fa..99180b208 100644
--- a/pkg/abi/linux/xattr.go
+++ b/pkg/abi/linux/xattr.go
@@ -18,6 +18,7 @@ package linux
const (
XATTR_NAME_MAX = 255
XATTR_SIZE_MAX = 65536
+ XATTR_LIST_MAX = 65536
XATTR_CREATE = 1
XATTR_REPLACE = 2
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 3948074ba..1a30f6967 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -5,10 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "atomicbitops",
srcs = [
- "atomic_bitops.go",
- "atomic_bitops_amd64.s",
- "atomic_bitops_arm64.s",
- "atomic_bitops_common.go",
+ "atomicbitops.go",
+ "atomicbitops_amd64.s",
+ "atomicbitops_arm64.s",
+ "atomicbitops_noasm.go",
],
visibility = ["//:sandbox"],
)
@@ -16,7 +16,7 @@ go_library(
go_test(
name = "atomicbitops_test",
size = "small",
- srcs = ["atomic_bitops_test.go"],
+ srcs = ["atomicbitops_test.go"],
library = ":atomicbitops",
deps = ["//pkg/sync"],
)
diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomicbitops.go
index fcc41a9ea..1be081719 100644
--- a/pkg/atomicbitops/atomic_bitops.go
+++ b/pkg/atomicbitops/atomicbitops.go
@@ -14,47 +14,34 @@
// +build amd64 arm64
-// Package atomicbitops provides basic bitwise operations in an atomic way.
-// The implementation on amd64 leverages the LOCK prefix directly instead of
-// relying on the generic cas primitives, and the arm64 leverages the LDAXR
-// and STLXR pair primitives.
+// Package atomicbitops provides extensions to the sync/atomic package.
//
-// WARNING: the bitwise ops provided in this package doesn't imply any memory
-// ordering. Using them to construct locks must employ proper memory barriers.
+// All read-modify-write operations implemented by this package have
+// acquire-release memory ordering (like sync/atomic).
package atomicbitops
-// AndUint32 atomically applies bitwise and operation to *addr with val.
+// AndUint32 atomically applies bitwise AND operation to *addr with val.
func AndUint32(addr *uint32, val uint32)
-// OrUint32 atomically applies bitwise or operation to *addr with val.
+// OrUint32 atomically applies bitwise OR operation to *addr with val.
func OrUint32(addr *uint32, val uint32)
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
+// XorUint32 atomically applies bitwise XOR operation to *addr with val.
func XorUint32(addr *uint32, val uint32)
// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) uint32
-// AndUint64 atomically applies bitwise and operation to *addr with val.
+// AndUint64 atomically applies bitwise AND operation to *addr with val.
func AndUint64(addr *uint64, val uint64)
-// OrUint64 atomically applies bitwise or operation to *addr with val.
+// OrUint64 atomically applies bitwise OR operation to *addr with val.
func OrUint64(addr *uint64, val uint64)
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
+// XorUint64 atomically applies bitwise XOR operation to *addr with val.
func XorUint64(addr *uint64, val uint64)
// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) uint64
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool
diff --git a/pkg/atomicbitops/atomic_bitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index db0972001..54c887ee5 100644
--- a/pkg/atomicbitops/atomic_bitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -75,41 +75,3 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
CMPXCHGQ DX, 0(DI)
MOVQ AX, ret+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- TESTL AX, AX
- JZ fail
- LEAL 1(AX), DX
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB AX, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- LEAL -1(AX), DX
- TESTL DX, DX
- JZ fail
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB DX, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 97f8808c1..5c780851b 100644
--- a/pkg/atomicbitops/atomic_bitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -50,7 +50,6 @@ TEXT ·CompareAndSwapUint32(SB),$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
-
again:
LDAXRW (R0), R3
CMPW R1, R3
@@ -95,7 +94,6 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
-
again:
LDAXR (R0), R3
CMP R1, R3
@@ -105,35 +103,3 @@ again:
done:
MOVD R3, prev+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- CBZ R1, fail
- ADDW $1, R1
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- SUBSW $1, R1, R1
- BEQ fail
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 85163ad62..3b2898256 100644
--- a/pkg/atomicbitops/atomic_bitops_common.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -20,7 +20,6 @@ import (
"sync/atomic"
)
-// AndUint32 atomically applies bitwise and operation to *addr with val.
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,7 +30,6 @@ func AndUint32(addr *uint32, val uint32) {
}
}
-// OrUint32 atomically applies bitwise or operation to *addr with val.
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -42,7 +40,6 @@ func OrUint32(addr *uint32, val uint32) {
}
}
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -53,8 +50,6 @@ func XorUint32(addr *uint32, val uint32) {
}
}
-// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
-// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -67,7 +62,6 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
-// AndUint64 atomically applies bitwise and operation to *addr with val.
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -78,7 +72,6 @@ func AndUint64(addr *uint64, val uint64) {
}
}
-// OrUint64 atomically applies bitwise or operation to *addr with val.
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -89,7 +82,6 @@ func OrUint64(addr *uint64, val uint64) {
}
}
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -100,8 +92,6 @@ func XorUint64(addr *uint64, val uint64) {
}
}
-// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
-// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
@@ -113,35 +103,3 @@ func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
}
}
}
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 0 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v+1) {
- return true
- }
- }
-}
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 1 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v-1) {
- return true
- }
- }
-}
diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomicbitops_test.go
index 9466d3e23..73af71bb4 100644
--- a/pkg/atomicbitops/atomic_bitops_test.go
+++ b/pkg/atomicbitops/atomicbitops_test.go
@@ -196,67 +196,3 @@ func TestCompareAndSwapUint64(t *testing.T) {
}
}
}
-
-func TestIncUnlessZeroInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: 0,
- ret: false,
- },
- {
- initial: 1,
- final: 2,
- ret: true,
- },
- {
- initial: 2,
- final: 3,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := IncUnlessZeroInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
-
-func TestDecUnlessOneInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: -1,
- ret: true,
- },
- {
- initial: 1,
- final: 1,
- ret: false,
- },
- {
- initial: 2,
- final: 1,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := DecUnlessOneInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index 333ca0a04..a0bc55ea1 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -725,6 +725,18 @@ func vendorIDFromRegs(bx, cx, dx uint32) string {
return string(bytes)
}
+var maxXsaveSize = func() uint32 {
+ // Leaf 0 of xsaveinfo function returns the size for currently
+ // enabled xsave features in ebx, the maximum size if all valid
+ // features are saved with xsave in ecx, and valid XCR0 bits in
+ // edx:eax.
+ //
+ // If xSaveInfo isn't supported, cpuid will not fault but will
+ // return bogus values.
+ _, _, maxXsaveSize, _ := HostID(uint32(xSaveInfo), 0)
+ return maxXsaveSize
+}()
+
// ExtendedStateSize returns the number of bytes needed to save the "extended
// state" for this processor and the boundary it must be aligned to. Extended
// state includes floating point registers, and other cpu state that's not
@@ -736,12 +748,7 @@ func vendorIDFromRegs(bx, cx, dx uint32) string {
// about 2.5K worst case, with avx512).
func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
if fs.UseXsave() {
- // Leaf 0 of xsaveinfo function returns the size for currently
- // enabled xsave features in ebx, the maximum size if all valid
- // features are saved with xsave in ecx, and valid XCR0 bits in
- // edx:eax.
- _, _, maxSize, _ := HostID(uint32(xSaveInfo), 0)
- return uint(maxSize), 64
+ return uint(maxXsaveSize), 64
}
// If we don't support xsave, we fall back to fxsave, which requires
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index ee84471b2..67dd1e225 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -8,9 +8,11 @@ go_library(
name = "fspath",
srcs = [
"builder.go",
- "builder_unsafe.go",
"fspath.go",
],
+ deps = [
+ "//pkg/gohacks",
+ ],
)
go_test(
diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go
index 7ddb36826..6318d3874 100644
--- a/pkg/fspath/builder.go
+++ b/pkg/fspath/builder.go
@@ -16,6 +16,8 @@ package fspath
import (
"fmt"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
)
// Builder is similar to strings.Builder, but is used to produce pathnames
@@ -102,3 +104,9 @@ func (b *Builder) AppendString(str string) {
copy(b.buf[b.start:], b.buf[oldStart:])
copy(b.buf[len(b.buf)-len(str):], str)
}
+
+// String returns the accumulated string. No other methods should be called
+// after String.
+func (b *Builder) String() string {
+ return gohacks.StringFromImmutableBytes(b.buf[b.start:])
+}
diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go
index 9fb3fee24..4c983d5fd 100644
--- a/pkg/fspath/fspath.go
+++ b/pkg/fspath/fspath.go
@@ -67,7 +67,8 @@ func Parse(pathname string) Path {
// Path contains the information contained in a pathname string.
//
-// Path is copyable by value.
+// Path is copyable by value. The zero value for Path is equivalent to
+// fspath.Parse(""), i.e. the empty path.
type Path struct {
// Begin is an iterator to the first path component in the relative part of
// the path.
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
new file mode 100644
index 000000000..798a65eca
--- /dev/null
+++ b/pkg/gohacks/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gohacks",
+ srcs = [
+ "gohacks_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
new file mode 100644
index 000000000..aad675172
--- /dev/null
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -0,0 +1,57 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gohacks contains utilities for subverting the Go compiler.
+package gohacks
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Noescape hides a pointer from escape analysis. Noescape is the identity
+// function but escape analysis doesn't think the output depends on the input.
+// Noescape is inlined and currently compiles down to zero instructions.
+// USE CAREFULLY!
+//
+// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().)
+//
+//go:nosplit
+func Noescape(p unsafe.Pointer) unsafe.Pointer {
+ x := uintptr(p)
+ return unsafe.Pointer(x ^ 0)
+}
+
+// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the
+// same memory backing s instead of making a heap-allocated copy. This is only
+// valid if the returned slice is never mutated.
+func ImmutableBytesFromString(s string) []byte {
+ shdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
+ var bs []byte
+ bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ bshdr.Data = shdr.Data
+ bshdr.Len = shdr.Len
+ bshdr.Cap = shdr.Len
+ return bs
+}
+
+// StringFromImmutableBytes is equivalent to string(bs), except that it uses
+// the same memory backing bs instead of making a heap-allocated copy. This is
+// only valid if bs is never mutated after StringFromImmutableBytes returns.
+func StringFromImmutableBytes(bs []byte) string {
+ // This is cheaper than messing with reflect.StringHeader and
+ // reflect.SliceHeader, which as of this writing produces many dead stores
+ // of zeroes. Compare strings.Builder.String().
+ return *(*string)(unsafe.Pointer(&bs))
+}
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
index cab5fae55..b4f7bb5a4 100644
--- a/pkg/log/glog.go
+++ b/pkg/log/glog.go
@@ -46,7 +46,7 @@ var pid = os.Getpid()
// line The line number
// msg The user-supplied message
//
-func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
// Log level.
prefix := byte('?')
switch level {
@@ -64,9 +64,7 @@ func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, ar
microsecond := int(timestamp.Nanosecond() / 1000)
// 0 = this frame.
- // 1 = Debugf, etc.
- // 2 = Caller.
- _, file, line, ok := runtime.Caller(2)
+ _, file, line, ok := runtime.Caller(depth + 1)
if ok {
// Trim any directory path from the file.
slash := strings.LastIndexByte(file, byte('/'))
diff --git a/pkg/log/json.go b/pkg/log/json.go
index a278c8fc8..0943db1cc 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -62,7 +62,7 @@ type JSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := jsonLog{
Msg: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
index cee6eb514..6c6fc8b6f 100644
--- a/pkg/log/json_k8s.go
+++ b/pkg/log/json_k8s.go
@@ -33,7 +33,7 @@ type K8sJSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e *K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := k8sJSONLog{
Log: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 5056f17e6..a794da1aa 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -79,7 +79,7 @@ func (l Level) String() string {
type Emitter interface {
// Emit emits the given log statement. This allows for control over the
// timestamp used for logging.
- Emit(level Level, timestamp time.Time, format string, v ...interface{})
+ Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{})
}
// Writer writes the output to the given writer.
@@ -142,7 +142,7 @@ func (l *Writer) Write(data []byte) (int, error) {
}
// Emit emits the message.
-func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) {
fmt.Fprintf(l, format, args...)
}
@@ -150,9 +150,9 @@ func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...i
type MultiEmitter []Emitter
// Emit emits to all emitters.
-func (m *MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) {
for _, e := range *m {
- e.Emit(level, timestamp, format, v...)
+ e.Emit(1+depth, level, timestamp, format, v...)
}
}
@@ -167,7 +167,7 @@ type TestEmitter struct {
}
// Emit emits to the TestLogger.
-func (t *TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
t.Logf(format, v...)
}
@@ -198,22 +198,37 @@ type BasicLogger struct {
// Debugf implements logger.Debugf.
func (l *BasicLogger) Debugf(format string, v ...interface{}) {
- if l.IsLogging(Debug) {
- l.Emit(Debug, time.Now(), format, v...)
- }
+ l.DebugfAtDepth(1, format, v...)
}
// Infof implements logger.Infof.
func (l *BasicLogger) Infof(format string, v ...interface{}) {
- if l.IsLogging(Info) {
- l.Emit(Info, time.Now(), format, v...)
- }
+ l.InfofAtDepth(1, format, v...)
}
// Warningf implements logger.Warningf.
func (l *BasicLogger) Warningf(format string, v ...interface{}) {
+ l.WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs at a specific depth.
+func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Debug) {
+ l.Emit(1+depth, Debug, time.Now(), format, v...)
+ }
+}
+
+// InfofAtDepth logs at a specific depth.
+func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Info) {
+ l.Emit(1+depth, Info, time.Now(), format, v...)
+ }
+}
+
+// WarningfAtDepth logs at a specific depth.
+func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) {
if l.IsLogging(Warning) {
- l.Emit(Warning, time.Now(), format, v...)
+ l.Emit(1+depth, Warning, time.Now(), format, v...)
}
}
@@ -257,17 +272,32 @@ func SetLevel(newLevel Level) {
// Debugf logs to the global logger.
func Debugf(format string, v ...interface{}) {
- Log().Debugf(format, v...)
+ Log().DebugfAtDepth(1, format, v...)
}
// Infof logs to the global logger.
func Infof(format string, v ...interface{}) {
- Log().Infof(format, v...)
+ Log().InfofAtDepth(1, format, v...)
}
// Warningf logs to the global logger.
func Warningf(format string, v ...interface{}) {
- Log().Warningf(format, v...)
+ Log().WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs to the global logger.
+func DebugfAtDepth(depth int, format string, v ...interface{}) {
+ Log().DebugfAtDepth(1+depth, format, v...)
+}
+
+// InfofAtDepth logs to the global logger.
+func InfofAtDepth(depth int, format string, v ...interface{}) {
+ Log().InfofAtDepth(1+depth, format, v...)
+}
+
+// WarningfAtDepth logs to the global logger.
+func WarningfAtDepth(depth int, format string, v ...interface{}) {
+ Log().WarningfAtDepth(1+depth, format, v...)
}
// defaultStackSize is the default buffer size to allocate for stack traces.
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 3b6987665..d7794db63 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -34,6 +34,9 @@ const (
SyscallWidth = 4
)
+// ARMTrapFlag is the mask for the trap flag.
+const ARMTrapFlag = uint64(1) << 21
+
// aarch64FPState is aarch64 floating point state.
type aarch64FPState []byte
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 4c4b7d5cc..9b6bb26d0 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -9,6 +9,7 @@ go_library(
"device.go",
"fs.go",
"full.go",
+ "net_tun.go",
"null.go",
"random.go",
"tty.go",
@@ -19,15 +20,19 @@ go_library(
"//pkg/context",
"//pkg/rand",
"//pkg/safemem",
+ "//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
+ "//pkg/sentry/socket/netstack",
"//pkg/syserror",
+ "//pkg/tcpip/link/tun",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index 35bd23991..7e66c29b0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
})
}
-func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555))
+func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
@@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// A devpts is typically mounted at /dev/pts to provide
// pseudoterminal support. Place an empty directory there for
// the devpts to be mounted over.
- "pts": newDirectory(ctx, msrc),
+ "pts": newDirectory(ctx, nil, msrc),
// Similarly, applications expect a ptmx device at /dev/ptmx
// connected to the terminals provided by /dev/pts/. Rather
// than creating a device directly (which requires a hairy
@@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"ptmx": newSymlink(ctx, "pts/ptmx", msrc),
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+
+ "net": newDirectory(ctx, map[string]*fs.Inode{
+ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
+ }, msrc),
}
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
new file mode 100644
index 000000000..755644488
--- /dev/null
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// +stateify savable
+type netTunInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*netTunInodeOperations)(nil)
+
+func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations {
+ return &netTunInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
+}
+
+// +stateify savable
+type netTunFileOperations struct {
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ device tun.Device
+}
+
+var _ fs.FileOperations = (*netTunFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (fops *netTunFileOperations) Release() {
+ fops.device.Release()
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fops.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
+ flags := fops.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Write implements fs.FileOperations.Write.
+func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fops.device.Write(data)
+}
+
+// Read implements fs.FileOperations.Read.
+func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := fops.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fops.device.Readiness(mask)
+}
+
+// EventRegister implements watier.Waitable.EventRegister.
+func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fops.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements watier.Waitable.EventUnregister.
+func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ fops.device.EventUnregister(e)
+}
diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go
index e672a438c..a3d10770b 100644
--- a/pkg/sentry/fs/mount_test.go
+++ b/pkg/sentry/fs/mount_test.go
@@ -36,11 +36,12 @@ func mountPathsAre(root *Dirent, got []*Mount, want ...string) error {
gotPaths := make(map[string]struct{}, len(got))
gotStr := make([]string, len(got))
for i, g := range got {
- groot := g.Root()
- name, _ := groot.FullName(root)
- groot.DecRef()
- gotStr[i] = name
- gotPaths[name] = struct{}{}
+ if groot := g.Root(); groot != nil {
+ name, _ := groot.FullName(root)
+ groot.DecRef()
+ gotStr[i] = name
+ gotPaths[name] = struct{}{}
+ }
}
if len(got) != len(want) {
return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want)
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 574a2cc91..c7981f66e 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -100,10 +100,14 @@ func newUndoMount(d *Dirent) *Mount {
}
}
-// Root returns the root dirent of this mount. Callers must call DecRef on the
-// returned dirent.
+// Root returns the root dirent of this mount.
+//
+// This may return nil if the mount has already been free. Callers must handle this
+// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent.
func (m *Mount) Root() *Dirent {
- m.root.IncRef()
+ if !m.root.TryIncRef() {
+ return nil
+ }
return m.root
}
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index c10888100..94deb553b 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -60,13 +60,15 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
})
for _, m := range ms {
mroot := m.Root()
+ if mroot == nil {
+ continue // No longer valid.
+ }
mountPath, desc := mroot.FullName(rootDir)
mroot.DecRef()
if !desc {
// MountSources that are not descendants of the chroot jail are ignored.
continue
}
-
fn(mountPath, m)
}
}
@@ -91,6 +93,12 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
var buf bytes.Buffer
forEachMount(mif.t, func(mountPath string, m *fs.Mount) {
+ mroot := m.Root()
+ if mroot == nil {
+ return // No longer valid.
+ }
+ defer mroot.DecRef()
+
// Format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
@@ -107,9 +115,6 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
// (3) Major:Minor device ID. We don't have a superblock, so we
// just use the root inode device number.
- mroot := m.Root()
- defer mroot.DecRef()
-
sa := mroot.Inode.StableAttr
fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
@@ -207,6 +212,9 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
//
// The "needs dump"and fsck flags are always 0, which is allowed.
root := m.Root()
+ if root == nil {
+ return // No longer valid.
+ }
defer root.DecRef()
flags := root.Inode.MountSource.Flags
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 6f2775344..95d5817ff 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -43,7 +43,10 @@ import (
// newNet creates a new proc net entry.
func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
"snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 0772d4ae4..d4c4b533d 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -357,7 +357,9 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"ipv4": p.newSysNetIPv4Dir(ctx, msrc, s),
"core": p.newSysNetCore(ctx, msrc, s),
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index e657c39bc..6aa17bfc1 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -117,15 +117,19 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry)
// default anyways.
//
// TODO(gvisor.dev/issue/1623): Check mount has read and exec permission.
-func (l *vfsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
+func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
creds := auth.CredentialsFromContext(ctx)
+ path := fspath.Parse(pathname)
pop := &vfs.PathOperation{
Root: l.root,
- Start: l.root,
- Path: fspath.Parse(path),
+ Start: l.workingDir,
+ Path: path,
FollowFinalSymlink: resolveFinal,
}
+ if path.Absolute {
+ pop.Start = l.root
+ }
fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts)
if err != nil {
return nil, err
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index ce08a7d53..10c08fa90 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -73,9 +73,9 @@ func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNames
"meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
"mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
"net": newNetDir(root, inoGen, k),
- "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}),
"uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
- "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}),
}
inode := &tasksInode{
diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go
index 608fec017..d4e1812d8 100644
--- a/pkg/sentry/fsimpl/proc/tasks_net.go
+++ b/pkg/sentry/fsimpl/proc/tasks_net.go
@@ -39,7 +39,10 @@ import (
func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index c7ce74883..3d5dc463c 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -50,7 +50,9 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]*kernfs.Dentry{
"ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
"tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}),
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index d0be32e72..488478e29 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -128,6 +128,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns
ThreadGroup: tc,
TaskContext: &kernel.TaskContext{Name: name},
Credentials: auth.CredentialsFromContext(ctx),
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 7f7b791c4..e1b551422 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -16,7 +16,6 @@ package tmpfs
import (
"fmt"
- "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -347,10 +346,9 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, err
}
if opts.Flags&linux.O_TRUNC != 0 {
- impl.mu.Lock()
- impl.data.Truncate(0, impl.memFile)
- atomic.StoreUint64(&impl.size, 0)
- impl.mu.Unlock()
+ if _, err := impl.truncate(0); err != nil {
+ return nil, err
+ }
}
return &fd.vfsfd, nil
case *directory:
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index dab346a41..711442424 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -15,6 +15,7 @@
package tmpfs
import (
+ "fmt"
"io"
"math"
"sync/atomic"
@@ -22,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -34,25 +36,53 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// regularFile is a regular (=S_IFREG) tmpfs file.
type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
memFile *pgalloc.MemoryFile
- // mu protects the fields below.
- mu sync.RWMutex
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // Protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ // dataMu protects the fields below.
+ dataMu sync.RWMutex
// data maps offsets into the file to offsets into memFile that store
// the file's data.
+ //
+ // Protected by dataMu.
data fsutil.FileRangeSet
- // size is the size of data, but accessed using atomic memory
- // operations to avoid locking in inode.stat().
- size uint64
-
// seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
seals uint32
+
+ // size is the size of data.
+ //
+ // Protected by both dataMu and inode.mu; reading it requires holding
+ // either mutex, while writing requires holding both AND using atomics.
+ // Readers that do not require consistency (like Stat) may read the
+ // value atomically without holding either lock.
+ size uint64
}
func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
@@ -66,39 +96,170 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
// truncate grows or shrinks the file to the given size. It returns true if the
// file size was updated.
-func (rf *regularFile) truncate(size uint64) (bool, error) {
- rf.mu.Lock()
- defer rf.mu.Unlock()
+func (rf *regularFile) truncate(newSize uint64) (bool, error) {
+ rf.inode.mu.Lock()
+ defer rf.inode.mu.Unlock()
+ return rf.truncateLocked(newSize)
+}
- if size == rf.size {
+// Preconditions: rf.inode.mu must be held.
+func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) {
+ oldSize := rf.size
+ if newSize == oldSize {
// Nothing to do.
return false, nil
}
- if size > rf.size {
- // Growing the file.
+ // Need to hold inode.mu and dataMu while modifying size.
+ rf.dataMu.Lock()
+ if newSize > oldSize {
+ // Can we grow the file?
if rf.seals&linux.F_SEAL_GROW != 0 {
- // Seal does not allow growth.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- rf.size = size
+ // We only need to update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
return true, nil
}
- // Shrinking the file
+ // We are shrinking the file. First check if this is allowed.
if rf.seals&linux.F_SEAL_SHRINK != 0 {
- // Seal does not allow shrink.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have
- // mappings.
+ // Update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+
+ // Invalidate past translations of truncated pages.
+ oldpgend := fs.OffsetPageEnd(int64(oldSize))
+ newpgend := fs.OffsetPageEnd(int64(newSize))
+ if newpgend < oldpgend {
+ rf.mapsMu.Lock()
+ rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ rf.mapsMu.Unlock()
+ }
- rf.data.Truncate(size, rf.memFile)
- rf.size = size
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ rf.dataMu.Lock()
+ rf.data.Truncate(newSize, rf.memFile)
+ rf.dataMu.Unlock()
return true, nil
}
+// AddMapping implements memmap.Mappable.AddMapping.
+func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if rf.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ rf.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+
+ rf.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return rf.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ rf.dataMu.Lock()
+ defer rf.dataMu.Unlock()
+
+ // Constrain translations to f.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(int64(rf.size))
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: rf.memFile,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by f.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (*regularFile) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
type regularFileFD struct {
fileDescription
@@ -152,8 +313,10 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Overflow.
return 0, syserror.EFBIG
}
+ f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
+ f.inode.mu.Unlock()
putRegularFileReadWriter(rw)
return n, err
}
@@ -215,6 +378,12 @@ func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng
return nil
}
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ file := fd.inode().impl.(*regularFile)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
+}
+
// regularFileReadWriter implements safemem.Reader and Safemem.Writer.
type regularFileReadWriter struct {
file *regularFile
@@ -244,14 +413,15 @@ func putRegularFileReadWriter(rw *regularFileReadWriter) {
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- rw.file.mu.RLock()
+ rw.file.dataMu.RLock()
+ defer rw.file.dataMu.RUnlock()
+ size := rw.file.size
// Compute the range to read (limited by file size and overflow-checked).
- if rw.off >= rw.file.size {
- rw.file.mu.RUnlock()
+ if rw.off >= size {
return 0, io.EOF
}
- end := rw.file.size
+ end := size
if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
end = rend
}
@@ -265,7 +435,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
// Get internal mappings.
ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -275,7 +444,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -291,7 +459,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -299,13 +466,16 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
}
}
- rw.file.mu.RUnlock()
return done, nil
}
// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: inode.mu must be held.
func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- rw.file.mu.Lock()
+ // Hold dataMu so we can modify size.
+ rw.file.dataMu.Lock()
+ defer rw.file.dataMu.Unlock()
// Compute the range to write (overflow-checked).
end := rw.off + srcs.NumBytes()
@@ -316,7 +486,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
// Check if seals prevent either file growth or all writes.
switch {
case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
- rw.file.mu.Unlock()
return 0, syserror.EPERM
case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
// When growth is sealed, Linux effectively allows writes which would
@@ -338,7 +507,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
}
if end <= rw.off {
// Truncation would result in no data being written.
- rw.file.mu.Unlock()
return 0, syserror.EPERM
}
}
@@ -395,9 +563,8 @@ exitLoop:
// If the write ends beyond the file's previous size, it causes the
// file to grow.
if rw.off > rw.file.size {
- atomic.StoreUint64(&rw.file.size, rw.off)
+ rw.file.size = rw.off
}
- rw.file.mu.Unlock()
return done, retErr
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index c5bb17562..521206305 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -18,9 +18,10 @@
// Lock order:
//
// filesystem.mu
-// regularFileFD.offMu
-// regularFile.mu
// inode.mu
+// regularFileFD.offMu
+// regularFile.mapsMu
+// regularFile.dataMu
package tmpfs
import (
@@ -226,12 +227,15 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
- // This is unnecessary; it's mostly to simulate what tmpfs would do.
if regFile, ok := i.impl.(*regularFile); ok {
- regFile.mu.Lock()
+ // Hold inode.mu and regFile.dataMu while mutating
+ // size.
+ i.mu.Lock()
+ regFile.dataMu.Lock()
regFile.data.DropAll(regFile.memFile)
atomic.StoreUint64(&regFile.size, 0)
- regFile.mu.Unlock()
+ regFile.dataMu.Unlock()
+ i.mu.Unlock()
}
} else if refs < 0 {
panic("tmpfs.inode.decRef() called without holding a reference")
@@ -320,7 +324,7 @@ func (i *inode) setStat(stat linux.Statx) error {
if mask&linux.STATX_SIZE != 0 {
switch impl := i.impl.(type) {
case *regularFile:
- updated, err := impl.truncate(stat.Size)
+ updated, err := impl.truncateLocked(stat.Size)
if err != nil {
return err
}
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 334432abf..07bf39fed 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -10,6 +10,7 @@ go_library(
srcs = [
"context.go",
"inet.go",
+ "namespace.go",
"test_stack.go",
],
deps = [
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
new file mode 100644
index 000000000..c16667e7f
--- /dev/null
+++ b/pkg/sentry/inet/namespace.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+// Namespace represents a network namespace. See network_namespaces(7).
+//
+// +stateify savable
+type Namespace struct {
+ // stack is the network stack implementation of this network namespace.
+ stack Stack `state:"nosave"`
+
+ // creator allows kernel to create new network stack for network namespaces.
+ // If nil, no networking will function if network is namespaced.
+ creator NetworkStackCreator
+
+ // isRoot indicates whether this is the root network namespace.
+ isRoot bool
+}
+
+// NewRootNamespace creates the root network namespace, with creator
+// allowing new network namespaces to be created. If creator is nil, no
+// networking will function if the network is namespaced.
+func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace {
+ return &Namespace{
+ stack: stack,
+ creator: creator,
+ isRoot: true,
+ }
+}
+
+// NewNamespace creates a new network namespace from the root.
+func NewNamespace(root *Namespace) *Namespace {
+ n := &Namespace{
+ creator: root.creator,
+ }
+ n.init()
+ return n
+}
+
+// Stack returns the network stack of n. Stack may return nil if no network
+// stack is configured.
+func (n *Namespace) Stack() Stack {
+ return n.stack
+}
+
+// IsRoot returns whether n is the root network namespace.
+func (n *Namespace) IsRoot() bool {
+ return n.isRoot
+}
+
+// RestoreRootStack restores the root network namespace with stack. This should
+// only be called when restoring kernel.
+func (n *Namespace) RestoreRootStack(stack Stack) {
+ if !n.isRoot {
+ panic("RestoreRootStack can only be called on root network namespace")
+ }
+ if n.stack != nil {
+ panic("RestoreRootStack called after a stack has already been set")
+ }
+ n.stack = stack
+}
+
+func (n *Namespace) init() {
+ // Root network namespace will have stack assigned later.
+ if n.isRoot {
+ return
+ }
+ if n.creator != nil {
+ var err error
+ n.stack, err = n.creator.CreateStack()
+ if err != nil {
+ panic(err)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (n *Namespace) afterLoad() {
+ n.init()
+}
+
+// NetworkStackCreator allows new instances of a network stack to be created. It
+// is used by the kernel to create new network namespaces when requested.
+type NetworkStackCreator interface {
+ // CreateStack creates a new network stack for a network namespace.
+ CreateStack() (Stack, error)
+}
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 23b88f7a6..58001d56c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -296,6 +296,50 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return fds, nil
}
+// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// the given file description. If it succeeds, it takes a reference on file.
+func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ if minfd < 0 {
+ // Don't accept negative FDs.
+ return -1, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if minfd >= end {
+ return -1, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // From f.next to find available fd.
+ fd := minfd
+ if fd < f.next {
+ fd = f.next
+ }
+ for fd < end {
+ if d, _, _ := f.get(fd); d == nil {
+ f.setVFS2(fd, file, flags)
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fd + 1
+ }
+ return fd, nil
+ }
+ fd++
+ }
+ return -1, syscall.EMFILE
+}
+
// NewFDAt sets the file reference for the given FD. If there is an active
// reference for that FD, the ref count for that existing reference is
// decremented.
@@ -316,9 +360,6 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
return syscall.EBADF
}
- f.mu.Lock()
- defer f.mu.Unlock()
-
// Check the limit for the provided file.
if limitSet := limits.FromContext(ctx); limitSet != nil {
if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
@@ -327,6 +368,8 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
}
// Install the entry.
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.setAll(fd, file, fileVFS2, flags)
return nil
}
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index 7218aa24e..47f78df9a 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -244,6 +244,28 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
old.DecRef()
}
+// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd.
+//
+// This is not a valid call after free.
+func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) {
+ if !vd.Ok() {
+ panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry")
+ }
+
+ f.mu.Lock()
+
+ if !f.rootVFS2.Ok() {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd))
+ }
+
+ old := f.rootVFS2
+ vd.IncRef()
+ f.rootVFS2 = vd
+ f.mu.Unlock()
+ old.DecRef()
+}
+
// Umask returns the current umask.
func (f *FSContext) Umask() uint {
f.mu.Lock()
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 7da0368f1..8b76750e9 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -111,7 +111,7 @@ type Kernel struct {
timekeeper *Timekeeper
tasks *TaskSet
rootUserNamespace *auth.UserNamespace
- networkStack inet.Stack `state:"nosave"`
+ rootNetworkNamespace *inet.Namespace
applicationCores uint
useHostCores bool
extraAuxv []arch.AuxEntry
@@ -247,6 +247,10 @@ type Kernel struct {
// VFS keeps the filesystem state used across the kernel.
vfs vfs.VirtualFilesystem
+
+ // If set to true, report address space activation waits as if the task is in
+ // external wait so that the watchdog doesn't report the task stuck.
+ SleepForAddressSpaceActivation bool
}
// InitKernelArgs holds arguments to Init.
@@ -260,8 +264,9 @@ type InitKernelArgs struct {
// RootUserNamespace is the root user namespace.
RootUserNamespace *auth.UserNamespace
- // NetworkStack is the TCP/IP network stack. NetworkStack may be nil.
- NetworkStack inet.Stack
+ // RootNetworkNamespace is the root network namespace. If nil, no networking
+ // will be available.
+ RootNetworkNamespace *inet.Namespace
// ApplicationCores is the number of logical CPUs visible to sandboxed
// applications. The set of logical CPU IDs is [0, ApplicationCores); thus
@@ -320,7 +325,10 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.rootUTSNamespace = args.RootUTSNamespace
k.rootIPCNamespace = args.RootIPCNamespace
k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace
- k.networkStack = args.NetworkStack
+ k.rootNetworkNamespace = args.RootNetworkNamespace
+ if k.rootNetworkNamespace == nil {
+ k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil)
+ }
k.applicationCores = args.ApplicationCores
if args.UseHostCores {
k.useHostCores = true
@@ -543,8 +551,6 @@ func (ts *TaskSet) unregisterEpollWaiters() {
func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
- k.networkStack = net
-
initAppCores := k.applicationCores
// Load the pre-saved CPUID FeatureSet.
@@ -575,6 +581,10 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
log.Infof("Kernel load stats: %s", &stats)
log.Infof("Kernel load took [%s].", time.Since(kernelStart))
+ // rootNetworkNamespace should be populated after loading the state file.
+ // Restore the root network stack.
+ k.rootNetworkNamespace.RestoreRootStack(net)
+
// Load the memory file's state.
memoryStart := time.Now()
if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
@@ -905,6 +915,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
FSContext: fsContext,
FDTable: args.FDTable,
Credentials: args.Credentials,
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores),
UTSNamespace: args.UTSNamespace,
IPCNamespace: args.IPCNamespace,
@@ -1255,10 +1266,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
return k.rootAbstractSocketNamespace
}
-// NetworkStack returns the network stack. NetworkStack may return nil if no
-// network stack is available.
-func (k *Kernel) NetworkStack() inet.Stack {
- return k.networkStack
+// RootNetworkNamespace returns the root network namespace, always non-nil.
+func (k *Kernel) RootNetworkNamespace() *inet.Namespace {
+ return k.rootNetworkNamespace
}
// GlobalInit returns the thread group with ID 1 in the root PID namespace, or
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 18416643b..ded95f532 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -304,7 +304,7 @@ func (t *Task) rseqAddrInterrupt() {
}
var cs linux.RSeqCriticalSection
- if _, err := cs.CopyIn(t, critAddr); err != nil {
+ if err := cs.CopyIn(t, critAddr); err != nil {
t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index a3443ff21..2cee2e6ed 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -486,13 +486,10 @@ type Task struct {
numaPolicy int32
numaNodeMask uint64
- // If netns is true, the task is in a non-root network namespace. Network
- // namespaces aren't currently implemented in full; being in a network
- // namespace simply prevents the task from observing any network devices
- // (including loopback) or using abstract socket addresses (see unix(7)).
+ // netns is the task's network namespace. netns is never nil.
//
- // netns is protected by mu. netns is owned by the task goroutine.
- netns bool
+ // netns is protected by mu.
+ netns *inet.Namespace
// If rseqPreempted is true, before the next call to p.Switch(),
// interrupt rseq critical regions as defined by rseqAddr and
@@ -792,6 +789,15 @@ func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error)
return fds[0], nil
}
+// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ return t.fdTable.NewFDVFS2(t, fd, file, flags)
+}
+
// NewFDAt is a convenience wrapper for t.FDTable().NewFDAt.
//
// This automatically passes the task as the context.
@@ -801,6 +807,15 @@ func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error {
return t.fdTable.NewFDAt(t, fd, file, flags)
}
+// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return t.fdTable.NewFDAtVFS2(t, fd, file, flags)
+}
+
// WithMuLocked executes f with t.mu locked.
func (t *Task) WithMuLocked(f func(*Task)) {
t.mu.Lock()
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index ba74b4c1c..78866f280 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -54,8 +55,7 @@ type SharingOptions struct {
NewUserNamespace bool
// If NewNetworkNamespace is true, the task should have an independent
- // network namespace. (Note that network namespaces are not really
- // implemented; see comment on Task.netns for details.)
+ // network namespace.
NewNetworkNamespace bool
// If NewFiles is true, the task should use an independent file descriptor
@@ -199,6 +199,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
ipcns = NewIPCNamespace(userns)
}
+ netns := t.NetworkNamespace()
+ if opts.NewNetworkNamespace {
+ netns = inet.NewNamespace(netns)
+ }
+
// TODO(b/63601033): Implement CLONE_NEWNS.
mntnsVFS2 := t.mountNamespaceVFS2
if mntnsVFS2 != nil {
@@ -268,7 +273,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
FDTable: fdTable,
Credentials: creds,
Niceness: t.Niceness(),
- NetworkNamespaced: t.netns,
+ NetworkNamespace: netns,
AllowedCPUMask: t.CPUMask(),
UTSNamespace: utsns,
IPCNamespace: ipcns,
@@ -283,9 +288,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else {
cfg.InheritParent = t
}
- if opts.NewNetworkNamespace {
- cfg.NetworkNamespaced = true
- }
nt, err := t.tg.pidns.owner.NewTask(cfg)
if err != nil {
if opts.NewThreadGroup {
@@ -482,7 +484,7 @@ func (t *Task) Unshare(opts *SharingOptions) error {
t.mu.Unlock()
return syserror.EPERM
}
- t.netns = true
+ t.netns = inet.NewNamespace(t.netns)
}
if opts.NewUTSNamespace {
if !haveCapSysAdmin {
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 2be982684..0158b1788 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -140,7 +140,7 @@ func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*Task
}
// Prepare a new user address space to load into.
- m := mm.NewMemoryManager(k, k)
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
defer m.DecUsers(ctx)
args.MemoryManager = m
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 8f57a34a6..00c425cca 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -220,7 +220,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.mu.Unlock()
t.unstopVforkParent()
// NOTE(b/30316266): All locks must be dropped prior to calling Activate.
- t.MemoryManager().Activate()
+ t.MemoryManager().Activate(t)
t.ptraceExec(oldTID)
return (*runSyscallExit)(nil)
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index 6d737d3e5..eeccaa197 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -32,21 +32,21 @@ const (
// Infof logs an formatted info message by calling log.Infof.
func (t *Task) Infof(fmt string, v ...interface{}) {
if log.IsLogging(log.Info) {
- log.Infof(t.logPrefix.Load().(string)+fmt, v...)
+ log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Warningf logs a warning string by calling log.Warningf.
func (t *Task) Warningf(fmt string, v ...interface{}) {
if log.IsLogging(log.Warning) {
- log.Warningf(t.logPrefix.Load().(string)+fmt, v...)
+ log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Debugf creates a debug string that includes the task ID.
func (t *Task) Debugf(fmt string, v ...interface{}) {
if log.IsLogging(log.Debug) {
- log.Debugf(t.logPrefix.Load().(string)+fmt, v...)
+ log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
index 172a31e1d..f7711232c 100644
--- a/pkg/sentry/kernel/task_net.go
+++ b/pkg/sentry/kernel/task_net.go
@@ -22,14 +22,23 @@ import (
func (t *Task) IsNetworkNamespaced() bool {
t.mu.Lock()
defer t.mu.Unlock()
- return t.netns
+ return !t.netns.IsRoot()
}
// NetworkContext returns the network stack used by the task. NetworkContext
// may return nil if no network stack is available.
+//
+// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
+// NetworkNamespace().
func (t *Task) NetworkContext() inet.Stack {
- if t.IsNetworkNamespaced() {
- return nil
- }
- return t.k.networkStack
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns.Stack()
+}
+
+// NetworkNamespace returns the network namespace observed by the task.
+func (t *Task) NetworkNamespace() *inet.Namespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index f9236a842..a5035bb7f 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
@@ -65,9 +66,8 @@ type TaskConfig struct {
// Niceness is the niceness of the new task.
Niceness int
- // If NetworkNamespaced is true, the new task should observe a non-root
- // network namespace.
- NetworkNamespaced bool
+ // NetworkNamespace is the network namespace to be used for the new task.
+ NetworkNamespace *inet.Namespace
// AllowedCPUMask contains the cpus that this task can run on.
AllowedCPUMask sched.CPUSet
@@ -133,7 +133,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
allowedCPUMask: cfg.AllowedCPUMask.Copy(),
ioUsage: &usage.IO{},
niceness: cfg.Niceness,
- netns: cfg.NetworkNamespaced,
+ netns: cfg.NetworkNamespace,
utsns: cfg.UTSNamespace,
ipcns: cfg.IPCNamespace,
abstractSockets: cfg.AbstractSocketNamespace,
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 2bf3ce8a8..b02044ad2 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -30,7 +30,7 @@ var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown())
// Activate ensures that the task has an active address space.
func (t *Task) Activate() {
if mm := t.MemoryManager(); mm != nil {
- if err := mm.Activate(); err != nil {
+ if err := mm.Activate(t); err != nil {
panic("unable to activate mm: " + err.Error())
}
}
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index e58a63deb..0332fc71c 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -18,7 +18,7 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -39,11 +39,18 @@ func (mm *MemoryManager) AddressSpace() platform.AddressSpace {
//
// When this MemoryManager is no longer needed by a task, it should call
// Deactivate to release the reference.
-func (mm *MemoryManager) Activate() error {
+func (mm *MemoryManager) Activate(ctx context.Context) error {
// Fast path: the MemoryManager already has an active
// platform.AddressSpace, and we just need to indicate that we need it too.
- if atomicbitops.IncUnlessZeroInt32(&mm.active) {
- return nil
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 0 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active+1) {
+ return nil
+ }
}
for {
@@ -85,16 +92,20 @@ func (mm *MemoryManager) Activate() error {
if as == nil {
// AddressSpace is unavailable, we must wait.
//
- // activeMu must not be held while waiting, as the user
- // of the address space we are waiting on may attempt
- // to take activeMu.
- //
- // Don't call UninterruptibleSleepStart to register the
- // wait to allow the watchdog stuck task to trigger in
- // case a process is starved waiting for the address
- // space.
+ // activeMu must not be held while waiting, as the user of the address
+ // space we are waiting on may attempt to take activeMu.
mm.activeMu.Unlock()
+
+ sleep := mm.p.CooperativelySchedulesAddressSpace() && mm.sleepForActivation
+ if sleep {
+ // Mark this task sleeping while waiting for the address space to
+ // prevent the watchdog from reporting it as a stuck task.
+ ctx.UninterruptibleSleepStart(false)
+ }
<-c
+ if sleep {
+ ctx.UninterruptibleSleepFinish(false)
+ }
continue
}
@@ -118,8 +129,15 @@ func (mm *MemoryManager) Activate() error {
func (mm *MemoryManager) Deactivate() {
// Fast path: this is not the last goroutine to deactivate the
// MemoryManager.
- if atomicbitops.DecUnlessOneInt32(&mm.active) {
- return
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 1 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active-1) {
+ return
+ }
}
mm.activeMu.Lock()
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 47b8fbf43..d8a5b9d29 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -18,7 +18,6 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/limits"
@@ -29,16 +28,17 @@ import (
)
// NewMemoryManager returns a new MemoryManager with no mappings and 1 user.
-func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *MemoryManager {
+func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider, sleepForActivation bool) *MemoryManager {
return &MemoryManager{
- p: p,
- mfp: mfp,
- haveASIO: p.SupportsAddressSpaceIO(),
- privateRefs: &privateRefs{},
- users: 1,
- auxv: arch.Auxv{},
- dumpability: UserDumpable,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ p: p,
+ mfp: mfp,
+ haveASIO: p.SupportsAddressSpaceIO(),
+ privateRefs: &privateRefs{},
+ users: 1,
+ auxv: arch.Auxv{},
+ dumpability: UserDumpable,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: sleepForActivation,
}
}
@@ -80,9 +80,10 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
envv: mm.envv,
auxv: append(arch.Auxv(nil), mm.auxv...),
// IncRef'd below, once we know that there isn't an error.
- executable: mm.executable,
- dumpability: mm.dumpability,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ executable: mm.executable,
+ dumpability: mm.dumpability,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: mm.sleepForActivation,
}
// Copy vmas.
@@ -229,7 +230,15 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
// IncUsers increments mm's user count and returns true. If the user count is
// already 0, IncUsers does nothing and returns false.
func (mm *MemoryManager) IncUsers() bool {
- return atomicbitops.IncUnlessZeroInt32(&mm.users)
+ for {
+ users := atomic.LoadInt32(&mm.users)
+ if users == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt32(&mm.users, users, users+1) {
+ return true
+ }
+ }
}
// DecUsers decrements mm's user count. If the user count reaches 0, all
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 637383c7a..c2195ae11 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -226,6 +226,11 @@ type MemoryManager struct {
// aioManager keeps track of AIOContexts used for async IOs. AIOManager
// must be cloned when CLONE_VM is used.
aioManager aioManager
+
+ // sleepForActivation indicates whether the task should report to be sleeping
+ // before trying to activate the address space. When set to true, delays in
+ // activation are not reported as stuck tasks by the watchdog.
+ sleepForActivation bool
}
// vma represents a virtual memory area.
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index edacca741..fdc308542 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -31,7 +31,7 @@ import (
func testMemoryManager(ctx context.Context) *MemoryManager {
p := platform.FromContext(ctx)
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
- mm := NewMemoryManager(p, mfp)
+ mm := NewMemoryManager(p, mfp, false)
mm.layout = arch.MmapLayout{
MinAddr: p.MinUserAddress(),
MaxAddr: p.MaxUserAddress(),
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 8076c7529..f1afc74dc 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -329,10 +329,12 @@ func (m *machine) Destroy() {
}
// Get gets an available vCPU.
+//
+// This will return with the OS thread locked.
func (m *machine) Get() *vCPU {
+ m.mu.RLock()
runtime.LockOSThread()
tid := procid.Current()
- m.mu.RLock()
// Check for an exact match.
if c := m.vCPUs[tid]; c != nil {
@@ -343,8 +345,22 @@ func (m *machine) Get() *vCPU {
// The happy path failed. We now proceed to acquire an exclusive lock
// (because the vCPU map may change), and scan all available vCPUs.
+ // In this case, we first unlock the OS thread. Otherwise, if mu is
+ // not available, the current system thread will be parked and a new
+ // system thread spawned. We avoid this situation by simply refreshing
+ // tid after relocking the system thread.
m.mu.RUnlock()
+ runtime.UnlockOSThread()
m.mu.Lock()
+ runtime.LockOSThread()
+ tid = procid.Current()
+
+ // Recheck for an exact match.
+ if c := m.vCPUs[tid]; c != nil {
+ c.lock()
+ m.mu.Unlock()
+ return c
+ }
for {
// Scan for an available vCPU.
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 4667373d2..8834a1e1a 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -329,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
}
// PackTClass packs an IPV6_TCLASS socket control message.
-func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IPV6,
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index c91ec7494..7cd2ce55b 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"extensions.go",
"netfilter.go",
+ "targets.go",
"tcp_matcher.go",
"udp_matcher.go",
],
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 257cb485b..f68a2260d 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -248,13 +248,15 @@ func marshalTarget(target iptables.Target) []byte {
return marshalStandardTarget(iptables.RuleReturn)
case iptables.RedirectTarget:
return marshalRedirectTarget()
+ case JumpTarget:
+ return marshalJumpTarget(tg)
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
}
func marshalStandardTarget(verdict iptables.RuleVerdict) []byte {
- nflog("convert to binary: marshalling standard target with size %d", linux.SizeOfXTStandardTarget)
+ nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
target := linux.XTStandardTarget{
@@ -290,8 +292,25 @@ func marshalRedirectTarget() []byte {
},
}
copy(target.Target.Name[:], redirectTargetName)
-
+
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalJumpTarget(jt JumpTarget) []byte {
+ nflog("convert to binary: marshalling jump target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ // Verdict is overloaded by the ABI. When positive, it holds
+ // the jump offset from the start of the table.
+ Verdict: int32(jt.Offset),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
return binary.Marshal(ret, usermem.ByteOrder, target)
}
@@ -358,7 +377,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Convert input into a list of rules and their offsets.
var offset uint32
- var offsets []uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
nflog("set entries: processing entry at offset %d", offset)
@@ -419,11 +439,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
Target: target,
Matchers: matchers,
})
- offsets = append(offsets, offset)
+ offsets[offset] = int(entryIdx)
offset += uint32(entry.NextOffset)
if initialOptValLen-len(optVal) != int(entry.NextOffset) {
nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return syserr.ErrInvalidArgument
}
}
@@ -432,13 +453,13 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
- for ruleIdx, offset := range offsets {
+ for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
}
if offset == replace.Underflow[hook] {
if !validUnderflow(table.Rules[ruleIdx]) {
- nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP.")
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP")
return syserr.ErrInvalidArgument
}
table.Underflows[hk] = ruleIdx
@@ -467,16 +488,35 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// - There's some other rule after it.
// - There are no matchers.
if ruleIdx == len(table.Rules)-1 {
- nflog("user chain must have a rule or default policy.")
+ nflog("user chain must have a rule or default policy")
return syserr.ErrInvalidArgument
}
if len(table.Rules[ruleIdx].Matchers) != 0 {
- nflog("user chain's first node must have no matcheres.")
+ nflog("user chain's first node must have no matchers")
return syserr.ErrInvalidArgument
}
table.UserChains[target.Name] = ruleIdx + 1
}
+ // Set each jump to point to the appropriate rule. Right now they hold byte
+ // offsets.
+ for ruleIdx, rule := range table.Rules {
+ jump, ok := rule.Target.(JumpTarget)
+ if !ok {
+ continue
+ }
+
+ // Find the rule corresponding to the jump rule offset.
+ jumpTo, ok := offsets[jump.Offset]
+ if !ok {
+ nflog("failed to find a rule to jump to")
+ return syserr.ErrInvalidArgument
+ }
+ jump.RuleNum = jumpTo
+ rule.Target = jump
+ table.Rules[ruleIdx] = rule
+ }
+
// TODO(gvisor.dev/issue/170): Support other chains.
// Since we only support modifying the INPUT chain and redirect for
// PREROUTING chain right now, make sure all other chains point to
@@ -572,7 +612,12 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target
buf = optVal[:linux.SizeOfXTStandardTarget]
binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
- return translateToStandardTarget(standardTarget.Verdict)
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict)
+ }
+ // A verdict >= 0 indicates a jump.
+ return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
case errorTargetName:
// Error target.
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
new file mode 100644
index 000000000..c421b87cf
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+)
+
+// JumpTarget implements iptables.Target.
+type JumpTarget struct {
+ // Offset is the byte offset of the rule to jump to. It is used for
+ // marshaling and unmarshaling.
+ Offset uint32
+
+ // RuleNum is the rule to jump to.
+ RuleNum int
+}
+
+// Action implements iptables.Target.Action.
+func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) {
+ return iptables.RuleJump, jt.RuleNum
+}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9757fbfba..e187276c5 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1318,6 +1318,22 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
return ib, nil
+ case linux.IPV6_RECVTCLASS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1803,6 +1819,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ case linux.IPV6_RECVTCLASS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -2086,7 +2110,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
- linux.IPV6_RECVTCLASS,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
linux.IPV6_TCLASS,
@@ -2424,6 +2447,8 @@ func (s *SocketOperations) controlMessages() socket.ControlMessages {
Timestamp: s.readCM.Timestamp,
HasTOS: s.readCM.HasTOS,
TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
HasIPPacketInfo: s.readCM.HasIPPacketInfo,
PacketInfo: s.readCM.PacketInfo,
},
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 5afff2564..5f181f017 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -75,6 +75,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
switch protocol {
case syscall.IPPROTO_ICMP:
return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
case syscall.IPPROTO_UDP:
return header.UDPProtocolNumber, true, nil
case syscall.IPPROTO_TCP:
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 2f39a6f2b..88d5db9fc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"capability.go",
"clone.go",
+ "epoll.go",
"futex.go",
"linux64_amd64.go",
"linux64_arm64.go",
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
new file mode 100644
index 000000000..a6e48b836
--- /dev/null
+++ b/pkg/sentry/strace/epoll.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
+ }
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x ", eventAddr)
+ writeEpollEvent(&sb, e)
+ return sb.String()
+}
+
+func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string {
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x {", eventsAddr)
+ addr := eventsAddr
+ for i := uint64(0); i < numEvents; i++ {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(addr, &e); err != nil {
+ fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
+ continue
+ }
+ writeEpollEvent(&sb, e)
+ if uint64(sb.Len()) >= maxBytes {
+ sb.WriteString("...")
+ break
+ }
+ if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok {
+ fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr)
+ continue
+ }
+ }
+ sb.WriteString("}")
+ return sb.String()
+}
+
+func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) {
+ events := epollEventEvents.Parse(uint64(e.Events))
+ fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1])
+}
+
+var epollCtlOps = abi.ValueSet{
+ linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD",
+ linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL",
+ linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD",
+}
+
+var epollEventEvents = abi.FlagSet{
+ {Flag: linux.EPOLLIN, Name: "EPOLLIN"},
+ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"},
+ {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"},
+ {Flag: linux.EPOLLERR, Name: "EPOLLERR"},
+ {Flag: linux.EPOLLHUP, Name: "EPULLHUP"},
+ {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"},
+ {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"},
+ {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"},
+ {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"},
+ {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"},
+ {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"},
+ {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"},
+ {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"},
+ {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"},
+ {Flag: linux.EPOLLET, Name: "EPOLLET"},
+}
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index a4de545e9..71b92eaee 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -256,8 +256,8 @@ var linuxAMD64 = SyscallMap{
229: makeSyscallInfo("clock_getres", Hex, PostTimespec),
230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
231: makeSyscallInfo("exit_group", Hex),
- 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex),
- 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex),
+ 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
234: makeSyscallInfo("tgkill", Hex, Hex, Signal),
235: makeSyscallInfo("utimes", Path, Timeval),
// 236: vserver (not implemented in the Linux kernel)
@@ -305,7 +305,7 @@ var linuxAMD64 = SyscallMap{
278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
- 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
282: makeSyscallInfo("signalfd", Hex, Hex, Hex),
283: makeSyscallInfo("timerfd_create", Hex, Hex),
284: makeSyscallInfo("eventfd", Hex),
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index 8bc38545f..bd7361a52 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -45,8 +45,8 @@ var linuxARM64 = SyscallMap{
18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
19: makeSyscallInfo("eventfd2", Hex, Hex),
20: makeSyscallInfo("epoll_create1", Hex),
- 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
- 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
+ 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
23: makeSyscallInfo("dup", FD),
24: makeSyscallInfo("dup3", FD, FD, Hex),
25: makeSyscallInfo("fcntl", FD, Hex, Hex),
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 46cb2a1cc..77655558e 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -481,6 +481,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
case PollFDs:
output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case EpollCtlOp:
+ output = append(output, epollCtlOps.Parse(uint64(args[arg].Int())))
+ case EpollEvent:
+ output = append(output, epollEvent(t, args[arg].Pointer()))
+ case EpollEvents:
+ output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
case SelectFDSet:
output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
case Oct:
@@ -549,6 +555,8 @@ func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uint
output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer())
case PollFDs:
output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true)
+ case EpollEvents:
+ output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize))
case GetSockOptVal:
output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval)
case SetSockOptVal:
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 446d1e0f6..7e69b9279 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -228,6 +228,16 @@ const (
// SockOptLevel is the optname argument in getsockopt(2) and
// setsockopt(2).
SockOptName
+
+ // EpollCtlOp is the op argument to epoll_ctl(2).
+ EpollCtlOp
+
+ // EpollEvent is the event argument in epoll_ctl(2).
+ EpollEvent
+
+ // EpollEvents is an array of struct epoll_event. It is the events
+ // argument in epoll_wait(2)/epoll_pwait(2).
+ EpollEvents
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index fbef5b376..3ab93fbde 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// EpollCreate1 implements the epoll_create1(2) linux syscall.
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
@@ -164,3 +166,5 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+
+// LINT.ThenChange(vfs2/epoll.go)
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 421845ebb..c21f14dc0 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -130,6 +130,8 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string
return path, dirPath, nil
}
+// LINT.IfChange
+
func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -575,6 +577,10 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Ioctl implements linux syscall ioctl(2).
func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -650,6 +656,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/ioctl.go)
+
+// LINT.IfChange
+
// Getcwd implements the linux syscall getcwd(2).
func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -760,6 +770,10 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/fscontext.go)
+
+// LINT.IfChange
+
// Close implements linux syscall close(2).
func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -1094,6 +1108,8 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/fd.go)
+
const (
_FADV_NORMAL = 0
_FADV_RANDOM = 1
@@ -1141,6 +1157,8 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.IfChange
+
func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1421,6 +1439,10 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1480,6 +1502,10 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return n, nil, err
}
+// LINT.ThenChange(vfs2/stat.go)
+
+// LINT.IfChange
+
func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1516,6 +1542,10 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, unlinkAt(t, dirFD, addr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Truncate implements linux syscall truncate(2).
func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -1614,6 +1644,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/setstat.go)
+
// Umask implements linux syscall umask(2).
func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
mask := args[0].ModeT()
@@ -1621,6 +1653,8 @@ func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(mask), nil, nil
}
+// LINT.IfChange
+
// Change ownership of a file.
//
// uid and gid may be -1, in which case they will not be changed.
@@ -1987,6 +2021,10 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true)
}
+// LINT.ThenChange(vfs2/setstat.go)
+
+// LINT.IfChange
+
func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error {
newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
if err != nil {
@@ -2042,6 +2080,8 @@ func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
// Fallocate implements linux system call fallocate(2).
func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index f66f4ffde..b126fecc0 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -27,6 +27,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Getdents implements linux syscall getdents(2) for 64bit systems.
func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -244,3 +246,5 @@ func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error {
func (ds *direntSerializer) Written() int {
return ds.written
}
+
+// LINT.ThenChange(vfs2/getdents.go)
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index 297e920c4..3f7691eae 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Lseek implements linux syscall lseek(2).
func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -52,3 +54,5 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return uintptr(offset), nil, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 9959f6e61..91694d374 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -35,6 +35,8 @@ func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return uintptr(addr), nil, nil
}
+// LINT.IfChange
+
// Mmap implements linux syscall mmap(2).
func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
prot := args[2].Int()
@@ -104,6 +106,8 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(rv), nil, err
}
+// LINT.ThenChange(vfs2/mmap.go)
+
// Munmap implements linux syscall munmap(2).
func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64())
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 227692f06..78a2cb750 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskRead contains events that can be triggered on reads.
EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
@@ -388,3 +390,5 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index 8b66a9006..701b27b4a 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -23,6 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
return linux.Stat{
Dev: sattr.DeviceID,
@@ -131,8 +133,7 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
return err
}
s := statFromAttrs(t, d.Inode.StableAttr, uattr)
- _, err = s.CopyOut(t, statAddr)
- return err
+ return s.CopyOut(t, statAddr)
}
// fstat implements fstat for the given *fs.File.
@@ -142,8 +143,7 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
return err
}
s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
- _, err = s.CopyOut(t, statAddr)
- return err
+ return s.CopyOut(t, statAddr)
}
// Statx implements linux syscall statx(2).
@@ -299,3 +299,5 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
_, err = t.CopyOut(addr, &statfs)
return err
}
+
+// LINT.ThenChange(vfs2/stat.go)
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 3e55235bd..5ad465ae3 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Sync implements linux system call sync(2).
func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
t.MountNamespace().SyncAll(t)
@@ -135,3 +137,5 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
+
+// LINT.ThenChange(vfs2/sync.go)
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index aba892939..506ee54ce 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskWrite contains events that can be triggered on writes.
//
@@ -358,3 +360,5 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 9d8140b8a..2de5e3422 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// GetXattr implements linux syscall getxattr(2).
func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return getXattrFromPath(t, args, true)
@@ -418,3 +420,5 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
return d.Inode.RemoveXattr(t, d, name)
}
+
+// LINT.ThenChange(vfs2/xattr.go)
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 6b8a00b6e..f51761e81 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -5,18 +5,44 @@ package(licenses = ["notice"])
go_library(
name = "vfs2",
srcs = [
+ "epoll.go",
+ "epoll_unsafe.go",
+ "execve.go",
+ "fd.go",
+ "filesystem.go",
+ "fscontext.go",
+ "getdents.go",
+ "ioctl.go",
"linux64.go",
"linux64_override_amd64.go",
"linux64_override_arm64.go",
- "sys_read.go",
+ "mmap.go",
+ "path.go",
+ "poll.go",
+ "read_write.go",
+ "setstat.go",
+ "stat.go",
+ "sync.go",
+ "xattr.go",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/fspath",
+ "//pkg/gohacks",
"//pkg/sentry/arch",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
"//pkg/sentry/syscalls",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
new file mode 100644
index 000000000..d6cb0e79a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -0,0 +1,225 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EpollCreate1 implements Linux syscall epoll_create1(2).
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^linux.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements Linux syscall epoll_create(2).
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ // "Since Linux 2.6.8, the size argument is ignored, but must be greater
+ // than zero" - epoll_create(2)
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements Linux syscall epoll_ctl(2).
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ op := args[1].Int()
+ fd := args[2].Int()
+ eventAddr := args[3].Pointer()
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+ if epfile == file {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var event linux.EpollEvent
+ switch op {
+ case linux.EPOLL_CTL_ADD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.AddInterest(file, fd, event)
+ case linux.EPOLL_CTL_DEL:
+ return 0, nil, ep.DeleteInterest(file, fd)
+ case linux.EPOLL_CTL_MOD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.ModifyInterest(file, fd, event)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ const _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
+ if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
+ return 0, nil, syserror.EINVAL
+ }
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
+ // maxEvents), so that the buffer can be allocated on the stack.
+ var (
+ events [16]linux.EpollEvent
+ total int
+ ch chan struct{}
+ haveDeadline bool
+ deadline ktime.Time
+ )
+ for {
+ batchEvents := len(events)
+ if batchEvents > maxEvents {
+ batchEvents = maxEvents
+ }
+ n := ep.ReadEvents(events[:batchEvents])
+ maxEvents -= n
+ if n != 0 {
+ // Copy what we read out.
+ copiedEvents, err := copyOutEvents(t, eventsAddr, events[:n])
+ eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
+ total += copiedEvents
+ if err != nil {
+ if total != 0 {
+ return uintptr(total), nil, nil
+ }
+ return 0, nil, err
+ }
+ // If we've filled the application's event buffer, we're done.
+ if maxEvents == 0 {
+ return uintptr(total), nil, nil
+ }
+ // Loop if we read a full batch, under the expectation that there
+ // may be more events to read.
+ if n == batchEvents {
+ continue
+ }
+ }
+ // We get here if n != batchEvents. If we read any number of events
+ // (just now, or in a previous iteration of this loop), or if timeout
+ // is 0 (such that epoll_wait should be non-blocking), return the
+ // events we've read so far to the application.
+ if total != 0 || timeout == 0 {
+ return uintptr(total), nil, nil
+ }
+ // In the first iteration of this loop, register with the epoll
+ // instance for readability events, but then immediately continue the
+ // loop since we need to retry ReadEvents() before blocking. In all
+ // subsequent iterations, block until events are available, the timeout
+ // expires, or an interrupt arrives.
+ if ch == nil {
+ var w waiter.Entry
+ w, ch = waiter.NewChannelEntry(nil)
+ epfile.EventRegister(&w, waiter.EventIn)
+ defer epfile.EventUnregister(&w)
+ } else {
+ // Set up the timer if a timeout was specified.
+ if timeout > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ // total must be 0 since otherwise we would have returned
+ // above.
+ return 0, nil, err
+ }
+ }
+ }
+}
+
+// EpollPwait implements Linux syscall epoll_pwait(2).
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return EpollWait(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
new file mode 100644
index 000000000..825f325bf
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "reflect"
+ "runtime"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const sizeofEpollEvent = int(unsafe.Sizeof(linux.EpollEvent{}))
+
+func copyOutEvents(t *kernel.Task, addr usermem.Addr, events []linux.EpollEvent) (int, error) {
+ if len(events) == 0 {
+ return 0, nil
+ }
+ // Cast events to a byte slice for copying.
+ var eventBytes []byte
+ eventBytesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&eventBytes))
+ eventBytesHdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(&events[0])))
+ eventBytesHdr.Len = len(events) * sizeofEpollEvent
+ eventBytesHdr.Cap = len(events) * sizeofEpollEvent
+ copiedBytes, err := t.CopyOutBytes(addr, eventBytes)
+ runtime.KeepAlive(events)
+ copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
+ return copiedEvents, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
new file mode 100644
index 000000000..aef0078a8
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathnameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+ return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ var executable fsbridge.File
+ closeOnExec := false
+ if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute {
+ // We must open the executable ourselves since dirfd is used as the
+ // starting point while resolving path, but the task working directory
+ // is used as the starting point while resolving interpreters (Linux:
+ // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() =>
+ // do_open_execat(fd=AT_FDCWD)), and the loader package is currently
+ // incapable of handling this correctly.
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ start := dirfile.VirtualDentry()
+ start.IncRef()
+ dirfile.DecRef()
+ closeOnExec = dirfileFlags.CloseOnExec
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ })
+ start.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ executable = fsbridge.NewVFSFile(file)
+ }
+
+ // Load the new TaskContext.
+ mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
+ defer mntns.DecRef()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Opener: fsbridge.NewVFSLookup(mntns, root, wd),
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
new file mode 100644
index 000000000..3afcea665
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Close implements Linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ // Note that Remove provides a reference on the file that we may use to
+ // flush. It is still active until we drop the final reference below
+ // (and other reference-holding operations complete).
+ _, file := t.FDTable().Remove(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.OnClose(t)
+ return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
+}
+
+// Dup implements Linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newFD), nil, nil
+}
+
+// Dup2 implements Linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+
+ if oldfd == newfd {
+ // As long as oldfd is valid, dup2() does nothing and returns newfd.
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ file.DecRef()
+ return uintptr(newfd), nil, nil
+ }
+
+ return dup3(t, oldfd, newfd, 0)
+}
+
+// Dup3 implements Linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return dup3(t, oldfd, newfd, flags)
+}
+
+func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^linux.O_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(newfd), nil, nil
+}
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ cmd := args[1].Int()
+
+ file, flags := t.FDTable().GetVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ minfd := args[2].Int()
+ fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{
+ CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ t.FDTable().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ return 0, nil, nil
+ case linux.F_GETFL:
+ return uintptr(file.StatusFlags()), nil, nil
+ case linux.F_SETFL:
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ default:
+ // TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
new file mode 100644
index 000000000..fc5ceea4c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -0,0 +1,326 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Link implements Linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Linkat implements Linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+ if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
+ return syserror.ENOENT
+ }
+
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0))
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop)
+}
+
+// Mkdir implements Linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements Linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, mkdirat(t, dirfd, addr, mode)
+}
+
+func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+}
+
+// Mknod implements Linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ dev := args[2].Uint()
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev)
+}
+
+// Mknodat implements Linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ dev := args[3].Uint()
+ return 0, nil, mknodat(t, dirfd, addr, mode, dev)
+}
+
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ major, minor := linux.DecodeDeviceID(dev)
+ return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
+ Mode: linux.FileMode(mode &^ t.FSContext().Umask()),
+ DevMajor: uint32(major),
+ DevMinor: minor,
+ })
+}
+
+// Open implements Linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+ mode := args[2].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, flags, mode)
+}
+
+// Openat implements Linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ mode := args[3].ModeT()
+ return openat(t, dirfd, addr, flags, mode)
+}
+
+// Creat implements Linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode)
+}
+
+func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ return uintptr(fd), nil, err
+}
+
+// Rename implements Linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Renameat implements Linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */)
+}
+
+// Renameat2 implements Linux syscall renameat2(2).
+func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Uint()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error {
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ // "If oldpath refers to a symbolic link, the link is renamed" - rename(2)
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{
+ Flags: flags,
+ })
+}
+
+// Rmdir implements Linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlink implements Linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlinkat implements Linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if flags&^linux.AT_REMOVEDIR != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirat(t, dirfd, pathAddr)
+ }
+ return 0, nil, unlinkat(t, dirfd, pathAddr)
+}
+
+// Symlink implements Linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ linkpathAddr := args[1].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr)
+}
+
+// Symlinkat implements Linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ newdirfd := args[1].Int()
+ linkpathAddr := args[2].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr)
+}
+
+func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error {
+ target, err := t.CopyInString(targetAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ linkpath, err := copyInPath(t, linkpathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
new file mode 100644
index 000000000..317409a18
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getcwd implements Linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+
+ root := t.FSContext().RootDirectoryVFS2()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd)
+ root.DecRef()
+ wd.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Note this is >= because we need a terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Construct a byte slice containing a NUL terminator.
+ buf := t.CopyScratchBuffer(len(s) + 1)
+ copy(buf, s)
+ buf[len(buf)-1] = 0
+
+ // Write the pathname slice.
+ n, err := t.CopyOutBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Chdir implements Linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Fchdir implements Linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Chroot implements Linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetRootDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
new file mode 100644
index 000000000..ddc140b65
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Getdents implements Linux syscall getdents(2).
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, false /* isGetdents64 */)
+}
+
+// Getdents64 implements Linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, true /* isGetdents64 */)
+}
+
+func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ cb := getGetdentsCallback(t, addr, size, isGetdents64)
+ err := file.IterDirents(t, cb)
+ n := size - cb.remaining
+ putGetdentsCallback(cb)
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+type getdentsCallback struct {
+ t *kernel.Task
+ addr usermem.Addr
+ remaining int
+ isGetdents64 bool
+}
+
+var getdentsCallbackPool = sync.Pool{
+ New: func() interface{} {
+ return &getdentsCallback{}
+ },
+}
+
+func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback {
+ cb := getdentsCallbackPool.Get().(*getdentsCallback)
+ *cb = getdentsCallback{
+ t: t,
+ addr: addr,
+ remaining: size,
+ isGetdents64: isGetdents64,
+ }
+ return cb
+}
+
+func putGetdentsCallback(cb *getdentsCallback) {
+ cb.t = nil
+ getdentsCallbackPool.Put(cb)
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
+ var buf []byte
+ if cb.isGetdents64 {
+ // struct linux_dirent64 {
+ // ino64_t d_ino; /* 64-bit inode number */
+ // off64_t d_off; /* 64-bit offset to next structure */
+ // unsigned short d_reclen; /* Size of this dirent */
+ // unsigned char d_type; /* File type */
+ // char d_name[]; /* Filename (null-terminated) */
+ // };
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ buf[18] = dirent.Type
+ copy(buf[19:], dirent.Name)
+ buf[size-1] = 0 // NUL terminator
+ } else {
+ // struct linux_dirent {
+ // unsigned long d_ino; /* Inode number */
+ // unsigned long d_off; /* Offset to next linux_dirent */
+ // unsigned short d_reclen; /* Length of this linux_dirent */
+ // char d_name[]; /* Filename (null-terminated) */
+ // /* length is actually (d_reclen - 2 -
+ // offsetof(struct linux_dirent, d_name)) */
+ // /*
+ // char pad; // Zero padding byte
+ // char d_type; // File type (only since Linux
+ // // 2.6.4); offset is (d_reclen - 1)
+ // */
+ // };
+ if cb.t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
+ }
+ size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ copy(buf[18:], dirent.Name)
+ buf[size-3] = 0 // NUL terminator
+ buf[size-2] = 0 // zero padding byte
+ buf[size-1] = dirent.Type
+ }
+ n, err := cb.t.CopyOutBytes(cb.addr, buf)
+ if err != nil {
+ // Don't report partially-written dirents by advancing cb.addr or
+ // cb.remaining.
+ return err
+ }
+ cb.addr += usermem.Addr(n)
+ cb.remaining -= n
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
new file mode 100644
index 000000000..5a2418da9
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Ioctl implements Linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ ret, err := file.Ioctl(t, t.MemoryManager(), args)
+ return ret, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index e0ac32b33..7d220bc20 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package vfs2
import (
@@ -22,110 +24,142 @@ import (
// Override syscall table to add syscalls implementations from this package.
func Override(table map[uintptr]kernel.Syscall) {
table[0] = syscalls.Supported("read", Read)
-
- // Remove syscalls that haven't been converted yet. It's better to get ENOSYS
- // rather than a SIGSEGV deep in the stack.
- delete(table, 1) // write
- delete(table, 2) // open
- delete(table, 3) // close
- delete(table, 4) // stat
- delete(table, 5) // fstat
- delete(table, 6) // lstat
- delete(table, 7) // poll
- delete(table, 8) // lseek
- delete(table, 9) // mmap
- delete(table, 16) // ioctl
- delete(table, 17) // pread64
- delete(table, 18) // pwrite64
- delete(table, 19) // readv
- delete(table, 20) // writev
- delete(table, 21) // access
- delete(table, 22) // pipe
- delete(table, 32) // dup
- delete(table, 33) // dup2
- delete(table, 40) // sendfile
- delete(table, 59) // execve
- delete(table, 72) // fcntl
- delete(table, 73) // flock
- delete(table, 74) // fsync
- delete(table, 75) // fdatasync
- delete(table, 76) // truncate
- delete(table, 77) // ftruncate
- delete(table, 78) // getdents
- delete(table, 79) // getcwd
- delete(table, 80) // chdir
- delete(table, 81) // fchdir
- delete(table, 82) // rename
- delete(table, 83) // mkdir
- delete(table, 84) // rmdir
- delete(table, 85) // creat
- delete(table, 86) // link
- delete(table, 87) // unlink
- delete(table, 88) // symlink
- delete(table, 89) // readlink
- delete(table, 90) // chmod
- delete(table, 91) // fchmod
- delete(table, 92) // chown
- delete(table, 93) // fchown
- delete(table, 94) // lchown
- delete(table, 133) // mknod
- delete(table, 137) // statfs
- delete(table, 138) // fstatfs
- delete(table, 161) // chroot
- delete(table, 162) // sync
+ table[1] = syscalls.Supported("write", Write)
+ table[2] = syscalls.Supported("open", Open)
+ table[3] = syscalls.Supported("close", Close)
+ table[4] = syscalls.Supported("stat", Stat)
+ table[5] = syscalls.Supported("fstat", Fstat)
+ table[6] = syscalls.Supported("lstat", Lstat)
+ table[7] = syscalls.Supported("poll", Poll)
+ table[8] = syscalls.Supported("lseek", Lseek)
+ table[9] = syscalls.Supported("mmap", Mmap)
+ table[16] = syscalls.Supported("ioctl", Ioctl)
+ table[17] = syscalls.Supported("pread64", Pread64)
+ table[18] = syscalls.Supported("pwrite64", Pwrite64)
+ table[19] = syscalls.Supported("readv", Readv)
+ table[20] = syscalls.Supported("writev", Writev)
+ table[21] = syscalls.Supported("access", Access)
+ delete(table, 22) // pipe
+ table[23] = syscalls.Supported("select", Select)
+ table[32] = syscalls.Supported("dup", Dup)
+ table[33] = syscalls.Supported("dup2", Dup2)
+ delete(table, 40) // sendfile
+ delete(table, 41) // socket
+ delete(table, 42) // connect
+ delete(table, 43) // accept
+ delete(table, 44) // sendto
+ delete(table, 45) // recvfrom
+ delete(table, 46) // sendmsg
+ delete(table, 47) // recvmsg
+ delete(table, 48) // shutdown
+ delete(table, 49) // bind
+ delete(table, 50) // listen
+ delete(table, 51) // getsockname
+ delete(table, 52) // getpeername
+ delete(table, 53) // socketpair
+ delete(table, 54) // setsockopt
+ delete(table, 55) // getsockopt
+ table[59] = syscalls.Supported("execve", Execve)
+ table[72] = syscalls.Supported("fcntl", Fcntl)
+ delete(table, 73) // flock
+ table[74] = syscalls.Supported("fsync", Fsync)
+ table[75] = syscalls.Supported("fdatasync", Fdatasync)
+ table[76] = syscalls.Supported("truncate", Truncate)
+ table[77] = syscalls.Supported("ftruncate", Ftruncate)
+ table[78] = syscalls.Supported("getdents", Getdents)
+ table[79] = syscalls.Supported("getcwd", Getcwd)
+ table[80] = syscalls.Supported("chdir", Chdir)
+ table[81] = syscalls.Supported("fchdir", Fchdir)
+ table[82] = syscalls.Supported("rename", Rename)
+ table[83] = syscalls.Supported("mkdir", Mkdir)
+ table[84] = syscalls.Supported("rmdir", Rmdir)
+ table[85] = syscalls.Supported("creat", Creat)
+ table[86] = syscalls.Supported("link", Link)
+ table[87] = syscalls.Supported("unlink", Unlink)
+ table[88] = syscalls.Supported("symlink", Symlink)
+ table[89] = syscalls.Supported("readlink", Readlink)
+ table[90] = syscalls.Supported("chmod", Chmod)
+ table[91] = syscalls.Supported("fchmod", Fchmod)
+ table[92] = syscalls.Supported("chown", Chown)
+ table[93] = syscalls.Supported("fchown", Fchown)
+ table[94] = syscalls.Supported("lchown", Lchown)
+ table[132] = syscalls.Supported("utime", Utime)
+ table[133] = syscalls.Supported("mknod", Mknod)
+ table[137] = syscalls.Supported("statfs", Statfs)
+ table[138] = syscalls.Supported("fstatfs", Fstatfs)
+ table[161] = syscalls.Supported("chroot", Chroot)
+ table[162] = syscalls.Supported("sync", Sync)
delete(table, 165) // mount
delete(table, 166) // umount2
- delete(table, 172) // iopl
- delete(table, 173) // ioperm
delete(table, 187) // readahead
- delete(table, 188) // setxattr
- delete(table, 189) // lsetxattr
- delete(table, 190) // fsetxattr
- delete(table, 191) // getxattr
- delete(table, 192) // lgetxattr
- delete(table, 193) // fgetxattr
+ table[188] = syscalls.Supported("setxattr", Setxattr)
+ table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
+ table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
+ table[191] = syscalls.Supported("getxattr", Getxattr)
+ table[192] = syscalls.Supported("lgetxattr", Lgetxattr)
+ table[193] = syscalls.Supported("fgetxattr", Fgetxattr)
+ table[194] = syscalls.Supported("listxattr", Listxattr)
+ table[195] = syscalls.Supported("llistxattr", Llistxattr)
+ table[196] = syscalls.Supported("flistxattr", Flistxattr)
+ table[197] = syscalls.Supported("removexattr", Removexattr)
+ table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
+ table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
delete(table, 206) // io_setup
delete(table, 207) // io_destroy
delete(table, 208) // io_getevents
delete(table, 209) // io_submit
delete(table, 210) // io_cancel
- delete(table, 213) // epoll_create
- delete(table, 214) // epoll_ctl_old
- delete(table, 215) // epoll_wait_old
- delete(table, 216) // remap_file_pages
- delete(table, 217) // getdents64
- delete(table, 232) // epoll_wait
- delete(table, 233) // epoll_ctl
+ table[213] = syscalls.Supported("epoll_create", EpollCreate)
+ table[217] = syscalls.Supported("getdents64", Getdents64)
+ delete(table, 221) // fdavise64
+ table[232] = syscalls.Supported("epoll_wait", EpollWait)
+ table[233] = syscalls.Supported("epoll_ctl", EpollCtl)
+ table[235] = syscalls.Supported("utimes", Utimes)
delete(table, 253) // inotify_init
delete(table, 254) // inotify_add_watch
delete(table, 255) // inotify_rm_watch
- delete(table, 257) // openat
- delete(table, 258) // mkdirat
- delete(table, 259) // mknodat
- delete(table, 260) // fchownat
- delete(table, 261) // futimesat
- delete(table, 262) // fstatat
- delete(table, 263) // unlinkat
- delete(table, 264) // renameat
- delete(table, 265) // linkat
- delete(table, 266) // symlinkat
- delete(table, 267) // readlinkat
- delete(table, 268) // fchmodat
- delete(table, 269) // faccessat
- delete(table, 270) // pselect
- delete(table, 271) // ppoll
+ table[257] = syscalls.Supported("openat", Openat)
+ table[258] = syscalls.Supported("mkdirat", Mkdirat)
+ table[259] = syscalls.Supported("mknodat", Mknodat)
+ table[260] = syscalls.Supported("fchownat", Fchownat)
+ table[261] = syscalls.Supported("futimens", Futimens)
+ table[262] = syscalls.Supported("newfstatat", Newfstatat)
+ table[263] = syscalls.Supported("unlinkat", Unlinkat)
+ table[264] = syscalls.Supported("renameat", Renameat)
+ table[265] = syscalls.Supported("linkat", Linkat)
+ table[266] = syscalls.Supported("symlinkat", Symlinkat)
+ table[267] = syscalls.Supported("readlinkat", Readlinkat)
+ table[268] = syscalls.Supported("fchmodat", Fchmodat)
+ table[269] = syscalls.Supported("faccessat", Faccessat)
+ table[270] = syscalls.Supported("pselect", Pselect)
+ table[271] = syscalls.Supported("ppoll", Ppoll)
+ delete(table, 275) // splice
+ delete(table, 276) // tee
+ table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
+ table[280] = syscalls.Supported("utimensat", Utimensat)
+ table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
+ delete(table, 282) // signalfd
+ delete(table, 283) // timerfd_create
+ delete(table, 284) // eventfd
delete(table, 285) // fallocate
- delete(table, 291) // epoll_create1
- delete(table, 292) // dup3
+ delete(table, 286) // timerfd_settime
+ delete(table, 287) // timerfd_gettime
+ delete(table, 288) // accept4
+ delete(table, 289) // signalfd4
+ delete(table, 290) // eventfd2
+ table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
+ table[292] = syscalls.Supported("dup3", Dup3)
delete(table, 293) // pipe2
delete(table, 294) // inotify_init1
- delete(table, 295) // preadv
- delete(table, 296) // pwritev
- delete(table, 306) // syncfs
- delete(table, 316) // renameat2
+ table[295] = syscalls.Supported("preadv", Preadv)
+ table[296] = syscalls.Supported("pwritev", Pwritev)
+ delete(table, 299) // recvmmsg
+ table[306] = syscalls.Supported("syncfs", Syncfs)
+ delete(table, 307) // sendmmsg
+ table[316] = syscalls.Supported("renameat2", Renameat2)
delete(table, 319) // memfd_create
- delete(table, 322) // execveat
- delete(table, 327) // preadv2
- delete(table, 328) // pwritev2
- delete(table, 332) // statx
+ table[322] = syscalls.Supported("execveat", Execveat)
+ table[327] = syscalls.Supported("preadv2", Preadv2)
+ table[328] = syscalls.Supported("pwritev2", Pwritev2)
+ table[332] = syscalls.Supported("statx", Statx)
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
index 6af5c400f..a6b367468 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package vfs2
import (
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
new file mode 100644
index 000000000..60a43f0a0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mmap implements Linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := args[4].Int()
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef()
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // mmap unconditionally requires that the FD is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !file.IsWritable() {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
new file mode 100644
index 000000000..97da6c647
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) {
+ pathname, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fspath.Path{}, err
+ }
+ return fspath.Parse(pathname), nil
+}
+
+type taskPathOperation struct {
+ pop vfs.PathOperation
+ haveStartRef bool
+}
+
+func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) {
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ haveStartRef := false
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ root.DecRef()
+ return taskPathOperation{}, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ haveStartRef = true
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ root.DecRef()
+ return taskPathOperation{}, syserror.EBADF
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ haveStartRef = true
+ dirfile.DecRef()
+ }
+ }
+ return taskPathOperation{
+ pop: vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ },
+ haveStartRef: haveStartRef,
+ }, nil
+}
+
+func (tpop *taskPathOperation) Release() {
+ tpop.pop.Root.DecRef()
+ if tpop.haveStartRef {
+ tpop.pop.Start.DecRef()
+ tpop.haveStartRef = false
+ }
+}
+
+type shouldAllowEmptyPath bool
+
+const (
+ disallowEmptyPath shouldAllowEmptyPath = false
+ allowEmptyPath shouldAllowEmptyPath = true
+)
+
+type shouldFollowFinalSymlink bool
+
+const (
+ nofollowFinalSymlink shouldFollowFinalSymlink = false
+ followFinalSymlink shouldFollowFinalSymlink = true
+)
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
new file mode 100644
index 000000000..dbf4882da
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -0,0 +1,584 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum allowable files for poll & select. This has no
+// equivalent in Linux; it exists in gVisor since allocation failure in Go is
+// unrecoverable.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file description and waiter of a PollFD.
+type pollState struct {
+ file *vfs.FileDescription
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.GetFileVFS2(pfd.FD)
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef()
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef()
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or
+// positive if interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ haveTimeout := timeout >= 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := copyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
+ fdCount := 0
+ for i := 0; i < nBytes; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < nBytes; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note, another thread
+ // might be about to close fd. This is racy, but that's
+ // OK. Linux is racy in the same way.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef()
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Do the syscall, then count the number of bits set.
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ return tsRemaining.CopyOut(t, timespecAddr)
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ return tvRemaining.CopyOut(t, timevalAddr)
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // On an interrupt poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ var timeval linux.Timeval
+ if err := timeval.CopyIn(t, timevalAddr); err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ if t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
+ }
+ var maskStruct sigSetWithSize
+ if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ return 0, nil, err
+ }
+ if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ var timespec linux.Timespec
+ if err := timespec.CopyIn(t, timespecAddr); err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syserror.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
+
+func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error {
+ if maskAddr == 0 {
+ return nil
+ }
+ if maskSize != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if err := mask.CopyIn(t, maskAddr); err != nil {
+ return err
+ }
+ mask &^= kernel.UnblockableSignals
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
new file mode 100644
index 000000000..35f6308d6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -0,0 +1,511 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+ eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
+
+// Read implements Linux syscall read(2).
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Readv implements Linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
+func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.Read(t, dst, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Read(t, dst, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pread64 implements Linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Preadv implements Linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements Linux syscall preadv2(2).
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.ReadOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = read(t, file, dst, opts)
+ } else {
+ n, err = pread(t, file, dst, offset, opts)
+ }
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.PRead(t, dst, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PRead(t, dst, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Write implements Linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Writev implements Linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.Write(t, src, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Write(t, src, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pwrite64 implements Linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Pwritev implements Linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements Linux syscall pwritev2(2).
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.WriteOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = write(t, file, src, opts)
+ } else {
+ n, err = pwrite(t, file, src, offset, opts)
+ }
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
+
+func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.PWrite(t, src, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PWrite(t, src, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Lseek implements Linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newoff, err := file.Seek(t, offset, whence)
+ return uintptr(newoff), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
new file mode 100644
index 000000000..9250659ff
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -0,0 +1,380 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+// Chmod implements Linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode)
+}
+
+// Fchmodat implements Linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, fchmodat(t, dirfd, pathAddr, mode)
+}
+
+func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Fchmod implements Linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].ModeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Chown implements Linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */)
+}
+
+// Lchown implements Linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ owner := args[2].Int()
+ group := args[3].Int()
+ flags := args[4].Int()
+ return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags)
+}
+
+func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error {
+ userns := t.UserNamespace()
+ if owner != -1 {
+ kuid := userns.MapToKUID(auth.UID(owner))
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_UID
+ opts.Stat.UID = uint32(kuid)
+ }
+ if group != -1 {
+ kgid := userns.MapToKGID(auth.GID(group))
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_GID
+ opts.Stat.GID = uint32(kgid)
+ }
+ return nil
+}
+
+// Fchown implements Linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ owner := args[1].Int()
+ group := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, file.SetStat(t, opts)
+}
+
+// Truncate implements Linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Ftruncate implements Linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Utime implements Linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times linux.Utime
+ if err := times.CopyIn(t, timesAddr); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime.Sec = times.Actime
+ opts.Stat.Mtime.Sec = times.Modtime
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimes implements Linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Usec * 1000),
+ }
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Usec * 1000),
+ }
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimensat implements Linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Futimens implements Linux syscall futimens(2).
+func Futimens(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ timesAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.SetStat(t, opts)
+}
+
+func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_ATIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Nsec),
+ }
+ }
+ if times[1].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_MTIME
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Nsec),
+ }
+ }
+ return nil
+}
+
+func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.SetStat() instead of
+ // VirtualFilesystem.SetStatAt(), since the former may be able
+ // to use opened file state to expedite the SetStat.
+ err := dirfile.SetStat(t, *opts)
+ dirfile.DecRef()
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+ return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ }, opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
new file mode 100644
index 000000000..dca8d7011
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -0,0 +1,346 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Stat implements Linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */)
+}
+
+// Lstat implements Linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2).
+func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+ return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags)
+}
+
+func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+}
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
+
+func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
+ return linux.Timespec{
+ Sec: sxts.Sec,
+ Nsec: int64(sxts.Nsec),
+ }
+}
+
+// Fstat implements Linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ statx, err := file.Stat(t, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return 0, nil, stat.CopyOut(t, statAddr)
+}
+
+// Statx implements Linux syscall statx(2).
+func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+ mask := args[3].Uint()
+ statxAddr := args[4].Pointer()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: mask,
+ Sync: uint32(flags & linux.AT_STATX_SYNC_TYPE),
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for statx(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+}
+
+func userifyStatx(t *kernel.Task, statx *linux.Statx) {
+ userns := t.UserNamespace()
+ statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow())
+ statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow())
+}
+
+// Readlink implements Linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+ return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size)
+}
+
+// Access implements Linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Faccessat implements Linux syscall access(2).
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Readlinkat implements Linux syscall mknodat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ return readlinkat(t, dirfd, pathAddr, bufAddr, size)
+}
+
+func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
+ if int(size) <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // "Since Linux 2.6.39, pathname can be an empty string, in which case the
+ // call operates on the symbolic link referred to by dirfd ..." -
+ // readlinkat(2)
+ tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if len(target) > int(size) {
+ target = target[:size]
+ }
+ n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target))
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Statfs implements Linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
+
+// Fstatfs implements Linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufAddr := args[1].Pointer()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
new file mode 100644
index 000000000..365250b0b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements Linux syscall sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t)
+}
+
+// Syncfs implements Linux syscall syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SyncFS(t)
+}
+
+// Fsync implements Linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.Sync(t)
+}
+
+// Fdatasync implements Linux syscall fdatasync(2).
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata.
+ return Fsync(t, args)
+}
+
+// SyncFileRange implements Linux syscall sync_file_range(2).
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ flags := args[3].Uint()
+
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if nbytes < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of
+ // [offset, offset+nbytes).
+ return 0, nil, file.Sync(t)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_read.go b/pkg/sentry/syscalls/linux/vfs2/sys_read.go
deleted file mode 100644
index 7667524c7..000000000
--- a/pkg/sentry/syscalls/linux/vfs2/sys_read.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs2
-
-import (
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // EventMaskRead contains events that can be triggered on reads.
- EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
-)
-
-// Read implements linux syscall read(2). Note that we try to get a buffer that
-// is exactly the size requested because some applications like qemu expect
-// they can do large reads all at once. Bug for bug. Same for other read
-// calls below.
-func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- fd := args[0].Int()
- addr := args[1].Pointer()
- size := args[2].SizeT()
-
- file := t.GetFileVFS2(fd)
- if file == nil {
- return 0, nil, syserror.EBADF
- }
- defer file.DecRef()
-
- // Check that the size is legitimate.
- si := int(size)
- if si < 0 {
- return 0, nil, syserror.EINVAL
- }
-
- // Get the destination of the read.
- dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
- AddressSpaceActive: true,
- })
- if err != nil {
- return 0, nil, err
- }
-
- n, err := read(t, file, dst, vfs.ReadOptions{})
- t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, linux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
-}
-
-func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- n, err := file.Read(t, dst, opts)
- if err != syserror.ErrWouldBlock {
- return n, err
- }
-
- // Register for notifications.
- w, ch := waiter.NewChannelEntry(nil)
- file.EventRegister(&w, EventMaskRead)
-
- total := n
- for {
- // Shorten dst to reflect bytes previously read.
- dst = dst.DropFirst(int(n))
-
- // Issue the request and break out if it completes with anything other than
- // "would block".
- n, err := file.Read(t, dst, opts)
- total += n
- if err != syserror.ErrWouldBlock {
- break
- }
- if err := t.Block(ch); err != nil {
- break
- }
- }
- file.EventUnregister(&w)
-
- return total, err
-}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
new file mode 100644
index 000000000..89e9ff4d7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -0,0 +1,353 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Listxattr implements Linux syscall listxattr(2).
+func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, followFinalSymlink)
+}
+
+// Llistxattr implements Linux syscall llistxattr(2).
+func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, nofollowFinalSymlink)
+}
+
+func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Flistxattr implements Linux syscall flistxattr(2).
+func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ names, err := file.Listxattr(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Getxattr implements Linux syscall getxattr(2).
+func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, followFinalSymlink)
+}
+
+// Lgetxattr implements Linux syscall lgetxattr(2).
+func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, nofollowFinalSymlink)
+}
+
+func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Fgetxattr implements Linux syscall fgetxattr(2).
+func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := file.Getxattr(t, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Setxattr implements Linux syscall setxattr(2).
+func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, followFinalSymlink)
+}
+
+// Lsetxattr implements Linux syscall lsetxattr(2).
+func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, nofollowFinalSymlink)
+}
+
+func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Fsetxattr implements Linux syscall fsetxattr(2).
+func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Removexattr implements Linux syscall removexattr(2).
+func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, followFinalSymlink)
+}
+
+// Lremovexattr implements Linux syscall lremovexattr(2).
+func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, nofollowFinalSymlink)
+}
+
+func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name)
+}
+
+// Fremovexattr implements Linux syscall fremovexattr(2).
+func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Removexattr(t, name)
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) {
+ if size > linux.XATTR_LIST_MAX {
+ size = linux.XATTR_LIST_MAX
+ }
+ var buf bytes.Buffer
+ for _, name := range names {
+ buf.WriteString(name)
+ buf.WriteByte(0)
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the list.
+ return buf.Len(), nil
+ }
+ if buf.Len() > int(size) {
+ if size >= linux.XATTR_LIST_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(listAddr, buf.Bytes())
+}
+
+func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ return "", syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return "", err
+ }
+ return gohacks.StringFromImmutableBytes(buf), nil
+}
+
+func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ size = linux.XATTR_SIZE_MAX
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the value.
+ return len(value), nil
+ }
+ if len(value) > int(size) {
+ if size >= linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value))
+}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 0b4f18ab5..07c8383e6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -43,6 +43,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index eed41139b..3da45d744 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -202,6 +202,9 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin
// Add epi to file.epolls so that it is removed when the last
// FileDescription reference is dropped.
file.epollMu.Lock()
+ if file.epolls == nil {
+ file.epolls = make(map[*epollInterest]struct{})
+ }
file.epolls[epi] = struct{}{}
file.epollMu.Unlock()
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index 1fe766a44..bc7581698 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -26,6 +26,7 @@ import (
"sync/atomic"
"unsafe"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -160,7 +161,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
loop:
for {
@@ -361,12 +362,3 @@ func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
//go:linkname rand32 runtime.fastrand
func rand32() uint32
-
-// This is copy/pasted from runtime.noescape(), and is needed because arguments
-// apparently escape from all functions defined by linkname.
-//
-//go:nosplit
-func noescape(p unsafe.Pointer) unsafe.Pointer {
- x := uintptr(p)
- return unsafe.Pointer(x ^ 0)
-}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 8a0b382f6..eb4ebb511 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -228,7 +228,7 @@ func (rp *ResolvingPath) Advance() {
rp.pit = next
} else { // at end of path segment, continue with next one
rp.curPart--
- rp.pit = rp.parts[rp.curPart-1]
+ rp.pit = rp.parts[rp.curPart]
}
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 8f29031b2..73f8043be 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -385,15 +385,11 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
// Only a regular file can be executed.
stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
if err != nil {
+ fd.DecRef()
return nil, err
}
- if stat.Mask&linux.STATX_TYPE != 0 {
- // This shouldn't happen, but if type can't be retrieved, file can't
- // be executed.
- return nil, syserror.EACCES
- }
- if t := linux.FileMode(stat.Mode).FileType(); t != linux.ModeRegular {
- ctx.Infof("%q is not a regular file: %v", pop.Path, t)
+ if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG {
+ fd.DecRef()
return nil, syserror.EACCES
}
}
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
new file mode 100644
index 000000000..0500a22cf
--- /dev/null
+++ b/pkg/syncevent/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "syncevent",
+ srcs = [
+ "broadcaster.go",
+ "receiver.go",
+ "source.go",
+ "syncevent.go",
+ "waiter_amd64.s",
+ "waiter_arm64.s",
+ "waiter_asm_unsafe.go",
+ "waiter_noasm_unsafe.go",
+ "waiter_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/atomicbitops",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "syncevent_test",
+ size = "small",
+ srcs = [
+ "broadcaster_test.go",
+ "syncevent_example_test.go",
+ "waiter_test.go",
+ ],
+ library = ":syncevent",
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go
new file mode 100644
index 000000000..4bff59e7d
--- /dev/null
+++ b/pkg/syncevent/broadcaster.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Broadcaster is an implementation of Source that supports any number of
+// subscribed Receivers.
+//
+// The zero value of Broadcaster is valid and has no subscribed Receivers.
+// Broadcaster is not copyable by value.
+//
+// All Broadcaster methods may be called concurrently from multiple goroutines.
+type Broadcaster struct {
+ // Broadcaster is implemented as a hash table where keys are assigned by
+ // the Broadcaster and returned as SubscriptionIDs, making it safe to use
+ // the identity function for hashing. The hash table resolves collisions
+ // using linear probing and features Robin Hood insertion and backward
+ // shift deletion in order to support a relatively high load factor
+ // efficiently, which matters since the cost of Broadcast is linear in the
+ // size of the table.
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // Invariants: len(table) is 0 or a power of 2.
+ table []broadcasterSlot
+
+ // load is the number of entries in table with receiver != nil.
+ load int
+
+ lastID SubscriptionID
+}
+
+type broadcasterSlot struct {
+ // Invariants: If receiver == nil, then filter == NoEvents and id == 0.
+ // Otherwise, id != 0.
+ receiver *Receiver
+ filter Set
+ id SubscriptionID
+}
+
+const (
+ broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1
+
+ broadcasterMaxLoadNum = 13
+ broadcasterMaxLoadDen = 16
+)
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID {
+ b.mu.Lock()
+
+ // Assign an ID for this subscription.
+ b.lastID++
+ id := b.lastID
+
+ // Expand the table if over the maximum load factor:
+ //
+ // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen
+ // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table)
+ b.load++
+ if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) {
+ // Double the number of slots in the new table.
+ newlen := broadcasterMinNonZeroTableSize
+ if len(b.table) != 0 {
+ newlen = 2 * len(b.table)
+ }
+ if newlen <= cap(b.table) {
+ // Reuse excess capacity in the current table, moving entries not
+ // already in their first-probed positions to better ones.
+ newtable := b.table[:newlen]
+ newmask := uint64(newlen - 1)
+ for i := range b.table {
+ if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) {
+ entry := b.table[i]
+ b.table[i] = broadcasterSlot{}
+ broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter)
+ }
+ }
+ b.table = newtable
+ } else {
+ newtable := make([]broadcasterSlot, newlen)
+ // Copy existing entries to the new table.
+ for i := range b.table {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ }
+ }
+ // Switch to the new table.
+ b.table = newtable
+ }
+ }
+
+ broadcasterTableInsert(b.table, id, r, filter)
+ b.mu.Unlock()
+ return id
+}
+
+// Preconditions: table must not be full. len(table) is a power of 2.
+func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) {
+ entry := broadcasterSlot{
+ receiver: r,
+ filter: filter,
+ id: id,
+ }
+ mask := uint64(len(table) - 1)
+ i := uint64(id) & mask
+ disp := uint64(0)
+ for {
+ if table[i].receiver == nil {
+ table[i] = entry
+ return
+ }
+ // If we've been displaced farther from our first-probed slot than the
+ // element stored in this one, swap elements and switch to inserting
+ // the replaced one. (This is Robin Hood insertion.)
+ slotDisp := (i - uint64(table[i].id)) & mask
+ if disp > slotDisp {
+ table[i], entry = entry, table[i]
+ disp = slotDisp
+ }
+ i = (i + 1) & mask
+ disp++
+ }
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) {
+ b.mu.Lock()
+
+ mask := uint64(len(b.table) - 1)
+ i := uint64(id) & mask
+ for {
+ if b.table[i].id == id {
+ // Found the element to remove. Move all subsequent elements
+ // backward until we either find an empty slot, or an element that
+ // is already in its first-probed slot. (This is backward shift
+ // deletion.)
+ for {
+ next := (i + 1) & mask
+ if b.table[next].receiver == nil {
+ break
+ }
+ if uint64(b.table[next].id)&mask == next {
+ break
+ }
+ b.table[i] = b.table[next]
+ i = next
+ }
+ b.table[i] = broadcasterSlot{}
+ break
+ }
+ i = (i + 1) & mask
+ }
+
+ // If a table 1/4 of the current size would still be at or under the
+ // maximum load factor (i.e. the current table size is at least two
+ // expansions bigger than necessary), halve the size of the table to reduce
+ // the cost of Broadcast. Since we are concerned with iteration time and
+ // not memory usage, reuse the existing slice to reduce future allocations
+ // from table re-expansion.
+ b.load--
+ if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) {
+ newlen := len(b.table) / 2
+ newtable := b.table[:newlen]
+ for i := newlen; i < len(b.table); i++ {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ b.table[i] = broadcasterSlot{}
+ }
+ }
+ b.table = newtable
+ }
+
+ b.mu.Unlock()
+}
+
+// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset
+// of events to which they subscribed. The order in which Receivers are
+// notified is unspecified.
+func (b *Broadcaster) Broadcast(events Set) {
+ b.mu.Lock()
+ for i := range b.table {
+ if intersection := events & b.table[i].filter; intersection != 0 {
+ // We don't need to check if broadcasterSlot.receiver is nil, since
+ // if it is then broadcasterSlot.filter is 0.
+ b.table[i].receiver.Notify(intersection)
+ }
+ }
+ b.mu.Unlock()
+}
+
+// FilteredEvents returns the set of events for which Broadcast will notify at
+// least one Receiver, i.e. the union of filters for all subscribed Receivers.
+func (b *Broadcaster) FilteredEvents() Set {
+ var es Set
+ b.mu.Lock()
+ for i := range b.table {
+ es |= b.table[i].filter
+ }
+ b.mu.Unlock()
+ return es
+}
diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go
new file mode 100644
index 000000000..e88779e23
--- /dev/null
+++ b/pkg/syncevent/broadcaster_test.go
@@ -0,0 +1,376 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestBroadcasterFilter(t *testing.T) {
+ const numReceivers = 2 * MaxEvents
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents))
+ }
+ for ev := 0; ev < MaxEvents; ev++ {
+ br.Broadcast(1 << ev)
+ for i := range ws {
+ want := NoEvents
+ if i%MaxEvents == ev {
+ want = 1 << ev
+ }
+ if got := ws[i].Receiver().PendingAndAckAll(); got != want {
+ t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want)
+ }
+ }
+ }
+}
+
+// TestBroadcasterManySubscriptions tests that subscriptions are not lost by
+// table expansion/compaction.
+func TestBroadcasterManySubscriptions(t *testing.T) {
+ const numReceivers = 5000 // arbitrary
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ ids := make([]SubscriptionID, numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Subscribe receiver i.
+ ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1)
+ // Check that receivers [0, i] are subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[j].Pending() != 1 {
+ t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i)
+ }
+ ws[j].Ack(1)
+ }
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Unsubscribe receiver unsub[i].
+ br.UnsubscribeEvents(ids[unsub[i]])
+ // Check that receivers [unsub[0], unsub[i]] are not subscribed, and that
+ // receivers (unsub[i], unsub[numReceivers]) are still subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[unsub[j]].Pending() != 0 {
+ t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ }
+ for j := i + 1; j < numReceivers; j++ {
+ if ws[unsub[j]].Pending() != 1 {
+ t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ ws[unsub[j]].Ack(1)
+ }
+ }
+}
+
+var (
+ receiverCountsNonZero = []int{1, 4, 16, 64}
+ receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...)
+)
+
+// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage
+// pattern X (described in terms of Broadcaster) with Broadcaster, a
+// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively.
+
+// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe
+// cycle.
+
+func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) {
+ var br Broadcaster
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ id := br.SubscribeEvents(w.Receiver(), 1)
+ br.UnsubscribeEvents(id)
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribe(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m[w.Receiver()] = Set(1)
+ mu.Unlock()
+ mu.Lock()
+ delete(m, w.Receiver())
+ mu.Unlock()
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) {
+ var q waiter.Queue
+ e, _ := waiter.NewChannelEntry(nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.EventRegister(&e, 1)
+ q.EventUnregister(&e)
+ }
+}
+
+// BenchmarkXxxSubscribeUnsubscribeBatch is similar to
+// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large
+// number of Receivers at a time in order to measure the amortized overhead of
+// table expansion/compaction. (Since waiter.Queue is implemented using a
+// linked list, BenchmarkQueueSubscribeUnsubscribe and
+// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same
+// result.)
+
+const numBatchReceivers = 1000
+
+func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+ ids := make([]SubscriptionID, numBatchReceivers)
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ br.UnsubscribeEvents(ids[unsub[j]])
+ }
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ m[ws[j].Receiver()] = Set(1)
+ mu.Unlock()
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ delete(m, ws[unsub[j]].Receiver())
+ mu.Unlock()
+ }
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) {
+ var q waiter.Queue
+ es := make([]waiter.Entry, numBatchReceivers)
+ for i := range es {
+ es[i], _ = waiter.NewChannelEntry(nil)
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventRegister(&es[j], 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventUnregister(&es[unsub[j]])
+ }
+ }
+}
+
+// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast
+// already-pending events to multiple Receivers.
+
+func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+ br.Broadcast(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ for i := 0; i < n; i++ {
+ e, _ := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ }
+ q.Notify(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ }
+ })
+ }
+}
+
+// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to
+// multiple Receivers, check that all Receivers have received the event, and
+// clear the event from all Receivers.
+
+func BenchmarkBroadcasterBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ chs := make([]chan struct{}, n)
+ for i := range chs {
+ e, ch := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ chs[i] = ch
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ for _, ch := range chs {
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("channel did not receive event")
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go
new file mode 100644
index 000000000..5c86e5400
--- /dev/null
+++ b/pkg/syncevent/receiver.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+)
+
+// Receiver is an event sink that holds pending events and invokes a callback
+// whenever new events become pending. Receiver's methods may be called
+// concurrently from multiple goroutines.
+//
+// Receiver.Init() must be called before first use.
+type Receiver struct {
+ // pending is the set of pending events. pending is accessed using atomic
+ // memory operations.
+ pending uint64
+
+ // cb is notified when new events become pending. cb is immutable after
+ // Init().
+ cb ReceiverCallback
+}
+
+// ReceiverCallback receives callbacks from a Receiver.
+type ReceiverCallback interface {
+ // NotifyPending is called when the corresponding Receiver has new pending
+ // events.
+ //
+ // NotifyPending is called synchronously from Receiver.Notify(), so
+ // implementations must not take locks that may be held by callers of
+ // Receiver.Notify(). NotifyPending may be called concurrently from
+ // multiple goroutines.
+ NotifyPending()
+}
+
+// Init must be called before first use of r.
+func (r *Receiver) Init(cb ReceiverCallback) {
+ r.cb = cb
+}
+
+// Pending returns the set of pending events.
+func (r *Receiver) Pending() Set {
+ return Set(atomic.LoadUint64(&r.pending))
+}
+
+// Notify sets the given events as pending.
+func (r *Receiver) Notify(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already pending.
+ if p&es == es {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-OR because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) {
+ // If the CAS fails, fall back to atomic-OR.
+ atomicbitops.OrUint64(&r.pending, uint64(es))
+ }
+ r.cb.NotifyPending()
+}
+
+// Ack unsets the given events as pending.
+func (r *Receiver) Ack(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already not pending.
+ if p&es == 0 {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-AND because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) {
+ // If the CAS fails, fall back to atomic-AND.
+ atomicbitops.AndUint64(&r.pending, ^uint64(es))
+ }
+}
+
+// PendingAndAckAll unsets all events as pending and returns the set of
+// previously-pending events.
+//
+// PendingAndAckAll should only be used in preference to a call to Pending
+// followed by a conditional call to Ack when the caller expects events to be
+// pending (e.g. after a call to ReceiverCallback.NotifyPending()).
+func (r *Receiver) PendingAndAckAll() Set {
+ return Set(atomic.SwapUint64(&r.pending, 0))
+}
diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go
new file mode 100644
index 000000000..ddffb171a
--- /dev/null
+++ b/pkg/syncevent/source.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+// Source represents an event source.
+type Source interface {
+ // SubscribeEvents causes the Source to notify the given Receiver of the
+ // given subset of events.
+ //
+ // Preconditions: r != nil. The ReceiverCallback for r must not take locks
+ // that are ordered prior to the Source; for example, it cannot call any
+ // Source methods.
+ SubscribeEvents(r *Receiver, filter Set) SubscriptionID
+
+ // UnsubscribeEvents causes the Source to stop notifying the Receiver
+ // subscribed by a previous call to SubscribeEvents that returned the given
+ // SubscriptionID.
+ //
+ // Preconditions: UnsubscribeEvents may be called at most once for any
+ // given SubscriptionID.
+ UnsubscribeEvents(id SubscriptionID)
+}
+
+// SubscriptionID identifies a call to Source.SubscribeEvents.
+type SubscriptionID uint64
+
+// UnsubscribeAndAck is a convenience function that unsubscribes r from the
+// given events from src and also clears them from r.
+func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) {
+ src.UnsubscribeEvents(id)
+ r.Ack(filter)
+}
+
+// NoopSource implements Source by never sending events to subscribed
+// Receivers.
+type NoopSource struct{}
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID {
+ return 0
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (NoopSource) UnsubscribeEvents(SubscriptionID) {
+}
+
+// See Broadcaster for a non-noop implementations of Source.
diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go
new file mode 100644
index 000000000..9fb6a06de
--- /dev/null
+++ b/pkg/syncevent/syncevent.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syncevent provides efficient primitives for goroutine
+// synchronization based on event bitmasks.
+package syncevent
+
+// Set is a bitmask where each bit represents a distinct user-defined event.
+// The event package does not treat any bits in Set specially.
+type Set uint64
+
+const (
+ // NoEvents is a Set containing no events.
+ NoEvents = Set(0)
+
+ // AllEvents is a Set containing all possible events.
+ AllEvents = ^Set(0)
+
+ // MaxEvents is the number of distinct events that can be represented by a Set.
+ MaxEvents = 64
+)
diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go
new file mode 100644
index 000000000..bfb18e2ea
--- /dev/null
+++ b/pkg/syncevent/syncevent_example_test.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+)
+
+func Example_ioReadinessInterrputible() {
+ const (
+ evReady = Set(1 << iota)
+ evInterrupt
+ )
+ errNotReady := fmt.Errorf("not ready for I/O")
+
+ // State of some I/O object.
+ var (
+ br Broadcaster
+ ready uint32
+ )
+ doIO := func() error {
+ if atomic.LoadUint32(&ready) == 0 {
+ return errNotReady
+ }
+ return nil
+ }
+ go func() {
+ // The I/O object eventually becomes ready for I/O.
+ time.Sleep(100 * time.Millisecond)
+ // When it does, it first ensures that future calls to isReady() return
+ // true, then broadcasts the readiness event to Receivers.
+ atomic.StoreUint32(&ready, 1)
+ br.Broadcast(evReady)
+ }()
+
+ // Each user of the I/O object owns a Waiter.
+ var w Waiter
+ w.Init()
+ // The Waiter may be asynchronously interruptible, e.g. for signal
+ // handling in the sentry.
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ w.Receiver().Notify(evInterrupt)
+ }()
+
+ // To use the I/O object:
+ //
+ // Optionally, if the I/O object is likely to be ready, attempt I/O first.
+ err := doIO()
+ if err == nil {
+ // Success, we're done.
+ return /* nil */
+ }
+ if err != errNotReady {
+ // Failure, I/O failed for some reason other than readiness.
+ return /* err */
+ }
+ // Subscribe for readiness events from the I/O object.
+ id := br.SubscribeEvents(w.Receiver(), evReady)
+ // When we are finished blocking, unsubscribe from readiness events and
+ // remove readiness events from the pending event set.
+ defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id)
+ for {
+ // Attempt I/O again. This must be done after the call to SubscribeEvents,
+ // since the I/O object might have become ready between the previous call
+ // to doIO and the call to SubscribeEvents.
+ err = doIO()
+ if err == nil {
+ return /* nil */
+ }
+ if err != errNotReady {
+ return /* err */
+ }
+ // Block until either the I/O object indicates it is ready, or we are
+ // interrupted.
+ events := w.Wait()
+ if events&evInterrupt != 0 {
+ // In the specific case of sentry signal handling, signal delivery
+ // is handled by another system, so we aren't responsible for
+ // acknowledging evInterrupt.
+ return /* errInterrupted */
+ }
+ // Note that, in a concurrent context, the I/O object might become
+ // ready and then not ready again. To handle this:
+ //
+ // - evReady must be acknowledged before calling doIO() again (rather
+ // than after), so that if the I/O object becomes ready *again* after
+ // the call to doIO(), the readiness event is not lost.
+ //
+ // - We must loop instead of just calling doIO() once after receiving
+ // evReady.
+ w.Ack(evReady)
+ }
+}
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s
new file mode 100644
index 000000000..985b56ae5
--- /dev/null
+++ b/pkg/syncevent/waiter_amd64.s
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVQ g+0(FP), DI
+ MOVQ wg+8(FP), SI
+
+ MOVQ $·preparingG(SB), AX
+ LOCK
+ CMPXCHGQ DI, 0(SI)
+
+ SETEQ AX
+ MOVB AX, ret+16(FP)
+
+ RET
+
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s
new file mode 100644
index 000000000..20d7ac23b
--- /dev/null
+++ b/pkg/syncevent/waiter_arm64.s
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVD wg+8(FP), R0
+ MOVD $·preparingG(SB), R1
+ MOVD g+0(FP), R2
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE ok
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ok:
+ CSET EQ, R0
+ MOVB R0, ret+16(FP)
+ RET
+
diff --git a/pkg/fspath/builder_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go
index 75606808d..0995e9053 100644
--- a/pkg/fspath/builder_unsafe.go
+++ b/pkg/syncevent/waiter_asm_unsafe.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,16 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package fspath
+// +build amd64 arm64
+
+package syncevent
import (
"unsafe"
)
-// String returns the accumulated string. No other methods should be called
-// after String.
-func (b *Builder) String() string {
- bs := b.buf[b.start:]
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
-}
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
new file mode 100644
index 000000000..1c4b0e39a
--- /dev/null
+++ b/pkg/syncevent/waiter_noasm_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// waiterUnlock is called from g0, so when the race detector is enabled,
+// waiterUnlock must be implemented in assembly since no race context is
+// available.
+//
+// +build !race
+// +build !amd64,!arm64
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// waiterUnlock is the "unlock function" passed to runtime.gopark by
+// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
+// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
+// should be aborted.
+//
+//go:nosplit
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with nil, in which case we should not sleep.
+ return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), g)
+}
diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go
new file mode 100644
index 000000000..3c8cbcdd8
--- /dev/null
+++ b/pkg/syncevent/waiter_test.go
@@ -0,0 +1,414 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestWaiterAlreadyPending(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ w.Notify(want)
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterAsyncNotify(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ w.Notify(want)
+ }()
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterWaitFor(t *testing.T) {
+ var w Waiter
+ w.Init()
+ evWaited := Set(1)
+ evOther := Set(2)
+ w.Notify(evOther)
+ notifiedEvent := uint32(0)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ atomic.StoreUint32(&notifiedEvent, 1)
+ w.Notify(evWaited)
+ }()
+ if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want {
+ t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want)
+ }
+ if atomic.LoadUint32(&notifiedEvent) == 0 {
+ t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event")
+ }
+}
+
+func TestWaiterWaitAndAckAll(t *testing.T) {
+ var w Waiter
+ w.Init()
+ w.Notify(AllEvents)
+ if got := w.WaitAndAckAll(); got != AllEvents {
+ t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents)
+ }
+ if got := w.Pending(); got != NoEvents {
+ t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got)
+ }
+}
+
+// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage
+// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and
+// buffered chan struct{} respectively. When the maximum number of event
+// sources is relevant, we use 3 event sources because this is representative
+// of the kernel.Task.block() use case: an interrupt source, a timeout source,
+// and the actual event source being waited on.
+
+// Event set used by most benchmarks.
+const evBench Set = 1
+
+// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of
+// an event that is already pending.
+
+func BenchmarkWaiterNotifyRedundant(b *testing.B) {
+ var w Waiter
+ w.Init()
+ w.Notify(evBench)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyRedundant(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+ w.Assert()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ }
+}
+
+func BenchmarkChannelNotifyRedundant(b *testing.B) {
+ ch := make(chan struct{}, 1)
+ ch <- struct{}{}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }
+}
+
+// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter an
+// event, return that event using a blocking check, and then unset the event as
+// pending.
+
+func BenchmarkWaiterNotifyWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+
+ // wait + ack
+ <-ch
+ }
+}
+
+// BenchmarkSleeperMultiNotifyWaitAck is equivalent to
+// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a
+// temporary sleep.Waker. This is necessary when multiple goroutines may wait
+// for the same event, since each sleep.Waker can wake only a single
+// sleep.Sleeper.
+//
+// The syncevent package does not require a distinct object for each
+// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and
+// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state
+// for channels, runtime.sudog, is inescapably runtime-allocated, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would
+// also be identical.
+
+func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ // The sleep package doesn't provide sync.Pool allocation of Wakers;
+ // we do for a fairer comparison.
+ wakerPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Waker{}
+ },
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := wakerPool.Get().(*sleep.Waker)
+ s.AddWaker(w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ wakerPool.Put(w)
+ }
+}
+
+// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also
+// includes allocation of a temporary Waiter. This models the case where a
+// goroutine not already associated with a Waiter needs one in order to block.
+//
+// The analogous state for channels is built into runtime.g, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be
+// identical.
+
+func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) {
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := GetWaiter()
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ PutWaiter(w)
+ }
+}
+
+func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) {
+ // The sleep package doesn't provide sync.Pool allocation of Sleepers;
+ // we do for a fairer comparison.
+ sleeperPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Sleeper{}
+ },
+ }
+ var w sleep.Waker
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s := sleeperPool.Get().(*sleep.Sleeper)
+ s.AddWaker(&w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ sleeperPool.Put(s)
+ }
+}
+
+// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows
+// for multiple event sources.
+
+func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ws[0].Assert()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, wanted 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+
+ // wait + clear
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an
+// event while another goroutine signals the event. This assumes that a new
+// goroutine doesn't run immediately (i.e. the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+
+func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(1)
+ }()
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Assert()
+ }()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }()
+ <-ch
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but
+// allows for multiple event sources.
+
+func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(evBench)
+ }()
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ ws[0].Assert()
+ }()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, expected 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+ }()
+
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
new file mode 100644
index 000000000..112e0e604
--- /dev/null
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -0,0 +1,206 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.11
+// +build !go1.15
+
+// Check go:linkname function signatures when updating Go version.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+//go:linkname goready runtime.goready
+func goready(g unsafe.Pointer, traceskip int)
+
+const (
+ waitReasonSelect = 9 // Go: src/runtime/runtime2.go
+ traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
+)
+
+// Waiter allows a goroutine to block on pending events received by a Receiver.
+//
+// Waiter.Init() must be called before first use.
+type Waiter struct {
+ r Receiver
+
+ // g is one of:
+ //
+ // - nil: No goroutine is blocking in Wait.
+ //
+ // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // completed waiterUnlock(). Thus the wait can only be interrupted by
+ // replacing the value of g with nil (the G may not be in state Gwaiting
+ // yet, so we can't call goready.)
+ //
+ // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
+ // goroutine blocked in Wait, which can only be woken by calling goready.
+ g unsafe.Pointer `state:"zerovalue"`
+}
+
+// Sentinel object for Waiter.g.
+var preparingG struct{}
+
+// Init must be called before first use of w.
+func (w *Waiter) Init() {
+ w.r.Init(w)
+}
+
+// Receiver returns the Receiver that receives events that unblock calls to
+// w.Wait().
+func (w *Waiter) Receiver() *Receiver {
+ return &w.r
+}
+
+// Pending returns the set of pending events.
+func (w *Waiter) Pending() Set {
+ return w.r.Pending()
+}
+
+// Wait blocks until at least one event is pending, then returns the set of
+// pending events. It does not affect the set of pending events; callers must
+// call w.Ack() to do so, or use w.WaitAndAck() instead.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) Wait() Set {
+ return w.WaitFor(AllEvents)
+}
+
+// WaitFor blocks until at least one event in es is pending, then returns the
+// set of pending events (including those not in es). It does not affect the
+// set of pending events; callers must call w.Ack() to do so.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitFor(es Set) Set {
+ for {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending.
+ if p := w.r.Pending(); p&es != NoEvents {
+ return p
+ }
+
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if p := w.r.Pending(); p&es != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ }
+}
+
+// Ack marks the given events as not pending.
+func (w *Waiter) Ack(es Set) {
+ w.r.Ack(es)
+}
+
+// WaitAndAckAll blocks until at least one event is pending, then marks all
+// events as not pending and returns the set of previously-pending events.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitAndAckAll() Set {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending. Call Pending() first since, in the common case that events are
+ // not yet pending, this skips an atomic swap on w.r.pending.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+
+ for {
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+
+ // Check for pending events. We call PendingAndAckAll() directly now since
+ // we only expect to be woken after events become pending.
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+}
+
+// Notify marks the given events as pending, possibly unblocking concurrent
+// calls to w.Wait() or w.WaitFor().
+func (w *Waiter) Notify(es Set) {
+ w.r.Notify(es)
+}
+
+// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter
+// should not call NotifyPending.
+func (w *Waiter) NotifyPending() {
+ // Optimization: Skip the atomic swap on w.g if there is no sleeping
+ // goroutine. NotifyPending is called after w.r.Pending() is updated, so
+ // concurrent and future calls to w.Wait() will observe pending events and
+ // abort sleeping.
+ if atomic.LoadPointer(&w.g) == nil {
+ return
+ }
+ // Wake a sleeping G, or prevent a G that is preparing to sleep from doing
+ // so. Swap is needed here to ensure that only one call to NotifyPending
+ // calls goready.
+ if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
+ goready(g, 0)
+ }
+}
+
+var waiterPool = sync.Pool{
+ New: func() interface{} {
+ w := &Waiter{}
+ w.Init()
+ return w
+ },
+}
+
+// GetWaiter returns an unused Waiter. PutWaiter should be called to release
+// the Waiter once it is no longer needed.
+//
+// Where possible, users should prefer to associate each goroutine that calls
+// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of
+// Waiters in hot paths.
+func GetWaiter() *Waiter {
+ return waiterPool.Get().(*Waiter)
+}
+
+// PutWaiter releases an unused Waiter previously returned by GetWaiter.
+func PutWaiter(w *Waiter) {
+ waiterPool.Put(w)
+}
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 2269f6237..4b5a0fca6 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -29,6 +29,7 @@ var (
EACCES = error(syscall.EACCES)
EAGAIN = error(syscall.EAGAIN)
EBADF = error(syscall.EBADF)
+ EBADFD = error(syscall.EBADFD)
EBUSY = error(syscall.EBUSY)
ECHILD = error(syscall.ECHILD)
ECONNREFUSED = error(syscall.ECONNREFUSED)
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ea0a0409a..3c552988a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
+ _, err := r.CreateEndpoint(&wq)
if err != nil {
t.Fatalf("r.CreateEndpoint() = %v", err)
}
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
+ // Endpoint will be closed in deferred s.Close (above).
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
stop = func() {
c1.Close()
c2.Close()
+ s.Close()
+ s.Wait()
}
if err := l.Close(); err != nil {
@@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 150310c11..17e94c562 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) {
vv.views = append(vv.views, vv2.views...)
vv.size += vv2.size
}
+
+// AppendView appends the given view into this vectorised view.
+func (vv *VectorisedView) AppendView(v View) {
+ vv.views = append(vv.views, v)
+ vv.size += len(v)
+}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 4d6ae0871..c6c160dfc 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -161,6 +161,20 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTClass creates a checker that checks the TCLASS field in
+// ControlMessages.
+func ReceiveTClass(want uint32) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTClass {
+ t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
+ }
+ if got := cm.TClass; got != want {
+ t.Fatalf("got cm.TClass = %d, want %d", got, want)
+ }
+ }
+}
+
// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
func ReceiveTOS(want uint8) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index f7dc4f720..80ddbd442 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -156,25 +156,54 @@ func EmptyNatTable() Table {
}
}
+// A chainVerdict is what a table decides should be done with a packet.
+type chainVerdict int
+
+const (
+ // chainAccept indicates the packet should continue through netstack.
+ chainAccept chainVerdict = iota
+
+ // chainAccept indicates the packet should be dropped.
+ chainDrop
+
+ // chainReturn indicates the packet should return to the calling chain
+ // or the underflow rule of a builtin chain.
+ chainReturn
+)
+
+
// Check runs pkt through the rules for hook. It returns true when the packet
// should continue traversing the network stack and false when it should be
// dropped.
//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
- // TODO(gvisor.dev/issue/170): A lot of this is uncomplicated because
- // we're missing features. Jumps, the call stack, etc. aren't checked
- // for yet because we're yet to support them.
-
// Go through each table containing the hook.
for _, tablename := range it.Priorities[hook] {
- switch verdict := it.checkTable(hook, pkt, tablename); verdict {
+ table := it.Tables[tablename]
+ ruleIdx := table.BuiltinChains[hook]
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict {
// If the table returns Accept, move on to the next table.
- case TableAccept:
+ case chainAccept:
continue
// The Drop verdict is final.
- case TableDrop:
+ case chainDrop:
return false
+ case chainReturn:
+ // Any Return from a built-in chain means we have to
+ // call the underflow.
+ underflow := table.Rules[table.Underflows[hook]]
+ switch v, _ := underflow.Target.Action(pkt); v {
+ case RuleAccept:
+ continue
+ case RuleDrop:
+ return false
+ case RuleJump, RuleReturn:
+ panic("Underflows should only return RuleAccept or RuleDrop.")
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", v))
+ }
+
default:
panic(fmt.Sprintf("Unknown verdict %v.", verdict))
}
@@ -185,37 +214,37 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
}
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) TableVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
- table := it.Tables[tablename]
- for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ {
- switch verdict := it.checkRule(hook, pkt, table, ruleIdx); verdict {
+ for ruleIdx < len(table.Rules) {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict {
case RuleAccept:
- return TableAccept
+ return chainAccept
case RuleDrop:
- return TableDrop
-
- case RuleContinue:
- continue
+ return chainDrop
case RuleReturn:
- // TODO(gvisor.dev/issue/170): We don't implement jump
- // yet, so any Return is from a built-in chain. That
- // means we have to to call the underflow.
- underflow := table.Rules[table.Underflows[hook]]
- // Underflow is guaranteed to be an unconditional
- // ACCEPT or DROP.
- switch v, _ := underflow.Target.Action(pkt, underflow.Filter); v {
- case RuleAccept:
- return TableAccept
- case RuleDrop:
- return TableDrop
- case RuleContinue, RuleReturn:
- panic("Underflows should only return RuleAccept or RuleDrop.")
+ return chainReturn
+
+ case RuleJump:
+ // "Jumping" to the next rule just means we're
+ // continuing on down the list.
+ if jumpTo == ruleIdx+1 {
+ ruleIdx++
+ continue
+ }
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict {
+ case chainAccept:
+ return chainAccept
+ case chainDrop:
+ return chainDrop
+ case chainReturn:
+ ruleIdx++
+ continue
default:
- panic(fmt.Sprintf("Unknown verdict: %d", v))
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
}
default:
@@ -226,11 +255,11 @@ func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename stri
// We got through the entire table without a decision. Default to DROP
// for safety.
- return TableDrop
+ return chainDrop
}
// Precondition: pk.NetworkHeader is set.
-func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) RuleVerdict {
+func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// If pkt.NetworkHeader hasn't been set yet, it will be contained in
@@ -242,7 +271,8 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru
// First check whether the packet matches the IP header filter.
// TODO(gvisor.dev/issue/170): Support other fields of the filter.
if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() {
- return RuleContinue
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
}
// Go through each rule matcher. If they all match, run
@@ -250,14 +280,14 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru
for _, matcher := range rule.Matchers {
matches, hotdrop := matcher.Match(hook, pkt, "")
if hotdrop {
- return RuleDrop
+ return RuleDrop, 0
}
if !matches {
- return RuleContinue
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
}
}
// All the matchers matched, so run the target.
- verdict, _ := rule.Target.Action(pkt, rule.Filter)
- return verdict
+ return rule.Target.Action(pkt, rule.Filter)
}
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
index a75938da3..5dbb28145 100644
--- a/pkg/tcpip/iptables/targets.go
+++ b/pkg/tcpip/iptables/targets.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// This file contains various Targets.
-
package iptables
import (
@@ -26,16 +24,16 @@ import (
type AcceptTarget struct{}
// Action implements Target.Action.
-func (AcceptTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) {
- return RuleAccept, ""
+func (AcceptTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) {
+ return RuleAccept, 0
}
// DropTarget drops packets.
type DropTarget struct{}
// Action implements Target.Action.
-func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) {
- return RuleDrop, ""
+func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) {
+ return RuleDrop, 0
}
// ErrorTarget logs an error and drops the packet. It represents a target that
@@ -43,9 +41,9 @@ func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (Rule
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) {
+func (ErrorTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
- return RuleDrop, ""
+ return RuleDrop, 0
}
// UserChainTarget marks a rule as the beginning of a user chain.
@@ -54,7 +52,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, string) {
+func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -63,8 +61,8 @@ func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict,
type ReturnTarget struct{}
// Action implements Target.Action.
-func (ReturnTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, string) {
- return RuleReturn, ""
+func (ReturnTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, int) {
+ return RuleReturn, 0
}
// RedirectTarget redirects the packet by modifying the destination port/IP.
@@ -88,13 +86,13 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) {
+func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) {
headerView := pkt.Data.First()
// Network header should be set.
netHeader := header.IPv4(headerView)
if netHeader == nil {
- return RuleDrop, ""
+ return RuleDrop, 0
}
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
@@ -111,7 +109,7 @@ func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (
tcp := header.TCP(headerView[hlen:])
tcp.SetDestinationPort(rt.MinPort)
default:
- return RuleDrop, ""
+ return RuleDrop, 0
}
- return RuleAccept, ""
+ return RuleAccept, 0
}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 0102831d0..8bd3a2c94 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -56,17 +56,6 @@ const (
NumHooks
)
-// A TableVerdict is what a table decides should be done with a packet.
-type TableVerdict int
-
-const (
- // TableAccept indicates the packet should continue through netstack.
- TableAccept TableVerdict = iota
-
- // TableDrop indicates the packet should be dropped.
- TableDrop
-)
-
// A RuleVerdict is what a rule decides should be done with a packet.
type RuleVerdict int
@@ -74,12 +63,12 @@ const (
// RuleAccept indicates the packet should continue through netstack.
RuleAccept RuleVerdict = iota
- // RuleContinue indicates the packet should continue to the next rule.
- RuleContinue
-
// RuleDrop indicates the packet should be dropped.
RuleDrop
+ // RuleJump indicates the packet should jump to another chain.
+ RuleJump
+
// RuleReturn indicates the packet should return to the previous chain.
RuleReturn
)
@@ -175,5 +164,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the name of the chain to jump to.
- Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string)
+ Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 3974c464e..b8b93e78e 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = ["channel.go"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 78d447acd..5944ba190 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -20,6 +20,7 @@ package channel
import (
"context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -33,6 +34,118 @@ type PacketInfo struct {
Route stack.Route
}
+// Notification is the interface for receiving notification from the packet
+// queue.
+type Notification interface {
+ // WriteNotify will be called when a write happens to the queue.
+ WriteNotify()
+}
+
+// NotificationHandle is an opaque handle to the registered notification target.
+// It can be used to unregister the notification when no longer interested.
+//
+// +stateify savable
+type NotificationHandle struct {
+ n Notification
+}
+
+type queue struct {
+ // mu protects fields below.
+ mu sync.RWMutex
+ // c is the outbound packet channel. Sending to c should hold mu.
+ c chan PacketInfo
+ numWrite int
+ numRead int
+ notify []*NotificationHandle
+}
+
+func (q *queue) Close() {
+ close(q.c)
+}
+
+func (q *queue) Read() (PacketInfo, bool) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ select {
+ case p := <-q.c:
+ q.numRead++
+ return p, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ // We have to receive from channel without holding the lock, since it can
+ // block indefinitely. This will cause a window that numWrite - numRead
+ // produces a larger number, but won't go to negative. numWrite >= numRead
+ // still holds.
+ select {
+ case pkt := <-q.c:
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.numRead++
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) Write(p PacketInfo) bool {
+ wrote := false
+
+ // It's important to make sure nobody can see numWrite until we increment it,
+ // so numWrite >= numRead holds.
+ q.mu.Lock()
+ select {
+ case q.c <- p:
+ wrote = true
+ q.numWrite++
+ default:
+ }
+ notify := q.notify
+ q.mu.Unlock()
+
+ if wrote {
+ // Send notification outside of lock.
+ for _, h := range notify {
+ h.n.WriteNotify()
+ }
+ }
+ return wrote
+}
+
+func (q *queue) Num() int {
+ q.mu.RLock()
+ defer q.mu.RUnlock()
+ n := q.numWrite - q.numRead
+ if n < 0 {
+ panic("numWrite < numRead")
+ }
+ return n
+}
+
+func (q *queue) AddNotify(notify Notification) *NotificationHandle {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ h := &NotificationHandle{n: notify}
+ q.notify = append(q.notify, h)
+ return h
+}
+
+func (q *queue) RemoveNotify(handle *NotificationHandle) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ // Make a copy, since we reads the array outside of lock when notifying.
+ notify := make([]*NotificationHandle, 0, len(q.notify))
+ for _, h := range q.notify {
+ if h != handle {
+ notify = append(notify, h)
+ }
+ }
+ q.notify = notify
+}
+
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
@@ -41,14 +154,16 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
LinkEPCapabilities stack.LinkEndpointCapabilities
- // c is where outbound packets are queued.
- c chan PacketInfo
+ // Outbound packet queue.
+ q *queue
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- c: make(chan PacketInfo, size),
+ q: &queue{
+ c: make(chan PacketInfo, size),
+ },
mtu: mtu,
linkAddr: linkAddr,
}
@@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
// Close closes e. Further packet injections will panic. Reads continue to
// succeed until all packets are read.
func (e *Endpoint) Close() {
- close(e.c)
+ e.q.Close()
}
-// Read does non-blocking read for one packet from the outbound packet queue.
+// Read does non-blocking read one packet from the outbound packet queue.
func (e *Endpoint) Read() (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- default:
- return PacketInfo{}, false
- }
+ return e.q.Read()
}
// ReadContext does blocking read for one packet from the outbound packet queue.
// It can be cancelled by ctx, and in this case, it returns false.
func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- case <-ctx.Done():
- return PacketInfo{}, false
- }
+ return e.q.ReadContext(ctx)
}
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
- select {
- case <-e.c:
- c++
- default:
+ if _, ok := e.Read(); !ok {
return c
}
+ c++
}
}
+// NumQueued returns the number of packet queued for outbound.
+func (e *Endpoint) NumQueued() int {
+ return e.q.Num()
+}
+
// InjectInbound injects an inbound packet.
func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
@@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
Route: route,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
@@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
route.Release()
payloadView := pkts[0].Data.ToView()
n := 0
-packetLoop:
for _, pkt := range pkts {
off := pkt.DataOffset
size := pkt.DataSize
@@ -185,12 +289,10 @@ packetLoop:
Route: route,
}
- select {
- case e.c <- p:
- n++
- default:
- break packetLoop
+ if !e.q.Write(p) {
+ break
}
+ n++
}
return n, nil
@@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
GSO: nil,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
+
+// AddNotify adds a notification target for receiving event about outgoing
+// packets.
+func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
+ return e.q.AddNotify(notify)
+}
+
+// RemoveNotify removes handle from the list of notification targets.
+func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
+ e.q.RemoveNotify(handle)
+}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index e5096ea38..e0db6cf54 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -4,6 +4,22 @@ package(licenses = ["notice"])
go_library(
name = "tun",
- srcs = ["tun_unsafe.go"],
+ srcs = [
+ "device.go",
+ "protocol.go",
+ "tun_unsafe.go",
+ ],
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
)
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
new file mode 100644
index 000000000..6ff47a742
--- /dev/null
+++ b/pkg/tcpip/link/tun/device.go
@@ -0,0 +1,352 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // drivers/net/tun.c:tun_net_init()
+ defaultDevMtu = 1500
+
+ // Queue length for outbound packet, arriving at fd side for read. Overflow
+ // causes packet drops. gVisor implementation-specific.
+ defaultDevOutQueueLen = 1024
+)
+
+var zeroMAC [6]byte
+
+// Device is an opened /dev/net/tun device.
+//
+// +stateify savable
+type Device struct {
+ waiter.Queue
+
+ mu sync.RWMutex `state:"nosave"`
+ endpoint *tunEndpoint
+ notifyHandle *channel.NotificationHandle
+ flags uint16
+}
+
+// beforeSave is invoked by stateify.
+func (d *Device) beforeSave() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ // TODO(b/110961832): Restore the device to stack. At this moment, the stack
+ // is not savable.
+ if d.endpoint != nil {
+ panic("/dev/net/tun does not support save/restore when a device is associated with it.")
+ }
+}
+
+// Release implements fs.FileOperations.Release.
+func (d *Device) Release() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Decrease refcount if there is an endpoint associated with this file.
+ if d.endpoint != nil {
+ d.endpoint.RemoveNotify(d.notifyHandle)
+ d.endpoint.DecRef()
+ d.endpoint = nil
+ }
+}
+
+// SetIff services TUNSETIFF ioctl(2) request.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.endpoint != nil {
+ return syserror.EINVAL
+ }
+
+ // Input validations.
+ isTun := flags&linux.IFF_TUN != 0
+ isTap := flags&linux.IFF_TAP != 0
+ supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
+ if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
+ return syserror.EINVAL
+ }
+
+ prefix := "tun"
+ if isTap {
+ prefix = "tap"
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix)
+ if err != nil {
+ return syserror.EINVAL
+ }
+
+ d.endpoint = endpoint
+ d.notifyHandle = d.endpoint.AddNotify(d)
+ d.flags = flags
+ return nil
+}
+
+func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) {
+ for {
+ // 1. Try to attach to an existing NIC.
+ if name != "" {
+ if nic, found := s.GetNICByName(name); found {
+ endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if !ok {
+ // Not a NIC created by tun device.
+ return nil, syserror.EOPNOTSUPP
+ }
+ if !endpoint.TryIncRef() {
+ // Race detected: NIC got deleted in between.
+ continue
+ }
+ return endpoint, nil
+ }
+ }
+
+ // 2. Creating a new NIC.
+ id := tcpip.NICID(s.UniqueID())
+ endpoint := &tunEndpoint{
+ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
+ stack: s,
+ nicID: id,
+ name: name,
+ }
+ if endpoint.name == "" {
+ endpoint.name = fmt.Sprintf("%s%d", prefix, id)
+ }
+ err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
+ Name: endpoint.name,
+ })
+ switch err {
+ case nil:
+ return endpoint, nil
+ case tcpip.ErrDuplicateNICID:
+ // Race detected: A NIC has been created in between.
+ continue
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+}
+
+// Write inject one inbound packet to the network interface.
+func (d *Device) Write(data []byte) (int64, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return 0, syserror.EBADFD
+ }
+ if !endpoint.IsAttached() {
+ return 0, syserror.EIO
+ }
+
+ dataLen := int64(len(data))
+
+ // Packet information.
+ var pktInfoHdr PacketInfoHeader
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ if len(data) < PacketInfoHeaderSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize])
+ data = data[PacketInfoHeaderSize:]
+ }
+
+ // Ethernet header (TAP only).
+ var ethHdr header.Ethernet
+ if d.hasFlags(linux.IFF_TAP) {
+ if len(data) < header.EthernetMinimumSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ ethHdr = header.Ethernet(data[:header.EthernetMinimumSize])
+ data = data[header.EthernetMinimumSize:]
+ }
+
+ // Try to determine network protocol number, default zero.
+ var protocol tcpip.NetworkProtocolNumber
+ switch {
+ case pktInfoHdr != nil:
+ protocol = pktInfoHdr.Protocol()
+ case ethHdr != nil:
+ protocol = ethHdr.Type()
+ }
+
+ // Try to determine remote link address, default zero.
+ var remote tcpip.LinkAddress
+ switch {
+ case ethHdr != nil:
+ remote = ethHdr.SourceAddress()
+ default:
+ remote = tcpip.LinkAddress(zeroMAC[:])
+ }
+
+ pkt := tcpip.PacketBuffer{
+ Data: buffer.View(data).ToVectorisedView(),
+ }
+ if ethHdr != nil {
+ pkt.LinkHeader = buffer.View(ethHdr)
+ }
+ endpoint.InjectLinkAddr(protocol, remote, pkt)
+ return dataLen, nil
+}
+
+// Read reads one outgoing packet from the network interface.
+func (d *Device) Read() ([]byte, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return nil, syserror.EBADFD
+ }
+
+ for {
+ info, ok := endpoint.Read()
+ if !ok {
+ return nil, syserror.ErrWouldBlock
+ }
+
+ v, ok := d.encodePkt(&info)
+ if !ok {
+ // Ignore unsupported packet.
+ continue
+ }
+ return v, nil
+ }
+}
+
+// encodePkt encodes packet for fd side.
+func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
+ var vv buffer.VectorisedView
+
+ // Packet information.
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ hdr := make(PacketInfoHeader, PacketInfoHeaderSize)
+ hdr.Encode(&PacketInfoFields{
+ Protocol: info.Proto,
+ })
+ vv.AppendView(buffer.View(hdr))
+ }
+
+ // If the packet does not already have link layer header, and the route
+ // does not exist, we can't compute it. This is possibly a raw packet, tun
+ // device doesn't support this at the moment.
+ if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" {
+ return nil, false
+ }
+
+ // Ethernet header (TAP only).
+ if d.hasFlags(linux.IFF_TAP) {
+ // Add ethernet header if not provided.
+ if info.Pkt.LinkHeader == nil {
+ hdr := &header.EthernetFields{
+ SrcAddr: info.Route.LocalLinkAddress,
+ DstAddr: info.Route.RemoteLinkAddress,
+ Type: info.Proto,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = d.endpoint.LinkAddress()
+ }
+
+ eth := make(header.Ethernet, header.EthernetMinimumSize)
+ eth.Encode(hdr)
+ vv.AppendView(buffer.View(eth))
+ } else {
+ vv.AppendView(info.Pkt.LinkHeader)
+ }
+ }
+
+ // Append upper headers.
+ vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):]))
+ // Append data payload.
+ vv.Append(info.Pkt.Data)
+
+ return vv.ToView(), true
+}
+
+// Name returns the name of the attached network interface. Empty string if
+// unattached.
+func (d *Device) Name() string {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ if d.endpoint != nil {
+ return d.endpoint.name
+ }
+ return ""
+}
+
+// Flags returns the flags set for d. Zero value if unset.
+func (d *Device) Flags() uint16 {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ return d.flags
+}
+
+func (d *Device) hasFlags(flags uint16) bool {
+ return d.flags&flags == flags
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint != nil && endpoint.NumQueued() == 0 {
+ mask &= ^waiter.EventIn
+ }
+ }
+ return mask & (waiter.EventIn | waiter.EventOut)
+}
+
+// WriteNotify implements channel.Notification.WriteNotify.
+func (d *Device) WriteNotify() {
+ d.Notify(waiter.EventIn)
+}
+
+// tunEndpoint is the link endpoint for the NIC created by the tun device.
+//
+// It is ref-counted as multiple opening files can attach to the same NIC.
+// The last owner is responsible for deleting the NIC.
+type tunEndpoint struct {
+ *channel.Endpoint
+
+ refs.AtomicRefCount
+
+ stack *stack.Stack
+ nicID tcpip.NICID
+ name string
+}
+
+// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
+func (e *tunEndpoint) DecRef() {
+ e.DecRefWithDestructor(func() {
+ e.stack.RemoveNIC(e.nicID)
+ })
+}
diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go
new file mode 100644
index 000000000..89d9d91a9
--- /dev/null
+++ b/pkg/tcpip/link/tun/protocol.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // PacketInfoHeaderSize is the size of the packet information header.
+ PacketInfoHeaderSize = 4
+
+ offsetFlags = 0
+ offsetProtocol = 2
+)
+
+// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is
+// not set.
+type PacketInfoFields struct {
+ Flags uint16
+ Protocol tcpip.NetworkProtocolNumber
+}
+
+// PacketInfoHeader is the wire representation of the packet information sent if
+// IFF_NO_PI flag is not set.
+type PacketInfoHeader []byte
+
+// Encode encodes f into h.
+func (h PacketInfoHeader) Encode(f *PacketInfoFields) {
+ binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags)
+ binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol))
+}
+
+// Flags returns the flag field in h.
+func (h PacketInfoHeader) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[offsetFlags:])
+}
+
+// Protocol returns the protocol field in h.
+func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:]))
+}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 4da13c5df..e9fcc89a8 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
}, nil
}
-// LinkAddressProtocol implements stack.LinkAddressResolver.
+// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
-// LinkAddressRequest implements stack.LinkAddressResolver.
+// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
@@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
}
-// ResolveStaticAddress implements stack.LinkAddressResolver.
+// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return broadcastMAC, true
@@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
return tcpip.LinkAddress([]byte(nil)), false
}
-// SetOption implements NetworkProtocol.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.NetworkProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements NetworkProtocol.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.NetworkProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
// NewProtocol returns an ARP network protocol.
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 6597e6781..4f1742938 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 180a480fd..9aef5234b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 045409bda..f651871ce 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -1148,22 +1148,27 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
return true
}
-// cleanupHostOnlyState cleans up any state that is only useful for hosts.
+// cleanupState cleans up ndp's state.
//
-// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a
-// host to a router. This function will invalidate all discovered on-link
-// prefixes, discovered routers, and auto-generated addresses as routers do not
-// normally process Router Advertisements to discover default routers and
-// on-link prefixes, and auto-generate addresses via SLAAC.
+// If hostOnly is true, then only host-specific state will be cleaned up.
+//
+// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
+// transitioning from a host to a router. This function will invalidate all
+// discovered on-link prefixes, discovered routers, and auto-generated
+// addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address will not be
+// invalidated as routers are also expected to generate a link-local address.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupHostOnlyState() {
+func (ndp *ndpState) cleanupState(hostOnly bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
linkLocalAddrs := 0
for addr := range ndp.autoGenAddresses {
// RFC 4862 section 5 states that routers are also expected to generate a
- // link-local address so we do not invalidate them.
- if linkLocalSubnet.Contains(addr) {
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if hostOnly && linkLocalSubnet.Contains(addr) {
linkLocalAddrs++
continue
}
@@ -1230,7 +1235,7 @@ func (ndp *ndpState) startSolicitingRouters() {
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
pkt := header.ICMPv6(hdr.Prepend(payloadSize))
pkt.SetType(header.ICMPv6RouterSolicit)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 1f6f77439..6e9306d09 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -267,6 +267,17 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s
}
}
+// channelLinkWithHeaderLength is a channel.Endpoint with a configurable
+// header length.
+type channelLinkWithHeaderLength struct {
+ *channel.Endpoint
+ headerLength uint16
+}
+
+func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
+ return l.headerLength
+}
+
// Check e to make sure that the event is for addr on nic with ID 1, and the
// resolved flag set to resolved with the specified err.
func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
@@ -323,21 +334,46 @@ func TestDADDisabled(t *testing.T) {
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+// This tests also validates the NDP NS packet that is transmitted.
func TestDADResolve(t *testing.T) {
const nicID = 1
tests := []struct {
name string
+ linkHeaderLen uint16
dupAddrDetectTransmits uint8
retransTimer time.Duration
expectedRetransmitTimer time.Duration
}{
- {"1:1s:1s", 1, time.Second, time.Second},
- {"2:1s:1s", 2, time.Second, time.Second},
- {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ {
+ name: "1:1s:1s",
+ dupAddrDetectTransmits: 1,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "2:1s:1s",
+ linkHeaderLen: 1,
+ dupAddrDetectTransmits: 2,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "1:2s:2s",
+ linkHeaderLen: 2,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 2 * time.Second,
+ expectedRetransmitTimer: 2 * time.Second,
+ },
// 0s is an invalid RetransmitTimer timer and will be fixed to
// the default RetransmitTimer value of 1s.
- {"1:0s:1s", 1, 0, time.Second},
+ {
+ name: "1:0s:1s",
+ linkHeaderLen: 3,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 0,
+ expectedRetransmitTimer: time.Second,
+ },
}
for _, test := range tests {
@@ -356,10 +392,13 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
- e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
+ if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -445,6 +484,10 @@ func TestDADResolve(t *testing.T) {
checker.NDPNSTargetAddress(addr1),
checker.NDPNSOptions(nil),
))
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
})
}
@@ -592,70 +635,94 @@ func TestDADFail(t *testing.T) {
}
}
-// TestDADStop tests to make sure that the DAD process stops when an address is
-// removed.
func TestDADStop(t *testing.T) {
const nicID = 1
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- ndpConfigs := stack.NDPConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 2,
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
+ tests := []struct {
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests to make sure that DAD stops when an address is removed.
+ {
+ name: "Remove address",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
+ }
+ },
+ },
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ // Tests to make sure that DAD stops when the NIC is disabled.
+ {
+ name: "Disable NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- // Remove the address. This should stop DAD.
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
- }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
- // Wait for DAD to fail (since the address was removed during DAD).
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ test.stopFn(t, s)
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
- // Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
- t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Errorf("got NeighborSolicit = %d, want <= 1", got)
+ }
+ })
}
}
@@ -2886,17 +2953,16 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
}
-// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers
-// and prefixes, and non-linklocal auto-generated addresses are invalidated when
-// a NIC becomes a router.
-func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
+// TestCleanupNDPState tests that all discovered routers and prefixes, and
+// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestCleanupNDPState(t *testing.T) {
t.Parallel()
const (
- lifetimeSeconds = 5
- maxEvents = 4
- nicID1 = 1
- nicID2 = 2
+ lifetimeSeconds = 5
+ maxRouterAndPrefixEvents = 4
+ nicID1 = 1
+ nicID2 = 2
)
prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2912,254 +2978,308 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
PrefixLen: 64,
}
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
+ tests := []struct {
+ name string
+ cleanupFn func(t *testing.T, s *stack.Stack)
+ keepAutoGenLinkLocal bool
+ maxAutoGenAddrEvents int
+ }{
+ // A NIC should still keep its auto-generated link-local address when
+ // becoming a router.
+ {
+ name: "Forwarding Enable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
+ keepAutoGenLinkLocal: true,
+ maxAutoGenAddrEvents: 4,
},
- NDPDisp: &ndpDisp,
- })
- expectRouterEvent := func() (bool, ndpRouterEvent) {
- select {
- case e := <-ndpDisp.routerC:
- return true, e
- default:
- }
+ // A NIC should cleanup all NDP state when it is disabled.
+ {
+ name: "NIC Disable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- return false, ndpRouterEvent{}
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ },
}
- expectPrefixEvent := func() (bool, ndpPrefixEvent) {
- select {
- case e := <-ndpDisp.prefixC:
- return true, e
- default:
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
- return false, ndpPrefixEvent{}
- }
+ expectRouterEvent := func() (bool, ndpRouterEvent) {
+ select {
+ case e := <-ndpDisp.routerC:
+ return true, e
+ default:
+ }
- expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
- select {
- case e := <-ndpDisp.autoGenAddrC:
- return true, e
- default:
- }
+ return false, ndpRouterEvent{}
+ }
- return false, ndpAutoGenAddrEvent{}
- }
+ expectPrefixEvent := func() (bool, ndpPrefixEvent) {
+ select {
+ case e := <-ndpDisp.prefixC:
+ return true, e
+ default:
+ }
- e1 := channel.New(0, 1280, linkAddr1)
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
- // We have other tests that make sure we receive the *correct* events
- // on normal discovery of routers/prefixes, and auto-generated
- // addresses. Here we just make sure we get an event and let other tests
- // handle the correctness check.
- expectAutoGenAddrEvent()
+ return false, ndpPrefixEvent{}
+ }
- e2 := channel.New(0, 1280, linkAddr2)
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
- }
- expectAutoGenAddrEvent()
+ expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ return true, e
+ default:
+ }
- // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
- // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
- // llAddr4) to discover multiple routers and prefixes, and auto-gen
- // multiple addresses.
+ return false, ndpAutoGenAddrEvent{}
+ }
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
- }
+ e1 := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+ // We have other tests that make sure we receive the *correct* events
+ // on normal discovery of routers/prefixes, and auto-generated
+ // addresses. Here we just make sure we get an event and let other tests
+ // handle the correctness check.
+ expectAutoGenAddrEvent()
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
- }
+ e2 := channel.New(0, 1280, linkAddr2)
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+ expectAutoGenAddrEvent()
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
- }
+ // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
+ // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
+ // llAddr4) to discover multiple routers and prefixes, and auto-gen
+ // multiple addresses.
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
+ }
- // We should have the auto-generated addresses added.
- nicinfo := s.NICInfo()
- nic1Addrs := nicinfo[nicID1].ProtocolAddresses
- nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
+ }
- // We can't proceed any further if we already failed the test (missing
- // some discovery/auto-generated address events or addresses).
- if t.Failed() {
- t.FailNow()
- }
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
+ }
- s.SetForwarding(true)
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
+ }
- // Collect invalidation events after becoming a router
- gotRouterEvents := make(map[ndpRouterEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectRouterEvent()
- if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotRouterEvents[e]++
- }
- gotPrefixEvents := make(map[ndpPrefixEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectPrefixEvent()
- if !ok {
- t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotPrefixEvents[e]++
- }
- gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectAutoGenAddrEvent()
- if !ok {
- t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotAutoGenAddrEvents[e]++
- }
+ // We should have the auto-generated addresses added.
+ nicinfo := s.NICInfo()
+ nic1Addrs := nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs := nicinfo[nicID2].ProtocolAddresses
+ if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
- // No need to proceed any further if we already failed the test (missing
- // some invalidation events).
- if t.Failed() {
- t.FailNow()
- }
+ // We can't proceed any further if we already failed the test (missing
+ // some discovery/auto-generated address events or addresses).
+ if t.Failed() {
+ t.FailNow()
+ }
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
- }
- expectedPrefixEvents := map[ndpPrefixEvent]int{
- {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
- t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
- }
- expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
- {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
- }
- if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
- t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
- }
+ test.cleanupFn(t, s)
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ // Collect invalidation events after having NDP state cleaned up.
+ gotRouterEvents := make(map[ndpRouterEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectRouterEvent()
+ if !ok {
+ t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotRouterEvents[e]++
+ }
+ gotPrefixEvents := make(map[ndpPrefixEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectPrefixEvent()
+ if !ok {
+ t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotPrefixEvents[e]++
+ }
+ gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
+ for i := 0; i < test.maxAutoGenAddrEvents; i++ {
+ ok, e := expectAutoGenAddrEvent()
+ if !ok {
+ t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
+ break
+ }
+ gotAutoGenAddrEvents[e]++
+ }
- // Should not get any more events (invalidation timers should have been
- // cancelled when we transitioned into a router).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
- select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
- default:
- }
- select {
- case <-ndpDisp.prefixC:
- t.Error("unexpected prefix event")
- default:
- }
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Error("unexpected auto-generated address event")
- default:
+ // No need to proceed any further if we already failed the test (missing
+ // some invalidation events).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ expectedRouterEvents := map[ndpRouterEvent]int{
+ {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
+ t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ }
+ expectedPrefixEvents := map[ndpPrefixEvent]int{
+ {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
+ t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
+ }
+ expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
+ {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
+ }
+
+ if !test.keepAutoGenLinkLocal {
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
+ }
+
+ if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
+ t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
+ }
+
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ }
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // Should not get any more events (invalidation timers should have been
+ // cancelled when the NDP state was cleaned up).
+ time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ select {
+ case <-ndpDisp.routerC:
+ t.Error("unexpected router event")
+ default:
+ }
+ select {
+ case <-ndpDisp.prefixC:
+ t.Error("unexpected prefix event")
+ default:
+ }
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Error("unexpected auto-generated address event")
+ default:
+ }
+ })
}
}
@@ -3259,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
func TestRouterSolicitation(t *testing.T) {
t.Parallel()
+ const nicID = 1
+
tests := []struct {
name string
+ linkHeaderLen uint16
maxRtrSolicit uint8
rtrSolicitInt time.Duration
effectiveRtrSolicitInt time.Duration
@@ -3277,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS with delay",
+ linkHeaderLen: 1,
maxRtrSolicit: 2,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3285,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Single RS without delay",
+ linkHeaderLen: 2,
maxRtrSolicit: 1,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3293,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS without delay and invalid zero interval",
+ linkHeaderLen: 3,
maxRtrSolicit: 2,
rtrSolicitInt: 0,
effectiveRtrSolicitInt: 4 * time.Second,
@@ -3330,8 +3456,11 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
- e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
waitForPkt := func(timeout time.Duration) {
t.Helper()
ctx, _ := context.WithTimeout(context.Background(), timeout)
@@ -3357,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPRS(),
)
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
@@ -3373,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) {
MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
// Make sure each RS got sent at the right
@@ -3406,77 +3539,130 @@ func TestRouterSolicitation(t *testing.T) {
})
}
-// TestStopStartSolicitingRouters tests that when forwarding is enabled or
-// disabled, router solicitations are stopped or started, respecitively.
func TestStopStartSolicitingRouters(t *testing.T) {
t.Parallel()
+ const nicID = 1
const interval = 500 * time.Millisecond
const delay = time.Second
const maxRtrSolicitations = 3
- e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
+ tests := []struct {
+ name string
+ startFn func(t *testing.T, s *stack.Stack)
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests that when forwarding is enabled or disabled, router solicitations
+ // are stopped or started, respectively.
+ {
+ name: "Forwarding enabled and disabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(false)
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
- // Enable forwarding which should stop router solicitations.
- s.SetForwarding(true)
- ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- // A single RS may have been sent before forwarding was enabled.
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok = e.ReadContext(ctx); ok {
- t.Fatal("Should not have sent more than one RS message")
- }
- }
+ // Tests that when a NIC is enabled or disabled, router solicitations
+ // are started or stopped, respectively.
+ {
+ name: "NIC disabled and enabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Enabling forwarding again should do nothing.
- s.SetForwarding(true)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
- }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Disable forwarding which should start router solicitations.
- s.SetForwarding(false)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- // Disabling forwarding again should do nothing.
- s.SetForwarding(false)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Stop soliciting routers.
+ test.stopFn(t, s)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ // A single RS may have been sent before forwarding was enabled.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok = e.ReadContext(ctx); ok {
+ t.Fatal("should not have sent more than one RS message")
+ }
+ }
+
+ // Stopping router solicitations after it has already been stopped should
+ // do nothing.
+ test.stopFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
+ }
+
+ // Start soliciting routers.
+ test.startFn(t, s)
+ waitForPkt(delay + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ }
+
+ // Starting router solicitations after it has already completed should do
+ // nothing.
+ test.startFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after finishing router solicitations")
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a75dc0322..271e55be0 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -28,6 +28,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/iptables"
)
+var ipv4BroadcastAddr = tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 8 * header.IPv4AddressSize,
+ },
+}
+
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
@@ -133,11 +141,77 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
+ nic.linkEP.Attach(nic)
+
return nic
}
-// enable enables the NIC. enable will attach the link to its LinkEndpoint and
-// join the IPv6 All-Nodes Multicast address (ff02::1).
+// enabled returns true if n is enabled.
+func (n *NIC) enabled() bool {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ return enabled
+}
+
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if !enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if !n.mu.enabled {
+ return nil
+ }
+
+ // TODO(b/147015577): Should Routes that are currently bound to n be
+ // invalidated? Currently, Routes will continue to work when a NIC is enabled
+ // again, and applications may not know that the underlying NIC was ever
+ // disabled.
+
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ n.mu.ndp.stopSolicitingRouters()
+ n.mu.ndp.cleanupState(false /* hostOnly */)
+
+ // Stop DAD for all the unicast IPv6 endpoints that are in the
+ // permanentTentative state.
+ for _, r := range n.mu.endpoints {
+ if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+ }
+
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The address may have already been removed.
+ if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ n.mu.enabled = false
+ return nil
+}
+
+// enable enables n.
+//
+// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast
+// address (ff02::1), start DAD for permanent addresses, and start soliciting
+// routers if the stack is not operating as a router. If the stack is also
+// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
n.mu.RLock()
enabled := n.mu.enabled
@@ -155,14 +229,9 @@ func (n *NIC) enable() *tcpip.Error {
n.mu.enabled = true
- n.attachLinkEndpoint()
-
// Create an endpoint to receive broadcast packets on this interface.
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
- }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
@@ -184,6 +253,14 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
// state.
//
@@ -201,10 +278,6 @@ func (n *NIC) enable() *tcpip.Error {
}
}
- if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
- return err
- }
-
// Do not auto-generate an IPv6 link-local address for loopback devices.
if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
// The valid and preferred lifetime is infinite for the auto-generated
@@ -226,6 +299,33 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+// remove detaches NIC from the link endpoint, and marks existing referenced
+// network endpoints expired. This guarantees no packets between this NIC and
+// the network stack.
+func (n *NIC) remove() *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ // Remove permanent and permanentTentative addresses, so no packet goes out.
+ var errs []*tcpip.Error
+ for nid, ref := range n.mu.endpoints {
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
+ errs = append(errs, err)
+ }
+ }
+ }
+ if len(errs) > 0 {
+ return errs[0]
+ }
+
+ return nil
+}
+
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
@@ -235,7 +335,7 @@ func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
- n.mu.ndp.cleanupHostOnlyState()
+ n.mu.ndp.cleanupState(true /* hostOnly */)
n.mu.ndp.stopSolicitingRouters()
}
@@ -250,12 +350,6 @@ func (n *NIC) becomeIPv6Host() {
n.mu.ndp.startSolicitingRouters()
}
-// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
-// to start delivering packets.
-func (n *NIC) attachLinkEndpoint() {
- n.linkEP.Attach(n)
-}
-
// setPromiscuousMode enables or disables promiscuous mode.
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
@@ -713,6 +807,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
case permanentExpired, temporary:
continue
}
+
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -1010,6 +1105,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return nil
}
+// isInGroup returns true if n has joined the multicast group addr.
+func (n *NIC) isInGroup(addr tcpip.Address) bool {
+ n.mu.RLock()
+ joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
+ n.mu.RUnlock()
+
+ return joins != 0
+}
+
func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1237,6 +1341,11 @@ func (n *NIC) Stack() *Stack {
return n.stack
}
+// LinkEndpoint returns the link endpoint of n.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
// isAddrTentative returns true if addr is tentative on n.
//
// Note that if addr is not associated with n, then this function will return
@@ -1423,7 +1532,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
//
// r's NIC must be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- return r.getKind() != permanentExpired || r.nic.mu.spoofing
+ return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ec91f60dd..f9fd8f18f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -74,10 +74,11 @@ type TransportEndpoint interface {
// HandleControlPacket takes ownership of pkt.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
- // Close puts the endpoint in a closed state and frees all resources
- // associated with it. This cleanup may happen asynchronously. Wait can
- // be used to block on this asynchronous cleanup.
- Close()
+ // Abort initiates an expedited endpoint teardown. It puts the endpoint
+ // in a closed state and frees all resources associated with it. This
+ // cleanup may happen asynchronously. Wait can be used to block on this
+ // asynchronous cleanup.
+ Abort()
// Wait waits for any worker goroutines owned by the endpoint to stop.
//
@@ -160,6 +161,13 @@ type TransportProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -277,7 +285,7 @@ type NetworkProtocol interface {
// DefaultPrefixLen returns the protocol's default prefix length.
DefaultPrefixLen() int
- // ParsePorts returns the source and destination addresses stored in a
+ // ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
@@ -293,6 +301,13 @@ type NetworkProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6eac16e16..ebb6c5e3b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -881,6 +881,8 @@ type NICOptions struct {
// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
// NICOptions. See the documentation on type NICOptions for details on how
// NICs can be configured.
+//
+// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -900,7 +902,6 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
-
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -910,34 +911,88 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
-// `LinkEndpoint.Attach` to start delivering packets to it.
+// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, nic := range s.nics {
+ if nic.Name() == name {
+ return nic, true
+ }
+ }
+ return nil, false
+}
+
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
return nic.enable()
}
+// DisableNIC disables the given NIC.
+func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.disable()
+}
+
// CheckNIC checks if a NIC is usable.
func (s *Stack) CheckNIC(id tcpip.NICID) bool {
s.mu.RLock()
+ defer s.mu.RUnlock()
+
nic, ok := s.nics[id]
- s.mu.RUnlock()
- if ok {
- return nic.linkEP.IsAttached()
+ if !ok {
+ return false
}
- return false
+
+ return nic.enabled()
+}
+
+// RemoveNIC removes NIC and all related routes from the network stack.
+func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ delete(s.nics, id)
+
+ // Remove routes in-place. n tracks the number of routes written.
+ n := 0
+ for i, r := range s.routeTable {
+ if r.NIC != id {
+ // Keep this route.
+ if i > n {
+ s.routeTable[n] = r
+ }
+ n++
+ }
+ }
+ s.routeTable = s.routeTable[:n]
+
+ return nic.remove()
}
// NICAddressRanges returns a map of NICIDs to their associated subnets.
@@ -989,7 +1044,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.linkEP.IsAttached(),
+ Running: nic.enabled(),
Promiscuous: nic.isPromiscuousMode(),
Loopback: nic.isLoopback(),
}
@@ -1151,7 +1206,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok {
+ if nic, ok := s.nics[id]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
@@ -1161,7 +1216,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok {
+ if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
@@ -1391,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
// Endpoints created or modified during this call may not get closed.
func (s *Stack) Close() {
for _, e := range s.RegisteredEndpoints() {
- e.Close()
+ e.Abort()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Close()
+ }
+ for _, p := range s.networkProtocols {
+ p.Close()
}
}
@@ -1409,6 +1470,12 @@ func (s *Stack) Wait() {
for _, e := range s.CleanupEndpoints() {
e.Wait()
}
+ for _, p := range s.transportProtocols {
+ p.proto.Wait()
+ }
+ for _, p := range s.networkProtocols {
+ p.Wait()
+ }
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1614,6 +1681,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
return tcpip.ErrUnknownNICID
}
+// IsInGroup returns true if the NIC with ID nicID has joined the multicast
+// group multicastAddr.
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.isInGroup(multicastAddr), nil
+ }
+ return false, tcpip.ErrUnknownNICID
+}
+
// IPTables returns the stack's iptables.
func (s *Stack) IPTables() iptables.IPTables {
s.tablesMu.RLock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 7ba604442..edf6bec52 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -234,10 +235,33 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements TransportProtocol.Close.
+func (*fakeNetworkProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeNetworkProtocol) Wait() {}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
+// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
+// that LinkEndpoint.Attach was called.
+type linkEPWithMockedAttach struct {
+ stack.LinkEndpoint
+ attached bool
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
+ l.LinkEndpoint.Attach(d)
+ l.attached = true
+}
+
+func (l *linkEPWithMockedAttach) isAttached() bool {
+ return l.attached
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
@@ -509,6 +533,296 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr
}
}
+// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to
+// a NetworkDispatcher when the NIC is created.
+func TestAttachToLinkEndpointImmediately(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ nicOpts stack.NICOptions
+ }{
+ {
+ name: "Create enabled NIC",
+ nicOpts: stack.NICOptions{Disabled: false},
+ },
+ {
+ name: "Create disabled NIC",
+ nicOpts: stack.NICOptions{Disabled: true},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+
+ if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
+ }
+ if !e.isAttached() {
+ t.Fatalf("link endpoint not attached to a network disatcher")
+ }
+ })
+ }
+}
+
+func TestDisableUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := loopback.New()
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ checkNIC := func(enabled bool) {
+ t.Helper()
+
+ allNICInfo := s.NICInfo()
+ nicInfo, ok := allNICInfo[nicID]
+ if !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ } else if nicInfo.Flags.Running != enabled {
+ t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
+ }
+
+ if got := s.CheckNIC(nicID); got != enabled {
+ t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
+ }
+ }
+
+ // NIC should initially report itself as disabled.
+ checkNIC(false)
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(true)
+
+ // If the NIC is not reporting a correct enabled status, we cannot trust the
+ // next check so end the test here.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(false)
+}
+
+func TestRoutesWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Disabling NIC1 should result in no routes to odd addresses. Routes to even
+ // addresses should continue to be available as NIC2 is still enabled.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ nic1Dst := tcpip.Address("\x05")
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ nic2Dst := tcpip.Address("\x06")
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Disabling NIC2 should result in no routes to even addresses. No route
+ // should be available to any address as routes to odd addresses were made
+ // unavailable by disabling NIC1 above.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ // Enabling NIC1 should make routes to odd addresses available again. Routes
+ // to even addresses should continue to be unavailable as NIC2 is still
+ // disabled.
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+}
+
+func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ nic1Dst := tcpip.Address("\x05")
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
+
+ nic2Dst := tcpip.Address("\x06")
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
+
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC1 should fail.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC2 should fail.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ // Writes with Routes that use the re-enabled NIC1 should succeed.
+ // TODO(b/147015577): Should we instead completely invalidate all Routes that
+ // were bound to a disabled NIC at some point?
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+}
+
func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
@@ -2173,13 +2487,29 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
e := channel.New(0, 1280, test.linkAddr)
s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: test.nicName}
+ nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
}
- var expectedMainAddr tcpip.AddressWithPrefix
+ // A new disabled NIC should not have any address, even if auto generation
+ // was enabled.
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+ // Enabling the NIC should attempt auto-generation of a link-local
+ // address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ var expectedMainAddr tcpip.AddressWithPrefix
if test.shouldGen {
expectedMainAddr = tcpip.AddressWithPrefix{
Address: test.expectedAddr,
@@ -2609,6 +2939,111 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
}
+func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should add the IPv4 broadcast address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 1 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
+ }
+ want := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
+ if allNICAddrs[0] != want {
+ t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+ }
+
+ // Disabling the NIC should remove the IPv4 broadcast address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+}
+
+func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
+ // not been enabled yet.
+ isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+}
+
// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
// was disabled have DAD performed on them when the NIC is enabled.
func TestDoDADWhenNICEnabled(t *testing.T) {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index d686e6eb8..778c0a4d6 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// Close implements stack.TransportEndpoint.Close.
-func (ep *multiPortEndpoint) Close() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Close()
- }
-}
-
-// Wait implements stack.TransportEndpoint.Wait.
-func (ep *multiPortEndpoint) Wait() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Wait()
- }
-}
-
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 869c69a6d..5d1da2f8b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
+func (f *fakeTransportEndpoint) Abort() {
+ f.Close()
+}
+
func (f *fakeTransportEndpoint) Close() {
f.route.Release()
}
@@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Abort implements TransportProtocol.Abort.
+func (*fakeTransportProtocol) Abort() {}
+
+// Close implements tcpip.Endpoint.Close.
+func (*fakeTransportProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeTransportProtocol) Wait() {}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 9ca39ce40..3dc5d87d6 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -323,11 +323,11 @@ type ControlMessages struct {
// TOS is the IPv4 type of service of the associated packet.
TOS uint8
- // HasTClass indicates whether Tclass is valid/set.
+ // HasTClass indicates whether TClass is valid/set.
HasTClass bool
- // Tclass is the IPv6 traffic class of the associated packet.
- TClass int32
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
// HasIPPacketInfo indicates whether PacketInfo is set.
HasIPPacketInfo bool
@@ -341,9 +341,15 @@ type ControlMessages struct {
// networking stack.
type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
- // associated with it.
+ // associated with it. Close initiates the teardown process, the
+ // Endpoint may not be fully closed when Close returns.
Close()
+ // Abort initiates an expedited endpoint teardown. As compared to
+ // Close, Abort prioritizes closing the Endpoint quickly over cleanly.
+ // Abort is best effort; implementing Abort with Close is acceptable.
+ Abort()
+
// Read reads data from the endpoint and optionally returns the sender.
//
// This method does not block if there is no data pending. It will also
@@ -502,9 +508,13 @@ type WriteOptions struct {
type SockOptBool int
const (
+ // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the
+ // IPV6_TCLASS ancillary message is passed with incoming packets.
+ ReceiveTClassOption SockOptBool = iota
+
// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
// ancillary message is passed with incoming packets.
- ReceiveTOSOption SockOptBool = iota
+ ReceiveTOSOption
// V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
@@ -514,6 +524,9 @@ const (
// if more inforamtion is provided with incoming packets such
// as interface index and address.
ReceiveIPPacketInfoOption
+
+ // TODO(b/146901447): convert existing bool socket options to be handled via
+ // Get/SetSockOptBool
)
// SockOptInt represents socket options which values have the int type.
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 48764b978..2f98a996f 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -25,6 +25,8 @@ import (
)
// StdClock implements Clock with the time package.
+//
+// +stateify savable
type StdClock struct{}
var _ Clock = (*StdClock)(nil)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 42afb3f5b..426da1ee6 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 9ce500e80..113d92901 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index fc5bc69fa..5722815e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
return ep, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (ep *endpoint) Close() {
ep.mu.Lock()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ee9c4c58b..2ef5fac76 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (e *endpoint) Close() {
e.mu.Lock()
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 08afb7c17..13e383ffc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
@@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
- e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
s := sleep.Sleeper{}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5c5397823..7730e6445 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.snd.updateMaxPayloadSize(mtu, count)
}
- if n&notifyReset != 0 {
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
return tcpip.ErrConnectionAborted
}
@@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 {
+ if n&notifyClose != 0 || n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index e18012ac0..d792b07d6 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -68,17 +68,28 @@ func (q *epQueue) empty() bool {
type processor struct {
epQ epQueue
newEndpointWaker sleep.Waker
+ closeWaker sleep.Waker
id int
+ wg sync.WaitGroup
}
func newProcessor(id int) *processor {
p := &processor{
id: id,
}
+ p.wg.Add(1)
go p.handleSegments()
return p
}
+func (p *processor) close() {
+ p.closeWaker.Assert()
+}
+
+func (p *processor) wait() {
+ p.wg.Wait()
+}
+
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
@@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) {
func (p *processor) handleSegments() {
const newEndpointWaker = 1
+ const closeWaker = 2
s := sleep.Sleeper{}
s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ s.AddWaker(&p.closeWaker, closeWaker)
defer s.Done()
for {
- s.Fetch(true)
+ id, ok := s.Fetch(true)
+ if ok && id == closeWaker {
+ p.wg.Done()
+ return
+ }
for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
if ep.segmentQueue.empty() {
continue
@@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher {
}
}
+func (d *dispatcher) close() {
+ for _, p := range d.processors {
+ p.close()
+ }
+}
+
+func (d *dispatcher) wait() {
+ for _, p := range d.processors {
+ p.wait()
+ }
+}
+
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index f2be0e651..f1ad19dac 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -121,6 +121,8 @@ const (
notifyDrain
notifyReset
notifyResetByPeer
+ // notifyAbort is a request for an expedited teardown.
+ notifyAbort
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
}
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ // The abort notification is not processed synchronously, so no
+ // synchronization is needed.
+ //
+ // If the endpoint becomes connected after this check, we still close
+ // the endpoint. This worst case results in a slower abort.
+ //
+ // If the endpoint disconnected after the check, nothing needs to be
+ // done, so sending a notification which will potentially be ignored is
+ // fine.
+ if e.EndpointState().connected() {
+ e.notifyProtocolGoroutine(notifyAbort)
+ return
+ }
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources associated
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
@@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() {
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
- if !e.workerRunning {
+ switch e.EndpointState() {
+ // Sockets in StateSynRecv state(passive connections) are closed when
+ // the handshake fails or if the listening socket is closed while
+ // handshake was in progress. In such cases the handshake goroutine
+ // is already gone by the time Close is called and we need to cleanup
+ // here.
+ case StateInitial, StateBound, StateSynRecv:
e.cleanupLocked()
- } else {
+ e.setEndpointState(StateClose)
+ case StateError, StateClose:
+ // do nothing.
+ default:
e.workerCleanup = true
e.notifyProtocolGoroutine(notifyClose)
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 958c06fa7..73098d904 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -194,7 +194,7 @@ func replyWithReset(s *segment) {
sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
-// SetOption implements TransportProtocol.SetOption.
+// SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case SACKEnabled:
@@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
}
-// Option implements TransportProtocol.Option.
+// Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
@@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements stack.TransportProtocol.Close.
+func (p *protocol) Close() {
+ p.dispatcher.close()
+}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (p *protocol) Wait() {
+ p.dispatcher.wait()
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index cc118c993..5b2b16afa 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -543,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) {
),
)
- // Wait for the TIME-WAIT state to transition to CLOSED.
- time.Sleep(1 * time.Second)
+ // Wait for a little more than the TIME-WAIT duration for the socket to
+ // transition to CLOSED state.
+ time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 3fe91cac2..1c6a600b8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -32,7 +32,8 @@ type udpPacket struct {
packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- tos uint8
+ // tos stores either the receiveTOS or receiveTClass value.
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -119,6 +120,10 @@ type endpoint struct {
// as ancillary data to ControlMessages on Read.
receiveTOS bool
+ // receiveTClass determines if the incoming IPv6 TClass header field is
+ // passed as ancillary data to ControlMessages on Read.
+ receiveTClass bool
+
// receiveIPPacketInfo determines if the packet info is returned by Read.
receiveIPPacketInfo bool
@@ -181,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -258,13 +268,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
}
e.mu.RLock()
receiveTOS := e.receiveTOS
+ receiveTClass := e.receiveTClass
receiveIPPacketInfo := e.receiveIPPacketInfo
e.mu.RUnlock()
if receiveTOS {
cm.HasTOS = true
cm.TOS = p.tos
}
-
+ if receiveTClass {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
if receiveIPPacketInfo {
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
@@ -490,6 +505,17 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrNotSupported
+ }
+
+ e.mu.Lock()
+ e.receiveTClass = v
+ e.mu.Unlock()
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -709,6 +735,17 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrNotSupported
+ }
+
+ e.mu.RLock()
+ v := e.receiveTClass
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -1273,6 +1310,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
packet.packetInfo.LocalAddr = r.LocalAddress
packet.packetInfo.DestinationAddr = r.RemoteAddress
packet.packetInfo.NIC = r.NICID()
+ case header.IPv6ProtocolNumber:
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
packet.timestamp = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 259c3072a..8df089d22 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index f0ff3fe71..34b7c2360 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -409,6 +409,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
@@ -1336,7 +1337,7 @@ func TestSetTTL(t *testing.T) {
}
}
-func TestTOSV4(t *testing.T) {
+func TestSetTOS(t *testing.T) {
for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1347,23 +1348,23 @@ func TestTOSV4(t *testing.T) {
const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
}
if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err)
}
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
if want := tcpip.IPv4TOSOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
}
testWrite(c, flow, checker.TOS(tos, 0))
@@ -1371,7 +1372,7 @@ func TestTOSV4(t *testing.T) {
}
}
-func TestTOSV6(t *testing.T) {
+func TestSetTClass(t *testing.T) {
for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1379,71 +1380,92 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = testTOS
+ const tClass = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt failed: %s", err)
+ if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil {
+ c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err)
}
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if want := tcpip.IPv6TrafficClassOption(tClass); v != want {
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
}
- testWrite(c, flow, checker.TOS(tos, 0))
+ // The header getter for TClass is called TOS, so use that checker.
+ testWrite(c, flow, checker.TOS(tClass, 0))
})
}
}
-func TestReceiveTOSV4(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, broadcast} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+func TestReceiveTosTClass(t *testing.T) {
+ testCases := []struct {
+ name string
+ getReceiveOption tcpip.SockOptBool
+ tests []testFlow
+ }{
+ {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
+ {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ }
+ for _, testCase := range testCases {
+ for _, flow := range testCase.tests {
+ t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- c.createEndpointForFlow(flow)
+ c.createEndpointForFlow(flow)
+ option := testCase.getReceiveOption
+ name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
- }
- // Test for expected default value.
- if v != false {
- c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false)
- }
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
+ }
- want := true
- if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil {
- c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err)
- }
+ want := true
+ if err := c.ep.SetSockOptBool(option, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
+ }
- got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
- }
- if got != want {
- c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want)
- }
+ got, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ }
- // Verify that the correct received TOS is handed through as
- // ancillary data to the ControlMessages struct.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
- testRead(c, flow, checker.ReceiveTOS(testTOS))
- })
+ if got != want {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
+ }
+
+ // Verify that the correct received TOS or TClass is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+ switch option {
+ case tcpip.ReceiveTClassOption:
+ testRead(c, flow, checker.ReceiveTClass(testTOS))
+ case tcpip.ReceiveTOSOption:
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ default:
+ t.Fatalf("unknown test variant: %s", name)
+ }
+ })
+ }
}
}
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
index ff8b9e91a..6c9ada9c7 100644
--- a/pkg/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -25,7 +25,6 @@ go_library(
"bytes_io_unsafe.go",
"usermem.go",
"usermem_arm64.go",
- "usermem_unsafe.go",
"usermem_x86.go",
],
visibility = ["//:sandbox"],
@@ -33,6 +32,7 @@ go_library(
"//pkg/atomicbitops",
"//pkg/binary",
"//pkg/context",
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/safemem",
"//pkg/syserror",
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 71fd4e155..d2f4403b0 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -251,7 +252,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
}
end, ok := addr.AddLength(uint64(readlen))
if !ok {
- return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT
}
// Shorten the read to avoid crossing page boundaries, since faulting
// in a page unnecessarily is expensive. This also ensures that partial
@@ -272,16 +273,16 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
// Look for the terminating zero byte, which may have occurred before
// hitting err.
if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
+ return gohacks.StringFromImmutableBytes(buf[:done+i]), nil
}
done += n
if err != nil {
- return stringFromImmutableBytes(buf[:done]), err
+ return gohacks.StringFromImmutableBytes(buf[:done]), err
}
addr = end
}
- return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
+ return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG
}
// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The