summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD4
-rw-r--r--pkg/abi/linux/aio.go3
-rw-r--r--pkg/abi/linux/bpf.go1
-rw-r--r--pkg/abi/linux/capability.go4
-rw-r--r--pkg/abi/linux/fcntl.go4
-rw-r--r--pkg/abi/linux/ioctl.go6
-rw-r--r--pkg/abi/linux/ipc.go2
-rw-r--r--pkg/abi/linux/linux.go9
-rw-r--r--pkg/abi/linux/netfilter.go22
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go22
-rw-r--r--pkg/abi/linux/netlink.go2
-rw-r--r--pkg/abi/linux/poll.go2
-rw-r--r--pkg/abi/linux/rusage.go2
-rw-r--r--pkg/abi/linux/seccomp.go23
-rw-r--r--pkg/abi/linux/sem.go4
-rw-r--r--pkg/abi/linux/shm.go6
-rw-r--r--pkg/abi/linux/signal.go2
-rw-r--r--pkg/abi/linux/socket.go13
-rw-r--r--pkg/abi/linux/time.go16
-rw-r--r--pkg/abi/linux/utsname.go2
-rw-r--r--pkg/amutex/BUILD5
-rw-r--r--pkg/amutex/amutex.go30
-rw-r--r--pkg/bpf/decoder.go13
-rw-r--r--pkg/bpf/decoder_test.go4
-rw-r--r--pkg/bpf/program_builder.go23
-rw-r--r--pkg/bpf/program_builder_test.go42
-rw-r--r--pkg/context/BUILD1
-rw-r--r--pkg/context/context.go57
-rw-r--r--pkg/lisafs/README.md363
-rw-r--r--pkg/marshal/BUILD17
-rw-r--r--pkg/marshal/marshal.go184
-rw-r--r--pkg/marshal/marshal_impl_util.go78
-rw-r--r--pkg/marshal/primitive/BUILD18
-rw-r--r--pkg/marshal/primitive/primitive.go247
-rw-r--r--pkg/safemem/BUILD4
-rw-r--r--pkg/seccomp/seccomp.go177
-rw-r--r--pkg/seccomp/seccomp_rules.go75
-rw-r--r--pkg/seccomp/seccomp_test.go603
-rw-r--r--pkg/seccomp/seccomp_test_victim.go2
-rw-r--r--pkg/sentry/arch/BUILD3
-rw-r--r--pkg/sentry/arch/arch.go7
-rw-r--r--pkg/sentry/arch/arch_amd64.go12
-rw-r--r--pkg/sentry/arch/arch_arm64.go12
-rw-r--r--pkg/sentry/arch/signal_act.go2
-rw-r--r--pkg/sentry/arch/signal_stack.go2
-rw-r--r--pkg/sentry/fs/host/BUILD2
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go4
-rw-r--r--pkg/sentry/fs/host/tty.go2
-rw-r--r--pkg/sentry/fs/proc/task.go2
-rw-r--r--pkg/sentry/fs/tty/BUILD2
-rw-r--r--pkg/sentry/fs/tty/master.go2
-rw-r--r--pkg/sentry/fs/tty/queue.go2
-rw-r--r--pkg/sentry/fs/tty/replica.go2
-rw-r--r--pkg/sentry/fs/tty/terminal.go2
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD4
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go2
-rw-r--r--pkg/sentry/fsimpl/devpts/queue.go2
-rw-r--r--pkg/sentry/fsimpl/devpts/replica.go2
-rw-r--r--pkg/sentry/fsimpl/devpts/terminal.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD4
-rw-r--r--pkg/sentry/fsimpl/fuse/connection.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go8
-rw-r--r--pkg/sentry/fsimpl/host/BUILD2
-rw-r--r--pkg/sentry/fsimpl/host/socket_unsafe.go4
-rw-r--r--pkg/sentry/fsimpl/host/tty.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go19
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go78
-rw-r--r--pkg/sentry/fsimpl/overlay/non_directory.go102
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go34
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go13
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go2
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD2
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go19
-rw-r--r--pkg/sentry/kernel/BUILD4
-rw-r--r--pkg/sentry/kernel/auth/BUILD1
-rw-r--r--pkg/sentry/kernel/auth/id.go4
-rw-r--r--pkg/sentry/kernel/fd_table.go143
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go60
-rw-r--r--pkg/sentry/kernel/ptrace.go23
-rw-r--r--pkg/sentry/kernel/ptrace_amd64.go2
-rw-r--r--pkg/sentry/kernel/syscalls.go10
-rw-r--r--pkg/sentry/kernel/task_clone.go6
-rw-r--r--pkg/sentry/kernel/task_exit.go3
-rw-r--r--pkg/sentry/kernel/task_futex.go7
-rw-r--r--pkg/sentry/kernel/task_run.go2
-rw-r--r--pkg/sentry/kernel/task_syscall.go7
-rw-r--r--pkg/sentry/kernel/task_usermem.go43
-rw-r--r--pkg/sentry/kernel/threads.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go10
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s4
-rw-r--r--pkg/sentry/socket/BUILD2
-rw-r--r--pkg/sentry/socket/hostinet/BUILD4
-rw-r--r--pkg/sentry/socket/hostinet/socket.go4
-rw-r--r--pkg/sentry/socket/netlink/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/socket.go4
-rw-r--r--pkg/sentry/socket/netstack/BUILD4
-rw-r--r--pkg/sentry/socket/netstack/netstack.go4
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go4
-rw-r--r--pkg/sentry/socket/socket.go2
-rw-r--r--pkg/sentry/socket/unix/BUILD2
-rw-r--r--pkg/sentry/socket/unix/unix.go2
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go2
-rw-r--r--pkg/sentry/strace/BUILD1
-rw-r--r--pkg/sentry/strace/epoll.go4
-rw-r--r--pkg/sentry/strace/socket.go23
-rw-r--r--pkg/sentry/strace/strace.go32
-rw-r--r--pkg/sentry/syscalls/linux/BUILD5
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go33
-rw-r--r--pkg/sentry/syscalls/linux/sys_capability.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go26
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go14
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go31
-rw-r--r--pkg/sentry/syscalls/linux/sys_identity.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_poll.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_rlimit.go19
-rw-r--r--pkg/sentry/syscalls/linux/sys_rusage.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_sched.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_seccomp.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go65
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go11
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_sysinfo.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go15
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_timer.go101
-rw-r--r--pkg/sentry/syscalls/linux/sys_timerfd.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_tls_amd64.go13
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go27
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go49
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go6
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go9
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go24
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go65
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go17
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/timerfd.go6
-rw-r--r--pkg/sentry/vfs/epoll.go16
-rw-r--r--pkg/syserror/syserror_test.go20
-rw-r--r--pkg/tcpip/header/ipv4.go10
-rw-r--r--pkg/tcpip/header/ipv6.go4
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go18
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go500
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go10
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go137
-rw-r--r--pkg/tcpip/network/testutil/BUILD17
-rw-r--r--pkg/tcpip/network/testutil/testutil.go92
-rw-r--r--pkg/tcpip/stack/route.go23
-rw-r--r--pkg/tcpip/stack/stack.go4
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go18
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go120
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go7
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go68
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go32
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go6
168 files changed, 3597 insertions, 1102 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index b5c5cc20b..cdcaa8c73 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -74,9 +74,9 @@ go_library(
"//pkg/abi",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go
index 86ee3f8b5..5fc099892 100644
--- a/pkg/abi/linux/aio.go
+++ b/pkg/abi/linux/aio.go
@@ -42,6 +42,8 @@ const (
//
// The priority field is currently ignored in the implementation below. Also
// note that the IOCB_FLAG_RESFD feature is not supported.
+//
+// +marshal
type IOCallback struct {
Data uint64
Key uint32
@@ -64,6 +66,7 @@ type IOCallback struct {
// IOEvent describes an I/O result.
//
+// +marshal
// +stateify savable
type IOEvent struct {
Data uint64
diff --git a/pkg/abi/linux/bpf.go b/pkg/abi/linux/bpf.go
index aa3d3ce70..9422fcf69 100644
--- a/pkg/abi/linux/bpf.go
+++ b/pkg/abi/linux/bpf.go
@@ -16,6 +16,7 @@ package linux
// BPFInstruction is a raw BPF virtual machine instruction.
//
+// +marshal slice:BPFInstructionSlice
// +stateify savable
type BPFInstruction struct {
// OpCode is the operation to execute.
diff --git a/pkg/abi/linux/capability.go b/pkg/abi/linux/capability.go
index 965f74663..afd16cc27 100644
--- a/pkg/abi/linux/capability.go
+++ b/pkg/abi/linux/capability.go
@@ -177,12 +177,16 @@ const (
)
// CapUserHeader is equivalent to Linux's cap_user_header_t.
+//
+// +marshal
type CapUserHeader struct {
Version uint32
Pid int32
}
// CapUserData is equivalent to Linux's cap_user_data_t.
+//
+// +marshal slice:CapUserDataSlice
type CapUserData struct {
Effective uint32
Permitted uint32
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index 9242e80a5..cc3571fad 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -45,6 +45,8 @@ const (
)
// Flock is the lock structure for F_SETLK.
+//
+// +marshal
type Flock struct {
Type int16
Whence int16
@@ -63,6 +65,8 @@ const (
)
// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX.
+//
+// +marshal
type FOwnerEx struct {
Type int32
PID int32
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index a4fe7501d..3356a2b4a 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -113,6 +113,12 @@ const (
_IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
)
+// Constants from uapi/linux/fs.h.
+const (
+ FS_IOC_GETFLAGS = 2147771905
+ FS_VERITY_FL = 1048576
+)
+
// Constants from uapi/linux/fsverity.h.
const (
FS_IOC_ENABLE_VERITY = 1082156677
diff --git a/pkg/abi/linux/ipc.go b/pkg/abi/linux/ipc.go
index 22acd2d43..c6e65df62 100644
--- a/pkg/abi/linux/ipc.go
+++ b/pkg/abi/linux/ipc.go
@@ -37,6 +37,8 @@ const IPC_PRIVATE = 0
// features like 32-bit UIDs.
// IPCPerm is equivalent to struct ipc64_perm.
+//
+// +marshal
type IPCPerm struct {
Key uint32
UID uint32
diff --git a/pkg/abi/linux/linux.go b/pkg/abi/linux/linux.go
index 281acdbde..3b4abece1 100644
--- a/pkg/abi/linux/linux.go
+++ b/pkg/abi/linux/linux.go
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package linux contains the constants and types needed to interface with a Linux kernel.
+// Package linux contains the constants and types needed to interface with a
+// Linux kernel.
package linux
// NumSoftIRQ is the number of software IRQs, exposed via /proc/stat.
@@ -21,6 +22,8 @@ package linux
const NumSoftIRQ = 10
// Sysinfo is the structure provided by sysinfo on linux versions > 2.3.48.
+//
+// +marshal
type Sysinfo struct {
Uptime int64
Loads [3]uint64
@@ -34,6 +37,6 @@ type Sysinfo struct {
_ [6]byte // Pad Procs to 64bits.
TotalHigh uint64
FreeHigh uint64
- Unit uint32
- /* The _f field in the glibc version of Sysinfo has size 0 on AMD64 */
+ Unit uint32 `marshal:"unaligned"` // Struct ends mid-64-bit-word.
+ // The _f field in the glibc version of Sysinfo has size 0 on AMD64.
}
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 91e35366f..1c5b34711 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -17,9 +17,9 @@ package linux
import (
"io"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// This file contains structures required to support netfilter, specifically
@@ -450,9 +450,9 @@ func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
}
// CopyIn implements marshal.Marshallable.CopyIn.
-func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
- length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+func (ke *KernelIPTGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+ buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
// Unmarshal unconditionally. If we had a short copy-in, this results in a
// partially unmarshalled struct.
ke.UnmarshalBytes(buf) // escapes: fallback.
@@ -460,21 +460,21 @@ func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int
}
// CopyOut implements marshal.Marshallable.CopyOut.
-func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+func (ke *KernelIPTGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
// Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
// back to MarshalBytes.
- return task.CopyOutBytes(addr, ke.marshalAll(task))
+ return cc.CopyOutBytes(addr, ke.marshalAll(cc))
}
// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+func (ke *KernelIPTGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
// Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
// back to MarshalBytes.
- return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+ return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit])
}
-func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte {
- buf := task.CopyScratchBuffer(ke.SizeBytes())
+func (ke *KernelIPTGetEntries) marshalAll(cc marshal.CopyContext) []byte {
+ buf := cc.CopyScratchBuffer(ke.SizeBytes())
ke.MarshalBytes(buf)
return buf
}
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index f6117024c..a137940b6 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -17,9 +17,9 @@ package linux
import (
"io"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// This file contains structures required to support IPv6 netfilter and
@@ -128,9 +128,9 @@ func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) {
}
// CopyIn implements marshal.Marshallable.CopyIn.
-func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
- length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+func (ke *KernelIP6TGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+ buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := cc.CopyInBytes(addr, buf) // escapes: okay.
// Unmarshal unconditionally. If we had a short copy-in, this results
// in a partially unmarshalled struct.
ke.UnmarshalBytes(buf) // escapes: fallback.
@@ -138,21 +138,21 @@ func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (in
}
// CopyOut implements marshal.Marshallable.CopyOut.
-func (ke *KernelIP6TGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+func (ke *KernelIP6TGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
// Type KernelIP6TGetEntries doesn't have a packed layout in memory,
// fall back to MarshalBytes.
- return task.CopyOutBytes(addr, ke.marshalAll(task))
+ return cc.CopyOutBytes(addr, ke.marshalAll(cc))
}
// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (ke *KernelIP6TGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+func (ke *KernelIP6TGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
// Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall
// back to MarshalBytes.
- return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+ return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit])
}
-func (ke *KernelIP6TGetEntries) marshalAll(task marshal.Task) []byte {
- buf := task.CopyScratchBuffer(ke.SizeBytes())
+func (ke *KernelIP6TGetEntries) marshalAll(cc marshal.CopyContext) []byte {
+ buf := cc.CopyScratchBuffer(ke.SizeBytes())
ke.MarshalBytes(buf)
return buf
}
diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go
index 0ba086c76..b41f94a69 100644
--- a/pkg/abi/linux/netlink.go
+++ b/pkg/abi/linux/netlink.go
@@ -40,6 +40,8 @@ const (
)
// SockAddrNetlink is struct sockaddr_nl, from uapi/linux/netlink.h.
+//
+// +marshal
type SockAddrNetlink struct {
Family uint16
_ uint16
diff --git a/pkg/abi/linux/poll.go b/pkg/abi/linux/poll.go
index c04d26e4c..3443a5768 100644
--- a/pkg/abi/linux/poll.go
+++ b/pkg/abi/linux/poll.go
@@ -15,6 +15,8 @@
package linux
// PollFD is struct pollfd, used by poll(2)/ppoll(2), from uapi/asm-generic/poll.h.
+//
+// +marshal slice:PollFDSlice
type PollFD struct {
FD int32
Events int16
diff --git a/pkg/abi/linux/rusage.go b/pkg/abi/linux/rusage.go
index d8302dc85..e29d0ac7e 100644
--- a/pkg/abi/linux/rusage.go
+++ b/pkg/abi/linux/rusage.go
@@ -26,6 +26,8 @@ const (
)
// Rusage represents the Linux struct rusage.
+//
+// +marshal
type Rusage struct {
UTime Timeval
STime Timeval
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
index d0607e256..b07cafe12 100644
--- a/pkg/abi/linux/seccomp.go
+++ b/pkg/abi/linux/seccomp.go
@@ -34,11 +34,11 @@ type BPFAction uint32
const (
SECCOMP_RET_KILL_PROCESS BPFAction = 0x80000000
- SECCOMP_RET_KILL_THREAD = 0x00000000
- SECCOMP_RET_TRAP = 0x00030000
- SECCOMP_RET_ERRNO = 0x00050000
- SECCOMP_RET_TRACE = 0x7ff00000
- SECCOMP_RET_ALLOW = 0x7fff0000
+ SECCOMP_RET_KILL_THREAD BPFAction = 0x00000000
+ SECCOMP_RET_TRAP BPFAction = 0x00030000
+ SECCOMP_RET_ERRNO BPFAction = 0x00050000
+ SECCOMP_RET_TRACE BPFAction = 0x7ff00000
+ SECCOMP_RET_ALLOW BPFAction = 0x7fff0000
)
func (a BPFAction) String() string {
@@ -64,6 +64,19 @@ func (a BPFAction) Data() uint16 {
return uint16(a & SECCOMP_RET_DATA)
}
+// WithReturnCode sets the lower 16 bits of the SECCOMP_RET_ERRNO or
+// SECCOMP_RET_TRACE actions to the provided return code, overwriting the previous
+// action, and returns a new BPFAction. If not SECCOMP_RET_ERRNO or
+// SECCOMP_RET_TRACE then this panics.
+func (a BPFAction) WithReturnCode(code uint16) BPFAction {
+ // mask out the previous return value
+ baseAction := a & SECCOMP_RET_ACTION_FULL
+ if baseAction == SECCOMP_RET_ERRNO || baseAction == SECCOMP_RET_TRACE {
+ return BPFAction(uint32(baseAction) | uint32(code))
+ }
+ panic("WithReturnCode only valid for SECCOMP_RET_ERRNO and SECCOMP_RET_TRACE")
+}
+
// SockFprog is sock_fprog taken from <linux/filter.h>.
type SockFprog struct {
Len uint16
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
index de422c519..487a626cc 100644
--- a/pkg/abi/linux/sem.go
+++ b/pkg/abi/linux/sem.go
@@ -35,6 +35,8 @@ const (
const SEM_UNDO = 0x1000
// SemidDS is equivalent to struct semid64_ds.
+//
+// +marshal
type SemidDS struct {
SemPerm IPCPerm
SemOTime TimeT
@@ -45,6 +47,8 @@ type SemidDS struct {
}
// Sembuf is equivalent to struct sembuf.
+//
+// +marshal slice:SembufSlice
type Sembuf struct {
SemNum uint16
SemOp int16
diff --git a/pkg/abi/linux/shm.go b/pkg/abi/linux/shm.go
index e45aadb10..274b1e847 100644
--- a/pkg/abi/linux/shm.go
+++ b/pkg/abi/linux/shm.go
@@ -51,6 +51,8 @@ const (
// ShmidDS is equivalent to struct shmid64_ds. Source:
// include/uapi/asm-generic/shmbuf.h
+//
+// +marshal
type ShmidDS struct {
ShmPerm IPCPerm
ShmSegsz uint64
@@ -66,6 +68,8 @@ type ShmidDS struct {
}
// ShmParams is equivalent to struct shminfo. Source: include/uapi/linux/shm.h
+//
+// +marshal
type ShmParams struct {
ShmMax uint64
ShmMin uint64
@@ -75,6 +79,8 @@ type ShmParams struct {
}
// ShmInfo is equivalent to struct shm_info. Source: include/uapi/linux/shm.h
+//
+// +marshal
type ShmInfo struct {
UsedIDs int32 // Number of currently existing segments.
_ [4]byte
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index 1c330e763..6ca57ffbb 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -214,6 +214,8 @@ const (
)
// Sigevent represents struct sigevent.
+//
+// +marshal
type Sigevent struct {
Value uint64 // union sigval {int, void*}
Signo int32
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index e37c8727d..d156d41e4 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -14,7 +14,10 @@
package linux
-import "gvisor.dev/gvisor/pkg/binary"
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
+)
// Address families, from linux/socket.h.
const (
@@ -265,6 +268,8 @@ type InetMulticastRequestWithNIC struct {
type Inet6Addr [16]byte
// SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h.
+//
+// +marshal
type SockAddrInet6 struct {
Family uint16
Port uint16
@@ -274,6 +279,8 @@ type SockAddrInet6 struct {
}
// SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h.
+//
+// +marshal
type SockAddrLink struct {
Family uint16
Protocol uint16
@@ -290,6 +297,8 @@ type SockAddrLink struct {
const UnixPathMax = 108
// SockAddrUnix is struct sockaddr_un, from uapi/linux/un.h.
+//
+// +marshal
type SockAddrUnix struct {
Family uint16
Path [UnixPathMax]int8
@@ -299,6 +308,8 @@ type SockAddrUnix struct {
// equivalent to struct sockaddr. SockAddr ensures that a well-defined set of
// types can be used as socket addresses.
type SockAddr interface {
+ marshal.Marshallable
+
// implementsSockAddr exists purely to allow a type to indicate that they
// implement this interface. This method is a no-op and shouldn't be called.
implementsSockAddr()
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index e6860ed49..206f5af7e 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -93,6 +93,8 @@ const (
const maxSecInDuration = math.MaxInt64 / int64(time.Second)
// TimeT represents time_t in <time.h>. It represents time in seconds.
+//
+// +marshal
type TimeT int64
// NsecToTimeT translates nanoseconds to TimeT (seconds).
@@ -102,7 +104,7 @@ func NsecToTimeT(nsec int64) TimeT {
// Timespec represents struct timespec in <time.h>.
//
-// +marshal
+// +marshal slice:TimespecSlice
type Timespec struct {
Sec int64
Nsec int64
@@ -158,7 +160,7 @@ const SizeOfTimeval = 16
// Timeval represents struct timeval in <time.h>.
//
-// +marshal
+// +marshal slice:TimevalSlice
type Timeval struct {
Sec int64
Usec int64
@@ -196,6 +198,8 @@ func DurationToTimeval(dur time.Duration) Timeval {
}
// Itimerspec represents struct itimerspec in <time.h>.
+//
+// +marshal
type Itimerspec struct {
Interval Timespec
Value Timespec
@@ -206,12 +210,16 @@ type Itimerspec struct {
// struct timeval it_interval; /* next value */
// struct timeval it_value; /* current value */
// };
+//
+// +marshal
type ItimerVal struct {
Interval Timeval
Value Timeval
}
// ClockT represents type clock_t.
+//
+// +marshal
type ClockT int64
// ClockTFromDuration converts time.Duration to clock_t.
@@ -220,6 +228,8 @@ func ClockTFromDuration(d time.Duration) ClockT {
}
// Tms represents struct tms, used by times(2).
+//
+// +marshal
type Tms struct {
UTime ClockT
STime ClockT
@@ -229,6 +239,8 @@ type Tms struct {
// TimerID represents type timer_t, which identifies a POSIX per-process
// interval timer.
+//
+// +marshal
type TimerID int32
// StatxTimestamp represents struct statx_timestamp.
diff --git a/pkg/abi/linux/utsname.go b/pkg/abi/linux/utsname.go
index 60f220a67..cb7c95437 100644
--- a/pkg/abi/linux/utsname.go
+++ b/pkg/abi/linux/utsname.go
@@ -26,6 +26,8 @@ const (
)
// UtsName represents struct utsname, the struct returned by uname(2).
+//
+// +marshal
type UtsName struct {
Sysname [UTSLen + 1]byte
Nodename [UTSLen + 1]byte
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index ffc918846..bd3a5cce9 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -6,7 +6,10 @@ go_library(
name = "amutex",
srcs = ["amutex.go"],
visibility = ["//:sandbox"],
- deps = ["//pkg/syserror"],
+ deps = [
+ "//pkg/context",
+ "//pkg/syserror",
+ ],
)
go_test(
diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go
index a078a31db..d7acc1d9f 100644
--- a/pkg/amutex/amutex.go
+++ b/pkg/amutex/amutex.go
@@ -19,41 +19,17 @@ package amutex
import (
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/syserror"
)
// Sleeper must be implemented by users of the abortable mutex to allow for
// cancellation of waits.
-type Sleeper interface {
- // SleepStart is called by the AbortableMutex.Lock() function when the
- // mutex is contended and the goroutine is about to sleep.
- //
- // A channel can be returned that causes the sleep to be canceled if
- // it's readable. If no cancellation is desired, nil can be returned.
- SleepStart() <-chan struct{}
-
- // SleepFinish is called by AbortableMutex.Lock() once a contended mutex
- // is acquired or the wait is aborted.
- SleepFinish(success bool)
-
- // Interrupted returns true if the wait is aborted.
- Interrupted() bool
-}
+type Sleeper = context.ChannelSleeper
// NoopSleeper is a stateless no-op implementation of Sleeper for anonymous
// embedding in other types that do not support cancelation.
-type NoopSleeper struct{}
-
-// SleepStart implements Sleeper.SleepStart.
-func (NoopSleeper) SleepStart() <-chan struct{} {
- return nil
-}
-
-// SleepFinish implements Sleeper.SleepFinish.
-func (NoopSleeper) SleepFinish(success bool) {}
-
-// Interrupted implements Sleeper.Interrupted.
-func (NoopSleeper) Interrupted() bool { return false }
+type NoopSleeper = context.Context
// Block blocks until either receiving from ch succeeds (in which case it
// returns nil) or sleeper is interrupted (in which case it returns
diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go
index c8ee0c3b1..069d0395d 100644
--- a/pkg/bpf/decoder.go
+++ b/pkg/bpf/decoder.go
@@ -21,10 +21,15 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// DecodeProgram translates an array of BPF instructions into text format.
-func DecodeProgram(program []linux.BPFInstruction) (string, error) {
+// DecodeProgram translates a compiled BPF program into text format.
+func DecodeProgram(p Program) (string, error) {
+ return DecodeInstructions(p.instructions)
+}
+
+// DecodeInstructions translates an array of BPF instructions into text format.
+func DecodeInstructions(instns []linux.BPFInstruction) (string, error) {
var ret bytes.Buffer
- for line, s := range program {
+ for line, s := range instns {
ret.WriteString(fmt.Sprintf("%v: ", line))
if err := decode(s, line, &ret); err != nil {
return "", err
@@ -34,7 +39,7 @@ func DecodeProgram(program []linux.BPFInstruction) (string, error) {
return ret.String(), nil
}
-// Decode translates BPF instruction into text format.
+// Decode translates a single BPF instruction into text format.
func Decode(inst linux.BPFInstruction) (string, error) {
var ret bytes.Buffer
err := decode(inst, -1, &ret)
diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go
index 6a023f0c0..bb971ce21 100644
--- a/pkg/bpf/decoder_test.go
+++ b/pkg/bpf/decoder_test.go
@@ -93,7 +93,7 @@ func TestDecode(t *testing.T) {
}
}
-func TestDecodeProgram(t *testing.T) {
+func TestDecodeInstructions(t *testing.T) {
for _, test := range []struct {
name string
program []linux.BPFInstruction
@@ -126,7 +126,7 @@ func TestDecodeProgram(t *testing.T) {
program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)},
fail: true},
} {
- got, err := DecodeProgram(test.program)
+ got, err := DecodeInstructions(test.program)
if test.fail {
if err == nil {
t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got)
diff --git a/pkg/bpf/program_builder.go b/pkg/bpf/program_builder.go
index 7992044d0..caaf99c83 100644
--- a/pkg/bpf/program_builder.go
+++ b/pkg/bpf/program_builder.go
@@ -32,13 +32,21 @@ type ProgramBuilder struct {
// Maps label names to label objects.
labels map[string]*label
+ // unusableLabels are labels that are added before being referenced in a
+ // jump. Any labels added this way cannot be referenced later in order to
+ // avoid backwards references.
+ unusableLabels map[string]bool
+
// Array of BPF instructions that makes up the program.
instructions []linux.BPFInstruction
}
// NewProgramBuilder creates a new ProgramBuilder instance.
func NewProgramBuilder() *ProgramBuilder {
- return &ProgramBuilder{labels: map[string]*label{}}
+ return &ProgramBuilder{
+ labels: map[string]*label{},
+ unusableLabels: map[string]bool{},
+ }
}
// label contains information to resolve a label to an offset.
@@ -108,9 +116,12 @@ func (b *ProgramBuilder) AddJumpLabels(code uint16, k uint32, jtLabel, jfLabel s
func (b *ProgramBuilder) AddLabel(name string) error {
l, ok := b.labels[name]
if !ok {
- // This is done to catch jump backwards cases, but it's not strictly wrong
- // to have unused labels.
- return fmt.Errorf("Adding a label that hasn't been used is not allowed: %v", name)
+ if _, ok = b.unusableLabels[name]; ok {
+ return fmt.Errorf("label %q already set", name)
+ }
+ // Mark the label as unusable. This is done to catch backwards jumps.
+ b.unusableLabels[name] = true
+ return nil
}
if l.target != -1 {
return fmt.Errorf("label %q target already set: %v", name, l.target)
@@ -141,6 +152,10 @@ func (b *ProgramBuilder) addLabelSource(labelName string, t jmpType) {
func (b *ProgramBuilder) resolveLabels() error {
for key, v := range b.labels {
+ if _, ok := b.unusableLabels[key]; ok {
+ return fmt.Errorf("backwards reference detected for label: %q", key)
+ }
+
if v.target == -1 {
return fmt.Errorf("label target not set: %v", key)
}
diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go
index 92ca5f4c3..37f684f25 100644
--- a/pkg/bpf/program_builder_test.go
+++ b/pkg/bpf/program_builder_test.go
@@ -26,16 +26,16 @@ func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error {
if err != nil {
return fmt.Errorf("Instructions() failed: %v", err)
}
- got, err := DecodeProgram(instructions)
+ got, err := DecodeInstructions(instructions)
if err != nil {
- return fmt.Errorf("DecodeProgram('instructions') failed: %v", err)
+ return fmt.Errorf("DecodeInstructions('instructions') failed: %v", err)
}
- expectedDecoded, err := DecodeProgram(expected)
+ expectedDecoded, err := DecodeInstructions(expected)
if err != nil {
- return fmt.Errorf("DecodeProgram('expected') failed: %v", err)
+ return fmt.Errorf("DecodeInstructions('expected') failed: %v", err)
}
if got != expectedDecoded {
- return fmt.Errorf("DecodeProgram() failed, expected: %q, got: %q", expectedDecoded, got)
+ return fmt.Errorf("DecodeInstructions() failed, expected: %q, got: %q", expectedDecoded, got)
}
return nil
}
@@ -124,10 +124,38 @@ func TestProgramBuilderLabelWithNoInstruction(t *testing.T) {
}
}
+// TestProgramBuilderUnusedLabel tests that adding an unused label doesn't
+// cause program generation to fail.
func TestProgramBuilderUnusedLabel(t *testing.T) {
p := NewProgramBuilder()
- if err := p.AddLabel("unused"); err == nil {
- t.Errorf("AddLabel(unused) should have failed")
+ p.AddStmt(Ld+Abs+W, 10)
+ p.AddJump(Jmp+Ja, 10, 0, 0)
+
+ expected := []linux.BPFInstruction{
+ Stmt(Ld+Abs+W, 10),
+ Jump(Jmp+Ja, 10, 0, 0),
+ }
+
+ if err := p.AddLabel("unused"); err != nil {
+ t.Errorf("AddLabel(unused) should have succeeded")
+ }
+
+ if err := validate(p, expected); err != nil {
+ t.Errorf("Validate() failed: %v", err)
+ }
+}
+
+// TestProgramBuilderBackwardsReference tests that including a backwards
+// reference to a label in a program causes a failure.
+func TestProgramBuilderBackwardsReference(t *testing.T) {
+ p := NewProgramBuilder()
+ if err := p.AddLabel("bw_label"); err != nil {
+ t.Errorf("failed to add label")
+ }
+ p.AddStmt(Ld+Abs+W, 10)
+ p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "bw_label", 0)
+ if _, err := p.Instructions(); err == nil {
+ t.Errorf("Instructions() should have failed")
}
}
diff --git a/pkg/context/BUILD b/pkg/context/BUILD
index 239f31149..f33e23bf7 100644
--- a/pkg/context/BUILD
+++ b/pkg/context/BUILD
@@ -7,7 +7,6 @@ go_library(
srcs = ["context.go"],
visibility = ["//:sandbox"],
deps = [
- "//pkg/amutex",
"//pkg/log",
],
)
diff --git a/pkg/context/context.go b/pkg/context/context.go
index 5319b6d8d..2613bc752 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -26,7 +26,6 @@ import (
"context"
"time"
- "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/log"
)
@@ -68,9 +67,10 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
// In both cases, values extracted from the Context should be used instead.
type Context interface {
log.Logger
- amutex.Sleeper
context.Context
+ ChannelSleeper
+
// UninterruptibleSleepStart indicates the beginning of an uninterruptible
// sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate
// is true and the Context represents a Task, the Task's AddressSpace is
@@ -85,29 +85,60 @@ type Context interface {
UninterruptibleSleepFinish(activate bool)
}
-// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep
-// methods for anonymous embedding in other types that do not implement sleeps.
-type NoopSleeper struct {
- amutex.NoopSleeper
+// A ChannelSleeper represents a goroutine that may sleep interruptibly, where
+// interruption is indicated by a channel becoming readable.
+type ChannelSleeper interface {
+ // SleepStart is called before going to sleep interruptibly. If SleepStart
+ // returns a non-nil channel and that channel becomes ready for receiving
+ // while the goroutine is sleeping, the goroutine should be woken, and
+ // SleepFinish(false) should be called. Otherwise, SleepFinish(true) should
+ // be called after the goroutine stops sleeping.
+ SleepStart() <-chan struct{}
+
+ // SleepFinish is called after an interruptibly-sleeping goroutine stops
+ // sleeping, as documented by SleepStart.
+ SleepFinish(success bool)
+
+ // Interrupted returns true if the channel returned by SleepStart is
+ // ready for receiving.
+ Interrupted() bool
+}
+
+// NoopSleeper is a noop implementation of ChannelSleeper and
+// Context.UninterruptibleSleep* methods for anonymous embedding in other types
+// that do not implement special behavior around sleeps.
+type NoopSleeper struct{}
+
+// SleepStart implements ChannelSleeper.SleepStart.
+func (NoopSleeper) SleepStart() <-chan struct{} {
+ return nil
+}
+
+// SleepFinish implements ChannelSleeper.SleepFinish.
+func (NoopSleeper) SleepFinish(success bool) {}
+
+// Interrupted implements ChannelSleeper.Interrupted.
+func (NoopSleeper) Interrupted() bool {
+ return false
}
-// UninterruptibleSleepStart does nothing.
-func (NoopSleeper) UninterruptibleSleepStart(bool) {}
+// UninterruptibleSleepStart implements Context.UninterruptibleSleepStart.
+func (NoopSleeper) UninterruptibleSleepStart(deactivate bool) {}
-// UninterruptibleSleepFinish does nothing.
-func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
+// UninterruptibleSleepFinish implements Context.UninterruptibleSleepFinish.
+func (NoopSleeper) UninterruptibleSleepFinish(activate bool) {}
-// Deadline returns zero values, meaning no deadline.
+// Deadline implements context.Context.Deadline.
func (NoopSleeper) Deadline() (time.Time, bool) {
return time.Time{}, false
}
-// Done returns nil.
+// Done implements context.Context.Done.
func (NoopSleeper) Done() <-chan struct{} {
return nil
}
-// Err returns nil.
+// Err returns context.Context.Err.
func (NoopSleeper) Err() error {
return nil
}
diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md
new file mode 100644
index 000000000..51d0d40e5
--- /dev/null
+++ b/pkg/lisafs/README.md
@@ -0,0 +1,363 @@
+# Replacing 9P
+
+## Background
+
+The Linux filesystem model consists of the following key aspects (modulo mounts,
+which are outside the scope of this discussion):
+
+- A `struct inode` represents a "filesystem object", such as a directory or a
+ regular file. "Filesystem object" is most precisely defined by the practical
+ properties of an inode, such as an immutable type (regular file, directory,
+ symbolic link, etc.) and its independence from the path originally used to
+ obtain it.
+
+- A `struct dentry` represents a node in a filesystem tree. Semantically, each
+ dentry is immutably associated with an inode representing the filesystem
+ object at that position. (Linux implements optimizations involving reuse of
+ unreferenced dentries, which allows their associated inodes to change, but
+ this is outside the scope of this discussion.)
+
+- A `struct file` represents an open file description (hereafter FD) and is
+ needed to perform I/O. Each FD is immutably associated with the dentry
+ through which it was opened.
+
+The current gVisor virtual filesystem implementation (hereafter VFS1) closely
+imitates the Linux design:
+
+- `struct inode` => `fs.Inode`
+
+- `struct dentry` => `fs.Dirent`
+
+- `struct file` => `fs.File`
+
+gVisor accesses most external filesystems through a variant of the 9P2000.L
+protocol, including extensions for performance (`walkgetattr`) and for features
+not supported by vanilla 9P2000.L (`flushf`, `lconnect`). The 9P protocol family
+is inode-based; 9P fids represent a file (equivalently "file system object"),
+and the protocol is structured around alternatively obtaining fids to represent
+files (with `walk` and, in gVisor, `walkgetattr`) and performing operations on
+those fids.
+
+In the sections below, a **shared** filesystem is a filesystem that is *mutably*
+accessible by multiple concurrent clients, such that a **non-shared** filesystem
+is a filesystem that is either read-only or accessible by only a single client.
+
+## Problems
+
+### Serialization of Path Component RPCs
+
+Broadly speaking, VFS1 traverses each path component in a pathname, alternating
+between verifying that each traversed dentry represents an inode that represents
+a searchable directory and moving to the next dentry in the path.
+
+In the context of a remote filesystem, the structure of this traversal means
+that - modulo caching - a path involving N components requires at least N-1
+*sequential* RPCs to obtain metadata for intermediate directories, incurring
+significant latency. (In vanilla 9P2000.L, 2(N-1) RPCs are required: N-1 `walk`
+and N-1 `getattr`. We added the `walkgetattr` RPC to reduce this overhead.) On
+non-shared filesystems, this overhead is primarily significant during
+application startup; caching mitigates much of this overhead at steady state. On
+shared filesystems, where correct caching requires revalidation (requiring RPCs
+for each revalidated directory anyway), this overhead is consistently ruinous.
+
+### Inefficient RPCs
+
+9P is not exceptionally economical with RPCs in general. In addition to the
+issue described above:
+
+- Opening an existing file in 9P involves at least 2 RPCs: `walk` to produce
+ an unopened fid representing the file, and `lopen` to open the fid.
+
+- Creating a file also involves at least 2 RPCs: `walk` to produce an unopened
+ fid representing the parent directory, and `lcreate` to create the file and
+ convert the fid to an open fid representing the created file. In practice,
+ both the Linux and gVisor 9P clients expect to have an unopened fid for the
+ created file (necessitating an additional `walk`), as well as attributes for
+ the created file (necessitating an additional `getattr`), for a total of 4
+ RPCs. (In a shared filesystem, where whether a file already exists can
+ change between RPCs, a correct implementation of `open(O_CREAT)` would have
+ to alternate between these two paths (plus `clunk`ing the temporary fid
+ between alternations, since the nature of the `fid` differs between the two
+ paths). Neither Linux nor gVisor implement the required alternation, so
+ `open(O_CREAT)` without `O_EXCL` can spuriously fail with `EEXIST` on both.)
+
+- Closing (`clunk`ing) a fid requires an RPC. VFS1 issues this RPC
+ asynchronously in an attempt to reduce critical path latency, but scheduling
+ overhead makes this not clearly advantageous in practice.
+
+- `read` and `readdir` can return partial reads without a way to indicate EOF,
+ necessitating an additional final read to detect EOF.
+
+- Operations that affect filesystem state do not consistently return updated
+ filesystem state. In gVisor, the client implementation attempts to handle
+ this by tracking what it thinks updated state "should" be; this is complex,
+ and especially brittle for timestamps (which are often not arbitrarily
+ settable). In Linux, the client implemtation invalidates cached metadata
+ whenever it performs such an operation, and reloads it when a dentry
+ corresponding to an inode with no valid cached metadata is revalidated; this
+ is simple, but necessitates an additional `getattr`.
+
+### Dentry/Inode Ambiguity
+
+As noted above, 9P's documentation tends to imply that unopened fids represent
+an inode. In practice, most filesystem APIs present very limited interfaces for
+working with inodes at best, such that the interpretation of unopened fids
+varies:
+
+- Linux's 9P client associates unopened fids with (dentry, uid) pairs. When
+ caching is enabled, it also associates each inode with the first fid opened
+ writably that references that inode, in order to support page cache
+ writeback.
+
+- gVisor's 9P client associates unopened fids with inodes, and also caches
+ opened fids in inodes in a manner similar to Linux.
+
+- The runsc fsgofer associates unopened fids with both "dentries" (host
+ filesystem paths) and "inodes" (host file descriptors); which is used
+ depends on the operation invoked on the fid.
+
+For non-shared filesystems, this confusion has resulted in correctness issues
+that are (in gVisor) currently handled by a number of coarse-grained locks that
+serialize renames with all other filesystem operations. For shared filesystems,
+this means inconsistent behavior in the presence of concurrent mutation.
+
+## Design
+
+Almost all Linux filesystem syscalls describe filesystem resources in one of two
+ways:
+
+- Path-based: A filesystem position is described by a combination of a
+ starting position and a sequence of path components relative to that
+ position, where the starting position is one of:
+
+ - The VFS root (defined by mount namespace and chroot), for absolute paths
+
+ - The VFS position of an existing FD, for relative paths passed to `*at`
+ syscalls (e.g. `statat`)
+
+ - The current working directory, for relative paths passed to non-`*at`
+ syscalls and `*at` syscalls with `AT_FDCWD`
+
+- File-description-based: A filesystem object is described by an existing FD,
+ passed to a `f*` syscall (e.g. `fstat`).
+
+Many of our issues with 9P arise from its (and VFS') interposition of a model
+based on inodes between the filesystem syscall API and filesystem
+implementations. We propose to replace 9P with a protocol that does not feature
+inodes at all, and instead closely follows the filesystem syscall API by
+featuring only path-based and FD-based operations, with minimal deviations as
+necessary to ameliorate deficiencies in the syscall interface (see below). This
+approach addresses the issues described above:
+
+- Even on shared filesystems, most application filesystem syscalls are
+ translated to a single RPC (possibly excepting special cases described
+ below), which is a logical lower bound.
+
+- The behavior of application syscalls on shared filesystems is
+ straightforwardly predictable: path-based syscalls are translated to
+ path-based RPCs, which will re-lookup the file at that path, and FD-based
+ syscalls are translated to FD-based RPCs, which use an existing open file
+ without performing another lookup. (This is at least true on gofers that
+ proxy the host local filesystem; other filesystems that lack support for
+ e.g. certain operations on FDs may have different behavior, but this
+ divergence is at least still predictable and inherent to the underlying
+ filesystem implementation.)
+
+Note that this approach is only feasible in gVisor's next-generation virtual
+filesystem (VFS2), which does not assume the existence of inodes and allows the
+remote filesystem client to translate whole path-based syscalls into RPCs. Thus
+one of the unavoidable tradeoffs associated with such a protocol vs. 9P is the
+inability to construct a Linux client that is performance-competitive with
+gVisor.
+
+### File Permissions
+
+Many filesystem operations are side-effectual, such that file permissions must
+be checked before such operations take effect. The simplest approach to file
+permission checking is for the sentry to obtain permissions from the remote
+filesystem, then apply permission checks in the sentry before performing the
+application-requested operation. However, this requires an additional RPC per
+application syscall (which can't be mitigated by caching on shared filesystems).
+Alternatively, we may delegate file permission checking to gofers. In general,
+file permission checks depend on the following properties of the accessor:
+
+- Filesystem UID/GID
+
+- Supplementary GIDs
+
+- Effective capabilities in the accessor's user namespace (i.e. the accessor's
+ effective capability set)
+
+- All UIDs and GIDs mapped in the accessor's user namespace (which determine
+ if the accessor's capabilities apply to accessed files)
+
+We may choose to delay implementation of file permission checking delegation,
+although this is potentially costly since it doubles the number of required RPCs
+for most operations on shared filesystems. We may also consider compromise
+options, such as only delegating file permission checks for accessors in the
+root user namespace.
+
+### Symbolic Links
+
+gVisor usually interprets symbolic link targets in its VFS rather than on the
+filesystem containing the symbolic link; thus e.g. a symlink to
+"/proc/self/maps" on a remote filesystem resolves to said file in the sentry's
+procfs rather than the host's. This implies that:
+
+- Remote filesystem servers that proxy filesystems supporting symlinks must
+ check if each path component is a symlink during path traversal.
+
+- Absolute symlinks require that the sentry restart the operation at its
+ contextual VFS root (which is task-specific and may not be on a remote
+ filesystem at all), so if a remote filesystem server encounters an absolute
+ symlink during path traversal on behalf of a path-based operation, it must
+ terminate path traversal and return the symlink target.
+
+- Relative symlinks begin target resolution in the parent directory of the
+ symlink, so in theory most relative symlinks can be handled automatically
+ during the path traversal that encounters the symlink, provided that said
+ traversal is supplied with the number of remaining symlinks before `ELOOP`.
+ However, the new path traversed by the symlink target may cross VFS mount
+ boundaries, such that it's only safe for remote filesystem servers to
+ speculatively follow relative symlinks for side-effect-free operations such
+ as `stat` (where the sentry can simply ignore results that are inapplicable
+ due to crossing mount boundaries). We may choose to delay implementation of
+ this feature, at the cost of an additional RPC per relative symlink (note
+ that even if the symlink target crosses a mount boundary, the sentry will
+ need to `stat` the path to the mount boundary to confirm that each traversed
+ component is an accessible directory); until it is implemented, relative
+ symlinks may be handled like absolute symlinks, by terminating path
+ traversal and returning the symlink target.
+
+The possibility of symlinks (and the possibility of a compromised sentry) means
+that the sentry may issue RPCs with paths that, in the absence of symlinks,
+would traverse beyond the root of the remote filesystem. For example, the sentry
+may issue an RPC with a path like "/foo/../..", on the premise that if "/foo" is
+a symlink then the resulting path may be elsewhere on the remote filesystem. To
+handle this, path traversal must also track its current depth below the remote
+filesystem root, and terminate path traversal if it would ascend beyond this
+point.
+
+### Path Traversal
+
+Since path-based VFS operations will translate to path-based RPCs, filesystem
+servers will need to handle path traversal. From the perspective of a given
+filesystem implementation in the server, there are two basic approaches to path
+traversal:
+
+- Inode-walk: For each path component, obtain a handle to the underlying
+ filesystem object (e.g. with `open(O_PATH)`), check if that object is a
+ symlink (as described above) and that that object is accessible by the
+ caller (e.g. with `fstat()`), then continue to the next path component (e.g.
+ with `openat()`). This ensures that the checked filesystem object is the one
+ used to obtain the next object in the traversal, which is intuitively
+ appealing. However, while this approach works for host local filesystems, it
+ requires features that are not widely supported by other filesystems.
+
+- Path-walk: For each path component, use a path-based operation to determine
+ if the filesystem object currently referred to by that path component is a
+ symlink / is accessible. This is highly portable, but suffers from quadratic
+ behavior (at the level of the underlying filesystem implementation, the
+ first path component will be traversed a number of times equal to the number
+ of path components in the path).
+
+The implementation should support either option by delegating path traversal to
+filesystem implementations within the server (like VFS and the remote filesystem
+protocol itself), as inode-walking is still safe, efficient, amenable to FD
+caching, and implementable on non-shared host local filesystems (a sufficiently
+common case as to be worth considering in the design).
+
+Both approaches are susceptible to race conditions that may permit sandboxed
+filesystem escapes:
+
+- Under inode-walk, a malicious application may cause a directory to be moved
+ (with `rename`) during path traversal, such that the filesystem
+ implementation incorrectly determines whether subsequent inodes are located
+ in paths that should be visible to sandboxed applications.
+
+- Under path-walk, a malicious application may cause a non-symlink file to be
+ replaced with a symlink during path traversal, such that following path
+ operations will incorrectly follow the symlink.
+
+Both race conditions can, to some extent, be mitigated in filesystem server
+implementations by synchronizing path traversal with the hazardous operations in
+question. However, shared filesystems are frequently used to share data between
+sandboxed and unsandboxed applications in a controlled way, and in some cases a
+malicious sandboxed application may be able to take advantage of a hazardous
+filesystem operation performed by an unsandboxed application. In some cases,
+filesystem features may be available to ensure safety even in such cases (e.g.
+[the new openat2() syscall](https://man7.org/linux/man-pages/man2/openat2.2.html)),
+but it is not clear how to solve this problem in general. (Note that this issue
+is not specific to our design; rather, it is a fundamental limitation of
+filesystem sandboxing.)
+
+### Filesystem Multiplexing
+
+A given sentry may need to access multiple distinct remote filesystems (e.g.
+different volumes for a given container). In many cases, there is no advantage
+to serving these filesystems from distinct filesystem servers, or accessing them
+through distinct connections (factors such as maximum RPC concurrency should be
+based on available host resources). Therefore, the protocol should support
+multiplexing of distinct filesystem trees within a single session. 9P supports
+this by allowing multiple calls to the `attach` RPC to produce fids representing
+distinct filesystem trees, but this is somewhat clunky; we propose a much
+simpler mechanism wherein each message that conveys a path also conveys a
+numeric filesystem ID that identifies a filesystem tree.
+
+## Alternatives Considered
+
+### Additional Extensions to 9P
+
+There are at least three conceptual aspects to 9P:
+
+- Wire format: messages with a 4-byte little-endian size prefix, strings with
+ a 2-byte little-endian size prefix, etc. Whether the wire format is worth
+ retaining is unclear; in particular, it's unclear that the 9P wire format
+ has a significant advantage over protobufs, which are substantially easier
+ to extend. Note that the official Go protobuf implementation is widely known
+ to suffer from a significant number of performance deficiencies, so if we
+ choose to switch to protobuf, we may need to use an alternative toolchain
+ such as `gogo/protobuf` (which is also widely used in the Go ecosystem, e.g.
+ by Kubernetes).
+
+- Filesystem model: fids, qids, etc. Discarding this is one of the motivations
+ for this proposal.
+
+- RPCs: Twalk, Tlopen, etc. In addition to previously-described
+ inefficiencies, most of these are dependent on the filesystem model and
+ therefore must be discarded.
+
+### FUSE
+
+The FUSE (Filesystem in Userspace) protocol is frequently used to provide
+arbitrary userspace filesystem implementations to a host Linux kernel.
+Unfortunately, FUSE is also inode-based, and therefore doesn't address any of
+the problems we have with 9P.
+
+### virtio-fs
+
+virtio-fs is an ongoing project aimed at improving Linux VM filesystem
+performance when accessing Linux host filesystems (vs. virtio-9p). In brief, it
+is based on:
+
+- Using a FUSE client in the guest that communicates over virtio with a FUSE
+ server in the host.
+
+- Using DAX to map the host page cache into the guest.
+
+- Using a file metadata table in shared memory to avoid VM exits for metadata
+ updates.
+
+None of these improvements seem applicable to gVisor:
+
+- As explained above, FUSE is still inode-based, so it is still susceptible to
+ most of the problems we have with 9P.
+
+- Our use of host file descriptors already allows us to leverage the host page
+ cache for file contents.
+
+- Our need for shared filesystem coherence is usually based on a user
+ requirement that an out-of-sandbox filesystem mutation is guaranteed to be
+ visible by all subsequent observations from within the sandbox, or vice
+ versa; it's not clear that this can be guaranteed without a synchronous
+ signaling mechanism like an RPC.
diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD
new file mode 100644
index 000000000..4aec98218
--- /dev/null
+++ b/pkg/marshal/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "marshal",
+ srcs = [
+ "marshal.go",
+ "marshal_impl_util.go",
+ ],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go
new file mode 100644
index 000000000..d8cb44b40
--- /dev/null
+++ b/pkg/marshal/marshal.go
@@ -0,0 +1,184 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal defines the Marshallable interface for
+// serialize/deserializing go data structures to/from memory, according to the
+// Linux ABI.
+//
+// Implementations of this interface are typically automatically generated by
+// tools/go_marshal. See the go_marshal README for details.
+package marshal
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// CopyContext defines the memory operations required to marshal to and from
+// user memory. Typically, kernel.Task is used to provide implementations for
+// these operations.
+type CopyContext interface {
+ // CopyScratchBuffer provides a task goroutine-local scratch buffer. See
+ // kernel.CopyScratchBuffer.
+ CopyScratchBuffer(size int) []byte
+
+ // CopyOutBytes writes the contents of b to the task's memory. See
+ // kernel.CopyOutBytes.
+ CopyOutBytes(addr usermem.Addr, b []byte) (int, error)
+
+ // CopyInBytes reads the contents of the task's memory to b. See
+ // kernel.CopyInBytes.
+ CopyInBytes(addr usermem.Addr, b []byte) (int, error)
+}
+
+// Marshallable represents operations on a type that can be marshalled to and
+// from memory.
+//
+// go-marshal automatically generates implementations for this interface for
+// types marked as '+marshal'.
+type Marshallable interface {
+ io.WriterTo
+
+ // SizeBytes is the size of the memory representation of a type in
+ // marshalled form.
+ //
+ // SizeBytes must handle a nil receiver. Practically, this means SizeBytes
+ // cannot deference any fields on the object implementing it (but will
+ // likely make use of the type of these fields).
+ SizeBytes() int
+
+ // MarshalBytes serializes a copy of a type to dst.
+ // Precondition: dst must be at least SizeBytes() in length.
+ MarshalBytes(dst []byte)
+
+ // UnmarshalBytes deserializes a type from src.
+ // Precondition: src must be at least SizeBytes() in length.
+ UnmarshalBytes(src []byte)
+
+ // Packed returns true if the marshalled size of the type is the same as the
+ // size it occupies in memory. This happens when the type has no fields
+ // starting at unaligned addresses (should always be true by default for ABI
+ // structs, verified by automatically generated tests when using
+ // go_marshal), and has no fields marked `marshal:"unaligned"`.
+ //
+ // Packed must return the same result for all possible values of the type
+ // implementing it. Violating this constraint implies the type doesn't have
+ // a static memory layout, and will lead to memory corruption.
+ // Go-marshal-generated code reuses the result of Packed for multiple values
+ // of the same type.
+ Packed() bool
+
+ // MarshalUnsafe serializes a type by bulk copying its in-memory
+ // representation to the dst buffer. This is only safe to do when the type
+ // has no implicit padding, see Marshallable.Packed. When Packed would
+ // return false, MarshalUnsafe should fall back to the safer but slower
+ // MarshalBytes.
+ // Precondition: dst must be at least SizeBytes() in length.
+ MarshalUnsafe(dst []byte)
+
+ // UnmarshalUnsafe deserializes a type by directly copying to the underlying
+ // memory allocated for the object by the runtime.
+ //
+ // This allows much faster unmarshalling of types which have no implicit
+ // padding, see Marshallable.Packed. When Packed would return false,
+ // UnmarshalUnsafe should fall back to the safer but slower unmarshal
+ // mechanism implemented in UnmarshalBytes.
+ // Precondition: src must be at least SizeBytes() in length.
+ UnmarshalUnsafe(src []byte)
+
+ // CopyIn deserializes a Marshallable type from a task's memory. This may
+ // only be called from a task goroutine. This is more efficient than calling
+ // UnmarshalUnsafe on Marshallable.Packed types, as the type being
+ // marshalled does not escape. The implementation should avoid creating
+ // extra copies in memory by directly deserializing to the object's
+ // underlying memory.
+ //
+ // If the copy-in from the task memory is only partially successful, CopyIn
+ // should still attempt to deserialize as much data as possible. See comment
+ // for UnmarshalBytes.
+ CopyIn(cc CopyContext, addr usermem.Addr) (int, error)
+
+ // CopyOut serializes a Marshallable type to a task's memory. This may only
+ // be called from a task goroutine. This is more efficient than calling
+ // MarshalUnsafe on Marshallable.Packed types, as the type being serialized
+ // does not escape. The implementation should avoid creating extra copies in
+ // memory by directly serializing from the object's underlying memory.
+ //
+ // The copy-out to the task memory may be partially successful, in which
+ // case CopyOut returns how much data was serialized. See comment for
+ // MarshalBytes for implications.
+ CopyOut(cc CopyContext, addr usermem.Addr) (int, error)
+
+ // CopyOutN is like CopyOut, but explicitly requests a partial
+ // copy-out. Note that this may yield unexpected results for non-packed
+ // types and the caller may only want to allow this for packed types. See
+ // comment on MarshalBytes.
+ //
+ // The limit must be less than or equal to SizeBytes().
+ CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error)
+}
+
+// go-marshal generates additional functions for a type based on additional
+// clauses to the +marshal directive. They are documented below.
+//
+// Slice API
+// =========
+//
+// Adding a "slice" clause to the +marshal directive for structs or newtypes on
+// primitives like this:
+//
+// // +marshal slice:FooSlice
+// type Foo struct { ... }
+//
+// Generates four additional functions for marshalling slices of Foos like this:
+//
+// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It
+// // might be more efficient that repeatedly calling Foo.MarshalUnsafe
+// // over a []Foo in a loop if the type is Packed.
+// // Preconditions: dst must be at least len(src)*Foo.SizeBytes() in length.
+// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... }
+//
+// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It
+// // might be more efficient that repeatedly calling Foo.UnmarshalUnsafe
+// // over a []Foo in a loop if the type is Packed.
+// // Preconditions: src must be at least len(dst)*Foo.SizeBytes() in length.
+// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... }
+//
+// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory.
+// func CopyFooSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Foo) (int, error) { ... }
+//
+// // CopyFooSliceIn copies out a slice of Foo objects to the task's memory.
+// func CopyFooSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []Foo) (int, error) { ... }
+//
+// The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where
+// %s is the first argument to the slice clause. This directive is not supported
+// for newtypes on arrays.
+//
+// The slice clause also takes an optional second argument, which must be the
+// value "inner":
+//
+// // +marshal slice:Int32Slice:inner
+// type Int32 int32
+//
+// This is only valid on newtypes on primitives, and causes the generated
+// functions to accept slices of the inner type instead:
+//
+// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []int32) (int, error) { ... }
+//
+// Without "inner", they would instead be:
+//
+// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Int32) (int, error) { ... }
+//
+// This may help avoid a cast depending on how the generated functions are used.
diff --git a/pkg/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go
new file mode 100644
index 000000000..69f33321e
--- /dev/null
+++ b/pkg/marshal/marshal_impl_util.go
@@ -0,0 +1,78 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package marshal
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// StubMarshallable implements the Marshallable interface.
+// StubMarshallable is a convenient embeddable type for satisfying the
+// marshallable interface, but provides no actual implementation. It is
+// useful when the marshallable interface needs to be implemented manually,
+// but the caller doesn't require the full marshallable interface.
+type StubMarshallable struct{}
+
+// WriteTo implements Marshallable.WriteTo.
+func (StubMarshallable) WriteTo(w io.Writer) (n int64, err error) {
+ panic("Please implement your own WriteTo function")
+}
+
+// SizeBytes implements Marshallable.SizeBytes.
+func (StubMarshallable) SizeBytes() int {
+ panic("Please implement your own SizeBytes function")
+}
+
+// MarshalBytes implements Marshallable.MarshalBytes.
+func (StubMarshallable) MarshalBytes(dst []byte) {
+ panic("Please implement your own MarshalBytes function")
+}
+
+// UnmarshalBytes implements Marshallable.UnmarshalBytes.
+func (StubMarshallable) UnmarshalBytes(src []byte) {
+ panic("Please implement your own UnMarshalBytes function")
+}
+
+// Packed implements Marshallable.Packed.
+func (StubMarshallable) Packed() bool {
+ panic("Please implement your own Packed function")
+}
+
+// MarshalUnsafe implements Marshallable.MarshalUnsafe.
+func (StubMarshallable) MarshalUnsafe(dst []byte) {
+ panic("Please implement your own MarshalUnsafe function")
+}
+
+// UnmarshalUnsafe implements Marshallable.UnmarshalUnsafe.
+func (StubMarshallable) UnmarshalUnsafe(src []byte) {
+ panic("Please implement your own UnmarshalUnsafe function")
+}
+
+// CopyIn implements Marshallable.CopyIn.
+func (StubMarshallable) CopyIn(cc CopyContext, addr usermem.Addr) (int, error) {
+ panic("Please implement your own CopyIn function")
+}
+
+// CopyOut implements Marshallable.CopyOut.
+func (StubMarshallable) CopyOut(cc CopyContext, addr usermem.Addr) (int, error) {
+ panic("Please implement your own CopyOut function")
+}
+
+// CopyOutN implements Marshallable.CopyOutN.
+func (StubMarshallable) CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) {
+ panic("Please implement your own CopyOutN function")
+}
diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD
new file mode 100644
index 000000000..06741e6d1
--- /dev/null
+++ b/pkg/marshal/primitive/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "primitive",
+ srcs = [
+ "primitive.go",
+ ],
+ marshal = True,
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/marshal",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go
new file mode 100644
index 000000000..dfdae5d60
--- /dev/null
+++ b/pkg/marshal/primitive/primitive.go
@@ -0,0 +1,247 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package primitive defines marshal.Marshallable implementations for primitive
+// types.
+package primitive
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Int8 is a marshal.Marshallable implementation for int8.
+//
+// +marshal slice:Int8Slice:inner
+type Int8 int8
+
+// Uint8 is a marshal.Marshallable implementation for uint8.
+//
+// +marshal slice:Uint8Slice:inner
+type Uint8 uint8
+
+// Int16 is a marshal.Marshallable implementation for int16.
+//
+// +marshal slice:Int16Slice:inner
+type Int16 int16
+
+// Uint16 is a marshal.Marshallable implementation for uint16.
+//
+// +marshal slice:Uint16Slice:inner
+type Uint16 uint16
+
+// Int32 is a marshal.Marshallable implementation for int32.
+//
+// +marshal slice:Int32Slice:inner
+type Int32 int32
+
+// Uint32 is a marshal.Marshallable implementation for uint32.
+//
+// +marshal slice:Uint32Slice:inner
+type Uint32 uint32
+
+// Int64 is a marshal.Marshallable implementation for int64.
+//
+// +marshal slice:Int64Slice:inner
+type Int64 int64
+
+// Uint64 is a marshal.Marshallable implementation for uint64.
+//
+// +marshal slice:Uint64Slice:inner
+type Uint64 uint64
+
+// ByteSlice is a marshal.Marshallable implementation for []byte.
+// This is a convenience wrapper around a dynamically sized type, and can't be
+// embedded in other marshallable types because it breaks assumptions made by
+// go-marshal internals. It violates the "no dynamically-sized types"
+// constraint of the go-marshal library.
+type ByteSlice []byte
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (b *ByteSlice) SizeBytes() int {
+ return len(*b)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (b *ByteSlice) MarshalBytes(dst []byte) {
+ copy(dst, *b)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (b *ByteSlice) UnmarshalBytes(src []byte) {
+ copy(*b, src)
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (b *ByteSlice) Packed() bool {
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (b *ByteSlice) MarshalUnsafe(dst []byte) {
+ b.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (b *ByteSlice) UnmarshalUnsafe(src []byte) {
+ b.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+ return cc.CopyInBytes(addr, *b)
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
+ return cc.CopyOutBytes(addr, *b)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
+ return cc.CopyOutBytes(addr, (*b)[:limit])
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) {
+ n, err := w.Write(*b)
+ return int64(n), err
+}
+
+var _ marshal.Marshallable = (*ByteSlice)(nil)
+
+// Below, we define some convenience functions for marshalling primitive types
+// using the newtypes above, without requiring superfluous casts.
+
+// 16-bit integers
+
+// CopyInt16In is a convenient wrapper for copying in an int16 from the task's
+// memory.
+func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, error) {
+ var buf Int16
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int16(buf)
+ return n, nil
+}
+
+// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's
+// memory.
+func CopyInt16Out(cc marshal.CopyContext, addr usermem.Addr, src int16) (int, error) {
+ srcP := Int16(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's
+// memory.
+func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int, error) {
+ var buf Uint16
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint16(buf)
+ return n, nil
+}
+
+// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's
+// memory.
+func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int, error) {
+ srcP := Uint16(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// 32-bit integers
+
+// CopyInt32In is a convenient wrapper for copying in an int32 from the task's
+// memory.
+func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, error) {
+ var buf Int32
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int32(buf)
+ return n, nil
+}
+
+// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's
+// memory.
+func CopyInt32Out(cc marshal.CopyContext, addr usermem.Addr, src int32) (int, error) {
+ srcP := Int32(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's
+// memory.
+func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int, error) {
+ var buf Uint32
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint32(buf)
+ return n, nil
+}
+
+// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's
+// memory.
+func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int, error) {
+ srcP := Uint32(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// 64-bit integers
+
+// CopyInt64In is a convenient wrapper for copying in an int64 from the task's
+// memory.
+func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, error) {
+ var buf Int64
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int64(buf)
+ return n, nil
+}
+
+// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's
+// memory.
+func CopyInt64Out(cc marshal.CopyContext, addr usermem.Addr, src int64) (int, error) {
+ srcP := Int64(src)
+ return srcP.CopyOut(cc, addr)
+}
+
+// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's
+// memory.
+func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int, error) {
+ var buf Uint64
+ n, err := buf.CopyIn(cc, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint64(buf)
+ return n, nil
+}
+
+// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's
+// memory.
+func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int, error) {
+ srcP := Uint64(src)
+ return srcP.CopyOut(cc, addr)
+}
diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD
index ce30382ab..68ed074f8 100644
--- a/pkg/safemem/BUILD
+++ b/pkg/safemem/BUILD
@@ -11,9 +11,7 @@ go_library(
"seq_unsafe.go",
],
visibility = ["//:sandbox"],
- deps = [
- "//pkg/safecopy",
- ],
+ deps = ["//pkg/safecopy"],
)
go_test(
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index 55fd6967e..752e2dc32 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package seccomp provides basic seccomp filters for x86_64 (little endian).
+// Package seccomp provides generation of basic seccomp filters. Currently,
+// only little endian systems are supported.
package seccomp
import (
@@ -64,9 +65,9 @@ func Install(rules SyscallRules) error {
Rules: rules,
Action: linux.SECCOMP_RET_ALLOW,
},
- }, defaultAction)
+ }, defaultAction, defaultAction)
if log.IsLogging(log.Debug) {
- programStr, errDecode := bpf.DecodeProgram(instrs)
+ programStr, errDecode := bpf.DecodeInstructions(instrs)
if errDecode != nil {
programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr)
}
@@ -117,7 +118,7 @@ var SyscallName = func(sysno uintptr) string {
// BuildProgram builds a BPF program from the given map of actions to matching
// SyscallRules. The single generated program covers all provided RuleSets.
-func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFInstruction, error) {
+func BuildProgram(rules []RuleSet, defaultAction, badArchAction linux.BPFAction) ([]linux.BPFInstruction, error) {
program := bpf.NewProgramBuilder()
// Be paranoid and check that syscall is done in the expected architecture.
@@ -128,7 +129,7 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn
// defaultLabel is at the bottom of the program. The size of program
// may exceeds 255 lines, which is the limit of a condition jump.
program.AddJump(bpf.Jmp|bpf.Jeq|bpf.K, LINUX_AUDIT_ARCH, skipOneInst, 0)
- program.AddDirectJumpLabel(defaultLabel)
+ program.AddStmt(bpf.Ret|bpf.K, uint32(badArchAction))
if err := buildIndex(rules, program); err != nil {
return nil, err
}
@@ -144,6 +145,11 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn
// buildIndex builds a BST to quickly search through all syscalls.
func buildIndex(rules []RuleSet, program *bpf.ProgramBuilder) error {
+ // Do nothing if rules is empty.
+ if len(rules) == 0 {
+ return nil
+ }
+
// Build a list of all application system calls, across all given rule
// sets. We have a simple BST, but may dispatch individual matchers
// with different actions. The matchers are evaluated linearly.
@@ -216,42 +222,163 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc
labelled := false
for i, arg := range rule {
if arg != nil {
+ // Break out early if using MatchAny since no further
+ // instructions are required.
+ if _, ok := arg.(MatchAny); ok {
+ continue
+ }
+
+ // Determine the data offset for low and high bits of input.
+ dataOffsetLow := seccompDataOffsetArgLow(i)
+ dataOffsetHigh := seccompDataOffsetArgHigh(i)
+ if i == RuleIP {
+ dataOffsetLow = seccompDataOffsetIPLow
+ dataOffsetHigh = seccompDataOffsetIPHigh
+ }
+
+ // Add the conditional operation. Input values to the BPF
+ // program are 64bit values. However, comparisons in BPF can
+ // only be done on 32bit values. This means that we need to do
+ // multiple BPF comparisons in order to do one logical 64bit
+ // comparison.
switch a := arg.(type) {
- case AllowAny:
- case AllowValue:
- dataOffsetLow := seccompDataOffsetArgLow(i)
- dataOffsetHigh := seccompDataOffsetArgHigh(i)
- if i == RuleIP {
- dataOffsetLow = seccompDataOffsetIPLow
- dataOffsetHigh = seccompDataOffsetIPHigh
- }
+ case EqualTo:
+ // EqualTo checks that both the higher and lower 32bits are equal.
high, low := uint32(a>>32), uint32(a)
- // assert arg_low == low
+
+ // Assert that the lower 32bits are equal.
+ // arg_low == low ? continue : violation
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
- // assert arg_high == high
+
+ // Assert that the lower 32bits are also equal.
+ // arg_high == high ? continue/success : violation
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
labelled = true
+ case NotEqual:
+ // NotEqual checks that either the higher or lower 32bits
+ // are *not* equal.
+ high, low := uint32(a>>32), uint32(a)
+ labelGood := fmt.Sprintf("ne%v", i)
+
+ // Check if the higher 32bits are (not) equal.
+ // arg_low == low ? continue : success
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+
+ // Assert that the lower 32bits are not equal (assuming
+ // higher bits are equal).
+ // arg_high == high ? violation : continue/success
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
case GreaterThan:
- dataOffsetLow := seccompDataOffsetArgLow(i)
- dataOffsetHigh := seccompDataOffsetArgHigh(i)
- if i == RuleIP {
- dataOffsetLow = seccompDataOffsetIPLow
- dataOffsetHigh = seccompDataOffsetIPHigh
- }
- labelGood := fmt.Sprintf("gt%v", i)
+ // GreaterThan checks that the higher 32bits is greater
+ // *or* that the higher 32bits are equal and the lower
+ // 32bits are greater.
high, low := uint32(a>>32), uint32(a)
- // assert arg_high < high
+ labelGood := fmt.Sprintf("gt%v", i)
+
+ // Assert the higher 32bits are greater than or equal.
+ // arg_high >= high ? continue : violation (arg_high < high)
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
- // arg_high > high
+
+ // Assert that the lower 32bits are greater.
+ // arg_high == high ? continue : success (arg_high > high)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
- // arg_low < low
+ // arg_low > low ? continue/success : violation (arg_high == high and arg_low <= low)
p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
labelled = true
+ case GreaterThanOrEqual:
+ // GreaterThanOrEqual checks that the higher 32bits is
+ // greater *or* that the higher 32bits are equal and the
+ // lower 32bits are greater than or equal.
+ high, low := uint32(a>>32), uint32(a)
+ labelGood := fmt.Sprintf("ge%v", i)
+
+ // Assert the higher 32bits are greater than or equal.
+ // arg_high >= high ? continue : violation (arg_high < high)
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ // arg_high == high ? continue : success (arg_high > high)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+
+ // Assert that the lower 32bits are greater (assuming the
+ // higher bits are equal).
+ // arg_low >= low ? continue/success : violation (arg_high == high and arg_low < low)
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
+ case LessThan:
+ // LessThan checks that the higher 32bits is less *or* that
+ // the higher 32bits are equal and the lower 32bits are
+ // less.
+ high, low := uint32(a>>32), uint32(a)
+ labelGood := fmt.Sprintf("lt%v", i)
+
+ // Assert the higher 32bits are less than or equal.
+ // arg_high > high ? violation : continue
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ // arg_high == high ? continue : success (arg_high < high)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+
+ // Assert that the lower 32bits are less (assuming the
+ // higher bits are equal).
+ // arg_low >= low ? violation : continue
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jge|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
+ case LessThanOrEqual:
+ // LessThan checks that the higher 32bits is less *or* that
+ // the higher 32bits are equal and the lower 32bits are
+ // less than or equal.
+ high, low := uint32(a>>32), uint32(a)
+ labelGood := fmt.Sprintf("le%v", i)
+
+ // Assert the higher 32bits are less than or equal.
+ // assert arg_high > high ? violation : continue
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ // arg_high == high ? continue : success
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+
+ // Assert the lower bits are less than or equal (assuming
+ // the higher bits are equal).
+ // arg_low > low ? violation : success
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0)
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
+ case maskedEqual:
+ // MaskedEqual checks that the bitwise AND of the value and
+ // mask are equal for both the higher and lower 32bits.
+ high, low := uint32(a.value>>32), uint32(a.value)
+ maskHigh, maskLow := uint32(a.mask>>32), uint32(a.mask)
+
+ // Assert that the lower 32bits are equal when masked.
+ // A <- arg_low.
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ // A <- arg_low & maskLow
+ p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskLow)
+ // Assert that arg_low & maskLow == low.
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+
+ // Assert that the higher 32bits are equal when masked.
+ // A <- arg_high
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ // A <- arg_high & maskHigh
+ p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskHigh)
+ // Assert that arg_high & maskHigh == high.
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ labelled = true
default:
return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a))
}
diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go
index a52dc1b4e..daf165bbf 100644
--- a/pkg/seccomp/seccomp_rules.go
+++ b/pkg/seccomp/seccomp_rules.go
@@ -39,28 +39,79 @@ func seccompDataOffsetArgHigh(i int) uint32 {
return seccompDataOffsetArgLow(i) + 4
}
-// AllowAny is marker to indicate any value will be accepted.
-type AllowAny struct{}
+// MatchAny is marker to indicate any value will be accepted.
+type MatchAny struct{}
-func (a AllowAny) String() (s string) {
+func (a MatchAny) String() (s string) {
return "*"
}
-// AllowValue specifies a value that needs to be strictly matched.
-type AllowValue uintptr
+// EqualTo specifies a value that needs to be strictly matched.
+type EqualTo uintptr
+
+func (a EqualTo) String() (s string) {
+ return fmt.Sprintf("== %#x", uintptr(a))
+}
+
+// NotEqual specifies a value that is strictly not equal.
+type NotEqual uintptr
+
+func (a NotEqual) String() (s string) {
+ return fmt.Sprintf("!= %#x", uintptr(a))
+}
// GreaterThan specifies a value that needs to be strictly smaller.
type GreaterThan uintptr
-func (a AllowValue) String() (s string) {
- return fmt.Sprintf("%#x ", uintptr(a))
+func (a GreaterThan) String() (s string) {
+ return fmt.Sprintf("> %#x", uintptr(a))
+}
+
+// GreaterThanOrEqual specifies a value that needs to be smaller or equal.
+type GreaterThanOrEqual uintptr
+
+func (a GreaterThanOrEqual) String() (s string) {
+ return fmt.Sprintf(">= %#x", uintptr(a))
+}
+
+// LessThan specifies a value that needs to be strictly greater.
+type LessThan uintptr
+
+func (a LessThan) String() (s string) {
+ return fmt.Sprintf("< %#x", uintptr(a))
+}
+
+// LessThanOrEqual specifies a value that needs to be greater or equal.
+type LessThanOrEqual uintptr
+
+func (a LessThanOrEqual) String() (s string) {
+ return fmt.Sprintf("<= %#x", uintptr(a))
+}
+
+type maskedEqual struct {
+ mask uintptr
+ value uintptr
+}
+
+func (a maskedEqual) String() (s string) {
+ return fmt.Sprintf("& %#x == %#x", a.mask, a.value)
+}
+
+// MaskedEqual specifies a value that matches the input after the input is
+// masked (bitwise &) against the given mask. Can be used to verify that input
+// only includes certain approved flags.
+func MaskedEqual(mask, value uintptr) interface{} {
+ return maskedEqual{
+ mask: mask,
+ value: value,
+ }
}
// Rule stores the allowed syscall arguments.
//
// For example:
// rule := Rule {
-// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
+// EqualTo(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
// }
type Rule [7]interface{} // 6 arguments + RIP
@@ -89,12 +140,12 @@ func (r Rule) String() (s string) {
// rules := SyscallRules{
// syscall.SYS_FUTEX: []Rule{
// {
-// AllowAny{},
-// AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
+// MatchAny{},
+// EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
// }, // OR
// {
-// AllowAny{},
-// AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
+// MatchAny{},
+// EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
// },
// },
// syscall.SYS_GETPID: []Rule{},
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index 5238df8bd..23f30678d 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -76,11 +76,14 @@ func TestBasic(t *testing.T) {
}
for _, test := range []struct {
+ name string
ruleSets []RuleSet
defaultAction linux.BPFAction
+ badArchAction linux.BPFAction
specs []spec
}{
{
+ name: "Single syscall",
ruleSets: []RuleSet{
{
Rules: SyscallRules{1: {}},
@@ -88,26 +91,28 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Single syscall allowed",
+ desc: "syscall allowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Single syscall disallowed",
+ desc: "syscall disallowed",
data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
},
{
+ name: "Multiple rulesets",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- AllowValue(0x1),
+ EqualTo(0x1),
},
},
},
@@ -122,30 +127,32 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_KILL_THREAD,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Multiple rulesets allowed (1a)",
+ desc: "allowed (1a)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Multiple rulesets allowed (1b)",
+ desc: "allowed (1b)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple rulesets allowed (2)",
+ desc: "syscall 1 matched 2nd rule",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple rulesets allowed (2)",
+ desc: "no match",
data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
},
{
+ name: "Multiple syscalls",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
@@ -157,50 +164,52 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Multiple syscalls allowed (1)",
+ desc: "allowed (1)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Multiple syscalls allowed (3)",
+ desc: "allowed (3)",
data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Multiple syscalls allowed (5)",
+ desc: "allowed (5)",
data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Multiple syscalls disallowed (0)",
+ desc: "disallowed (0)",
data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple syscalls disallowed (2)",
+ desc: "disallowed (2)",
data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple syscalls disallowed (4)",
+ desc: "disallowed (4)",
data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple syscalls disallowed (6)",
+ desc: "disallowed (6)",
data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Multiple syscalls disallowed (100)",
+ desc: "disallowed (100)",
data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
},
{
+ name: "Wrong architecture",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
@@ -210,15 +219,17 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Wrong architecture",
+ desc: "arch (123)",
data: seccompData{nr: 1, arch: 123},
- want: linux.SECCOMP_RET_TRAP,
+ want: linux.SECCOMP_RET_KILL_THREAD,
},
},
},
{
+ name: "Syscall disallowed",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
@@ -228,22 +239,24 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Syscall disallowed, action trap",
+ desc: "action trap",
data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
},
{
+ name: "Syscall arguments",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- AllowAny{},
- AllowValue(0xf),
+ MatchAny{},
+ EqualTo(0xf),
},
},
},
@@ -251,29 +264,31 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Syscall argument allowed",
+ desc: "allowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Syscall argument disallowed",
+ desc: "disallowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}},
want: linux.SECCOMP_RET_TRAP,
},
},
},
{
+ name: "Multiple arguments",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- AllowValue(0xf),
+ EqualTo(0xf),
},
{
- AllowValue(0xe),
+ EqualTo(0xe),
},
},
},
@@ -281,28 +296,30 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "Syscall argument allowed, two rules",
+ desc: "match first rule",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "Syscall argument allowed, two rules",
+ desc: "match 2nd rule",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}},
want: linux.SECCOMP_RET_ALLOW,
},
},
},
{
+ name: "EqualTo",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- AllowValue(0),
- AllowValue(math.MaxUint64 - 1),
- AllowValue(math.MaxUint32),
+ EqualTo(0),
+ EqualTo(math.MaxUint64 - 1),
+ EqualTo(math.MaxUint32),
},
},
},
@@ -310,9 +327,10 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "64bit syscall argument allowed",
+ desc: "argument allowed (all match)",
data: seccompData{
nr: 1,
arch: LINUX_AUDIT_ARCH,
@@ -321,7 +339,7 @@ func TestBasic(t *testing.T) {
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "64bit syscall argument disallowed",
+ desc: "argument disallowed (one mismatch)",
data: seccompData{
nr: 1,
arch: LINUX_AUDIT_ARCH,
@@ -330,7 +348,7 @@ func TestBasic(t *testing.T) {
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "64bit syscall argument disallowed",
+ desc: "argument disallowed (multiple mismatch)",
data: seccompData{
nr: 1,
arch: LINUX_AUDIT_ARCH,
@@ -341,6 +359,103 @@ func TestBasic(t *testing.T) {
},
},
{
+ name: "NotEqual",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ NotEqual(0x7aabbccdd),
+ NotEqual(math.MaxUint64 - 1),
+ NotEqual(math.MaxUint32),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ },
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (one equal)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (all equal)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "GreaterThan",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ // 4294967298
+ // Both upper 32 bits and lower 32 bits are non-zero.
+ // 00000000000000000000000000000010
+ // 00000000000000000000000000000010
+ GreaterThan(0x00000002_00000002),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "high 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits equal",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "GreaterThan (multi)",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
@@ -355,46 +470,145 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "GreaterThan: Syscall argument allowed",
+ desc: "arg allowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "GreaterThan: Syscall argument disallowed (equal)",
+ desc: "arg disallowed (first arg equal)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "Syscall argument disallowed (smaller)",
+ desc: "arg disallowed (first arg smaller)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "GreaterThan2: Syscall argument allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xfbcd000d}},
+ desc: "arg disallowed (second arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (second arg smaller)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "GreaterThanOrEqual",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ // 4294967298
+ // Both upper 32 bits and lower 32 bits are non-zero.
+ // 00000000000000000000000000000010
+ // 00000000000000000000000000000010
+ GreaterThanOrEqual(0x00000002_00000002),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "high 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "GreaterThan2: Syscall argument disallowed (equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ desc: "high 32bits equal, low 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits equal",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "GreaterThanOrEqual (multi)",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ GreaterThanOrEqual(0xf),
+ GreaterThanOrEqual(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed (both greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg allowed (first arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (first arg smaller)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
- desc: "GreaterThan2: Syscall argument disallowed (smaller)",
+ desc: "arg allowed (second arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (second arg smaller)",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
+ {
+ desc: "arg disallowed (both arg smaller)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
},
},
{
+ name: "LessThan",
ruleSets: []RuleSet{
{
Rules: SyscallRules{
1: []Rule{
{
- RuleIP: AllowValue(0x7aabbccdd),
+ // 4294967298
+ // Both upper 32 bits and lower 32 bits are non-zero.
+ // 00000000000000000000000000000010
+ // 00000000000000000000000000000010
+ LessThan(0x00000002_00000002),
},
},
},
@@ -402,40 +616,307 @@ func TestBasic(t *testing.T) {
},
},
defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
specs: []spec{
{
- desc: "IP: Syscall instruction pointer allowed",
+ desc: "high 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits equal",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ },
+ {
+ name: "LessThan (multi)",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ LessThan(0x1),
+ LessThan(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (first arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (first arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (second arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (second arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (both arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "LessThanOrEqual",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ // 4294967298
+ // Both upper 32 bits and lower 32 bits are non-zero.
+ // 00000000000000000000000000000010
+ // 00000000000000000000000000000010
+ LessThanOrEqual(0x00000002_00000002),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "high 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits greater",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "high 32bits equal, low 32bits equal",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits equal, low 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "high 32bits less",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ },
+
+ {
+ name: "LessThanOrEqual (multi)",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ LessThanOrEqual(0x1),
+ LessThanOrEqual(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg allowed (first arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (first arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg allowed (second arg equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (second arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (both arg greater)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "MaskedEqual",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ // x & 00000001 00000011 (0x103) == 00000000 00000001 (0x1)
+ // Input x must have lowest order bit set and
+ // must *not* have 8th or second lowest order bit set.
+ MaskedEqual(0x103, 0x1),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "arg allowed (low order mandatory bit)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000000 00000001
+ args: [6]uint64{0x1},
+ },
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg allowed (low order optional bit)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000000 00000101
+ args: [6]uint64{0x5},
+ },
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "arg disallowed (lowest order bit not set)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000000 00000010
+ args: [6]uint64{0x2},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (second lowest order bit set)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000000 00000011
+ args: [6]uint64{0x3},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "arg disallowed (8th bit set)",
+ data: seccompData{
+ nr: 1,
+ arch: LINUX_AUDIT_ARCH,
+ // 00000000 00000000 00000001 00000000
+ args: [6]uint64{0x100},
+ },
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ name: "Instruction Pointer",
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ RuleIP: EqualTo(0x7aabbccdd),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ badArchAction: linux.SECCOMP_RET_KILL_THREAD,
+ specs: []spec{
+ {
+ desc: "allowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd},
want: linux.SECCOMP_RET_ALLOW,
},
{
- desc: "IP: Syscall instruction pointer disallowed",
+ desc: "disallowed",
data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344},
want: linux.SECCOMP_RET_TRAP,
},
},
},
} {
- instrs, err := BuildProgram(test.ruleSets, test.defaultAction)
- if err != nil {
- t.Errorf("%s: buildProgram() got error: %v", test.specs[0].desc, err)
- continue
- }
- p, err := bpf.Compile(instrs)
- if err != nil {
- t.Errorf("%s: bpf.Compile() got error: %v", test.specs[0].desc, err)
- continue
- }
- for _, spec := range test.specs {
- got, err := bpf.Exec(p, spec.data.asInput())
+ t.Run(test.name, func(t *testing.T) {
+ instrs, err := BuildProgram(test.ruleSets, test.defaultAction, test.badArchAction)
if err != nil {
- t.Errorf("%s: bpf.Exec() got error: %v", spec.desc, err)
- continue
+ t.Fatalf("BuildProgram() got error: %v", err)
+ }
+ p, err := bpf.Compile(instrs)
+ if err != nil {
+ t.Fatalf("bpf.Compile() got error: %v", err)
}
- if got != uint32(spec.want) {
- t.Errorf("%s: bpd.Exec() = %d, want: %d", spec.desc, got, spec.want)
+ for _, spec := range test.specs {
+ got, err := bpf.Exec(p, spec.data.asInput())
+ if err != nil {
+ t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err)
+ }
+ if got != uint32(spec.want) {
+ // Include a decoded version of the program in output for debugging purposes.
+ decoded, _ := bpf.DecodeInstructions(instrs)
+ t.Fatalf("%s: got: %d, want: %d\nBPF Program\n%s", spec.desc, got, spec.want, decoded)
+ }
}
- }
+ })
}
}
@@ -457,7 +938,7 @@ func TestRandom(t *testing.T) {
Rules: syscallRules,
Action: linux.SECCOMP_RET_ALLOW,
},
- }, linux.SECCOMP_RET_TRAP)
+ }, linux.SECCOMP_RET_TRAP, linux.SECCOMP_RET_KILL_THREAD)
if err != nil {
t.Fatalf("buildProgram() got error: %v", err)
}
diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go
index fe157f539..7f33e0d9e 100644
--- a/pkg/seccomp/seccomp_test_victim.go
+++ b/pkg/seccomp/seccomp_test_victim.go
@@ -100,7 +100,7 @@ func main() {
if !die {
syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{
{
- seccomp.AllowValue(10),
+ seccomp.EqualTo(10),
},
}
}
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 901e0f320..99e2b3389 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -33,11 +33,12 @@ go_library(
"//pkg/context",
"//pkg/cpuid",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/limits",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index a903d031c..d75d665ae 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -72,12 +73,12 @@ type Context interface {
// with return values of varying sizes (for example ARCH_GETFS). This
// is a simple utility function to convert to the native size in these
// cases, and then we can CopyOut.
- Native(val uintptr) interface{}
+ Native(val uintptr) marshal.Marshallable
// Value converts a native type back to a generic value.
// Once a value has been converted to native via the above call -- it
// can be converted back here.
- Value(val interface{}) uintptr
+ Value(val marshal.Marshallable) uintptr
// Width returns the number of bytes for a native value.
Width() uint
@@ -205,7 +206,7 @@ type Context interface {
// equivalent of arch_ptrace():
// PtracePeekUser implements ptrace(PTRACE_PEEKUSR).
- PtracePeekUser(addr uintptr) (interface{}, error)
+ PtracePeekUser(addr uintptr) (marshal.Marshallable, error)
// PtracePokeUser implements ptrace(PTRACE_POKEUSR).
PtracePokeUser(addr, data uintptr) error
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 1c3e3c14c..c7d3a206d 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -23,6 +23,8 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -179,14 +181,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
}
// Native returns the native type for the given val.
-func (c *context64) Native(val uintptr) interface{} {
- v := uint64(val)
+func (c *context64) Native(val uintptr) marshal.Marshallable {
+ v := primitive.Uint64(val)
return &v
}
// Value returns the generic val for the given native type.
-func (c *context64) Value(val interface{}) uintptr {
- return uintptr(*val.(*uint64))
+func (c *context64) Value(val marshal.Marshallable) uintptr {
+ return uintptr(*val.(*primitive.Uint64))
}
// Width returns the byte width of this architecture.
@@ -293,7 +295,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
const userStructSize = 928
// PtracePeekUser implements Context.PtracePeekUser.
-func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) {
if addr&7 != 0 || addr >= userStructSize {
return nil, syscall.EIO
}
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
index 550741d8c..680d23a9f 100644
--- a/pkg/sentry/arch/arch_arm64.go
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -22,6 +22,8 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -163,14 +165,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
}
// Native returns the native type for the given val.
-func (c *context64) Native(val uintptr) interface{} {
- v := uint64(val)
+func (c *context64) Native(val uintptr) marshal.Marshallable {
+ v := primitive.Uint64(val)
return &v
}
// Value returns the generic val for the given native type.
-func (c *context64) Value(val interface{}) uintptr {
- return uintptr(*val.(*uint64))
+func (c *context64) Value(val marshal.Marshallable) uintptr {
+ return uintptr(*val.(*primitive.Uint64))
}
// Width returns the byte width of this architecture.
@@ -274,7 +276,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
}
// PtracePeekUser implements Context.PtracePeekUser.
-func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) {
// TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64.
return c.Native(0), nil
}
diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go
index 32173aa20..d3e2324a8 100644
--- a/pkg/sentry/arch/signal_act.go
+++ b/pkg/sentry/arch/signal_act.go
@@ -14,7 +14,7 @@
package arch
-import "gvisor.dev/gvisor/tools/go_marshal/marshal"
+import "gvisor.dev/gvisor/pkg/marshal"
// Special values for SignalAct.Handler.
const (
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index 0fa738a1d..a1eae98f9 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -17,8 +17,8 @@
package arch
import (
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
const (
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 42a6c41c2..1368014c4 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/fdnotifier",
"//pkg/iovec",
"//pkg/log",
+ "//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
"//pkg/secio",
@@ -55,7 +56,6 @@ go_library(
"//pkg/unet",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index 5d4f312cf..c8231e0aa 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -65,10 +65,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (
controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
if n > length {
- return length, n, msg.Controllen, controlTrunc, err
+ return length, n, msg.Controllen, controlTrunc, nil
}
- return n, n, msg.Controllen, controlTrunc, err
+ return n, n, msg.Controllen, controlTrunc, nil
}
// fdWriteVec sends from bufs to fd.
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 87d56a51d..1183727ab 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -17,6 +17,7 @@ package host
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -24,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 9cf7f2a62..103bfc600 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -604,7 +604,7 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) (
var vss, rss, data uint64
s.t.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.Size()
+ fds = fdTable.CurrentMaxFDs()
}
if mm := t.MemoryManager(); mm != nil {
vss = mm.VirtualMemorySize()
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index fdd5a40d5..e6d0eb359 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -17,6 +17,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -31,7 +32,6 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index bebf90ffa..b91184b1b 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -17,6 +17,7 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -25,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
index e070a1b71..79975d812 100644
--- a/pkg/sentry/fs/tty/queue.go
+++ b/pkg/sentry/fs/tty/queue.go
@@ -17,6 +17,7 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -24,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
diff --git a/pkg/sentry/fs/tty/replica.go b/pkg/sentry/fs/tty/replica.go
index cb6cd6864..385d230fb 100644
--- a/pkg/sentry/fs/tty/replica.go
+++ b/pkg/sentry/fs/tty/replica.go
@@ -17,6 +17,7 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -24,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index c9dbf1f3b..4f431d74d 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -17,10 +17,10 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index ac48ab34b..48e13613a 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -30,6 +30,8 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -43,8 +45,6 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index d07e1ded8..83d790b38 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -17,6 +17,7 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// masterInode is the inode for the master end of the Terminal.
diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go
index ca36b66e9..55bff3e60 100644
--- a/pkg/sentry/fsimpl/devpts/queue.go
+++ b/pkg/sentry/fsimpl/devpts/queue.go
@@ -17,6 +17,7 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -24,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
diff --git a/pkg/sentry/fsimpl/devpts/replica.go b/pkg/sentry/fsimpl/devpts/replica.go
index 1f99f4b4d..58f6c1d3a 100644
--- a/pkg/sentry/fsimpl/devpts/replica.go
+++ b/pkg/sentry/fsimpl/devpts/replica.go
@@ -17,6 +17,7 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -26,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// replicaInode is the inode for the replica end of the Terminal.
diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go
index 731955d62..510bd6d89 100644
--- a/pkg/sentry/fsimpl/devpts/terminal.go
+++ b/pkg/sentry/fsimpl/devpts/terminal.go
@@ -17,9 +17,9 @@ package devpts
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// Terminal is a pseudoterminal.
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 53a4f3012..999c16bfd 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -42,6 +42,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
"//pkg/refs",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/fsimpl/kernfs",
@@ -52,7 +53,6 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
"@org_golang_x_sys//unix:go_default_library",
],
)
@@ -64,6 +64,7 @@ go_test(
library = ":fuse",
deps = [
"//pkg/abi/linux",
+ "//pkg/marshal",
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
@@ -71,6 +72,5 @@ go_test(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go
index 6df2728ab..eb1d1a2b7 100644
--- a/pkg/sentry/fsimpl/fuse/connection.go
+++ b/pkg/sentry/fsimpl/fuse/connection.go
@@ -24,12 +24,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// maxActiveRequestsDefault is the default setting controlling the upper bound
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index aedc2fa39..c18921dd8 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -28,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// echoTestOpcode is the Opcode used during testing. The server used in tests
@@ -408,17 +408,17 @@ func (t *testPayload) UnmarshalUnsafe(src []byte) {
}
// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+func (t *testPayload) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {
panic("not implemented")
}
// CopyOut implements marshal.Marshallable.CopyOut.
-func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+func (t *testPayload) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
panic("not implemented")
}
// CopyIn implements marshal.Marshallable.CopyIn.
-func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+func (t *testPayload) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {
panic("not implemented")
}
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index be1c88c82..56bcf9bdb 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -49,6 +49,7 @@ go_library(
"//pkg/fspath",
"//pkg/iovec",
"//pkg/log",
+ "//pkg/marshal/primitive",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -72,7 +73,6 @@ go_library(
"//pkg/unet",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go
index 35ded24bc..c0bf45f08 100644
--- a/pkg/sentry/fsimpl/host/socket_unsafe.go
+++ b/pkg/sentry/fsimpl/host/socket_unsafe.go
@@ -63,10 +63,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (
controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
if n > length {
- return length, n, msg.Controllen, controlTrunc, err
+ return length, n, msg.Controllen, controlTrunc, nil
}
- return n, n, msg.Controllen, controlTrunc, err
+ return n, n, msg.Controllen, controlTrunc, nil
}
// fdWriteVec sends from bufs to fd.
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
index 7a9be4b97..97cefa350 100644
--- a/pkg/sentry/fsimpl/host/tty.go
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -17,6 +17,7 @@ package host
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -25,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// TTYFileDescription implements vfs.FileDescriptionImpl for a host file
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index c0b863ba4..74408e322 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -259,9 +259,19 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
if opts.Stat.Mask == 0 {
return nil
}
- if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
+
+ // Note that not all fields are modifiable. For example, the file type and
+ // inode numbers are immutable after node creation. Setting the size is often
+ // allowed by kernfs files but does not do anything. If some other behavior is
+ // needed, the embedder should consider extending SetStat.
+ //
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() {
+ return syserror.EISDIR
+ }
if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
return err
}
@@ -284,13 +294,6 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut
atomic.StoreUint32(&a.gid, stat.GID)
}
- // Note that not all fields are modifiable. For example, the file type and
- // inode numbers are immutable after node creation.
-
- // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
- // Also, STATX_SIZE will need some special handling, because read-only static
- // files should return EIO for truncate operations.
-
return nil
}
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index c589b4746..360b77ef6 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -81,6 +82,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
Start: d.parent.upperVD,
Path: fspath.Parse(d.name),
}
+ // Used during copy-up of memory-mapped regular files.
+ var mmapOpts *memmap.MMapOpts
cleanupUndoCopyUp := func() {
var err error
if ftype == linux.S_IFDIR {
@@ -136,6 +139,25 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
break
}
}
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ if d.wrappedMappable != nil {
+ // We may have memory mappings of the file on the lower layer.
+ // Switch to mapping the file on the upper layer instead.
+ mmapOpts = &memmap.MMapOpts{
+ Perms: usermem.ReadWrite,
+ MaxPerms: usermem.ReadWrite,
+ }
+ if err := newFD.ConfigureMMap(ctx, mmapOpts); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ if mmapOpts.MappingIdentity != nil {
+ mmapOpts.MappingIdentity.DecRef(ctx)
+ }
+ // Don't actually switch Mappables until the end of copy-up; see
+ // below for why.
+ }
if err := newFD.SetStat(ctx, vfs.SetStatOptions{
Stat: linux.Statx{
Mask: linux.STATX_UID | linux.STATX_GID,
@@ -265,6 +287,62 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
atomic.StoreUint64(&d.ino, upperStat.Ino)
}
+ if mmapOpts != nil && mmapOpts.Mappable != nil {
+ // Note that if mmapOpts != nil, then d.mapsMu is locked for writing
+ // (from the S_IFREG path above).
+
+ // Propagate mappings of d to the new Mappable. Remember which mappings
+ // we added so we can remove them on failure.
+ upperMappable := mmapOpts.Mappable
+ allAdded := make(map[memmap.MappableRange]memmap.MappingsOfRange)
+ for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ added := make(memmap.MappingsOfRange)
+ for m := range seg.Value() {
+ if err := upperMappable.AddMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable); err != nil {
+ for m := range added {
+ upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
+ }
+ for mr, mappings := range allAdded {
+ for m := range mappings {
+ upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, mr.Start, m.Writable)
+ }
+ }
+ return err
+ }
+ added[m] = struct{}{}
+ }
+ allAdded[seg.Range()] = added
+ }
+
+ // Switch to the new Mappable. We do this at the end of copy-up
+ // because:
+ //
+ // - We need to switch Mappables (by changing d.wrappedMappable) before
+ // invalidating Translations from the old Mappable (to pick up
+ // Translations from the new one).
+ //
+ // - We need to lock d.dataMu while changing d.wrappedMappable, but
+ // must invalidate Translations with d.dataMu unlocked (due to lock
+ // ordering).
+ //
+ // - Consequently, once we unlock d.dataMu, other threads may
+ // immediately observe the new (copied-up) Mappable, which we want to
+ // delay until copy-up is guaranteed to succeed.
+ d.dataMu.Lock()
+ lowerMappable := d.wrappedMappable
+ d.wrappedMappable = upperMappable
+ d.dataMu.Unlock()
+ d.lowerMappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Remove mappings from the old Mappable.
+ for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ for m := range seg.Value() {
+ lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable)
+ }
+ }
+ d.lowerMappings.RemoveAll()
+ }
+
atomic.StoreUint32(&d.copiedUp, 1)
return nil
}
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
index 268b32537..74cfd3799 100644
--- a/pkg/sentry/fsimpl/overlay/non_directory.go
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -256,10 +257,105 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error {
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
- wrappedFD, err := fd.getCurrentFD(ctx)
+ if err := fd.ensureMappable(ctx, opts); err != nil {
+ return err
+ }
+ return vfs.GenericConfigureMMap(&fd.vfsfd, fd.dentry(), opts)
+}
+
+// ensureMappable ensures that fd.dentry().wrappedMappable is not nil.
+func (fd *nonDirectoryFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error {
+ d := fd.dentry()
+
+ // Fast path if we already have a Mappable for the current top layer.
+ if atomic.LoadUint32(&d.isMappable) != 0 {
+ return nil
+ }
+
+ // Only permit mmap of regular files, since other file types may have
+ // unpredictable behavior when mmapped (e.g. /dev/zero).
+ if atomic.LoadUint32(&d.mode)&linux.S_IFMT != linux.S_IFREG {
+ return syserror.ENODEV
+ }
+
+ // Get a Mappable for the current top layer.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ d.copyMu.RLock()
+ defer d.copyMu.RUnlock()
+ if atomic.LoadUint32(&d.isMappable) != 0 {
+ return nil
+ }
+ wrappedFD, err := fd.currentFDLocked(ctx)
if err != nil {
return err
}
- defer wrappedFD.DecRef(ctx)
- return wrappedFD.ConfigureMMap(ctx, opts)
+ if err := wrappedFD.ConfigureMMap(ctx, opts); err != nil {
+ return err
+ }
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef(ctx)
+ opts.MappingIdentity = nil
+ }
+ // Use this Mappable for all mappings of this layer (unless we raced with
+ // another call to ensureMappable).
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ if d.wrappedMappable == nil {
+ d.wrappedMappable = opts.Mappable
+ atomic.StoreUint32(&d.isMappable, 1)
+ }
+ return nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ if err := d.wrappedMappable.AddMapping(ctx, ms, ar, offset, writable); err != nil {
+ return err
+ }
+ if !d.isCopiedUp() {
+ d.lowerMappings.AddMapping(ms, ar, offset, writable)
+ }
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.wrappedMappable.RemoveMapping(ctx, ms, ar, offset, writable)
+ if !d.isCopiedUp() {
+ d.lowerMappings.RemoveMapping(ms, ar, offset, writable)
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ if err := d.wrappedMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil {
+ return err
+ }
+ if !d.isCopiedUp() {
+ d.lowerMappings.AddMapping(ms, dstAR, offset, writable)
+ }
+ return nil
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ d.dataMu.RLock()
+ defer d.dataMu.RUnlock()
+ return d.wrappedMappable.Translate(ctx, required, optional, at)
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (d *dentry) InvalidateUnsavable(ctx context.Context) error {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ return d.wrappedMappable.InvalidateUnsavable(ctx)
}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 9a8f7010e..b2efe5f80 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -22,6 +22,10 @@
// filesystem.renameMu
// dentry.dirMu
// dentry.copyMu
+// *** "memmap.Mappable locks" below this point
+// dentry.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// dentry.dataMu
//
// Locking dentry.dirMu in multiple dentries requires that parent dentries are
// locked before child dentries, and that filesystem.renameMu is locked to
@@ -37,6 +41,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -419,6 +424,35 @@ type dentry struct {
devMinor uint32
ino uint64
+ // If this dentry represents a regular file, then:
+ //
+ // - mapsMu is used to synchronize between copy-up and memmap.Mappable
+ // methods on dentry preceding mm.MemoryManager.activeMu in the lock order.
+ //
+ // - dataMu is used to synchronize between copy-up and
+ // dentry.(memmap.Mappable).Translate.
+ //
+ // - lowerMappings tracks memory mappings of the file. lowerMappings is
+ // used to invalidate mappings of the lower layer when the file is copied
+ // up to ensure that they remain coherent with subsequent writes to the
+ // file. (Note that, as of this writing, Linux overlayfs does not do this;
+ // this feature is a gVisor extension.) lowerMappings is protected by
+ // mapsMu.
+ //
+ // - If this dentry is copied-up, then wrappedMappable is the Mappable
+ // obtained from a call to the current top layer's
+ // FileDescription.ConfigureMMap(). Once wrappedMappable becomes non-nil
+ // (from a call to nonDirectoryFD.ensureMappable()), it cannot become nil.
+ // wrappedMappable is protected by mapsMu and dataMu.
+ //
+ // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is
+ // accessed using atomic memory operations.
+ mapsMu sync.Mutex
+ lowerMappings memmap.MappingSet
+ dataMu sync.RWMutex
+ wrappedMappable memmap.Mappable
+ isMappable uint32
+
locks vfs.FileLocks
}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 3f0d78461..94ec2ff69 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -86,14 +86,19 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, off
Name: strconv.FormatUint(uint64(fd), 10),
Type: typ,
Ino: i.fs.NextIno(),
- NextOff: offset + 1,
+ NextOff: int64(fd) + 3,
}
if err := cb.Handle(dirent); err != nil {
- return offset, err
+ // Getdents should iterate correctly despite mutation
+ // of fds, so we return the next fd to serialize plus
+ // 2 (which accounts for the "." and ".." tracked by
+ // kernfs) as the offset.
+ return int64(fd) + 2, err
}
- offset++
}
- return offset, nil
+ // We serialized them all. Next offset should be higher than last
+ // serialized fd.
+ return int64(fds[len(fds)-1]) + 3, nil
}
// fdDirInode represents the inode for /proc/[pid]/fd directory.
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 356036b9b..ce87b0d47 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -543,7 +543,7 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
var vss, rss, data uint64
s.task.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.Size()
+ fds = fdTable.CurrentMaxFDs()
}
if mm := t.MemoryManager(); mm != nil {
vss = mm.VirtualMemorySize()
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index d28450e53..bc8e38431 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -13,9 +13,11 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/marshal/primitive",
"//pkg/merkletree",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 0bac8e938..249cc1341 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -28,10 +28,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/merkletree"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -589,11 +591,28 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg
return 0, nil
}
+func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ f := int32(0)
+
+ // All enabled files should store a root hash. This flag is not settable
+ // via FS_IOC_SETFLAGS.
+ if len(fd.d.rootHash) != 0 {
+ f |= linux.FS_VERITY_FL
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ addr := args[2].Pointer()
+ _, err := primitive.CopyInt32Out(t, addr, f)
+ return 0, err
+}
+
// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := args[1].Uint(); cmd {
case linux.FS_IOC_ENABLE_VERITY:
return fd.enableVerity(ctx, uio, args)
+ case linux.FS_IOC_GETFLAGS:
+ return fd.getFlags(ctx, uio, args)
default:
return fd.lowerFD.Ioctl(ctx, uio, args)
}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index d436daab4..a43c549f1 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -197,6 +197,7 @@ go_library(
"gvisor.dev/gvisor/pkg/sentry/device",
"gvisor.dev/gvisor/pkg/tcpip",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
":uncaught_signal_go_proto",
@@ -212,6 +213,8 @@ go_library(
"//pkg/eventchannel",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/refs",
"//pkg/refs_vfs2",
@@ -261,7 +264,6 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 2bc49483a..869e49ebc 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -57,6 +57,7 @@ go_library(
"id_map_set.go",
"user_namespace.go",
],
+ marshal = True,
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/kernel/auth/id.go b/pkg/sentry/kernel/auth/id.go
index 0a58ba17c..4c32ee703 100644
--- a/pkg/sentry/kernel/auth/id.go
+++ b/pkg/sentry/kernel/auth/id.go
@@ -19,9 +19,13 @@ import (
)
// UID is a user ID in an unspecified user namespace.
+//
+// +marshal
type UID uint32
// GID is a group ID in an unspecified user namespace.
+//
+// +marshal slice:GIDSlice
type GID uint32
// In the root user namespace, user/group IDs have a 1-to-1 relationship with
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 89223fa36..0ec7344cd 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -111,8 +111,11 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
ctx := context.Background()
f.init() // Initialize table.
+ f.used = 0
for fd, d := range m {
- f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags)
+ if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil {
+ panic("VFS1 or VFS2 files set")
+ }
// Note that we do _not_ need to acquire a extra table reference here. The
// table reference will already be accounted for in the file, so we drop the
@@ -186,12 +189,6 @@ func (f *FDTable) DecRef(ctx context.Context) {
})
}
-// Size returns the number of file descriptor slots currently allocated.
-func (f *FDTable) Size() int {
- size := atomic.LoadInt32(&f.used)
- return int(size)
-}
-
// forEach iterates over all non-nil files in sorted order.
//
// It is the caller's responsibility to acquire an appropriate lock.
@@ -278,7 +275,6 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
}
f.mu.Lock()
- defer f.mu.Unlock()
// From f.next to find available fd.
if fd < f.next {
@@ -288,15 +284,25 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
// Install all entries.
for i := fd; i < end && len(fds) < len(files); i++ {
if d, _, _ := f.get(i); d == nil {
- f.set(ctx, i, files[len(fds)], flags) // Set the descriptor.
- fds = append(fds, i) // Record the file descriptor.
+ // Set the descriptor.
+ f.set(ctx, i, files[len(fds)], flags)
+ fds = append(fds, i) // Record the file descriptor.
}
}
// Failure? Unwind existing FDs.
if len(fds) < len(files) {
for _, i := range fds {
- f.set(ctx, i, nil, FDFlags{}) // Zap entry.
+ f.set(ctx, i, nil, FDFlags{})
+ }
+ f.mu.Unlock()
+
+ // Drop the reference taken by the call to f.set() that
+ // originally installed the file. Don't call f.drop()
+ // (generating inotify events, etc.) since the file should
+ // appear to have never been inserted into f.
+ for _, file := range files[:len(fds)] {
+ file.DecRef(ctx)
}
return nil, syscall.EMFILE
}
@@ -306,6 +312,7 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
f.next = fds[len(fds)-1] + 1
}
+ f.mu.Unlock()
return fds, nil
}
@@ -333,7 +340,6 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
}
f.mu.Lock()
- defer f.mu.Unlock()
// From f.next to find available fd.
if fd < f.next {
@@ -343,15 +349,25 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
// Install all entries.
for i := fd; i < end && len(fds) < len(files); i++ {
if d, _, _ := f.getVFS2(i); d == nil {
- f.setVFS2(ctx, i, files[len(fds)], flags) // Set the descriptor.
- fds = append(fds, i) // Record the file descriptor.
+ // Set the descriptor.
+ f.setVFS2(ctx, i, files[len(fds)], flags)
+ fds = append(fds, i) // Record the file descriptor.
}
}
// Failure? Unwind existing FDs.
if len(fds) < len(files) {
for _, i := range fds {
- f.setVFS2(ctx, i, nil, FDFlags{}) // Zap entry.
+ f.setVFS2(ctx, i, nil, FDFlags{})
+ }
+ f.mu.Unlock()
+
+ // Drop the reference taken by the call to f.setVFS2() that
+ // originally installed the file. Don't call f.dropVFS2()
+ // (generating inotify events, etc.) since the file should
+ // appear to have never been inserted into f.
+ for _, file := range files[:len(fds)] {
+ file.DecRef(ctx)
}
return nil, syscall.EMFILE
}
@@ -361,6 +377,7 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
f.next = fds[len(fds)-1] + 1
}
+ f.mu.Unlock()
return fds, nil
}
@@ -412,34 +429,49 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc
// reference for that FD, the ref count for that existing reference is
// decremented.
func (f *FDTable) NewFDAt(ctx context.Context, fd int32, file *fs.File, flags FDFlags) error {
- return f.newFDAt(ctx, fd, file, nil, flags)
+ df, _, err := f.newFDAt(ctx, fd, file, nil, flags)
+ if err != nil {
+ return err
+ }
+ if df != nil {
+ f.drop(ctx, df)
+ }
+ return nil
}
// NewFDAtVFS2 sets the file reference for the given FD. If there is an active
// reference for that FD, the ref count for that existing reference is
// decremented.
func (f *FDTable) NewFDAtVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) error {
- return f.newFDAt(ctx, fd, nil, file, flags)
+ _, dfVFS2, err := f.newFDAt(ctx, fd, nil, file, flags)
+ if err != nil {
+ return err
+ }
+ if dfVFS2 != nil {
+ f.dropVFS2(ctx, dfVFS2)
+ }
+ return nil
}
-func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) error {
+func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription, error) {
if fd < 0 {
// Don't accept negative FDs.
- return syscall.EBADF
+ return nil, nil, syscall.EBADF
}
// Check the limit for the provided file.
if limitSet := limits.FromContext(ctx); limitSet != nil {
if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
- return syscall.EMFILE
+ return nil, nil, syscall.EMFILE
}
}
// Install the entry.
f.mu.Lock()
defer f.mu.Unlock()
- f.setAll(ctx, fd, file, fileVFS2, flags)
- return nil
+
+ df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags)
+ return df, dfVFS2, nil
}
// SetFlags sets the flags for the given file descriptor.
@@ -550,30 +582,6 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 {
return fds
}
-// GetRefs returns a stable slice of references to all files and bumps the
-// reference count on each. The caller must use DecRef on each reference when
-// they're done using the slice.
-func (f *FDTable) GetRefs(ctx context.Context) []*fs.File {
- files := make([]*fs.File, 0, f.Size())
- f.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
- file.IncRef() // Acquire a reference for caller.
- files = append(files, file)
- })
- return files
-}
-
-// GetRefsVFS2 returns a stable slice of references to all files and bumps the
-// reference count on each. The caller must use DecRef on each reference when
-// they're done using the slice.
-func (f *FDTable) GetRefsVFS2(ctx context.Context) []*vfs.FileDescription {
- files := make([]*vfs.FileDescription, 0, f.Size())
- f.forEach(ctx, func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) {
- file.IncRef() // Acquire a reference for caller.
- files = append(files, file)
- })
- return files
-}
-
// Fork returns an independent FDTable.
func (f *FDTable) Fork(ctx context.Context) *FDTable {
clone := f.k.NewFDTable()
@@ -581,11 +589,8 @@ func (f *FDTable) Fork(ctx context.Context) *FDTable {
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
// The set function here will acquire an appropriate table
// reference for the clone. We don't need anything else.
- switch {
- case file != nil:
- clone.set(ctx, fd, file, flags)
- case fileVFS2 != nil:
- clone.setVFS2(ctx, fd, fileVFS2, flags)
+ if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil {
+ panic("VFS1 or VFS2 files set")
}
})
return clone
@@ -600,7 +605,6 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc
}
f.mu.Lock()
- defer f.mu.Unlock()
// Update current available position.
if fd < f.next {
@@ -616,24 +620,51 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc
case orig2 != nil:
orig2.IncRef()
}
+
if orig != nil || orig2 != nil {
- f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry.
+ orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry.
+ }
+ f.mu.Unlock()
+
+ if orig != nil {
+ f.drop(ctx, orig)
}
+ if orig2 != nil {
+ f.dropVFS2(ctx, orig2)
+ }
+
return orig, orig2
}
// RemoveIf removes all FDs where cond is true.
func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) {
- f.mu.Lock()
- defer f.mu.Unlock()
+ // TODO(gvisor.dev/issue/1624): Remove fs.File slice.
+ var files []*fs.File
+ var filesVFS2 []*vfs.FileDescription
+ f.mu.Lock()
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
if cond(file, fileVFS2, flags) {
- f.set(ctx, fd, nil, FDFlags{}) // Clear from table.
+ df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table.
+ if df != nil {
+ files = append(files, df)
+ }
+ if dfVFS2 != nil {
+ filesVFS2 = append(filesVFS2, dfVFS2)
+ }
// Update current available position.
if fd < f.next {
f.next = fd
}
}
})
+ f.mu.Unlock()
+
+ for _, file := range files {
+ f.drop(ctx, file)
+ }
+
+ for _, file := range filesVFS2 {
+ f.dropVFS2(ctx, file)
+ }
}
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index 555b14f8e..da79e6627 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -79,33 +79,37 @@ func (f *FDTable) getAll(fd int32) (*fs.File, *vfs.FileDescription, FDFlags, boo
return d.file, d.fileVFS2, d.flags, true
}
-// set sets an entry.
-//
-// This handles accounting changes, as well as acquiring and releasing the
-// reference needed by the table iff the file is different.
+// CurrentMaxFDs returns the number of file descriptors that may be stored in f
+// without reallocation.
+func (f *FDTable) CurrentMaxFDs() int {
+ slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
+ return len(slice)
+}
+
+// set sets an entry for VFS1, refer to setAll().
//
// Precondition: mu must be held.
-func (f *FDTable) set(ctx context.Context, fd int32, file *fs.File, flags FDFlags) {
- f.setAll(ctx, fd, file, nil, flags)
+func (f *FDTable) set(ctx context.Context, fd int32, file *fs.File, flags FDFlags) *fs.File {
+ dropFile, _ := f.setAll(ctx, fd, file, nil, flags)
+ return dropFile
}
-// setVFS2 sets an entry.
-//
-// This handles accounting changes, as well as acquiring and releasing the
-// reference needed by the table iff the file is different.
+// setVFS2 sets an entry for VFS2, refer to setAll().
//
// Precondition: mu must be held.
-func (f *FDTable) setVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) {
- f.setAll(ctx, fd, nil, file, flags)
+func (f *FDTable) setVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) *vfs.FileDescription {
+ _, dropFile := f.setAll(ctx, fd, nil, file, flags)
+ return dropFile
}
-// setAll sets an entry.
-//
-// This handles accounting changes, as well as acquiring and releasing the
-// reference needed by the table iff the file is different.
+// setAll sets the file description referred to by fd to file/fileVFS2. If
+// file/fileVFS2 are non-nil, it takes a reference on them. If setAll replaces
+// an existing file description, it returns it with the FDTable's reference
+// transferred to the caller, which must call f.drop/dropVFS2() on the returned
+// file after unlocking f.mu.
//
// Precondition: mu must be held.
-func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription) {
if file != nil && fileVFS2 != nil {
panic("VFS1 and VFS2 files set")
}
@@ -148,25 +152,25 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2
}
}
- // Drop the table reference.
+ // Adjust used.
+ switch {
+ case orig == nil && desc != nil:
+ atomic.AddInt32(&f.used, 1)
+ case orig != nil && desc == nil:
+ atomic.AddInt32(&f.used, -1)
+ }
+
if orig != nil {
switch {
case orig.file != nil:
if desc == nil || desc.file != orig.file {
- f.drop(ctx, orig.file)
+ return orig.file, nil
}
case orig.fileVFS2 != nil:
if desc == nil || desc.fileVFS2 != orig.fileVFS2 {
- f.dropVFS2(ctx, orig.fileVFS2)
+ return nil, orig.fileVFS2
}
}
}
-
- // Adjust used.
- switch {
- case orig == nil && desc != nil:
- atomic.AddInt32(&f.used, 1)
- case orig != nil && desc == nil:
- atomic.AddInt32(&f.used, -1)
- }
+ return nil, nil
}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 50df179c3..1145faf13 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
@@ -999,18 +1000,15 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
// at the address specified by the data parameter, and the return value
// is the error flag." - ptrace(2)
word := t.Arch().Native(0)
- if _, err := usermem.CopyObjectIn(t, target.MemoryManager(), addr, word, usermem.IOOpts{
- IgnorePermissions: true,
- }); err != nil {
+ if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil {
return err
}
- _, err := t.CopyOut(data, word)
+ _, err := word.CopyOut(t, data)
return err
case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA:
- _, err := usermem.CopyObjectOut(t, target.MemoryManager(), addr, t.Arch().Native(uintptr(data)), usermem.IOOpts{
- IgnorePermissions: true,
- })
+ word := t.Arch().Native(uintptr(data))
+ _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr)
return err
case linux.PTRACE_GETREGSET:
@@ -1078,12 +1076,12 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
if target.ptraceSiginfo == nil {
return syserror.EINVAL
}
- _, err := t.CopyOut(data, target.ptraceSiginfo)
+ _, err := target.ptraceSiginfo.CopyOut(t, data)
return err
case linux.PTRACE_SETSIGINFO:
var info arch.SignalInfo
- if _, err := t.CopyIn(data, &info); err != nil {
+ if _, err := info.CopyIn(t, data); err != nil {
return err
}
t.tg.pidns.owner.mu.RLock()
@@ -1098,7 +1096,8 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
if addr != linux.SignalSetSize {
return syserror.EINVAL
}
- _, err := t.CopyOut(data, target.SignalMask())
+ mask := target.SignalMask()
+ _, err := mask.CopyOut(t, data)
return err
case linux.PTRACE_SETSIGMASK:
@@ -1106,7 +1105,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
return syserror.EINVAL
}
var mask linux.SignalSet
- if _, err := t.CopyIn(data, &mask); err != nil {
+ if _, err := mask.CopyIn(t, data); err != nil {
return err
}
// The target's task goroutine is stopped, so this is safe:
@@ -1121,7 +1120,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
case linux.PTRACE_GETEVENTMSG:
t.tg.pidns.owner.mu.RLock()
defer t.tg.pidns.owner.mu.RUnlock()
- _, err := t.CopyOut(usermem.Addr(data), target.ptraceEventMsg)
+ _, err := primitive.CopyUint64Out(t, usermem.Addr(data), target.ptraceEventMsg)
return err
// PEEKSIGINFO is unimplemented but seems to have no users anywhere.
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
index cef1276ec..609ad3941 100644
--- a/pkg/sentry/kernel/ptrace_amd64.go
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -30,7 +30,7 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro
if err != nil {
return err
}
- _, err = t.CopyOut(data, n)
+ _, err = n.CopyOut(t, data)
return err
case linux.PTRACE_POKEUSR: // aka PTRACE_POKEUSER
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 413111faf..332bdb8e8 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -348,6 +348,16 @@ func (s *SyscallTable) LookupName(sysno uintptr) string {
return fmt.Sprintf("sys_%d", sysno) // Unlikely.
}
+// LookupNo looks up a syscall number by name.
+func (s *SyscallTable) LookupNo(name string) (uintptr, error) {
+ for i, syscall := range s.Table {
+ if syscall.Name == name {
+ return uintptr(i), nil
+ }
+ }
+ return 0, fmt.Errorf("syscall %q not found", name)
+}
+
// LookupEmulate looks up an emulation syscall number.
func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
sysno, ok := s.Emulate[addr]
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 9d7a9128f..fce1064a7 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -341,12 +341,12 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
nt.SetClearTID(opts.ChildTID)
}
if opts.ChildSetTID {
- // Can't use Task.CopyOut, which assumes AddressSpaceActive.
- usermem.CopyObjectOut(t, nt.MemoryManager(), opts.ChildTID, nt.ThreadID(), usermem.IOOpts{})
+ ctid := nt.ThreadID()
+ ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID)
}
ntid := t.tg.pidns.IDOfTask(nt)
if opts.ParentSetTID {
- t.CopyOut(opts.ParentTID, ntid)
+ ntid.CopyOut(t, opts.ParentTID)
}
kind := ptraceCloneKindClone
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index b76f7f503..b400a8b41 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -248,7 +248,8 @@ func (*runExitMain) execute(t *Task) taskRunState {
signaled := t.tg.exiting && t.tg.exitStatus.Signaled()
t.tg.signalHandlers.mu.Unlock()
if !signaled {
- if _, err := t.CopyOut(t.cleartid, ThreadID(0)); err == nil {
+ zero := ThreadID(0)
+ if _, err := zero.CopyOut(t, t.cleartid); err == nil {
t.Futex().Wake(t, t.cleartid, false, ^uint32(0), 1)
}
// If the CopyOut fails, there's nothing we can do.
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
index 4b535c949..c80391475 100644
--- a/pkg/sentry/kernel/task_futex.go
+++ b/pkg/sentry/kernel/task_futex.go
@@ -16,6 +16,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -87,7 +88,7 @@ func (t *Task) exitRobustList() {
return
}
- next := rl.List
+ next := primitive.Uint64(rl.List)
done := 0
var pendingLockAddr usermem.Addr
if rl.ListOpPending != 0 {
@@ -99,12 +100,12 @@ func (t *Task) exitRobustList() {
// We traverse to the next element of the list before we
// actually wake anything. This prevents the race where waking
// this futex causes a modification of the list.
- thisLockAddr := usermem.Addr(next + rl.FutexOffset)
+ thisLockAddr := usermem.Addr(uint64(next) + rl.FutexOffset)
// Try to decode the next element in the list before waking the
// current futex. But don't check the error until after we've
// woken the current futex. Linux does it in this order too
- _, nextErr := t.CopyIn(usermem.Addr(next), &next)
+ _, nextErr := next.CopyIn(t, usermem.Addr(next))
// Wakeup the current futex if it's not pending.
if thisLockAddr != pendingLockAddr {
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index aa3a573c0..8dc3fec90 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -141,7 +141,7 @@ func (*runApp) handleCPUIDInstruction(t *Task) error {
region := trace.StartRegion(t.traceContext, cpuidRegion)
expected := arch.CPUIDInstruction[:]
found := make([]byte, len(expected))
- _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
+ _, err := t.CopyInBytes(usermem.Addr(t.Arch().IP()), found)
if err == nil && bytes.Equal(expected, found) {
// Skip the cpuid instruction.
t.Arch().CPUIDEmulate(t)
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index 2dbf86547..0141459e7 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -287,7 +288,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState {
// Grab the caller up front, to make sure there's a sensible stack.
caller := t.Arch().Native(uintptr(0))
- if _, err := t.CopyIn(usermem.Addr(t.Arch().Stack()), caller); err != nil {
+ if _, err := caller.CopyIn(t, usermem.Addr(t.Arch().Stack())); err != nil {
t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
@@ -323,7 +324,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState {
type runVsyscallAfterPtraceEventSeccomp struct {
addr usermem.Addr
sysno uintptr
- caller interface{}
+ caller marshal.Marshallable
}
func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
@@ -346,7 +347,7 @@ func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState {
return t.doVsyscallInvoke(sysno, t.Arch().SyscallArgs(), r.caller)
}
-func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller interface{}) taskRunState {
+func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller marshal.Marshallable) taskRunState {
rval, ctrl, err := t.executeSyscall(sysno, args)
if ctrl != nil {
t.Debugf("vsyscall %d, caller %x: syscall control: %v", sysno, t.Arch().Value(caller), ctrl)
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 0cb86e390..ce134bf54 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -18,6 +18,7 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -43,17 +44,6 @@ func (t *Task) Deactivate() {
}
}
-// CopyIn copies a fixed-size value or slice of fixed-size values in from the
-// task's memory. The copy will fail with syscall.EFAULT if it traverses user
-// memory that is unmapped or not readable by the user.
-//
-// This Task's AddressSpace must be active.
-func (t *Task) CopyIn(addr usermem.Addr, dst interface{}) (int, error) {
- return usermem.CopyObjectIn(t, t.MemoryManager(), addr, dst, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-}
-
// CopyInBytes is a fast version of CopyIn if the caller can serialize the
// data without reflection and pass in a byte slice.
//
@@ -64,17 +54,6 @@ func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
})
}
-// CopyOut copies a fixed-size value or slice of fixed-size values out to the
-// task's memory. The copy will fail with syscall.EFAULT if it traverses user
-// memory that is unmapped or not writeable by the user.
-//
-// This Task's AddressSpace must be active.
-func (t *Task) CopyOut(addr usermem.Addr, src interface{}) (int, error) {
- return usermem.CopyObjectOut(t, t.MemoryManager(), addr, src, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-}
-
// CopyOutBytes is a fast version of CopyOut if the caller can serialize the
// data without reflection and pass in a byte slice.
//
@@ -114,7 +93,7 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([
var v []string
for {
argAddr := t.Arch().Native(0)
- if _, err := t.CopyIn(addr, argAddr); err != nil {
+ if _, err := argAddr.CopyIn(t, addr); err != nil {
return v, err
}
if t.Arch().Value(argAddr) == 0 {
@@ -302,29 +281,29 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp
}, nil
}
-// CopyContextWithOpts wraps a task to allow copying memory to and from the
-// task memory with user specified usermem.IOOpts.
-type CopyContextWithOpts struct {
+// copyContext implements marshal.CopyContext. It wraps a task to allow copying
+// memory to and from the task memory with custom usermem.IOOpts.
+type copyContext struct {
*Task
opts usermem.IOOpts
}
-// AsCopyContextWithOpts wraps the task and returns it as CopyContextWithOpts.
-func (t *Task) AsCopyContextWithOpts(opts usermem.IOOpts) *CopyContextWithOpts {
- return &CopyContextWithOpts{t, opts}
+// AsCopyContext wraps the task and returns it as CopyContext.
+func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext {
+ return &copyContext{t, opts}
}
// CopyInString copies a string in from the task's memory.
-func (t *CopyContextWithOpts) CopyInString(addr usermem.Addr, maxLen int) (string, error) {
+func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) {
return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts)
}
// CopyInBytes copies task memory into dst from an IO context.
-func (t *CopyContextWithOpts) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
return t.MemoryManager().CopyIn(t, addr, dst, t.opts)
}
// CopyOutBytes copies src into task memoryfrom an IO context.
-func (t *CopyContextWithOpts) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
return t.MemoryManager().CopyOut(t, addr, src, t.opts)
}
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 872e1a82d..5ae5906e8 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -36,6 +36,8 @@ import (
const TasksLimit = (1 << 16)
// ThreadID is a generic thread identifier.
+//
+// +marshal
type ThreadID int32
// String returns a decimal representation of the ThreadID.
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 84b699f0d..020bbda79 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -201,7 +201,7 @@ func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFActi
seccomp.RuleSet{
Rules: seccomp.SyscallRules{
syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ {seccomp.EqualTo(linux.ARCH_SET_CPUID), seccomp.EqualTo(0)},
},
},
Action: linux.SECCOMP_RET_ALLOW,
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index 2ce528601..8548853da 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -80,9 +80,9 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
Rules: seccomp.SyscallRules{
syscall.SYS_CLONE: []seccomp.Rule{
// Allow creation of new subprocesses (used by the master).
- {seccomp.AllowValue(syscall.CLONE_FILES | syscall.SIGKILL)},
+ {seccomp.EqualTo(syscall.CLONE_FILES | syscall.SIGKILL)},
// Allow creation of new threads within a single address space (used by addresss spaces).
- {seccomp.AllowValue(
+ {seccomp.EqualTo(
syscall.CLONE_FILES |
syscall.CLONE_FS |
syscall.CLONE_SIGHAND |
@@ -97,14 +97,14 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// For the stub prctl dance (all).
syscall.SYS_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(syscall.PR_SET_PDEATHSIG), seccomp.AllowValue(syscall.SIGKILL)},
+ {seccomp.EqualTo(syscall.PR_SET_PDEATHSIG), seccomp.EqualTo(syscall.SIGKILL)},
},
syscall.SYS_GETPPID: {},
// For the stub to stop itself (all).
syscall.SYS_GETPID: {},
syscall.SYS_KILL: []seccomp.Rule{
- {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SIGSTOP)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SIGSTOP)},
},
// Injected to support the address space operations.
@@ -115,7 +115,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
})
}
rules = appendArchSeccompRules(rules, defaultAction)
- instrs, err := seccomp.BuildProgram(rules, defaultAction)
+ instrs, err := seccomp.BuildProgram(rules, defaultAction, defaultAction)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 5125cd4fb..5f63cbd45 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -27,7 +27,9 @@
// ERET returns using the ELR and SPSR for the current exception level.
#define ERET() \
- WORD $0xd69f03e0
+ WORD $0xd69f03e0; \
+ DSB $7; \
+ ISB $15;
// RSV_REG is a register that holds el1 information temporarily.
#define RSV_REG R18_PLATFORM
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index c0fd3425b..a3f775d15 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/marshal",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
@@ -20,6 +21,5 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index e76e498de..632e33452 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -21,6 +21,8 @@ go_library(
"//pkg/context",
"//pkg/fdnotifier",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
@@ -40,8 +42,6 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 242e6bf76..7d3c4a01c 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -36,8 +38,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 0546801bf..1f926aa91 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -16,6 +16,8 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
@@ -36,8 +38,6 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 68a9b9a96..5ddcd4be5 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -38,8 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 1fb777a6c..fae3b6783 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -22,6 +22,8 @@ go_library(
"//pkg/binary",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -51,8 +53,6 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 91790834b..2e568bc3d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -40,6 +40,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -62,8 +64,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 0f342e655..c0212ad76 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -18,6 +18,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
@@ -29,8 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 04b259d27..fd31479e5 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -35,7 +36,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cb953e4dc..a89583dad 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -29,6 +29,7 @@ go_library(
"//pkg/context",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -49,6 +50,5 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 616530eb6..917055cea 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -39,7 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index e25c7e84a..3688f22d2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
@@ -32,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketVFS2 implements socket.SocketVFS2 (and by extension,
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 88d5db9fc..a920180d3 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -28,6 +28,7 @@ go_library(
"//pkg/binary",
"//pkg/bits",
"//pkg/eventchannel",
+ "//pkg/marshal/primitive",
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
index 5d51a7792..ae3b998c8 100644
--- a/pkg/sentry/strace/epoll.go
+++ b/pkg/sentry/strace/epoll.go
@@ -26,7 +26,7 @@ import (
func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
var e linux.EpollEvent
- if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ if _, err := e.CopyIn(t, eventAddr); err != nil {
return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
}
var sb strings.Builder
@@ -41,7 +41,7 @@ func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes ui
addr := eventsAddr
for i := uint64(0); i < numEvents; i++ {
var e linux.EpollEvent
- if _, err := t.CopyIn(addr, &e); err != nil {
+ if _, err := e.CopyIn(t, addr); err != nil {
fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
continue
}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 08e97e6c4..cc5f70cd4 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
@@ -166,7 +167,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
}
buf := make([]byte, length)
- if _, err := t.CopyIn(addr, &buf); err != nil {
+ if _, err := t.CopyInBytes(addr, buf); err != nil {
return fmt.Sprintf("%#x (error decoding control: %v)", addr, err)
}
@@ -302,7 +303,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string {
var msg slinux.MessageHeader64
- if err := slinux.CopyInMessageHeader64(t, addr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err)
}
s := fmt.Sprintf(
@@ -380,9 +381,9 @@ func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) str
func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) {
// socklen_t is 32-bits.
- var l uint32
- _, err := t.CopyIn(addr, &l)
- return l, err
+ var l primitive.Uint32
+ _, err := l.CopyIn(t, addr)
+ return uint32(l), err
}
func sockLenPointer(t *kernel.Task, addr usermem.Addr) string {
@@ -436,22 +437,22 @@ func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, o
func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string {
switch optLen {
case 1:
- var v uint8
- _, err := t.CopyIn(optVal, &v)
+ var v primitive.Uint8
+ _, err := v.CopyIn(t, optVal)
if err != nil {
return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
}
return fmt.Sprintf("%#x {value=%v}", optVal, v)
case 2:
- var v uint16
- _, err := t.CopyIn(optVal, &v)
+ var v primitive.Uint16
+ _, err := v.CopyIn(t, optVal)
if err != nil {
return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
}
return fmt.Sprintf("%#x {value=%v}", optVal, v)
case 4:
- var v uint32
- _, err := t.CopyIn(optVal, &v)
+ var v primitive.Uint32
+ _, err := v.CopyIn(t, optVal)
if err != nil {
return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 87b239730..52281ccc2 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -21,13 +21,13 @@ import (
"fmt"
"strconv"
"strings"
- "syscall"
"time"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -91,7 +91,7 @@ func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, ma
}
b := make([]byte, size)
- amt, err := t.CopyIn(ar.Start, b)
+ amt, err := t.CopyInBytes(ar.Start, b)
if err != nil {
iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q..., error decoding string: %v}", ar.Start, ar.Length(), b[:amt], err)
continue
@@ -118,7 +118,7 @@ func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) st
}
b := make([]byte, size)
- amt, err := t.CopyIn(addr, b)
+ amt, err := t.CopyInBytes(addr, b)
if err != nil {
return fmt.Sprintf("%#x (error decoding string: %s)", addr, err)
}
@@ -199,7 +199,7 @@ func fdVFS2(t *kernel.Task, fd int32) string {
func fdpair(t *kernel.Task, addr usermem.Addr) string {
var fds [2]int32
- _, err := t.CopyIn(addr, &fds)
+ _, err := primitive.CopyInt32SliceIn(t, addr, fds[:])
if err != nil {
return fmt.Sprintf("%#x (error decoding fds: %s)", addr, err)
}
@@ -209,7 +209,7 @@ func fdpair(t *kernel.Task, addr usermem.Addr) string {
func uname(t *kernel.Task, addr usermem.Addr) string {
var u linux.UtsName
- if _, err := t.CopyIn(addr, &u); err != nil {
+ if _, err := u.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err)
}
@@ -222,7 +222,7 @@ func utimensTimespec(t *kernel.Task, addr usermem.Addr) string {
}
var tim linux.Timespec
- if _, err := t.CopyIn(addr, &tim); err != nil {
+ if _, err := tim.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err)
}
@@ -244,7 +244,7 @@ func timespec(t *kernel.Task, addr usermem.Addr) string {
}
var tim linux.Timespec
- if _, err := t.CopyIn(addr, &tim); err != nil {
+ if _, err := tim.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err)
}
return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec)
@@ -256,7 +256,7 @@ func timeval(t *kernel.Task, addr usermem.Addr) string {
}
var tim linux.Timeval
- if _, err := t.CopyIn(addr, &tim); err != nil {
+ if _, err := tim.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding timeval: %s)", addr, err)
}
@@ -268,8 +268,8 @@ func utimbuf(t *kernel.Task, addr usermem.Addr) string {
return "null"
}
- var utim syscall.Utimbuf
- if _, err := t.CopyIn(addr, &utim); err != nil {
+ var utim linux.Utime
+ if _, err := utim.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding utimbuf: %s)", addr, err)
}
@@ -282,7 +282,7 @@ func stat(t *kernel.Task, addr usermem.Addr) string {
}
var stat linux.Stat
- if _, err := t.CopyIn(addr, &stat); err != nil {
+ if _, err := stat.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding stat: %s)", addr, err)
}
return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec))
@@ -330,7 +330,7 @@ func rusage(t *kernel.Task, addr usermem.Addr) string {
}
var ru linux.Rusage
- if _, err := t.CopyIn(addr, &ru); err != nil {
+ if _, err := ru.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding rusage: %s)", addr, err)
}
return fmt.Sprintf("%#x %+v", addr, ru)
@@ -342,7 +342,7 @@ func capHeader(t *kernel.Task, addr usermem.Addr) string {
}
var hdr linux.CapUserHeader
- if _, err := t.CopyIn(addr, &hdr); err != nil {
+ if _, err := hdr.CopyIn(t, addr); err != nil {
return fmt.Sprintf("%#x (error decoding header: %s)", addr, err)
}
@@ -367,7 +367,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
}
var hdr linux.CapUserHeader
- if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyIn(t, hdrAddr); err != nil {
return fmt.Sprintf("%#x (error decoding header: %v)", dataAddr, err)
}
@@ -376,7 +376,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
switch hdr.Version {
case linux.LINUX_CAPABILITY_VERSION_1:
var data linux.CapUserData
- if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ if _, err := data.CopyIn(t, dataAddr); err != nil {
return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err)
}
p = uint64(data.Permitted)
@@ -384,7 +384,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string {
e = uint64(data.Effective)
case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
var data [2]linux.CapUserData
- if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil {
return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err)
}
p = uint64(data[0].Permitted) | (uint64(data[1].Permitted) << 32)
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 4a9b04fd0..75752b2e6 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -56,6 +56,7 @@ go_library(
"sys_xattr.go",
"timespec.go",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
"//pkg/abi",
@@ -64,6 +65,8 @@ go_library(
"//pkg/bpf",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/rand",
"//pkg/safemem",
@@ -99,7 +102,5 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index da6bd85e1..5f26697d2 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -138,7 +138,7 @@ var AMD64 = &kernel.SyscallTable{
83: syscalls.Supported("mkdir", Mkdir),
84: syscalls.Supported("rmdir", Rmdir),
85: syscalls.Supported("creat", Creat),
- 86: syscalls.Supported("link", Link),
+ 86: syscalls.PartiallySupported("link", Link, "Limited support with Gofer. Link count and linked files may get out of sync because gVisor is not aware of external hardlinks.", nil),
87: syscalls.Supported("unlink", Unlink),
88: syscalls.Supported("symlink", Symlink),
89: syscalls.Supported("readlink", Readlink),
@@ -317,7 +317,7 @@ var AMD64 = &kernel.SyscallTable{
262: syscalls.Supported("fstatat", Fstatat),
263: syscalls.Supported("unlinkat", Unlinkat),
264: syscalls.Supported("renameat", Renameat),
- 265: syscalls.Supported("linkat", Linkat),
+ 265: syscalls.PartiallySupported("linkat", Linkat, "See link(2).", nil),
266: syscalls.Supported("symlinkat", Symlinkat),
267: syscalls.Supported("readlinkat", Readlinkat),
268: syscalls.Supported("fchmodat", Fchmodat),
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index e9d64dec5..0bf313a13 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -17,6 +17,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -36,7 +37,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
//
// The context pointer _must_ be zero initially.
var idIn uint64
- if _, err := t.CopyIn(idAddr, &idIn); err != nil {
+ if _, err := primitive.CopyUint64In(t, idAddr, &idIn); err != nil {
return 0, nil, err
}
if idIn != 0 {
@@ -49,7 +50,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
// Copy out the new ID.
- if _, err := t.CopyOut(idAddr, &id); err != nil {
+ if _, err := primitive.CopyUint64Out(t, idAddr, id); err != nil {
t.MemoryManager().DestroyAIOContext(t, id)
return 0, nil, err
}
@@ -142,7 +143,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
ev := v.(*linux.IOEvent)
// Copy out the result.
- if _, err := t.CopyOut(eventsAddr, ev); err != nil {
+ if _, err := ev.CopyOut(t, eventsAddr); err != nil {
if count > 0 {
return uintptr(count), nil, nil
}
@@ -338,21 +339,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
for i := int32(0); i < nrEvents; i++ {
- // Copy in the address.
- cbAddrNative := t.Arch().Native(0)
- if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
- if i > 0 {
- // Some successful.
- return uintptr(i), nil, nil
+ // Copy in the callback address.
+ var cbAddr usermem.Addr
+ switch t.Arch().Width() {
+ case 8:
+ var cbAddrP primitive.Uint64
+ if _, err := cbAddrP.CopyIn(t, addr); err != nil {
+ if i > 0 {
+ // Some successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
}
- // Nothing done.
- return 0, nil, err
+ cbAddr = usermem.Addr(cbAddrP)
+ default:
+ return 0, nil, syserror.ENOSYS
}
// Copy in this callback.
var cb linux.IOCallback
- cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
- if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+ if _, err := cb.CopyIn(t, cbAddr); err != nil {
if i > 0 {
// Some have been successful.
diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go
index adf5ea5f2..d3b85e11b 100644
--- a/pkg/sentry/syscalls/linux/sys_capability.go
+++ b/pkg/sentry/syscalls/linux/sys_capability.go
@@ -45,7 +45,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
dataAddr := args[1].Pointer()
var hdr linux.CapUserHeader
- if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyIn(t, hdrAddr); err != nil {
return 0, nil, err
}
// hdr.Pid doesn't need to be valid if this capget() is a "version probe"
@@ -65,7 +65,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
Permitted: uint32(p),
Inheritable: uint32(i),
}
- _, err = t.CopyOut(dataAddr, &data)
+ _, err = data.CopyOut(t, dataAddr)
return 0, nil, err
case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3:
@@ -88,12 +88,12 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
Inheritable: uint32(i >> 32),
},
}
- _, err = t.CopyOut(dataAddr, &data)
+ _, err = linux.CopyCapUserDataSliceOut(t, dataAddr, data[:])
return 0, nil, err
default:
hdr.Version = linux.HighestCapabilityVersion
- if _, err := t.CopyOut(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyOut(t, hdrAddr); err != nil {
return 0, nil, err
}
if dataAddr != 0 {
@@ -109,7 +109,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
dataAddr := args[1].Pointer()
var hdr linux.CapUserHeader
- if _, err := t.CopyIn(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyIn(t, hdrAddr); err != nil {
return 0, nil, err
}
switch hdr.Version {
@@ -118,7 +118,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.EPERM
}
var data linux.CapUserData
- if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ if _, err := data.CopyIn(t, dataAddr); err != nil {
return 0, nil, err
}
p := auth.CapabilitySet(data.Permitted) & auth.AllCapabilities
@@ -131,7 +131,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.EPERM
}
var data [2]linux.CapUserData
- if _, err := t.CopyIn(dataAddr, &data); err != nil {
+ if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil {
return 0, nil, err
}
p := (auth.CapabilitySet(data[0].Permitted) | (auth.CapabilitySet(data[1].Permitted) << 32)) & auth.AllCapabilities
@@ -141,7 +141,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
default:
hdr.Version = linux.HighestCapabilityVersion
- if _, err := t.CopyOut(hdrAddr, &hdr); err != nil {
+ if _, err := hdr.CopyOut(t, hdrAddr); err != nil {
return 0, nil, err
}
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 07c77e442..98331eb3c 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
@@ -613,7 +614,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FIONBIO:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
flags := file.Flags()
@@ -627,7 +628,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FIOASYNC:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
flags := file.Flags()
@@ -641,15 +642,14 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FIOSETOWN, linux.SIOCSPGRP:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
fSetOwn(t, file, set)
return 0, nil, nil
case linux.FIOGETOWN, linux.SIOCGPGRP:
- who := fGetOwn(t, file)
- _, err := t.CopyOut(args[2].Pointer(), &who)
+ _, err := primitive.CopyInt32Out(t, args[2].Pointer(), fGetOwn(t, file))
return 0, nil, err
default:
@@ -694,7 +694,7 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
// Top it off with a terminator.
- _, err = t.CopyOut(addr+usermem.Addr(bytes), []byte("\x00"))
+ _, err = t.CopyOutBytes(addr+usermem.Addr(bytes), []byte("\x00"))
return uintptr(bytes + 1), nil, err
}
@@ -962,7 +962,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Copy in the lock request.
flockAddr := args[2].Pointer()
var flock linux.Flock
- if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ if _, err := flock.CopyIn(t, flockAddr); err != nil {
return 0, nil, err
}
@@ -1052,12 +1052,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETOWN_EX:
addr := args[2].Pointer()
owner := fGetOwnEx(t, file)
- _, err := t.CopyOut(addr, &owner)
+ _, err := owner.CopyOut(t, addr)
return 0, nil, err
case linux.F_SETOWN_EX:
addr := args[2].Pointer()
var owner linux.FOwnerEx
- _, err := t.CopyIn(addr, &owner)
+ _, err := owner.CopyIn(t, addr)
if err != nil {
return 0, nil, err
}
@@ -1922,7 +1922,7 @@ func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ts := defaultSetToSystemTimeSpec()
if timesAddr != 0 {
var times linux.Utime
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := times.CopyIn(t, timesAddr); err != nil {
return 0, nil, err
}
ts = fs.TimeSpec{
@@ -1942,7 +1942,7 @@ func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
ts := defaultSetToSystemTimeSpec()
if timesAddr != 0 {
var times [2]linux.Timeval
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil {
return 0, nil, err
}
ts = fs.TimeSpec{
@@ -1970,7 +1970,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
ts := defaultSetToSystemTimeSpec()
if timesAddr != 0 {
var times [2]linux.Timespec
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil {
return 0, nil, err
}
if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) {
@@ -2004,7 +2004,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
ts := defaultSetToSystemTimeSpec()
if timesAddr != 0 {
var times [2]linux.Timeval
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil {
return 0, nil, err
}
if times[0].Usec >= 1e6 || times[0].Usec < 0 ||
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index 12b2fa690..f39ce0639 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -306,8 +306,8 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
// Despite the syscall using the name 'pid' for this variable, it is
// very much a tid.
tid := args[0].Int()
- head := args[1].Pointer()
- size := args[2].Pointer()
+ headAddr := args[1].Pointer()
+ sizeAddr := args[2].Pointer()
if tid < 0 {
return 0, nil, syserror.EINVAL
@@ -321,12 +321,16 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
}
// Copy out head pointer.
- if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil {
+ head := t.Arch().Native(uintptr(ot.GetRobustList()))
+ if _, err := head.CopyOut(t, headAddr); err != nil {
return 0, nil, err
}
- // Copy out size, which is a constant.
- if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil {
+ // Copy out size, which is a constant. Note that while size isn't
+ // an address, it is defined as the arch-dependent size_t, so it
+ // needs to be converted to a native-sized int.
+ size := t.Arch().Native(uintptr(linux.SizeOfRobustListHead))
+ if _, err := size.CopyOut(t, sizeAddr); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index 59004cefe..b25f7d881 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -19,7 +19,6 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -93,19 +92,23 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir
}
}
-// oldDirentHdr is a fixed sized header matching the fixed size
-// fields found in the old linux dirent struct.
+// oldDirentHdr is a fixed sized header matching the fixed size fields found in
+// the old linux dirent struct.
+//
+// +marshal
type oldDirentHdr struct {
Ino uint64
Off uint64
- Reclen uint16
+ Reclen uint16 `marshal:"unaligned"` // Struct ends mid-word.
}
-// direntHdr is a fixed sized header matching the fixed size
-// fields found in the new linux dirent struct.
+// direntHdr is a fixed sized header matching the fixed size fields found in the
+// new linux dirent struct.
+//
+// +marshal
type direntHdr struct {
OldHdr oldDirentHdr
- Typ uint8
+ Typ uint8 `marshal:"unaligned"` // Struct ends mid-word.
}
// dirent contains the data pointed to by a new linux dirent struct.
@@ -134,20 +137,20 @@ func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent
// the old linux dirent format.
func smallestDirent(a arch.Context) uint {
d := dirent{}
- return uint(binary.Size(d.Hdr.OldHdr)) + a.Width() + 1
+ return uint(d.Hdr.OldHdr.SizeBytes()) + a.Width() + 1
}
// smallestDirent64 returns the size of the smallest possible dirent using
// the new linux dirent format.
func smallestDirent64(a arch.Context) uint {
d := dirent{}
- return uint(binary.Size(d.Hdr)) + a.Width()
+ return uint(d.Hdr.SizeBytes()) + a.Width()
}
// padRec pads the name field until the rec length is a multiple of the width,
// which must be a power of 2. It returns the padded rec length.
func (d *dirent) padRec(width int) uint16 {
- a := int(binary.Size(d.Hdr)) + len(d.Name)
+ a := d.Hdr.SizeBytes() + len(d.Name)
r := (a + width) &^ (width - 1)
padding := r - a
d.Name = append(d.Name, make([]byte, padding)...)
@@ -157,7 +160,7 @@ func (d *dirent) padRec(width int) uint16 {
// Serialize64 serializes a Dirent struct to a byte slice, keeping the new
// linux dirent format. Returns the number of bytes serialized or an error.
func (d *dirent) Serialize64(w io.Writer) (int, error) {
- n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr))
+ n1, err := d.Hdr.WriteTo(w)
if err != nil {
return 0, err
}
@@ -165,14 +168,14 @@ func (d *dirent) Serialize64(w io.Writer) (int, error) {
if err != nil {
return 0, err
}
- return n1 + n2, nil
+ return int(n1) + n2, nil
}
// Serialize serializes a Dirent struct to a byte slice, using the old linux
// dirent format.
// Returns the number of bytes serialized or an error.
func (d *dirent) Serialize(w io.Writer) (int, error) {
- n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr.OldHdr))
+ n1, err := d.Hdr.OldHdr.WriteTo(w)
if err != nil {
return 0, err
}
@@ -184,7 +187,7 @@ func (d *dirent) Serialize(w io.Writer) (int, error) {
if err != nil {
return 0, err
}
- return n1 + n2 + n3, nil
+ return int(n1) + n2 + n3, nil
}
// direntSerializer implements fs.InodeOperationsInfoSerializer, serializing dirents to an
diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go
index 715ac45e6..a29d307e5 100644
--- a/pkg/sentry/syscalls/linux/sys_identity.go
+++ b/pkg/sentry/syscalls/linux/sys_identity.go
@@ -49,13 +49,13 @@ func Getresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
ruid := c.RealKUID.In(c.UserNamespace).OrOverflow()
euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow()
suid := c.SavedKUID.In(c.UserNamespace).OrOverflow()
- if _, err := t.CopyOut(ruidAddr, ruid); err != nil {
+ if _, err := ruid.CopyOut(t, ruidAddr); err != nil {
return 0, nil, err
}
- if _, err := t.CopyOut(euidAddr, euid); err != nil {
+ if _, err := euid.CopyOut(t, euidAddr); err != nil {
return 0, nil, err
}
- if _, err := t.CopyOut(suidAddr, suid); err != nil {
+ if _, err := suid.CopyOut(t, suidAddr); err != nil {
return 0, nil, err
}
return 0, nil, nil
@@ -84,13 +84,13 @@ func Getresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
rgid := c.RealKGID.In(c.UserNamespace).OrOverflow()
egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow()
sgid := c.SavedKGID.In(c.UserNamespace).OrOverflow()
- if _, err := t.CopyOut(rgidAddr, rgid); err != nil {
+ if _, err := rgid.CopyOut(t, rgidAddr); err != nil {
return 0, nil, err
}
- if _, err := t.CopyOut(egidAddr, egid); err != nil {
+ if _, err := egid.CopyOut(t, egidAddr); err != nil {
return 0, nil, err
}
- if _, err := t.CopyOut(sgidAddr, sgid); err != nil {
+ if _, err := sgid.CopyOut(t, sgidAddr); err != nil {
return 0, nil, err
}
return 0, nil, nil
@@ -157,7 +157,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
for i, kgid := range kgids {
gids[i] = kgid.In(t.UserNamespace()).OrOverflow()
}
- if _, err := t.CopyOut(args[1].Pointer(), gids); err != nil {
+ if _, err := auth.CopyGIDSliceOut(t, args[1].Pointer(), gids); err != nil {
return 0, nil, err
}
return uintptr(len(gids)), nil, nil
@@ -173,7 +173,7 @@ func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, t.SetExtraGIDs(nil)
}
gids := make([]auth.GID, size)
- if _, err := t.CopyIn(args[1].Pointer(), &gids); err != nil {
+ if _, err := auth.CopyGIDSliceIn(t, args[1].Pointer(), gids); err != nil {
return 0, nil, err
}
return 0, nil, t.SetExtraGIDs(gids)
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index d0109baa4..8ab062bca 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -239,7 +239,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return 0, nil, syserror.ENOMEM
}
resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize))
- _, err := t.CopyOut(vec, resident)
+ _, err := t.CopyOutBytes(vec, resident)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index c55beb39b..849a47476 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -46,7 +47,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
return 0, err
}
- if _, err := t.CopyOut(addr, fds); err != nil {
+ if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil {
for _, fd := range fds {
if file, _ := t.FDTable().Remove(t, fd); file != nil {
file.DecRef(t)
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index 789e2ed5b..254f4c9f9 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -162,7 +162,7 @@ func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD
pfd := make([]linux.PollFD, nfds)
if nfds > 0 {
- if _, err := t.CopyIn(addr, &pfd); err != nil {
+ if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil {
return nil, err
}
}
@@ -189,7 +189,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
// The poll entries are copied out regardless of whether
// any are set or not. This aligns with the Linux behavior.
if nfds > 0 && err == nil {
- if _, err := t.CopyOut(addr, pfd); err != nil {
+ if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil {
return remainingTimeout, 0, err
}
}
@@ -202,7 +202,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy
set := make([]byte, nBytes)
if addr != 0 {
- if _, err := t.CopyIn(addr, &set); err != nil {
+ if _, err := t.CopyInBytes(addr, set); err != nil {
return nil, err
}
// If we only use part of the last byte, mask out the extraneous bits.
@@ -329,19 +329,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
// Copy updated vectors back.
if readFDs != 0 {
- if _, err := t.CopyOut(readFDs, r); err != nil {
+ if _, err := t.CopyOutBytes(readFDs, r); err != nil {
return 0, err
}
}
if writeFDs != 0 {
- if _, err := t.CopyOut(writeFDs, w); err != nil {
+ if _, err := t.CopyOutBytes(writeFDs, w); err != nil {
return 0, err
}
}
if exceptFDs != 0 {
- if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ if _, err := t.CopyOutBytes(exceptFDs, e); err != nil {
return 0, err
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 64a725296..a892d2c62 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -43,7 +44,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, nil
case linux.PR_GET_PDEATHSIG:
- _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal()))
+ _, err := primitive.CopyInt32Out(t, args[1].Pointer(), int32(t.ParentDeathSignal()))
return 0, nil, err
case linux.PR_GET_DUMPABLE:
@@ -110,7 +111,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
buf[len] = 0
len++
}
- _, err := t.CopyOut(addr, buf[:len])
+ _, err := t.CopyOutBytes(addr, buf[:len])
if err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
index d5d5b6959..309c183a3 100644
--- a/pkg/sentry/syscalls/linux/sys_rlimit.go
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/limits"
@@ -26,17 +27,13 @@ import (
// rlimit describes an implementation of 'struct rlimit', which may vary from
// system-to-system.
type rlimit interface {
+ marshal.Marshallable
+
// toLimit converts an rlimit to a limits.Limit.
toLimit() *limits.Limit
// fromLimit converts a limits.Limit to an rlimit.
fromLimit(lim limits.Limit)
-
- // copyIn copies an rlimit from the untrusted app to the kernel.
- copyIn(t *kernel.Task, addr usermem.Addr) error
-
- // copyOut copies an rlimit from the kernel to the untrusted app.
- copyOut(t *kernel.Task, addr usermem.Addr) error
}
// newRlimit returns the appropriate rlimit type for 'struct rlimit' on this system.
@@ -50,6 +47,7 @@ func newRlimit(t *kernel.Task) (rlimit, error) {
}
}
+// +marshal
type rlimit64 struct {
Cur uint64
Max uint64
@@ -70,12 +68,12 @@ func (r *rlimit64) fromLimit(lim limits.Limit) {
}
func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error {
- _, err := t.CopyIn(addr, r)
+ _, err := r.CopyIn(t, addr)
return err
}
func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error {
- _, err := t.CopyOut(addr, *r)
+ _, err := r.CopyOut(t, addr)
return err
}
@@ -140,7 +138,8 @@ func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, err
}
rlim.fromLimit(lim)
- return 0, nil, rlim.copyOut(t, addr)
+ _, err = rlim.CopyOut(t, addr)
+ return 0, nil, err
}
// Setrlimit implements linux syscall setrlimit(2).
@@ -155,7 +154,7 @@ func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if err != nil {
return 0, nil, err
}
- if err := rlim.copyIn(t, addr); err != nil {
+ if _, err := rlim.CopyIn(t, addr); err != nil {
return 0, nil, syserror.EFAULT
}
_, err = prlimit64(t, resource, rlim.toLimit())
diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go
index 1674c7445..ac5c98a54 100644
--- a/pkg/sentry/syscalls/linux/sys_rusage.go
+++ b/pkg/sentry/syscalls/linux/sys_rusage.go
@@ -80,7 +80,7 @@ func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
ru := getrusage(t, which)
- _, err := t.CopyOut(addr, &ru)
+ _, err := ru.CopyOut(t, addr)
return 0, nil, err
}
@@ -104,7 +104,7 @@ func Times(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
CUTime: linux.ClockTFromDuration(cs2.UserTime),
CSTime: linux.ClockTFromDuration(cs2.SysTime),
}
- if _, err := t.CopyOut(addr, &r); err != nil {
+ if _, err := r.CopyOut(t, addr); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go
index 99f6993f5..cd6f4dd94 100644
--- a/pkg/sentry/syscalls/linux/sys_sched.go
+++ b/pkg/sentry/syscalls/linux/sys_sched.go
@@ -27,6 +27,8 @@ const (
)
// SchedParam replicates struct sched_param in sched.h.
+//
+// +marshal
type SchedParam struct {
schedPriority int64
}
@@ -45,7 +47,7 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.ESRCH
}
r := SchedParam{schedPriority: onlyPriority}
- if _, err := t.CopyOut(param, r); err != nil {
+ if _, err := r.CopyOut(t, param); err != nil {
return 0, nil, err
}
@@ -79,7 +81,7 @@ func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ke
return 0, nil, syserror.ESRCH
}
var r SchedParam
- if _, err := t.CopyIn(param, &r); err != nil {
+ if _, err := r.CopyIn(t, param); err != nil {
return 0, nil, syserror.EINVAL
}
if r.schedPriority != onlyPriority {
diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go
index 5b7a66f4d..4fdb4463c 100644
--- a/pkg/sentry/syscalls/linux/sys_seccomp.go
+++ b/pkg/sentry/syscalls/linux/sys_seccomp.go
@@ -24,6 +24,8 @@ import (
)
// userSockFprog is equivalent to Linux's struct sock_fprog on amd64.
+//
+// +marshal
type userSockFprog struct {
// Len is the length of the filter in BPF instructions.
Len uint16
@@ -33,7 +35,7 @@ type userSockFprog struct {
// Filter is a user pointer to the struct sock_filter array that makes up
// the filter program. Filter is a uint64 rather than a usermem.Addr
// because usermem.Addr is actually uintptr, which is not a fixed-size
- // type, and encoding/binary.Read objects to this.
+ // type.
Filter uint64
}
@@ -54,11 +56,11 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error {
}
var fprog userSockFprog
- if _, err := t.CopyIn(addr, &fprog); err != nil {
+ if _, err := fprog.CopyIn(t, addr); err != nil {
return err
}
filter := make([]linux.BPFInstruction, int(fprog.Len))
- if _, err := t.CopyIn(usermem.Addr(fprog.Filter), &filter); err != nil {
+ if _, err := linux.CopyBPFInstructionSliceIn(t, usermem.Addr(fprog.Filter), filter); err != nil {
return err
}
compiledFilter, err := bpf.Compile(filter)
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index 5f54f2456..47dadb800 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -18,6 +18,7 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -66,7 +67,7 @@ func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
ops := make([]linux.Sembuf, nsops)
- if _, err := t.CopyIn(sembufAddr, ops); err != nil {
+ if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil {
return 0, nil, err
}
@@ -116,8 +117,8 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
case linux.IPC_SET:
arg := args[3].Pointer()
- s := linux.SemidDS{}
- if _, err := t.CopyIn(arg, &s); err != nil {
+ var s linux.SemidDS
+ if _, err := s.CopyIn(t, arg); err != nil {
return 0, nil, err
}
@@ -188,7 +189,7 @@ func setValAll(t *kernel.Task, id int32, array usermem.Addr) error {
return syserror.EINVAL
}
vals := make([]uint16, set.Size())
- if _, err := t.CopyIn(array, vals); err != nil {
+ if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil {
return err
}
creds := auth.CredentialsFromContext(t)
@@ -217,7 +218,7 @@ func getValAll(t *kernel.Task, id int32, array usermem.Addr) error {
if err != nil {
return err
}
- _, err = t.CopyOut(array, vals)
+ _, err = primitive.CopyUint16SliceOut(t, array, vals)
return err
}
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index f0ae8fa8e..584064143 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -112,18 +112,18 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
stat, err := segment.IPCStat(t)
if err == nil {
- _, err = t.CopyOut(buf, stat)
+ _, err = stat.CopyOut(t, buf)
}
return 0, nil, err
case linux.IPC_INFO:
params := r.IPCInfo()
- _, err := t.CopyOut(buf, params)
+ _, err := params.CopyOut(t, buf)
return 0, nil, err
case linux.SHM_INFO:
info := r.ShmInfo()
- _, err := t.CopyOut(buf, info)
+ _, err := info.CopyOut(t, buf)
return 0, nil, err
}
@@ -137,11 +137,10 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
switch cmd {
case linux.IPC_SET:
var ds linux.ShmidDS
- _, err = t.CopyIn(buf, &ds)
- if err != nil {
+ if _, err = ds.CopyIn(t, buf); err != nil {
return 0, nil, err
}
- err = segment.Set(t, &ds)
+ err := segment.Set(t, &ds)
return 0, nil, err
case linux.IPC_RMID:
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index e4528d095..9feaca0da 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -19,6 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -29,8 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// LINT.IfChange
@@ -67,10 +67,10 @@ const flagsOffset = 48
const sizeOfInt32 = 4
// messageHeader64Len is the length of a MessageHeader64 struct.
-var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes())
// multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct.
-var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes())
// baseRecvFlags are the flags that are accepted across recvmsg(2),
// recvmmsg(2), and recvfrom(2).
@@ -78,6 +78,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT |
// MessageHeader64 is the 64-bit representation of the msghdr struct used in
// the recvmsg and sendmsg syscalls.
+//
+// +marshal
type MessageHeader64 struct {
// Name is the optional pointer to a network address buffer.
Name uint64
@@ -106,30 +108,14 @@ type MessageHeader64 struct {
// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
// the recvmmsg and sendmmsg syscalls.
+//
+// +marshal
type multipleMessageHeader64 struct {
msgHdr MessageHeader64
msgLen uint32
_ int32
}
-// CopyInMessageHeader64 copies a message header from user to kernel memory.
-func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
- b := t.CopyScratchBuffer(52)
- if _, err := t.CopyInBytes(addr, b); err != nil {
- return err
- }
-
- msg.Name = usermem.ByteOrder.Uint64(b[0:])
- msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
- msg.Iov = usermem.ByteOrder.Uint64(b[16:])
- msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
- msg.Control = usermem.ByteOrder.Uint64(b[32:])
- msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
- msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
-
- return nil
-}
-
// CaptureAddress allocates memory for and copies a socket address structure
// from the untrusted address space range.
func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
@@ -148,10 +134,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte,
// writeAddress writes a sockaddr structure and its length to an output buffer
// in the unstrusted address space range. If the address is bigger than the
// buffer, it is truncated.
-func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
// Get the buffer length.
var bufLen uint32
- if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil {
return err
}
@@ -160,7 +146,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user
}
// Write the length unconditionally.
- if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil {
return err
}
@@ -173,7 +159,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user
}
// Copy as much of the address as will fit in the buffer.
- encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ encodedAddr := t.CopyScratchBuffer(addr.SizeBytes())
+ addr.MarshalUnsafe(encodedAddr)
if bufLen > uint32(len(encodedAddr)) {
bufLen = uint32(len(encodedAddr))
}
@@ -247,7 +234,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
// Copy the file descriptors out.
- if _, err := t.CopyOut(socks, fds); err != nil {
+ if _, err := primitive.CopyInt32SliceOut(t, socks, fds); err != nil {
for _, fd := range fds {
if file, _ := t.FDTable().Remove(t, fd); file != nil {
file.DecRef(t)
@@ -456,8 +443,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
// Read the length. Reject negative values.
- optLen := int32(0)
- if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ var optLen int32
+ if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil {
return 0, nil, err
}
if optLen < 0 {
@@ -471,7 +458,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
vLen := int32(binary.Size(v))
- if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
return 0, nil, err
}
@@ -733,7 +720,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if !ok {
return 0, nil, syserror.EFAULT
}
- if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
break
}
count++
@@ -748,7 +735,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
// Capture the message header and io vectors.
var msg MessageHeader64
- if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, msgPtr); err != nil {
return 0, err
}
@@ -780,7 +767,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
if int(msg.Flags) != mflags {
// Copy out the flags to the caller.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
}
@@ -817,17 +804,17 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
// Copy the control data to the caller.
- if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
return 0, err
}
if len(controlData) > 0 {
- if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
// Copy out the flags to the caller.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
@@ -996,7 +983,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if !ok {
return 0, nil, syserror.EFAULT
}
- if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
break
}
count++
@@ -1011,7 +998,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) {
// Capture the message header.
var msg MessageHeader64
- if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, msgPtr); err != nil {
return 0, err
}
@@ -1022,7 +1009,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
return 0, syserror.ENOBUFS
}
controlData = make([]byte, msg.ControlLen)
- if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index c69941feb..46616c961 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -16,6 +16,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -141,7 +142,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Copy in the offset.
var offset int64
- if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
+ if _, err := primitive.CopyInt64In(t, offsetAddr, &offset); err != nil {
return 0, nil, err
}
@@ -149,11 +150,11 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
SrcOffset: true,
- SrcStart: offset,
+ SrcStart: int64(offset),
}, outFile.Flags().NonBlocking)
// Copy out the new offset.
- if _, err := t.CopyOut(offsetAddr, n+offset); err != nil {
+ if _, err := primitive.CopyInt64Out(t, offsetAddr, offset+n); err != nil {
return 0, nil, err
}
} else {
@@ -228,7 +229,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
var offset int64
- if _, err := t.CopyIn(outOffset, &offset); err != nil {
+ if _, err := primitive.CopyInt64In(t, outOffset, &offset); err != nil {
return 0, nil, err
}
@@ -246,7 +247,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
var offset int64
- if _, err := t.CopyIn(inOffset, &offset); err != nil {
+ if _, err := primitive.CopyInt64In(t, inOffset, &offset); err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index a5826f2dd..cda29a8b5 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -221,7 +221,7 @@ func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr
DevMajor: uint32(devMajor),
DevMinor: devMinor,
}
- _, err := t.CopyOut(statxAddr, &s)
+ _, err := s.CopyOut(t, statxAddr)
return err
}
@@ -283,7 +283,7 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
FragmentSize: d.Inode.StableAttr.BlockSize,
// Leave other fields 0 like simple_statfs does.
}
- _, err = t.CopyOut(addr, &statfs)
+ _, err = statfs.CopyOut(t, addr)
return err
}
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
index 297de052a..674d341b6 100644
--- a/pkg/sentry/syscalls/linux/sys_sysinfo.go
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -43,6 +43,6 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
FreeRAM: memFree,
Unit: 1,
}
- _, err := t.CopyOut(addr, si)
+ _, err := si.CopyOut(t, addr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 101096038..39ca9ea97 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -19,6 +19,7 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -311,13 +312,13 @@ func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusage
return 0, err
}
if statusAddr != 0 {
- if _, err := t.CopyOut(statusAddr, wr.Status); err != nil {
+ if _, err := primitive.CopyUint32Out(t, statusAddr, wr.Status); err != nil {
return 0, err
}
}
if rusageAddr != 0 {
ru := getrusage(wr.Task, linux.RUSAGE_BOTH)
- if _, err := t.CopyOut(rusageAddr, &ru); err != nil {
+ if _, err := ru.CopyOut(t, rusageAddr); err != nil {
return 0, err
}
}
@@ -395,14 +396,14 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// as well.
if infop != 0 {
var si arch.SignalInfo
- _, err = t.CopyOut(infop, &si)
+ _, err = si.CopyOut(t, infop)
}
}
return 0, nil, err
}
if rusageAddr != 0 {
ru := getrusage(wr.Task, linux.RUSAGE_BOTH)
- if _, err := t.CopyOut(rusageAddr, &ru); err != nil {
+ if _, err := ru.CopyOut(t, rusageAddr); err != nil {
return 0, nil, err
}
}
@@ -441,7 +442,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
default:
t.Warningf("waitid got incomprehensible wait status %d", s)
}
- _, err = t.CopyOut(infop, &si)
+ _, err = si.CopyOut(t, infop)
return 0, nil, err
}
@@ -558,9 +559,7 @@ func Getcpu(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// third argument to this system call is nowadays unused.
if cpu != 0 {
- buf := t.CopyScratchBuffer(4)
- usermem.ByteOrder.PutUint32(buf, uint32(t.CPU()))
- if _, err := t.CopyOutBytes(cpu, buf); err != nil {
+ if _, err := primitive.CopyInt32Out(t, cpu, t.CPU()); err != nil {
return 0, nil, err
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index a2a24a027..c5054d2f1 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -19,6 +19,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -168,7 +169,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(r), nil, nil
}
- if _, err := t.CopyOut(addr, r); err != nil {
+ if _, err := r.CopyOut(t, addr); err != nil {
return 0, nil, err
}
return uintptr(r), nil, nil
@@ -334,8 +335,8 @@ func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// Ask the time package for the timezone.
_, offset := time.Now().Zone()
// This int32 array mimics linux's struct timezone.
- timezone := [2]int32{-int32(offset) / 60, 0}
- _, err := t.CopyOut(tz, timezone)
+ timezone := []int32{-int32(offset) / 60, 0}
+ _, err := primitive.CopyInt32SliceOut(t, tz, timezone)
return 0, nil, err
}
return 0, nil, nil
diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go
index a4c400f87..45eef4feb 100644
--- a/pkg/sentry/syscalls/linux/sys_timer.go
+++ b/pkg/sentry/syscalls/linux/sys_timer.go
@@ -21,81 +21,63 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
const nsecPerSec = int64(time.Second)
-// copyItimerValIn copies an ItimerVal from the untrusted app range to the
-// kernel. The ItimerVal may be either 32 or 64 bits.
-// A NULL address is allowed because because Linux allows
-// setitimer(which, NULL, &old_value) which disables the timer.
-// There is a KERN_WARN message saying this misfeature will be removed.
-// However, that hasn't happened as of 3.19, so we continue to support it.
-func copyItimerValIn(t *kernel.Task, addr usermem.Addr) (linux.ItimerVal, error) {
- if addr == usermem.Addr(0) {
- return linux.ItimerVal{}, nil
- }
-
- switch t.Arch().Width() {
- case 8:
- // Native size, just copy directly.
- var itv linux.ItimerVal
- if _, err := t.CopyIn(addr, &itv); err != nil {
- return linux.ItimerVal{}, err
- }
-
- return itv, nil
- default:
- return linux.ItimerVal{}, syserror.ENOSYS
- }
-}
-
-// copyItimerValOut copies an ItimerVal to the untrusted app range.
-// The ItimerVal may be either 32 or 64 bits.
-// A NULL address is allowed, in which case no copy takes place
-func copyItimerValOut(t *kernel.Task, addr usermem.Addr, itv *linux.ItimerVal) error {
- if addr == usermem.Addr(0) {
- return nil
- }
-
- switch t.Arch().Width() {
- case 8:
- // Native size, just copy directly.
- _, err := t.CopyOut(addr, itv)
- return err
- default:
- return syserror.ENOSYS
- }
-}
-
// Getitimer implements linux syscall getitimer(2).
func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ if t.Arch().Width() != 8 {
+ // Definition of linux.ItimerVal assumes 64-bit architecture.
+ return 0, nil, syserror.ENOSYS
+ }
+
timerID := args[0].Int()
- val := args[1].Pointer()
+ addr := args[1].Pointer()
olditv, err := t.Getitimer(timerID)
if err != nil {
return 0, nil, err
}
- return 0, nil, copyItimerValOut(t, val, &olditv)
+ // A NULL address is allowed, in which case no copy out takes place.
+ if addr == 0 {
+ return 0, nil, nil
+ }
+ _, err = olditv.CopyOut(t, addr)
+ return 0, nil, err
}
// Setitimer implements linux syscall setitimer(2).
func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- timerID := args[0].Int()
- newVal := args[1].Pointer()
- oldVal := args[2].Pointer()
+ if t.Arch().Width() != 8 {
+ // Definition of linux.ItimerVal assumes 64-bit architecture.
+ return 0, nil, syserror.ENOSYS
+ }
- newitv, err := copyItimerValIn(t, newVal)
- if err != nil {
- return 0, nil, err
+ timerID := args[0].Int()
+ newAddr := args[1].Pointer()
+ oldAddr := args[2].Pointer()
+
+ var newitv linux.ItimerVal
+ // A NULL address is allowed because because Linux allows
+ // setitimer(which, NULL, &old_value) which disables the timer. There is a
+ // KERN_WARN message saying this misfeature will be removed. However, that
+ // hasn't happened as of 3.19, so we continue to support it.
+ if newAddr != 0 {
+ if _, err := newitv.CopyIn(t, newAddr); err != nil {
+ return 0, nil, err
+ }
}
olditv, err := t.Setitimer(timerID, newitv)
if err != nil {
return 0, nil, err
}
- return 0, nil, copyItimerValOut(t, oldVal, &olditv)
+ // A NULL address is allowed, in which case no copy out takes place.
+ if oldAddr == 0 {
+ return 0, nil, nil
+ }
+ _, err = olditv.CopyOut(t, oldAddr)
+ return 0, nil, err
}
// Alarm implements linux syscall alarm(2).
@@ -131,7 +113,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
var sev *linux.Sigevent
if sevp != 0 {
sev = &linux.Sigevent{}
- if _, err = t.CopyIn(sevp, sev); err != nil {
+ if _, err = sev.CopyIn(t, sevp); err != nil {
return 0, nil, err
}
}
@@ -141,7 +123,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, err
}
- if _, err := t.CopyOut(timerIDp, &id); err != nil {
+ if _, err := id.CopyOut(t, timerIDp); err != nil {
t.IntervalTimerDelete(id)
return 0, nil, err
}
@@ -157,7 +139,7 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
oldValAddr := args[3].Pointer()
var newVal linux.Itimerspec
- if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ if _, err := newVal.CopyIn(t, newValAddr); err != nil {
return 0, nil, err
}
oldVal, err := t.IntervalTimerSettime(timerID, newVal, flags&linux.TIMER_ABSTIME != 0)
@@ -165,9 +147,8 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
return 0, nil, err
}
if oldValAddr != 0 {
- if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
- return 0, nil, err
- }
+ _, err = oldVal.CopyOut(t, oldValAddr)
+ return 0, nil, err
}
return 0, nil, nil
}
@@ -181,7 +162,7 @@ func TimerGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if err != nil {
return 0, nil, err
}
- _, err = t.CopyOut(curValAddr, &curVal)
+ _, err = curVal.CopyOut(t, curValAddr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go
index 34b03e4ee..cadd9d348 100644
--- a/pkg/sentry/syscalls/linux/sys_timerfd.go
+++ b/pkg/sentry/syscalls/linux/sys_timerfd.go
@@ -81,7 +81,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
}
var newVal linux.Itimerspec
- if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ if _, err := newVal.CopyIn(t, newValAddr); err != nil {
return 0, nil, err
}
newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tf.Clock())
@@ -91,7 +91,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tm, oldS := tf.SetTime(newS)
if oldValAddr != 0 {
oldVal := ktime.ItimerspecFromSetting(tm, oldS)
- if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ if _, err := oldVal.CopyOut(t, oldValAddr); err != nil {
return 0, nil, err
}
}
@@ -116,6 +116,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tm, s := tf.GetTime()
curVal := ktime.ItimerspecFromSetting(tm, s)
- _, err := t.CopyOut(curValAddr, &curVal)
+ _, err := curVal.CopyOut(t, curValAddr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
index b3eb96a1c..6ddd30d5c 100644
--- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go
+++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
@@ -18,6 +18,7 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -30,17 +31,19 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
case linux.ARCH_GET_FS:
addr := args[1].Pointer()
fsbase := t.Arch().TLS()
- _, err := t.CopyOut(addr, uint64(fsbase))
- if err != nil {
- return 0, nil, err
+ switch t.Arch().Width() {
+ case 8:
+ if _, err := primitive.CopyUint64Out(t, addr, uint64(fsbase)); err != nil {
+ return 0, nil, err
+ }
+ default:
+ return 0, nil, syserror.ENOSYS
}
-
case linux.ARCH_SET_FS:
fsbase := args[1].Uint64()
if !t.Arch().SetTLS(uintptr(fsbase)) {
return 0, nil, syserror.EPERM
}
-
case linux.ARCH_GET_GS, linux.ARCH_SET_GS:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index e9d702e8e..66c5974f5 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -46,7 +46,7 @@ func Uname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Copy out the result.
va := args[0].Pointer()
- _, err := t.CopyOut(va, u)
+ _, err := u.CopyOut(t, va)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 0030dee39..9ee766552 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -45,6 +45,8 @@ go_library(
"//pkg/fspath",
"//pkg/gohacks",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsbridge",
@@ -73,7 +75,5 @@ go_library(
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index 42559bf69..6d0a38330 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -17,6 +17,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -38,21 +39,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
for i := int32(0); i < nrEvents; i++ {
- // Copy in the address.
- cbAddrNative := t.Arch().Native(0)
- if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
- if i > 0 {
- // Some successful.
- return uintptr(i), nil, nil
+ // Copy in the callback address.
+ var cbAddr usermem.Addr
+ switch t.Arch().Width() {
+ case 8:
+ var cbAddrP primitive.Uint64
+ if _, err := cbAddrP.CopyIn(t, addr); err != nil {
+ if i > 0 {
+ // Some successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
}
- // Nothing done.
- return 0, nil, err
+ cbAddr = usermem.Addr(cbAddrP)
+ default:
+ return 0, nil, syserror.ENOSYS
}
// Copy in this callback.
var cb linux.IOCallback
- cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
- if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+ if _, err := cb.CopyIn(t, cbAddr); err != nil {
if i > 0 {
// Some have been successful.
return uintptr(i), nil, nil
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index c62f03509..d0cbb77eb 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -24,7 +24,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -141,50 +140,26 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EINVAL
}
- // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
- // maxEvents), so that the buffer can be allocated on the stack.
+ // Allocate space for a few events on the stack for the common case in
+ // which we don't have too many events.
var (
- events [16]linux.EpollEvent
- total int
+ eventsArr [16]linux.EpollEvent
ch chan struct{}
haveDeadline bool
deadline ktime.Time
)
for {
- batchEvents := len(events)
- if batchEvents > maxEvents {
- batchEvents = maxEvents
- }
- n := ep.ReadEvents(events[:batchEvents])
- maxEvents -= n
- if n != 0 {
- // Copy what we read out.
- copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events[:n])
+ events := ep.ReadEvents(eventsArr[:0], maxEvents)
+ if len(events) != 0 {
+ copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events)
copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
- eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
- total += copiedEvents
- if err != nil {
- if total != 0 {
- return uintptr(total), nil, nil
- }
- return 0, nil, err
- }
- // If we've filled the application's event buffer, we're done.
- if maxEvents == 0 {
- return uintptr(total), nil, nil
- }
- // Loop if we read a full batch, under the expectation that there
- // may be more events to read.
- if n == batchEvents {
- continue
+ if copiedEvents != 0 {
+ return uintptr(copiedEvents), nil, nil
}
+ return 0, nil, err
}
- // We get here if n != batchEvents. If we read any number of events
- // (just now, or in a previous iteration of this loop), or if timeout
- // is 0 (such that epoll_wait should be non-blocking), return the
- // events we've read so far to the application.
- if total != 0 || timeout == 0 {
- return uintptr(total), nil, nil
+ if timeout == 0 {
+ return 0, nil, nil
}
// In the first iteration of this loop, register with the epoll
// instance for readability events, but then immediately continue the
@@ -207,8 +182,6 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if err == syserror.ETIMEDOUT {
err = nil
}
- // total must be 0 since otherwise we would have returned
- // above.
return 0, nil, err
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index fdd8f88c5..d8b8d9783 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -181,11 +181,11 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if !hasOwner {
return 0, nil, nil
}
- _, err := t.CopyOut(args[2].Pointer(), &owner)
+ _, err := owner.CopyOut(t, args[2].Pointer())
return 0, nil, err
case linux.F_SETOWN_EX:
var owner linux.FOwnerEx
- _, err := t.CopyIn(args[2].Pointer(), &owner)
+ _, err := owner.CopyIn(t, args[2].Pointer())
if err != nil {
return 0, nil, err
}
@@ -286,7 +286,7 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip
// Copy in the lock request.
flockAddr := args[2].Pointer()
var flock linux.Flock
- if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ if _, err := flock.CopyIn(t, flockAddr); err != nil {
return err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index baa8a49af..2806c3f6f 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserror"
@@ -47,7 +48,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FIONBIO:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
flags := file.StatusFlags()
@@ -60,7 +61,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FIOASYNC:
var set int32
- if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
flags := file.StatusFlags()
@@ -82,12 +83,12 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
who = owner.PID
}
}
- _, err := t.CopyOut(args[2].Pointer(), &who)
+ _, err := primitive.CopyInt32Out(t, args[2].Pointer(), who)
return 0, nil, err
case linux.FIOSETOWN, linux.SIOCSPGRP:
var who int32
- if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil {
+ if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &who); err != nil {
return 0, nil, err
}
ownerType := int32(linux.F_OWNER_PID)
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
index 3aa6d939d..ee38fdca0 100644
--- a/pkg/sentry/syscalls/linux/vfs2/pipe.go
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -16,6 +16,7 @@ package vfs2
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -51,7 +52,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
if err != nil {
return err
}
- if _, err := t.CopyOut(addr, fds); err != nil {
+ if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil {
for _, fd := range fds {
if _, file := t.FDTable().Remove(t, fd); file != nil {
file.DecRef(t)
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
index 79ad64039..c22e4ce54 100644
--- a/pkg/sentry/syscalls/linux/vfs2/poll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -165,7 +165,7 @@ func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD
pfd := make([]linux.PollFD, nfds)
if nfds > 0 {
- if _, err := t.CopyIn(addr, &pfd); err != nil {
+ if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil {
return nil, err
}
}
@@ -192,7 +192,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
// The poll entries are copied out regardless of whether
// any are set or not. This aligns with the Linux behavior.
if nfds > 0 && err == nil {
- if _, err := t.CopyOut(addr, pfd); err != nil {
+ if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil {
return remainingTimeout, 0, err
}
}
@@ -205,7 +205,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy
set := make([]byte, nBytes)
if addr != 0 {
- if _, err := t.CopyIn(addr, &set); err != nil {
+ if _, err := t.CopyInBytes(addr, set); err != nil {
return nil, err
}
// If we only use part of the last byte, mask out the extraneous bits.
@@ -332,19 +332,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
// Copy updated vectors back.
if readFDs != 0 {
- if _, err := t.CopyOut(readFDs, r); err != nil {
+ if _, err := t.CopyOutBytes(readFDs, r); err != nil {
return 0, err
}
}
if writeFDs != 0 {
- if _, err := t.CopyOut(writeFDs, w); err != nil {
+ if _, err := t.CopyOutBytes(writeFDs, w); err != nil {
return 0, err
}
}
if exceptFDs != 0 {
- if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ if _, err := t.CopyOutBytes(exceptFDs, e); err != nil {
return 0, err
}
}
@@ -497,6 +497,12 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return n, nil, err
}
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
// Pselect implements linux syscall pselect(2).
func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
nfds := int(args[0].Int()) // select(2) uses an int.
@@ -538,12 +544,6 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
return n, nil, err
}
-// +marshal
-type sigSetWithSize struct {
- sigsetAddr uint64
- sizeofSigset uint64
-}
-
// copyTimespecInToDuration copies a Timespec from the untrusted app range,
// validates it and converts it to a Duration.
//
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
index 5e6eb13ba..1ee37e5a8 100644
--- a/pkg/sentry/syscalls/linux/vfs2/setstat.go
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -346,7 +346,7 @@ func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opt
return nil
}
var times [2]linux.Timeval
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil {
return err
}
if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 {
@@ -410,7 +410,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op
return nil
}
var times [2]linux.Timespec
- if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil {
return err
}
if times[0].Nsec != linux.UTIME_OMIT {
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index a15dad29f..bfae6b7e9 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -19,6 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -30,8 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// minListenBacklog is the minimum reasonable backlog for listening sockets.
@@ -66,10 +66,10 @@ const flagsOffset = 48
const sizeOfInt32 = 4
// messageHeader64Len is the length of a MessageHeader64 struct.
-var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes())
// multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct.
-var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes())
// baseRecvFlags are the flags that are accepted across recvmsg(2),
// recvmmsg(2), and recvfrom(2).
@@ -77,6 +77,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT |
// MessageHeader64 is the 64-bit representation of the msghdr struct used in
// the recvmsg and sendmsg syscalls.
+//
+// +marshal
type MessageHeader64 struct {
// Name is the optional pointer to a network address buffer.
Name uint64
@@ -105,30 +107,14 @@ type MessageHeader64 struct {
// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
// the recvmmsg and sendmmsg syscalls.
+//
+// +marshal
type multipleMessageHeader64 struct {
msgHdr MessageHeader64
msgLen uint32
_ int32
}
-// CopyInMessageHeader64 copies a message header from user to kernel memory.
-func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
- b := t.CopyScratchBuffer(52)
- if _, err := t.CopyInBytes(addr, b); err != nil {
- return err
- }
-
- msg.Name = usermem.ByteOrder.Uint64(b[0:])
- msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
- msg.Iov = usermem.ByteOrder.Uint64(b[16:])
- msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
- msg.Control = usermem.ByteOrder.Uint64(b[32:])
- msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
- msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
-
- return nil
-}
-
// CaptureAddress allocates memory for and copies a socket address structure
// from the untrusted address space range.
func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
@@ -147,10 +133,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte,
// writeAddress writes a sockaddr structure and its length to an output buffer
// in the unstrusted address space range. If the address is bigger than the
// buffer, it is truncated.
-func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
// Get the buffer length.
var bufLen uint32
- if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil {
return err
}
@@ -159,7 +145,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user
}
// Write the length unconditionally.
- if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil {
return err
}
@@ -172,7 +158,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user
}
// Copy as much of the address as will fit in the buffer.
- encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ encodedAddr := t.CopyScratchBuffer(addr.SizeBytes())
+ addr.MarshalUnsafe(encodedAddr)
if bufLen > uint32(len(encodedAddr)) {
bufLen = uint32(len(encodedAddr))
}
@@ -250,7 +237,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, err
}
- if _, err := t.CopyOut(addr, fds); err != nil {
+ if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil {
for _, fd := range fds {
if _, file := t.FDTable().Remove(t, fd); file != nil {
file.DecRef(t)
@@ -459,8 +446,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
// Read the length. Reject negative values.
- optLen := int32(0)
- if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ var optLen int32
+ if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil {
return 0, nil, err
}
if optLen < 0 {
@@ -474,7 +461,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
vLen := int32(binary.Size(v))
- if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil {
return 0, nil, err
}
@@ -736,7 +723,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if !ok {
return 0, nil, syserror.EFAULT
}
- if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
break
}
count++
@@ -751,7 +738,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
// Capture the message header and io vectors.
var msg MessageHeader64
- if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, msgPtr); err != nil {
return 0, err
}
@@ -783,7 +770,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
if int(msg.Flags) != mflags {
// Copy out the flags to the caller.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
}
@@ -820,17 +807,17 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
}
// Copy the control data to the caller.
- if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
return 0, err
}
if len(controlData) > 0 {
- if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
// Copy out the flags to the caller.
- if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
return 0, err
}
@@ -999,7 +986,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if !ok {
return 0, nil, syserror.EFAULT
}
- if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
break
}
count++
@@ -1014,7 +1001,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) {
// Capture the message header.
var msg MessageHeader64
- if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ if _, err := msg.CopyIn(t, msgPtr); err != nil {
return 0, err
}
@@ -1025,7 +1012,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
return 0, syserror.ENOBUFS
}
controlData = make([]byte, msg.ControlLen)
- if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil {
return 0, err
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 5543cfac2..f55d74cd2 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
@@ -89,7 +90,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if inFile.Options().DenyPRead {
return 0, nil, syserror.EINVAL
}
- if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil {
+ if _, err := primitive.CopyInt64In(t, inOffsetPtr, &inOffset); err != nil {
return 0, nil, err
}
if inOffset < 0 {
@@ -104,7 +105,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if outFile.Options().DenyPWrite {
return 0, nil, syserror.EINVAL
}
- if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil {
+ if _, err := primitive.CopyInt64In(t, outOffsetPtr, &outOffset); err != nil {
return 0, nil, err
}
if outOffset < 0 {
@@ -160,12 +161,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Copy updated offsets out.
if inOffsetPtr != 0 {
- if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil {
+ if _, err := primitive.CopyInt64Out(t, inOffsetPtr, inOffset); err != nil {
return 0, nil, err
}
}
if outOffsetPtr != 0 {
- if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil {
+ if _, err := primitive.CopyInt64Out(t, outOffsetPtr, outOffset); err != nil {
return 0, nil, err
}
}
@@ -303,9 +304,12 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if inFile.Options().DenyPRead {
return 0, nil, syserror.ESPIPE
}
- if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
+ var offsetP primitive.Int64
+ if _, err := offsetP.CopyIn(t, offsetAddr); err != nil {
return 0, nil, err
}
+ offset = int64(offsetP)
+
if offset < 0 {
return 0, nil, syserror.EINVAL
}
@@ -422,7 +426,8 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if offsetAddr != 0 {
// Copy out the new offset.
- if _, err := t.CopyOut(offsetAddr, offset); err != nil {
+ offsetP := primitive.Uint64(offset)
+ if _, err := offsetP.CopyOut(t, offsetAddr); err != nil {
return 0, nil, err
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
index 7a26890ef..250870c03 100644
--- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
@@ -87,7 +87,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
}
var newVal linux.Itimerspec
- if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ if _, err := newVal.CopyIn(t, newValAddr); err != nil {
return 0, nil, err
}
newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock())
@@ -97,7 +97,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tm, oldS := tfd.SetTime(newS)
if oldValAddr != 0 {
oldVal := ktime.ItimerspecFromSetting(tm, oldS)
- if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ if _, err := oldVal.CopyOut(t, oldValAddr); err != nil {
return 0, nil, err
}
}
@@ -122,6 +122,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
tm, s := tfd.GetTime()
curVal := ktime.ItimerspecFromSetting(tm, s)
- _, err := t.CopyOut(curValAddr, &curVal)
+ _, err := curVal.CopyOut(t, curValAddr)
return 0, nil, err
}
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 1b5af9f73..754e76aec 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -331,11 +331,9 @@ func (ep *EpollInstance) removeLocked(epi *epollInterest) {
ep.mu.Unlock()
}
-// ReadEvents reads up to len(events) ready events into events and returns the
-// number of events read.
-//
-// Preconditions: len(events) != 0.
-func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int {
+// ReadEvents appends up to maxReady events to events and returns the updated
+// slice of events.
+func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent, maxEvents int) []linux.EpollEvent {
i := 0
// Hot path: avoid defer.
ep.mu.Lock()
@@ -368,16 +366,16 @@ func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int {
requeue.PushBack(epi)
}
// Report ievents.
- events[i] = linux.EpollEvent{
+ events = append(events, linux.EpollEvent{
Events: ievents.ToLinux(),
Data: epi.userData,
- }
+ })
i++
- if i == len(events) {
+ if i == maxEvents {
break
}
}
ep.ready.PushBackList(&requeue)
ep.mu.Unlock()
- return i
+ return events
}
diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go
index 29719752e..7036467c4 100644
--- a/pkg/syserror/syserror_test.go
+++ b/pkg/syserror/syserror_test.go
@@ -24,27 +24,20 @@ import (
var globalError error
-func returnErrnoAsError() error {
- return syscall.EINVAL
-}
-
-func returnError() error {
- return syserror.EINVAL
-}
-
-func BenchmarkReturnErrnoAsError(b *testing.B) {
+func BenchmarkAssignErrno(b *testing.B) {
for i := b.N; i > 0; i-- {
- returnErrnoAsError()
+ globalError = syscall.EINVAL
}
}
-func BenchmarkReturnError(b *testing.B) {
+func BenchmarkAssignError(b *testing.B) {
for i := b.N; i > 0; i-- {
- returnError()
+ globalError = syserror.EINVAL
}
}
func BenchmarkCompareErrno(b *testing.B) {
+ globalError = syscall.EAGAIN
j := 0
for i := b.N; i > 0; i-- {
if globalError == syscall.EINVAL {
@@ -54,6 +47,7 @@ func BenchmarkCompareErrno(b *testing.B) {
}
func BenchmarkCompareError(b *testing.B) {
+ globalError = syserror.EAGAIN
j := 0
for i := b.N; i > 0; i-- {
if globalError == syserror.EINVAL {
@@ -63,6 +57,7 @@ func BenchmarkCompareError(b *testing.B) {
}
func BenchmarkSwitchErrno(b *testing.B) {
+ globalError = syscall.EPERM
j := 0
for i := b.N; i > 0; i-- {
switch globalError {
@@ -77,6 +72,7 @@ func BenchmarkSwitchErrno(b *testing.B) {
}
func BenchmarkSwitchError(b *testing.B) {
+ globalError = syserror.EPERM
j := 0
for i := b.N; i > 0; i-- {
switch globalError {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 680eafd16..e8816c3f4 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -88,6 +88,16 @@ const (
// units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
+ // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload.
+ //
+ // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4
+ // header size). But RFC 791 section 3.2 discusses the design of the IPv4
+ // fragment "allows 2**13 = 8192 fragments of 8 octets each for a total of
+ // 65,536 octets. Note that this is consistent with the the datagram total
+ // length field (of course, the header is counted in the total length and not
+ // in the fragments)."
+ IPv4MaximumPayloadSize = 65536
+
// MinIPFragmentPayloadSize is the minimum number of payload bytes that
// the first fragment must carry when an IPv4 packet is fragmented.
MinIPFragmentPayloadSize = 8
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index ea3823898..0761a1807 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -74,6 +74,10 @@ const (
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
IPv6AddressSize = 16
+ // IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per
+ // RFC 8200 Section 4.5.
+ IPv6MaximumPayloadSize = 65535
+
// IPv6ProtocolNumber is IPv6's network protocol number.
IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 9339d637f..98bdd29db 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "math"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -55,6 +56,10 @@ const (
// UDPMinimumSize is the minimum size of a valid UDP packet.
UDPMinimumSize = 8
+ // UDPMaximumSize is the maximum size of a valid UDP packet. The length field
+ // in the UDP header is 16 bits as per RFC 768.
+ UDPMaximumSize = math.MaxUint16
+
// UDPProtocolNumber is UDP's transport protocol number.
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index d142b4ffa..c82593e71 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -30,6 +30,7 @@ go_test(
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index fa4ae2012..f4394749d 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -404,11 +404,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
// The packet is a fragment, let's try to reassemble it.
- last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < h.FragmentOffset() {
+ start := h.FragmentOffset()
+ // Drop the fragment if the size of the reassembled payload would exceed the
+ // maximum payload size.
+ //
+ // Note that this addition doesn't overflow even on 32bit architecture
+ // because pkt.Data.Size() should not exceed 65535 (the max IP datagram
+ // size). Otherwise the packet would've been rejected as invalid before
+ // reaching here.
+ if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
@@ -425,8 +429,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ID: uint32(h.ID()),
Protocol: proto,
},
- h.FragmentOffset(),
- last,
+ start,
+ start+uint16(pkt.Data.Size())-1,
h.More(),
proto,
pkt.Data,
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 197e3bc51..5e50558e8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,8 +17,6 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
- "fmt"
- "math/rand"
"testing"
"github.com/google/go-cmp/cmp"
@@ -28,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -92,31 +91,6 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// makeRandPkt generates a randomize packet. hdrLength indicates how much
-// data should already be in the header before WritePacket. extraLength
-// indicates how much extra space should be in the header. The payload is made
-// from many Views of the sizes listed in viewSizes.
-func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer {
- var views []buffer.View
- totalLength := 0
- for _, s := range viewSizes {
- newView := buffer.NewView(s)
- rand.Read(newView)
- views = append(views, newView)
- totalLength += s
- }
-
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: hdrLength + extraLength,
- Data: buffer.NewVectorisedView(totalLength, views),
- })
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil {
- panic(fmt.Sprintf("rand.Read: %s", err))
- }
- return pkt
-}
-
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -186,63 +160,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
}
-type errorChannel struct {
- *channel.Endpoint
- Ch chan *stack.PacketBuffer
- packetCollectorErrors []*tcpip.Error
-}
-
-// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
-// will return successive errors from packetCollectorErrors until the list is
-// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
- return &errorChannel{
- Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan *stack.PacketBuffer, size),
- packetCollectorErrors: packetCollectorErrors,
- }
-}
-
-// Drain removes all outbound packets from the channel and counts them.
-func (e *errorChannel) Drain() int {
- c := 0
- for {
- select {
- case <-e.Ch:
- c++
- default:
- return c
- }
- }
-}
-
-// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- select {
- case e.Ch <- pkt:
- default:
- }
-
- nextError := (*tcpip.Error)(nil)
- if len(e.packetCollectorErrors) > 0 {
- nextError = e.packetCollectorErrors[0]
- e.packetCollectorErrors = e.packetCollectorErrors[1:]
- }
- return nextError
-}
-
-type context struct {
+type testRoute struct {
stack.Route
- linkEP *errorChannel
+
+ linkEP *testutil.TestEndpoint
}
-func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute {
// Make the packet and write it.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
})
- ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- s.CreateNIC(1, ep)
+ testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors)
+ s.CreateNIC(1, testEP)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -262,9 +192,12 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
if err != nil {
t.Fatalf("s.FindRoute got %v, want %v", err, nil)
}
- return context{
+ t.Cleanup(func() {
+ testEP.Close()
+ })
+ return testRoute{
Route: r,
- linkEP: ep,
+ linkEP: testEP,
}
}
@@ -274,13 +207,13 @@ func TestFragmentation(t *testing.T) {
manyPayloadViewsSizes[i] = 7
}
fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- hdrLength int
- extraLength int
- payloadViewsSizes []int
- expectedFrags int
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ extraHeaderReserveLength int
+ payloadViewsSizes []int
+ expectedFrags int
}{
{"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
{"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
@@ -295,10 +228,10 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
source := pkt.Clone()
- c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
@@ -307,24 +240,13 @@ func TestFragmentation(t *testing.T) {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []*stack.PacketBuffer
- L:
- for {
- select {
- case pi := <-c.linkEP.Ch:
- results = append(results, pi)
- default:
- break L
- }
- }
-
- if got, want := len(results), ft.expectedFrags; got != want {
- t.Errorf("len(result) got %d, want %d", got, want)
+ if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want {
+ t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
}
- if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want)
}
- compareFragments(t, results, source, ft.mtu)
+ compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu)
})
}
}
@@ -335,21 +257,21 @@ func TestFragmentationErrors(t *testing.T) {
fragTests := []struct {
description string
mtu uint32
- hdrLength int
+ transportHeaderLength int
payloadViewsSizes []int
packetCollectorErrors []*tcpip.Error
}{
{"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
{"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
{"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
- {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
- c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: 42,
TOS: stack.DefaultTOS,
@@ -364,7 +286,7 @@ func TestFragmentationErrors(t *testing.T) {
if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
t.Errorf("err got %v, want %v", got, want)
}
- if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+ if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
}
})
@@ -372,115 +294,308 @@ func TestFragmentationErrors(t *testing.T) {
}
func TestInvalidFragments(t *testing.T) {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 6
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ autoChecksum bool // if true, the Checksum field will be overwritten.
+ }
+
// These packets have both IHL and TotalLength set to 0.
- testCases := []struct {
+ tests := []struct {
name string
- packets [][]byte
+ fragments []fragmentData
wantMalformedIPPackets uint64
wantMalformedFragments uint64
}{
{
- "ihl_totallen_zero_valid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- },
- 1,
- 0,
- },
- {
- "ihl_totallen_zero_invalid_frag_offset",
- [][]byte{
- {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset non-zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments,
+ FragmentOffset: 59776,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
},
- 1,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
},
{
- // Total Length of 37(20 bytes IP header + 17 bytes of
- // payload)
- // Frag Offset of 0x1ffe = 8190*8 = 65520
- // Leading to the fragment end to be past 65535.
- "ihl_totallen_valid_invalid_frag_offset_1",
- [][]byte{
- {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL and TotalLength zero, FragmentOffset zero",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: 0,
+ TOS: tos,
+ TotalLength: 0,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(12),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 0,
},
- // The following 3 tests were found by running a fuzzer and were
- // triggering a panic in the IPv4 reassembler code.
{
- "ihl_less_than_ipv4_minimum_size_1",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ // Payload 17 octets and Fragment offset 65520
+ // Leading to the fragment end to be past 65536.
+ name: "fragment ends past 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 17,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(17),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- "ihl_less_than_ipv4_minimum_size_2",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ // Payload 16 octets and fragment offset 65520
+ // Leading to the fragment end to be exactly 65536.
+ name: "fragment ends exactly at 65536",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 65520,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(16),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 0,
+ wantMalformedFragments: 0,
},
{
- "ihl_less_than_ipv4_minimum_size_3",
- [][]byte{
- {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "IHL less than IPv4 minimum size",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 1944,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 12,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize - 12,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 2,
- 0,
+ wantMalformedIPPackets: 2,
+ wantMalformedFragments: 0,
},
{
- "fragment_with_short_total_len_extra_payload",
- [][]byte{
- {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
- {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30},
+ name: "fragment with short TotalLength and extra payload",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 28,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 28816,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize + 4,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 4,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(28),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
{
- "multiple_fragments_with_more_fragments_set_to_false",
- [][]byte{
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ name: "multiple fragments with More Fragments flag set to false",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 128,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: payloadGen(8),
+ autoChecksum: true,
+ },
},
- 1,
- 1,
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
},
}
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- const nicID tcpip.NICID = 42
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{
ipv4.NewProtocol(),
},
})
+ e := channel.New(0, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize + len(f.payload)
+ hdr := buffer.NewPrependable(pktSize)
+
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
+ copy(ip[header.IPv4MinimumSize:], f.payload)
- var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
- var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
- ep := channel.New(10, 1500, linkAddr)
- s.CreateNIC(nicID, sniffer.New(ep))
+ if f.autoChecksum {
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ }
- for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ vv := hdr.View().ToVectorisedView()
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
}))
}
- if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want)
}
- if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want {
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want)
}
})
@@ -534,6 +649,9 @@ func TestReceiveFragments(t *testing.T) {
// the fragment block size of 8 (RFC 791 section 3.1 page 14).
ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
+ // Used to test the max reassembled payload length (65,535 octets).
+ ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2)
+ udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:]
type fragmentData struct {
srcAddr tcpip.Address
@@ -827,6 +945,28 @@ func TestReceiveFragments(t *testing.T) {
},
expectedPayloads: nil,
},
+ {
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload4Addr1ToAddr2[:65512],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 65512,
+ payload: ipv4Payload4Addr1ToAddr2[65512:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
}
for _, test := range tests {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index af3cd91c6..e821a8bff 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -311,12 +311,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// The packet is a fragment, let's try to reassemble it.
start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- last := start + uint16(fragmentPayloadLen) - 1
- // Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and pkt.Data.size() causes a
- // wrap around resulting in last being less than the offset.
- if last < start {
+ // Drop the fragment if the size of the reassembled payload would exceed
+ // the maximum payload size.
+ if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
@@ -333,7 +331,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ID: extHdr.ID(),
},
start,
- last,
+ start+uint16(fragmentPayloadLen)-1,
extHdr.More(),
uint8(rawPayload.Identifier),
rawPayload.Buf,
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 54787198f..5f9822c49 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
@@ -687,6 +688,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
// Used to test cases where the fragment blocks are not a multiple of
// the fragment block size of 8 (RFC 8200 section 4.5).
udpPayload3Length = 127
+ udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize
fragmentExtHdrLen = 8
// Note, not all routing extension headers will be 8 bytes but this test
// uses 8 byte routing extension headers for most sub tests.
@@ -731,6 +733,10 @@ func TestReceiveIPv6Fragments(t *testing.T) {
udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
+ var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte
+ udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:]
+ ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2)
+
tests := []struct {
name string
expectedPayload []byte
@@ -1020,6 +1026,44 @@ func TestReceiveIPv6Fragments(t *testing.T) {
expectedPayloads: nil,
},
{
+ name: "Two fragments reassembled into a maximum UDP packet",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[:65520],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}),
+
+ ipv6Payload4Addr1ToAddr2[65520:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2},
+ },
+ {
name: "Two fragments with per-fragment routing header with zero segments left",
fragments: []fragmentData{
{
@@ -1572,3 +1616,96 @@ func TestReceiveIPv6Fragments(t *testing.T) {
})
}
}
+
+func TestInvalidIPv6Fragments(t *testing.T) {
+ const (
+ nicID = 1
+ fragmentExtHdrLen = 8
+ )
+
+ payloadGen := func(payloadLen int) []byte {
+ payload := make([]byte, payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = 0x30
+ }
+ return payload
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ wantMalformedIPPackets uint64
+ wantMalformedFragments uint64
+ }{
+ {
+ name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16,
+ []buffer.View{
+ // Fragment extension header.
+ // Fragment offset = 8190, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
+ ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8,
+ ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8,
+ 0, 0, 0, 1}),
+ // Payload length = 16
+ payloadGen(16),
+ },
+ ),
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ NewProtocol(),
+ },
+ })
+ e := channel.New(0, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ }
+
+ if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
+ t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want)
+ }
+ if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
+ t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
new file mode 100644
index 000000000..e218563d0
--- /dev/null
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ srcs = [
+ "testutil.go",
+ ],
+ visibility = ["//pkg/tcpip/network/ipv4:__pkg__"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
new file mode 100644
index 000000000..bf5ce74be
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil defines types and functions used to test Network Layer
+// functionality such as IP fragmentation.
+package testutil
+
+import (
+ "fmt"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// TestEndpoint is an endpoint used for testing, it stores packets written to it
+// and can mock errors.
+type TestEndpoint struct {
+ *channel.Endpoint
+
+ // WrittenPackets is where we store packets written via WritePacket().
+ WrittenPackets []*stack.PacketBuffer
+
+ packetCollectorErrors []*tcpip.Error
+}
+
+// NewTestEndpoint creates a new TestEndpoint endpoint.
+//
+// packetCollectorErrors can be used to set error values and each call to
+// WritePacket will remove the first one from the slice and return it until
+// the slice is empty - at that point it will return nil every time.
+func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint {
+ return &TestEndpoint{
+ Endpoint: ep,
+ WrittenPackets: make([]*stack.PacketBuffer, 0),
+ packetCollectorErrors: packetCollectorErrors,
+ }
+}
+
+// WritePacket stores outbound packets and may return an error if one was
+// injected.
+func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.WrittenPackets = append(e.WrittenPackets, pkt)
+
+ if len(e.packetCollectorErrors) > 0 {
+ nextError := e.packetCollectorErrors[0]
+ e.packetCollectorErrors = e.packetCollectorErrors[1:]
+ return nextError
+ }
+
+ return nil
+}
+
+// MakeRandPkt generates a randomized packet. transportHeaderLength indicates
+// how many random bytes will be copied in the Transport Header.
+// extraHeaderReserveLength indicates how much extra space will be reserved for
+// the other headers. The payload is made from Views of the sizes listed in
+// viewSizes.
+func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
+ var views buffer.VectorisedView
+
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ if _, err := rand.Read(newView); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ views.AppendView(newView)
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
+ Data: views,
+ })
+ pkt.NetworkProtocolNumber = proto
+ if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ return pkt
+}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index c2eabde9e..2cbbf0de8 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -48,10 +48,6 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
-
- // directedBroadcast indicates whether this route is sending a directed
- // broadcast packet.
- directedBroadcast bool
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -303,24 +299,27 @@ func (r *Route) Stack() *Stack {
return r.ref.stack()
}
+func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
+ if addr == header.IPv4Broadcast {
+ return true
+ }
+
+ subnet := r.ref.addrWithPrefix().Subnet()
+ return subnet.IsBroadcast(addr)
+}
+
// IsOutboundBroadcast returns true if the route is for an outbound broadcast
// packet.
func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
- return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast
+ return r.isV4Broadcast(r.RemoteAddress)
}
// IsInboundBroadcast returns true if the route is for an inbound broadcast
// packet.
func (r *Route) IsInboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
- if r.LocalAddress == header.IPv4Broadcast {
- return true
- }
-
- addr := r.ref.addrWithPrefix()
- subnet := addr.Subnet()
- return subnet.IsBroadcast(r.LocalAddress)
+ return r.isV4Broadcast(r.LocalAddress)
}
// ReverseRoute returns new route with given source and destination address.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index def8b0b43..6a683545d 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1311,13 +1311,11 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
- r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr)
-
if len(route.Gateway) > 0 {
if needRoute {
r.NextHop = route.Gateway
}
- } else if r.directedBroadcast {
+ } else if subnet := ref.addrWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index b902c6ca9..0774b5382 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isMulticastOrBroadcast(id.LocalAddress) {
+ if isInboundMulticastOrBroadcast(r) {
mpep.handlePacketAll(r, id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
@@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a TCP packet with a non-unicast source or destination
// address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isMulticastOrBroadcast(addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+func isInboundMulticastOrBroadcast(r *Route) bool {
+ return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isUnicast(addr tcpip.Address) bool {
- return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+func isInboundUnicast(r *Route) bool {
+ return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r)
+}
+
+func isOutboundUnicast(r *Route) bool {
+ return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress)
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 52c27e045..659acbc7a 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -436,3 +437,122 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
})
}
}
+
+// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all
+// interested endpoints.
+func TestReuseAddrAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ localPort = 9000
+ loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
+ )
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+
+ tests := []struct {
+ name string
+ broadcastAddr tcpip.Address
+ }{
+ {
+ name: "Subnet directed broadcast",
+ broadcastAddr: loopbackBroadcast,
+ },
+ {
+ name: "IPv4 broadcast",
+ broadcastAddr: header.IPv4Broadcast,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ if err := s.CreateNIC(nicID, loopback.New()); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x7f\x00\x00\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ // We use the empty subnet instead of just the loopback subnet so we
+ // also have a route to the IPv4 Broadcast address.
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // We create endpoints that bind to both the wildcard address and the
+ // broadcast address to make sure both of these types of "broadcast
+ // interested" endpoints receive broadcast packets.
+ wq := waiter.Queue{}
+ var eps []tcpip.Endpoint
+ for _, bindWildcard := range []bool{false, true} {
+ // Create multiple endpoints for each type of "broadcast interested"
+ // endpoint so we can test that all endpoints receive the broadcast
+ // packet.
+ for i := 0; i < 2; i++ {
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err)
+ }
+
+ bindAddr := tcpip.FullAddress{Port: localPort}
+ if bindWildcard {
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ } else {
+ bindAddr.Addr = test.broadcastAddr
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err)
+ }
+ }
+
+ eps = append(eps, ep)
+ }
+ }
+
+ for i, wep := range eps {
+ writeOpts := tcpip.WriteOptions{
+ To: &tcpip.FullAddress{
+ Addr: test.broadcastAddr,
+ Port: localPort,
+ },
+ }
+ if n, _, err := wep.Write(data, writeOpts); err != nil {
+ t.Fatalf("eps[%d].Write(_, _): %s", i, err)
+ } else if want := int64(len(data)); n != want {
+ t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
+ }
+
+ for j, rep := range eps {
+ if gotPayload, _, err := rep.Read(nil); err != nil {
+ t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
+ } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index bc920a03b..cfd43b5e3 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -436,6 +436,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// Just silently drop any RST packets in TIME_WAIT. We do not support
// TIME_WAIT assasination as a result we confirm w/ fix 1 as described
// in https://tools.ietf.org/html/rfc1337#section-3.
+ //
+ // This behavior overrides RFC793 page 70 where we transition to CLOSED
+ // on receiving RST, which is also default Linux behavior.
+ // On Linux the RST can be ignored by setting sysctl net.ipv4.tcp_rfc1337.
+ //
+ // As we do not yet support PAWS, we are being conservative in ignoring
+ // RSTs by default.
if s.flagIsSet(header.TCPFlagRst) {
return false, false
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 0d13e1efd..b1e5f1b24 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5214,6 +5214,8 @@ func TestListenBacklogFull(t *testing.T) {
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+ subnet := context.StackAddrWithPrefix.Subnet()
+ subnetBroadcastAddr := subnet.Broadcast()
tests := []struct {
name string
@@ -5221,53 +5223,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) {
dstAddr tcpip.Address
}{
{
- "SourceUnspecified",
- header.IPv4Any,
- context.StackAddr,
+ name: "SourceUnspecified",
+ srcAddr: header.IPv4Any,
+ dstAddr: context.StackAddr,
},
{
- "SourceBroadcast",
- header.IPv4Broadcast,
- context.StackAddr,
+ name: "SourceBroadcast",
+ srcAddr: header.IPv4Broadcast,
+ dstAddr: context.StackAddr,
},
{
- "SourceOurMulticast",
- multicastAddr,
- context.StackAddr,
+ name: "SourceOurMulticast",
+ srcAddr: multicastAddr,
+ dstAddr: context.StackAddr,
},
{
- "SourceOtherMulticast",
- otherMulticastAddr,
- context.StackAddr,
+ name: "SourceOtherMulticast",
+ srcAddr: otherMulticastAddr,
+ dstAddr: context.StackAddr,
},
{
- "DestUnspecified",
- context.TestAddr,
- header.IPv4Any,
+ name: "DestUnspecified",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Any,
},
{
- "DestBroadcast",
- context.TestAddr,
- header.IPv4Broadcast,
+ name: "DestBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: header.IPv4Broadcast,
},
{
- "DestOurMulticast",
- context.TestAddr,
- multicastAddr,
+ name: "DestOurMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: multicastAddr,
},
{
- "DestOtherMulticast",
- context.TestAddr,
- otherMulticastAddr,
+ name: "DestOtherMulticast",
+ srcAddr: context.TestAddr,
+ dstAddr: otherMulticastAddr,
+ },
+ {
+ name: "SrcSubnetBroadcast",
+ srcAddr: subnetBroadcastAddr,
+ dstAddr: context.StackAddr,
+ },
+ {
+ name: "DestSubnetBroadcast",
+ srcAddr: context.TestAddr,
+ dstAddr: subnetBroadcastAddr,
},
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -5367,11 +5375,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) {
}
for _, test := range tests {
- test := test // capture range variable
-
t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index baf7df197..85e8c1c75 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -53,11 +53,11 @@ const (
TestPort = 4096
// StackV6Addr is the IPv6 address assigned to the stack.
- StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
// TestV6Addr is the source address for packets sent to the stack via
// the link layer endpoint.
- TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// StackV4MappedAddr is StackAddr as a mapped v6 address.
StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
@@ -73,6 +73,18 @@ const (
testInitialSequenceNumber = 789
)
+// StackAddrWithPrefix is StackAddr with its associated prefix length.
+var StackAddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackAddr,
+ PrefixLen: 24,
+}
+
+// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
+var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: StackV6Addr,
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+}
+
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -184,12 +196,20 @@ func New(t *testing.T, mtu uint32) *Context {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
}
s.SetRouteTable([]tcpip.Route{
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 2828b2c01..b572c39db 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -139,7 +139,7 @@ type endpoint struct {
// multicastMemberships that need to be remvoed when the endpoint is
// closed. Protected by the mu mutex.
- multicastMemberships []multicastMembership
+ multicastMemberships map[multicastMembership]struct{}
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -182,12 +182,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
// TTL=1.
//
// Linux defaults to TTL=1.
- multicastTTL: 1,
- multicastLoop: true,
- rcvBufSizeMax: 32 * 1024,
- sndBufSizeMax: 32 * 1024,
- state: StateInitial,
- uniqueID: s.UniqueID(),
+ multicastTTL: 1,
+ multicastLoop: true,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
+ multicastMemberships: make(map[multicastMembership]struct{}),
+ state: StateInitial,
+ uniqueID: s.UniqueID(),
}
// Override with stack defaults.
@@ -237,10 +238,10 @@ func (e *endpoint) Close() {
e.boundPortFlags = ports.Flags{}
}
- for _, mem := range e.multicastMemberships {
+ for mem := range e.multicastMemberships {
e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr)
}
- e.multicastMemberships = nil
+ e.multicastMemberships = make(map[multicastMembership]struct{})
// Close the receive list and drain it.
e.rcvMu.Lock()
@@ -752,17 +753,15 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- for _, mem := range e.multicastMemberships {
- if mem == memToInsert {
- return tcpip.ErrPortInUse
- }
+ if _, ok := e.multicastMemberships[memToInsert]; ok {
+ return tcpip.ErrPortInUse
}
if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
return err
}
- e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+ e.multicastMemberships[memToInsert] = struct{}{}
case *tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
@@ -786,18 +785,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
- memToRemoveIndex := -1
e.mu.Lock()
defer e.mu.Unlock()
- for i, mem := range e.multicastMemberships {
- if mem == memToRemove {
- memToRemoveIndex = i
- break
- }
- }
- if memToRemoveIndex == -1 {
+ if _, ok := e.multicastMemberships[memToRemove]; !ok {
return tcpip.ErrBadLocalAddress
}
@@ -805,8 +797,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
return err
}
- e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
- e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ delete(e.multicastMemberships, memToRemove)
case *tcpip.BindToDeviceOption:
id := tcpip.NICID(*v)
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 851e6b635..858c99a45 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
- for _, m := range e.multicastMemberships {
+ for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 0cbc045d8..d5881d183 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -2343,8 +2343,10 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
NIC: nicID1,
},
},
- remoteAddr: remNetSubnetBcast,
- requiresBroadcastOpt: true,
+ remoteAddr: remNetSubnetBcast,
+ // TODO(gvisor.dev/issue/3938): Once we support marking a route as
+ // broadcast, this test should require the broadcast option to be set.
+ requiresBroadcastOpt: false,
},
}