Diffstat (limited to 'pkg')
-rw-r--r--  pkg/abi/linux/BUILD | 1
-rw-r--r--  pkg/abi/linux/arch_amd64.go | 23
-rw-r--r--  pkg/abi/linux/epoll_amd64.go | 4
-rw-r--r--  pkg/abi/linux/epoll_arm64.go | 4
-rw-r--r--  pkg/abi/linux/mm.go | 17
-rw-r--r--  pkg/abi/linux/seccomp.go | 7
-rw-r--r--  pkg/flipcall/packet_window_allocator.go | 4
-rw-r--r--  pkg/seccomp/seccomp_unsafe.go | 9
-rw-r--r--  pkg/sentry/arch/signal_arm64.go | 30
-rw-r--r--  pkg/sentry/fs/proc/task.go | 3
-rw-r--r--  pkg/sentry/fs/tty/line_discipline.go | 4
-rw-r--r--  pkg/sentry/fs/tty/master.go | 4
-rw-r--r--  pkg/sentry/fs/tty/queue.go | 4
-rw-r--r--  pkg/sentry/fs/tty/slave.go | 4
-rw-r--r--  pkg/sentry/fs/tty/terminal.go | 4
-rw-r--r--  pkg/sentry/fs/user/BUILD | 34
-rw-r--r--  pkg/sentry/fs/user/user.go | 237
-rw-r--r--  pkg/sentry/fs/user/user_test.go | 198
-rw-r--r--  pkg/sentry/fsimpl/devpts/BUILD | 43
-rw-r--r--  pkg/sentry/fsimpl/devpts/devpts.go | 207
-rw-r--r--  pkg/sentry/fsimpl/devpts/devpts_test.go | 56
-rw-r--r--  pkg/sentry/fsimpl/devpts/line_discipline.go | 449
-rw-r--r--  pkg/sentry/fsimpl/devpts/master.go | 226
-rw-r--r--  pkg/sentry/fsimpl/devpts/queue.go | 240
-rw-r--r--  pkg/sentry/fsimpl/devpts/slave.go | 186
-rw-r--r--  pkg/sentry/fsimpl/devpts/terminal.go | 124
-rw-r--r--  pkg/sentry/fsimpl/devtmpfs/devtmpfs.go | 13
-rw-r--r--  pkg/sentry/fsimpl/ext/BUILD | 2
-rw-r--r--  pkg/sentry/fsimpl/ext/ext_test.go | 3
-rw-r--r--  pkg/sentry/fsimpl/gofer/BUILD | 2
-rw-r--r--  pkg/sentry/fsimpl/gofer/directory.go | 147
-rw-r--r--  pkg/sentry/fsimpl/gofer/filesystem.go | 333
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer.go | 177
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer_test.go | 2
-rw-r--r--  pkg/sentry/fsimpl/gofer/handle.go | 5
-rw-r--r--  pkg/sentry/fsimpl/gofer/handle_unsafe.go | 66
-rw-r--r--  pkg/sentry/fsimpl/host/BUILD | 3
-rw-r--r--  pkg/sentry/fsimpl/host/host.go | 31
-rw-r--r--  pkg/sentry/fsimpl/kernfs/filesystem.go | 4
-rw-r--r--  pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 9
-rw-r--r--  pkg/sentry/fsimpl/proc/task_files.go | 13
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks_sys.go | 2
-rw-r--r--  pkg/sentry/hostfd/BUILD | 17
-rw-r--r--  pkg/sentry/hostfd/hostfd.go | 84
-rw-r--r--  pkg/sentry/hostfd/hostfd_unsafe.go | 107
-rw-r--r--  pkg/sentry/kernel/epoll/BUILD | 1
-rw-r--r--  pkg/sentry/kernel/epoll/epoll.go | 20
-rw-r--r--  pkg/sentry/kernel/task.go | 2
-rw-r--r--  pkg/sentry/kernel/task_run.go | 1
-rw-r--r--  pkg/sentry/kernel/task_sched.go | 4
-rw-r--r--  pkg/sentry/kernel/task_signals.go | 13
-rw-r--r--  pkg/sentry/loader/loader.go | 9
-rw-r--r--  pkg/sentry/loader/vdso.go | 25
-rw-r--r--  pkg/sentry/mm/mm.go | 3
-rw-r--r--  pkg/sentry/mm/procfs.go | 4
-rw-r--r--  pkg/sentry/mm/syscalls.go | 4
-rw-r--r--  pkg/sentry/platform/kvm/context.go | 3
-rw-r--r--  pkg/sentry/platform/kvm/kvm.go | 5
-rw-r--r--  pkg/sentry/platform/platform.go | 21
-rw-r--r--  pkg/sentry/platform/ptrace/ptrace.go | 13
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess.go | 2
-rw-r--r--  pkg/sentry/socket/netfilter/tcp_matcher.go | 5
-rw-r--r--  pkg/sentry/socket/netfilter/udp_matcher.go | 5
-rw-r--r--  pkg/sentry/syscalls/epoll.go | 3
-rw-r--r--  pkg/sentry/syscalls/linux/sys_epoll.go | 27
-rw-r--r--  pkg/sentry/syscalls/linux/sys_mempolicy.go | 18
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/BUILD | 1
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/epoll.go | 7
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go | 44
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/getdents.go | 6
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go | 4
-rw-r--r--  pkg/sentry/vfs/file_description.go | 15
-rw-r--r--  pkg/sentry/vfs/filesystem.go | 3
-rw-r--r--  pkg/sentry/vfs/options.go | 19
-rw-r--r--  pkg/tcpip/buffer/view.go | 55
-rw-r--r--  pkg/tcpip/buffer/view_test.go | 113
-rw-r--r--  pkg/tcpip/link/loopback/loopback.go | 10
-rw-r--r--  pkg/tcpip/link/rawfile/BUILD | 9
-rw-r--r--  pkg/tcpip/link/rawfile/rawfile_test.go | 46
-rw-r--r--  pkg/tcpip/link/rawfile/rawfile_unsafe.go | 6
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem_test.go | 2
-rw-r--r--  pkg/tcpip/link/sniffer/sniffer.go | 65
-rw-r--r--  pkg/tcpip/network/arp/arp.go | 5
-rw-r--r--  pkg/tcpip/network/ipv4/icmp.go | 20
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4.go | 12
-rw-r--r--  pkg/tcpip/network/ipv6/icmp.go | 74
-rw-r--r--  pkg/tcpip/network/ipv6/icmp_test.go | 3
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6.go | 6
-rw-r--r--  pkg/tcpip/stack/forwarder_test.go | 13
-rw-r--r--  pkg/tcpip/stack/iptables.go | 22
-rw-r--r--  pkg/tcpip/stack/iptables_targets.go | 23
-rw-r--r--  pkg/tcpip/stack/nic.go | 34
-rw-r--r--  pkg/tcpip/stack/packet_buffer.go | 8
-rw-r--r--  pkg/tcpip/stack/stack_test.go | 10
-rw-r--r--  pkg/tcpip/stack/transport_test.go | 5
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go | 8
-rw-r--r--  pkg/tcpip/transport/tcp/BUILD | 2
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go | 43
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint_state.go | 5
-rw-r--r--  pkg/tcpip/transport/tcp/segment.go | 29
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_noracedetector_test.go | 2
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_test.go | 4
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go | 6
-rw-r--r--  pkg/tcpip/transport/udp/protocol.go | 9
-rw-r--r--  pkg/test/criutil/BUILD | 14
-rw-r--r--  pkg/test/criutil/criutil.go | 306
-rw-r--r--  pkg/test/dockerutil/BUILD | 14
-rw-r--r--  pkg/test/dockerutil/dockerutil.go | 581
-rw-r--r--  pkg/test/testutil/BUILD | 20
-rw-r--r--  pkg/test/testutil/testutil.go | 550
-rw-r--r--  pkg/test/testutil/testutil_runfiles.go | 75
111 files changed, 4992 insertions(+), 646 deletions(-)
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 322d1ccc4..59b0e138a 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -10,6 +10,7 @@ go_library(
name = "linux",
srcs = [
"aio.go",
+ "arch_amd64.go",
"audit.go",
"bpf.go",
"capability.go",
diff --git a/pkg/abi/linux/arch_amd64.go b/pkg/abi/linux/arch_amd64.go
new file mode 100644
index 000000000..0be31e755
--- /dev/null
+++ b/pkg/abi/linux/arch_amd64.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// Start and end addresses of the vsyscall page.
+const (
+ VSyscallStartAddr uint64 = 0xffffffffff600000
+ VSyscallEndAddr uint64 = 0xffffffffff601000
+)
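
These constants bound the fixed amd64 vsyscall page. A minimal, hypothetical sketch (the helper name is illustrative and not part of this change) of how an address might be tested against that range:

    // isVsyscallAddr reports whether addr falls inside the amd64 vsyscall page.
    func isVsyscallAddr(addr uint64) bool {
        return addr >= linux.VSyscallStartAddr && addr < linux.VSyscallEndAddr
    }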
diff --git a/pkg/abi/linux/epoll_amd64.go b/pkg/abi/linux/epoll_amd64.go
index 34ff18009..7e74b1143 100644
--- a/pkg/abi/linux/epoll_amd64.go
+++ b/pkg/abi/linux/epoll_amd64.go
@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package linux
// EpollEvent is equivalent to struct epoll_event from epoll(2).
//
-// +marshal
+// +marshal slice:EpollEventSlice
type EpollEvent struct {
Events uint32
// Linux makes struct epoll_event::data a __u64. We represent it as
diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go
index f86c35329..a35939cc9 100644
--- a/pkg/abi/linux/epoll_arm64.go
+++ b/pkg/abi/linux/epoll_arm64.go
@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package linux
// EpollEvent is equivalent to struct epoll_event from epoll(2).
//
-// +marshal
+// +marshal slice:EpollEventSlice
type EpollEvent struct {
Events uint32
// Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding
diff --git a/pkg/abi/linux/mm.go b/pkg/abi/linux/mm.go
index cd043dac3..07cc1895e 100644
--- a/pkg/abi/linux/mm.go
+++ b/pkg/abi/linux/mm.go
@@ -90,14 +90,19 @@ const (
MS_SYNC = 1 << 2
)
+// NumaPolicy is the NUMA memory policy for a memory range. See numa(7).
+//
+// +marshal
+type NumaPolicy int32
+
// Policies for get_mempolicy(2)/set_mempolicy(2).
const (
- MPOL_DEFAULT = 0
- MPOL_PREFERRED = 1
- MPOL_BIND = 2
- MPOL_INTERLEAVE = 3
- MPOL_LOCAL = 4
- MPOL_MAX = 5
+ MPOL_DEFAULT NumaPolicy = 0
+ MPOL_PREFERRED NumaPolicy = 1
+ MPOL_BIND NumaPolicy = 2
+ MPOL_INTERLEAVE NumaPolicy = 3
+ MPOL_LOCAL NumaPolicy = 4
+ MPOL_MAX NumaPolicy = 5
)
// Flags for get_mempolicy(2).
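
Typing the MPOL_* constants as NumaPolicy lets mempolicy values be passed around without casts. A minimal sketch, using a hypothetical helper that is not part of this change:

    // isValidNumaPolicy reports whether p names one of the policies accepted by
    // set_mempolicy(2), ignoring any mode flags.
    func isValidNumaPolicy(p linux.NumaPolicy) bool {
        switch p {
        case linux.MPOL_DEFAULT, linux.MPOL_PREFERRED, linux.MPOL_BIND,
            linux.MPOL_INTERLEAVE, linux.MPOL_LOCAL:
            return true
        default:
            return false
        }
    }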
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
index 4eeb5cd7a..d0607e256 100644
--- a/pkg/abi/linux/seccomp.go
+++ b/pkg/abi/linux/seccomp.go
@@ -63,3 +63,10 @@ func (a BPFAction) String() string {
func (a BPFAction) Data() uint16 {
return uint16(a & SECCOMP_RET_DATA)
}
+
+// SockFprog is sock_fprog taken from <linux/filter.h>.
+type SockFprog struct {
+ Len uint16
+ pad [6]byte
+ Filter *BPFInstruction
+}
diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window_allocator.go
index ccb918fab..af9cc3d21 100644
--- a/pkg/flipcall/packet_window_allocator.go
+++ b/pkg/flipcall/packet_window_allocator.go
@@ -134,7 +134,7 @@ func (pwa *PacketWindowAllocator) Allocate(size int) (PacketWindowDescriptor, er
start := pwa.nextAlloc
pwa.nextAlloc = end
return PacketWindowDescriptor{
- FD: pwa.fd,
+ FD: pwa.FD(),
Offset: start,
Length: size,
}, nil
@@ -158,7 +158,7 @@ func (pwa *PacketWindowAllocator) ensureFileSize(min int64) error {
}
newSize = newNewSize
}
- if err := syscall.Ftruncate(pwa.fd, newSize); err != nil {
+ if err := syscall.Ftruncate(pwa.FD(), newSize); err != nil {
return fmt.Errorf("ftruncate failed: %v", err)
}
pwa.fileSize = newSize
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
index be328db12..f7e986589 100644
--- a/pkg/seccomp/seccomp_unsafe.go
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -21,13 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// sockFprog is sock_fprog taken from <linux/filter.h>.
-type sockFprog struct {
- Len uint16
- pad [6]byte
- Filter *linux.BPFInstruction
-}
-
// SetFilter installs the given BPF program.
//
// This is safe to call from an afterFork context.
@@ -39,7 +32,7 @@ func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
return errno
}
- sockProg := sockFprog{
+ sockProg := linux.SockFprog{
Len: uint16(len(instrs)),
Filter: (*linux.BPFInstruction)(unsafe.Pointer(&instrs[0])),
}
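
Moving SockFprog into the linux package does not change how the filter reaches the kernel: the kernel still reads the program through a struct sock_fprog. A general Linux usage sketch (the standard prctl path, not necessarily gVisor's exact call path; the caller must have set no_new_privs or hold CAP_SYS_ADMIN):

    const (
        prSetSeccomp      = 22 // PR_SET_SECCOMP
        seccompModeFilter = 2  // SECCOMP_MODE_FILTER
    )

    // installFilter installs instrs as the calling thread's seccomp filter.
    func installFilter(instrs []linux.BPFInstruction) syscall.Errno {
        sockProg := linux.SockFprog{
            Len:    uint16(len(instrs)),
            Filter: (*linux.BPFInstruction)(unsafe.Pointer(&instrs[0])),
        }
        if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, prSetSeccomp,
            seccompModeFilter, uintptr(unsafe.Pointer(&sockProg))); errno != 0 {
            return errno
        }
        return 0
    }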
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 0c1db4b13..1cb1adf8c 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -98,9 +98,12 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
if ucSize < 0 {
panic("can't get size of UContext64")
}
- // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128.
- frameSize := int(st.Arch.Width()) + ucSize + 128
- frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8
+
+ // frameSize = ucSize + sizeof(siginfo).
+ // sizeof(siginfo) == 128.
+ // R30 stores the restorer address.
+ frameSize := ucSize + 128
+ frameBottom := (sp - usermem.Addr(frameSize)) & ^usermem.Addr(15)
sp = frameBottom + usermem.Addr(frameSize)
st.Bottom = sp
@@ -130,12 +133,27 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt
c.Regs.Regs[0] = uint64(info.Signo)
c.Regs.Regs[1] = uint64(infoAddr)
c.Regs.Regs[2] = uint64(ucAddr)
-
+ c.Regs.Regs[30] = uint64(act.Restorer)
return nil
}
// SignalRestore implements Context.SignalRestore.
-// Only used on intel.
func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
- return 0, SignalStack{}, nil
+ // Copy out the stack frame.
+ var uc UContext64
+ if _, err := st.Pop(&uc); err != nil {
+ return 0, SignalStack{}, err
+ }
+ var info SignalInfo
+ if _, err := st.Pop(&info); err != nil {
+ return 0, SignalStack{}, err
+ }
+
+ // Restore registers.
+ c.Regs.Regs = uc.MContext.Regs
+ c.Regs.Pc = uc.MContext.Pc
+ c.Regs.Sp = uc.MContext.Sp
+ c.Regs.Pstate = uc.MContext.Pstate
+
+ return uc.Sigset, uc.Stack, nil
}
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 4d42eac83..4bbe90198 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -73,8 +73,7 @@ func checkTaskState(t *kernel.Task) error {
type taskDir struct {
ramfs.Dir
- t *kernel.Task
- pidns *kernel.PIDNamespace
+ t *kernel.Task
}
var _ fs.InodeOperations = (*taskDir)(nil)
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
index 12b1c6097..2e9dd2d55 100644
--- a/pkg/sentry/fs/tty/line_discipline.go
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -27,6 +27,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// canonMaxBytes is the number of bytes that fit into a single line of
// terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE
@@ -443,3 +445,5 @@ func (l *lineDiscipline) peek(b []byte) int {
}
return size
}
+
+// LINT.ThenChange(../../fsimpl/devpts/line_discipline.go)
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index f62da49bd..fe07fa929 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -26,6 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// masterInodeOperations are the fs.InodeOperations for the master end of the
// Terminal (ptmx file).
//
@@ -232,3 +234,5 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
unimpl.EmitUnimplementedEvent(ctx)
}
}
+
+// LINT.ThenChange(../../fsimpl/devpts/master.go)
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
index 1ca79c0b2..ceabb9b1e 100644
--- a/pkg/sentry/fs/tty/queue.go
+++ b/pkg/sentry/fs/tty/queue.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
// TTYB_DEFAULT_MEM_LIMIT.
const waitBufMaxBytes = 131072
@@ -234,3 +236,5 @@ func (q *queue) waitBufAppend(b []byte) {
q.waitBuf = append(q.waitBuf, b)
q.waitBufLen += uint64(len(b))
}
+
+// LINT.ThenChange(../../fsimpl/devpts/queue.go)
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index 6a2dbc576..9871f6fc6 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// slaveInodeOperations are the fs.InodeOperations for the slave end of the
// Terminal (pts file).
//
@@ -172,3 +174,5 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem
return 0, syserror.ENOTTY
}
}
+
+// LINT.ThenChange(../../fsimpl/devpts/slave.go)
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index 5883f26db..ddcccf4da 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -23,6 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Terminal is a pseudoterminal.
//
// +stateify savable
@@ -126,3 +128,5 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
}
return tm.slaveKTTY
}
+
+// LINT.ThenChange(../../fsimpl/devpts/terminal.go)
diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD
new file mode 100644
index 000000000..f37f979f1
--- /dev/null
+++ b/pkg/sentry/fs/user/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "user",
+ srcs = ["user.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "user_test",
+ size = "small",
+ srcs = ["user_test.go"],
+ library = ":user",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/user/user.go b/pkg/sentry/fs/user/user.go
new file mode 100644
index 000000000..fe7f67c00
--- /dev/null
+++ b/pkg/sentry/fs/user/user.go
@@ -0,0 +1,237 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package user
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type fileReader struct {
+ // Ctx is the context for the file reader.
+ Ctx context.Context
+
+ // File is the file to read from.
+ File *fs.File
+}
+
+// Read implements io.Reader.Read.
+func (r *fileReader) Read(buf []byte) (int, error) {
+ n, err := r.File.Readv(r.Ctx, usermem.BytesIOSequence(buf))
+ return int(n), err
+}
+
+// getExecUserHome returns the home directory of the executing user read from
+// /etc/passwd as read from the container filesystem.
+func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.KUID) (string, error) {
+ // The default user home directory to return if no user matching the given
+ // uid is found in the image's /etc/passwd.
+ const defaultHome = "/"
+
+ // Open the /etc/passwd file from the dirent via the root mount namespace.
+ mnsRoot := rootMns.Root()
+ maxTraversals := uint(linux.MaxSymlinkTraversals)
+ dirent, err := rootMns.FindInode(ctx, mnsRoot, nil, "/etc/passwd", &maxTraversals)
+ if err != nil {
+ // NOTE: Ignore errors opening the passwd file. If the passwd file
+ // doesn't exist we will return the default home directory.
+ return defaultHome, nil
+ }
+ defer dirent.DecRef()
+
+ // Check read permissions on the file.
+ if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Read: true}); err != nil {
+ // NOTE: Ignore permission errors here and return the default home directory.
+ return defaultHome, nil
+ }
+
+ // Only open regular files. We don't open other files like named pipes as
+ // they may block and might present some attack surface to the container.
+ // Note that runc does not seem to do this kind of checking.
+ if !fs.IsRegular(dirent.Inode.StableAttr) {
+ return defaultHome, nil
+ }
+
+ f, err := dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true, Directory: false})
+ if err != nil {
+ return "", err
+ }
+ defer f.DecRef()
+
+ r := &fileReader{
+ Ctx: ctx,
+ File: f,
+ }
+
+ return findHomeInPasswd(uint32(uid), r, defaultHome)
+}
+
+type fileReaderVFS2 struct {
+ ctx context.Context
+ fd *vfs.FileDescription
+}
+
+func (r *fileReaderVFS2) Read(buf []byte) (int, error) {
+ n, err := r.fd.Read(r.ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ return int(n), err
+}
+
+func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.KUID) (string, error) {
+ const defaultHome = "/"
+
+ root := mns.Root()
+ defer root.DecRef()
+
+ creds := auth.CredentialsFromContext(ctx)
+
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse("/etc/passwd"),
+ }
+
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }
+
+ fd, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, target, opts)
+ if err != nil {
+ return defaultHome, nil
+ }
+ defer fd.DecRef()
+
+ r := &fileReaderVFS2{
+ ctx: ctx,
+ fd: fd,
+ }
+
+ homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
+ if err != nil {
+ return "", err
+ }
+
+ return homeDir, nil
+}
+
+// MaybeAddExecUserHome returns a new slice with the HOME environment variable
+// set if the slice does not already contain it, otherwise it returns the
+// original slice unmodified.
+func MaybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHome(ctx, mns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
+
+ return append(envv, "HOME="+homeDir), nil
+}
+
+// MaybeAddExecUserHomeVFS2 returns a new slice with the HOME environment
+// variable set if the slice does not already contain it, otherwise it returns
+// the original slice unmodified.
+func MaybeAddExecUserHomeVFS2(ctx context.Context, vmns *vfs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHomeVFS2(ctx, vmns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
+ return append(envv, "HOME="+homeDir), nil
+}
+
+// findHomeInPasswd parses a passwd file and returns the given user's home
+// directory. This function does its best to replicate runc's behavior.
+func findHomeInPasswd(uid uint32, passwd io.Reader, defaultHome string) (string, error) {
+ s := bufio.NewScanner(passwd)
+
+ for s.Scan() {
+ if err := s.Err(); err != nil {
+ return "", err
+ }
+
+ line := strings.TrimSpace(s.Text())
+ if line == "" {
+ continue
+ }
+
+ // Pull out the parts of the passwd entry we need. Parse loosely, as some
+ // passwd files may be poorly formatted, and for compatibility with runc.
+ //
+ // Per 'man 5 passwd'
+ // /etc/passwd contains one line for each user account, with seven
+ // fields delimited by colons (“:”). These fields are:
+ //
+ // - login name
+ // - optional encrypted password
+ // - numerical user ID
+ // - numerical group ID
+ // - user name or comment field
+ // - user home directory
+ // - optional user command interpreter
+ parts := strings.Split(line, ":")
+
+ found := false
+ homeDir := ""
+ for i, p := range parts {
+ switch i {
+ case 2:
+ parsedUID, err := strconv.ParseUint(p, 10, 32)
+ if err == nil && parsedUID == uint64(uid) {
+ found = true
+ }
+ case 5:
+ homeDir = p
+ }
+ }
+ if found {
+ // NOTE: If the uid is present but the home directory is not
+ // present in the /etc/passwd entry we return an empty string. This
+ // is, for better or worse, what runc does.
+ return homeDir, nil
+ }
+ }
+
+ return defaultHome, nil
+}
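
findHomeInPasswd is a pure function over its reader, so its behavior is easy to show. A hypothetical in-package example (the function is unexported, so this only works alongside user.go):

    // exampleFindHome resolves uid 1000 against a single passwd entry. An
    // unmatched uid would fall back to the supplied default of "/".
    func exampleFindHome() (string, error) {
        const entry = "adin::1000:1111::/home/adin:/bin/sh"
        return findHomeInPasswd(1000, strings.NewReader(entry), "/") // "/home/adin"
    }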
diff --git a/pkg/sentry/fs/user/user_test.go b/pkg/sentry/fs/user/user_test.go
new file mode 100644
index 000000000..7d8e9ac7c
--- /dev/null
+++ b/pkg/sentry/fs/user/user_test.go
@@ -0,0 +1,198 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package user
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// createEtcPasswd creates /etc/passwd with the given contents and mode. If
+// mode is empty, then no file will be created. If mode is not a regular file
+// mode, then contents is ignored.
+func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode linux.FileMode) error {
+ if err := root.CreateDirectory(ctx, root, "etc", fs.FilePermsFromMode(0755)); err != nil {
+ return err
+ }
+ etc, err := root.Walk(ctx, root, "etc")
+ if err != nil {
+ return err
+ }
+ defer etc.DecRef()
+ switch mode.FileType() {
+ case 0:
+ // Don't create anything.
+ return nil
+ case linux.S_IFREG:
+ passwd, err := etc.Create(ctx, root, "passwd", fs.FileFlags{Write: true}, fs.FilePermsFromMode(mode))
+ if err != nil {
+ return err
+ }
+ defer passwd.DecRef()
+ if _, err := passwd.Writev(ctx, usermem.BytesIOSequence([]byte(contents))); err != nil {
+ return err
+ }
+ return nil
+ case linux.S_IFDIR:
+ return etc.CreateDirectory(ctx, root, "passwd", fs.FilePermsFromMode(mode))
+ case linux.S_IFIFO:
+ return etc.CreateFifo(ctx, root, "passwd", fs.FilePermsFromMode(mode))
+ default:
+ return fmt.Errorf("unknown file type %x", mode.FileType())
+ }
+}
+
+// TestGetExecUserHome tests the getExecUserHome function.
+func TestGetExecUserHome(t *testing.T) {
+ tests := map[string]struct {
+ uid auth.KUID
+ passwdContents string
+ passwdMode linux.FileMode
+ expected string
+ }{
+ "success": {
+ uid: 1000,
+ passwdContents: "adin::1000:1111::/home/adin:/bin/sh",
+ passwdMode: linux.S_IFREG | 0666,
+ expected: "/home/adin",
+ },
+ "no_perms": {
+ uid: 1000,
+ passwdContents: "adin::1000:1111::/home/adin:/bin/sh",
+ passwdMode: linux.S_IFREG,
+ expected: "/",
+ },
+ "no_passwd": {
+ uid: 1000,
+ expected: "/",
+ },
+ "directory": {
+ uid: 1000,
+ passwdMode: linux.S_IFDIR | 0666,
+ expected: "/",
+ },
+ // Currently we don't allow named pipes.
+ "named_pipe": {
+ uid: 1000,
+ passwdMode: linux.S_IFIFO | 0666,
+ expected: "/",
+ },
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ ctx := contexttest.Context(t)
+ msrc := fs.NewPseudoMountSource(ctx)
+ rootInode := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc)
+
+ mns, err := fs.NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ t.Fatalf("NewMountNamespace failed: %v", err)
+ }
+ defer mns.DecRef()
+ root := mns.Root()
+ defer root.DecRef()
+ ctx = fs.WithRoot(ctx, root)
+
+ if err := createEtcPasswd(ctx, root, tc.passwdContents, tc.passwdMode); err != nil {
+ t.Fatalf("createEtcPasswd failed: %v", err)
+ }
+
+ got, err := getExecUserHome(ctx, mns, tc.uid)
+ if err != nil {
+ t.Fatalf("failed to get user home: %v", err)
+ }
+
+ if got != tc.expected {
+ t.Fatalf("expected %v, got: %v", tc.expected, got)
+ }
+ })
+ }
+}
+
+// TestFindHomeInPasswd tests the findHomeInPasswd function's passwd file parsing.
+func TestFindHomeInPasswd(t *testing.T) {
+ tests := map[string]struct {
+ uid uint32
+ passwd string
+ expected string
+ def string
+ }{
+ "empty": {
+ uid: 1000,
+ passwd: "",
+ expected: "/",
+ def: "/",
+ },
+ "whitespace": {
+ uid: 1000,
+ passwd: " ",
+ expected: "/",
+ def: "/",
+ },
+ "full": {
+ uid: 1000,
+ passwd: "adin::1000:1111::/home/adin:/bin/sh",
+ expected: "/home/adin",
+ def: "/",
+ },
+ // For better or worse, this is how runc works.
+ "partial": {
+ uid: 1000,
+ passwd: "adin::1000:1111:",
+ expected: "",
+ def: "/",
+ },
+ "multiple": {
+ uid: 1001,
+ passwd: "adin::1000:1111::/home/adin:/bin/sh\nian::1001:1111::/home/ian:/bin/sh",
+ expected: "/home/ian",
+ def: "/",
+ },
+ "duplicate": {
+ uid: 1000,
+ passwd: "adin::1000:1111::/home/adin:/bin/sh\nian::1000:1111::/home/ian:/bin/sh",
+ expected: "/home/adin",
+ def: "/",
+ },
+ "empty_lines": {
+ uid: 1001,
+ passwd: "adin::1000:1111::/home/adin:/bin/sh\n\n\nian::1001:1111::/home/ian:/bin/sh",
+ expected: "/home/ian",
+ def: "/",
+ },
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ got, err := findHomeInPasswd(tc.uid, strings.NewReader(tc.passwd), tc.def)
+ if err != nil {
+ t.Fatalf("error parsing passwd: %v", err)
+ }
+ if tc.expected != got {
+ t.Fatalf("expected %v, got: %v", tc.expected, got)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
new file mode 100644
index 000000000..585764223
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "devpts",
+ srcs = [
+ "devpts.go",
+ "line_discipline.go",
+ "master.go",
+ "queue.go",
+ "slave.go",
+ "terminal.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "devpts_test",
+ size = "small",
+ srcs = ["devpts_test.go"],
+ library = ":devpts",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/contexttest",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
new file mode 100644
index 000000000..07a69b940
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -0,0 +1,207 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package devpts provides a filesystem implementation that behaves like
+// devpts.
+package devpts
+
+import (
+ "fmt"
+ "math"
+ "sort"
+ "strconv"
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the filesystem name.
+const Name = "devpts"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ // No data allowed.
+ if opts.Data != "" {
+ return nil, nil, syserror.EINVAL
+ }
+
+ fs, root := fstype.newFilesystem(vfsObj, creds)
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// newFilesystem creates a new devpts filesystem with root directory and ptmx
+// master inode. It returns the filesystem and root Dentry.
+func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*kernfs.Filesystem, *kernfs.Dentry) {
+ fs := &kernfs.Filesystem{}
+ fs.Init(vfsObj, fstype)
+
+ // Construct the root directory. This is always inode id 1.
+ root := &rootInode{
+ slaves: make(map[uint32]*slaveInode),
+ }
+ root.InodeAttrs.Init(creds, 1, linux.ModeDirectory|0555)
+ root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ root.dentry.Init(root)
+
+ // Construct the pts master inode and dentry. Linux always uses inode
+ // id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx.
+ master := &masterInode{
+ root: root,
+ }
+ master.InodeAttrs.Init(creds, 2, linux.ModeCharacterDevice|0666)
+ master.dentry.Init(master)
+
+ // Add the master as a child of the root.
+ links := root.OrderedChildren.Populate(&root.dentry, map[string]*kernfs.Dentry{
+ "ptmx": &master.dentry,
+ })
+ root.IncLinks(links)
+
+ return fs, &root.dentry
+}
+
+// rootInode is the root directory inode for the devpts mounts.
+type rootInode struct {
+ kernfs.AlwaysValid
+ kernfs.InodeAttrs
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeNotSymlink
+ kernfs.OrderedChildren
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // master is the master pty inode. Immutable.
+ master *masterInode
+
+ // root is the root directory inode for this filesystem. Immutable.
+ root *rootInode
+
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // slaves maps pty ids to slave inodes.
+ slaves map[uint32]*slaveInode
+
+ // nextIdx is the next pty index to use. Must be accessed atomically.
+ //
+ // TODO(b/29356795): reuse indices when ptys are closed.
+ nextIdx uint32
+}
+
+var _ kernfs.Inode = (*rootInode)(nil)
+
+// allocateTerminal creates a new Terminal and installs a pts node for it.
+func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if i.nextIdx == math.MaxUint32 {
+ return nil, syserror.ENOMEM
+ }
+ idx := i.nextIdx
+ i.nextIdx++
+
+ // Sanity check that slave with idx does not exist.
+ if _, ok := i.slaves[idx]; ok {
+ panic(fmt.Sprintf("pty index collision; index %d already exists", idx))
+ }
+
+ // Create the new terminal and slave.
+ t := newTerminal(idx)
+ slave := &slaveInode{
+ root: i,
+ t: t,
+ }
+ // Linux always uses pty index + 3 as the inode id. See
+ // fs/devpts/inode.c:devpts_pty_new().
+ slave.InodeAttrs.Init(creds, uint64(idx+3), linux.ModeCharacterDevice|0600)
+ slave.dentry.Init(slave)
+ i.slaves[idx] = slave
+
+ return t, nil
+}
+
+// masterClose is called when the master end of t is closed.
+func (i *rootInode) masterClose(t *Terminal) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ // Sanity check that slave with idx exists.
+ if _, ok := i.slaves[t.n]; !ok {
+ panic(fmt.Sprintf("pty with index %d does not exist", t.n))
+ }
+ delete(i.slaves, t.n)
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *rootInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
+ return fd.VFSFileDescription(), nil
+}
+
+// Lookup implements kernfs.Inode.Lookup.
+func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ idx, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if si, ok := i.slaves[uint32(idx)]; ok {
+ si.dentry.IncRef()
+ return si.dentry.VFSDentry(), nil
+
+ }
+ return nil, syserror.ENOENT
+}
+
+// IterDirents implements kernfs.Inode.IterDirents.
+func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ ids := make([]int, 0, len(i.slaves))
+ for id := range i.slaves {
+ ids = append(ids, int(id))
+ }
+ sort.Ints(ids)
+ for _, id := range ids[relOffset:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(id), 10),
+ Type: linux.DT_CHR,
+ Ino: i.slaves[uint32(id)].InodeAttrs.Ino(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
diff --git a/pkg/sentry/fsimpl/devpts/devpts_test.go b/pkg/sentry/fsimpl/devpts/devpts_test.go
new file mode 100644
index 000000000..b7c149047
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/devpts_test.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestSimpleMasterToSlave(t *testing.T) {
+ ld := newLineDiscipline(linux.DefaultSlaveTermios)
+ ctx := contexttest.Context(t)
+ inBytes := []byte("hello, tty\n")
+ src := usermem.BytesIOSequence(inBytes)
+ outBytes := make([]byte, 32)
+ dst := usermem.BytesIOSequence(outBytes)
+
+ // Write to the input queue.
+ nw, err := ld.inputQueueWrite(ctx, src)
+ if err != nil {
+ t.Fatalf("error writing to input queue: %v", err)
+ }
+ if nw != int64(len(inBytes)) {
+ t.Fatalf("wrote wrong length: got %d, want %d", nw, len(inBytes))
+ }
+
+ // Read from the input queue.
+ nr, err := ld.inputQueueRead(ctx, dst)
+ if err != nil {
+ t.Fatalf("error reading from input queue: %v", err)
+ }
+ if nr != int64(len(inBytes)) {
+ t.Fatalf("read wrong length: got %d, want %d", nr, len(inBytes))
+ }
+
+ outStr := string(outBytes[:nr])
+ inStr := string(inBytes)
+ if outStr != inStr {
+ t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr)
+ }
+}
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
new file mode 100644
index 000000000..e201801d6
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -0,0 +1,449 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "bytes"
+ "unicode/utf8"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+const (
+ // canonMaxBytes is the number of bytes that fit into a single line of
+ // terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE
+ // in include/linux/tty.h.
+ canonMaxBytes = 4096
+
+ // nonCanonMaxBytes is the maximum number of bytes that can be read at
+ // a time in noncanonical mode.
+ nonCanonMaxBytes = canonMaxBytes - 1
+
+ spacesPerTab = 8
+)
+
+// lineDiscipline dictates how input and output are handled between the
+// pseudoterminal (pty) master and slave. It can be configured to alter I/O,
+// modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man
+// pages are good resources for how to affect the line discipline:
+//
+// * termios(3)
+// * tty_ioctl(4)
+//
+// This file corresponds most closely to drivers/tty/n_tty.c.
+//
+// lineDiscipline has a simple structure but supports a multitude of options
+// (see the above man pages). It consists of two queues of bytes: one from the
+// terminal master to slave (the input queue) and one from slave to master (the
+// output queue). When bytes are written to one end of the pty, the line
+// discipline reads the bytes, modifies them or takes special action if
+// required, and enqueues them to be read by the other end of the pty:
+//
+// input from terminal +-------------+ input to process (e.g. bash)
+// +------------------------>| input queue |---------------------------+
+// | (inputQueueWrite) +-------------+ (inputQueueRead) |
+// | |
+// | v
+// masterFD slaveFD
+// ^ |
+// | |
+// | output to terminal +--------------+ output from process |
+// +------------------------| output queue |<--------------------------+
+// (outputQueueRead) +--------------+ (outputQueueWrite)
+//
+// Lock order:
+// termiosMu
+// inQueue.mu
+// outQueue.mu
+//
+// +stateify savable
+type lineDiscipline struct {
+ // sizeMu protects size.
+ sizeMu sync.Mutex `state:"nosave"`
+
+ // size is the terminal size (width and height).
+ size linux.WindowSize
+
+ // inQueue is the input queue of the terminal.
+ inQueue queue
+
+ // outQueue is the output queue of the terminal.
+ outQueue queue
+
+ // termiosMu protects termios.
+ termiosMu sync.RWMutex `state:"nosave"`
+
+ // termios is the terminal configuration used by the lineDiscipline.
+ termios linux.KernelTermios
+
+ // column is the location in a row of the cursor. This is important for
+ // handling certain special characters like backspace.
+ column int
+
+ // masterWaiter is used to wait on the master end of the TTY.
+ masterWaiter waiter.Queue `state:"zerovalue"`
+
+ // slaveWaiter is used to wait on the slave end of the TTY.
+ slaveWaiter waiter.Queue `state:"zerovalue"`
+}
+
+func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
+ ld := lineDiscipline{termios: termios}
+ ld.inQueue.transformer = &inputQueueTransformer{}
+ ld.outQueue.transformer = &outputQueueTransformer{}
+ return &ld
+}
+
+// getTermios gets the linux.Termios for the tty.
+func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ // We must copy a Termios struct, not KernelTermios.
+ t := l.termios.ToTermios()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setTermios sets a linux.Termios for the tty.
+func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.Lock()
+ defer l.termiosMu.Unlock()
+ oldCanonEnabled := l.termios.LEnabled(linux.ICANON)
+ // We must copy a Termios struct, not KernelTermios.
+ var t linux.Termios
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ l.termios.FromTermios(t)
+
+ // If canonical mode is turned off, move bytes from inQueue's wait
+ // buffer to its read buffer. Anything already in the read buffer is
+ // now readable.
+ if oldCanonEnabled && !l.termios.LEnabled(linux.ICANON) {
+ l.inQueue.mu.Lock()
+ l.inQueue.pushWaitBufLocked(l)
+ l.inQueue.readable = true
+ l.inQueue.mu.Unlock()
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+
+ return 0, err
+}
+
+func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) masterReadiness() waiter.EventMask {
+ // We don't have to lock a termios because the default master termios
+ // is immutable.
+ return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios)
+}
+
+func (l *lineDiscipline) slaveReadiness() waiter.EventMask {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios)
+}
+
+func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.inQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.inQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.inQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.outQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.outQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.outQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// transformer is a helper interface to make it easier to stateify queue.
+type transformer interface {
+ // transform functions require queue's mutex to be held.
+ transform(*lineDiscipline, *queue, []byte) int
+}
+
+// outputQueueTransformer implements transformer. It performs line discipline
+// transformations on the output queue.
+//
+// +stateify savable
+type outputQueueTransformer struct{}
+
+// transform does output processing for one end of the pty. See
+// drivers/tty/n_tty.c:do_output_char for an analogous kernel function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*outputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // transformOutput is effectively always in noncanonical mode, as the
+ // master termios never has ICANON set.
+
+ if !l.termios.OEnabled(linux.OPOST) {
+ q.readBuf = append(q.readBuf, buf...)
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return len(buf)
+ }
+
+ var ret int
+ for len(buf) > 0 {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ ret += size
+ buf = buf[size:]
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\n':
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ if l.termios.OEnabled(linux.ONLCR) {
+ q.readBuf = append(q.readBuf, '\r', '\n')
+ continue
+ }
+ case '\r':
+ if l.termios.OEnabled(linux.ONOCR) && l.column == 0 {
+ continue
+ }
+ if l.termios.OEnabled(linux.OCRNL) {
+ cBytes[0] = '\n'
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ break
+ }
+ l.column = 0
+ case '\t':
+ spaces := spacesPerTab - l.column%spacesPerTab
+ if l.termios.OutputFlags&linux.TABDLY == linux.XTABS {
+ l.column += spaces
+ q.readBuf = append(q.readBuf, bytes.Repeat([]byte{' '}, spacesPerTab)...)
+ continue
+ }
+ l.column += spaces
+ case '\b':
+ if l.column > 0 {
+ l.column--
+ }
+ default:
+ l.column++
+ }
+ q.readBuf = append(q.readBuf, cBytes...)
+ }
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return ret
+}
+
+// inputQueueTransformer implements transformer. It performs line discipline
+// transformations on the input queue.
+//
+// +stateify savable
+type inputQueueTransformer struct{}
+
+// transform does input processing for one end of the pty. Characters read are
+// transformed according to flags set in the termios struct. See
+// drivers/tty/n_tty.c:n_tty_receive_char_special for an analogous kernel
+// function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*inputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // If there's a line waiting to be read in canonical mode, don't write
+ // anything else to the read buffer.
+ if l.termios.LEnabled(linux.ICANON) && q.readable {
+ return 0
+ }
+
+ maxBytes := nonCanonMaxBytes
+ if l.termios.LEnabled(linux.ICANON) {
+ maxBytes = canonMaxBytes
+ }
+
+ var ret int
+ for len(buf) > 0 && len(q.readBuf) < canonMaxBytes {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\r':
+ if l.termios.IEnabled(linux.IGNCR) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+ if l.termios.IEnabled(linux.ICRNL) {
+ cBytes[0] = '\n'
+ }
+ case '\n':
+ if l.termios.IEnabled(linux.INLCR) {
+ cBytes[0] = '\r'
+ }
+ }
+
+ // In canonical mode, we discard non-terminating characters
+ // after the first 4095.
+ if l.shouldDiscard(q, cBytes) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+
+ // Stop if the buffer would be overfilled.
+ if len(q.readBuf)+size > maxBytes {
+ break
+ }
+ buf = buf[size:]
+ ret += size
+
+ // If we get EOF, make the buffer available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsEOF(cBytes[0]) {
+ q.readable = true
+ break
+ }
+
+ q.readBuf = append(q.readBuf, cBytes...)
+
+ // Anything written to the readBuf will have to be echoed.
+ if l.termios.LEnabled(linux.ECHO) {
+ l.outQueue.writeBytes(cBytes, l)
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+
+ // If we finish a line, make it available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsTerminating(cBytes) {
+ q.readable = true
+ break
+ }
+ }
+
+ // In noncanonical mode, everything is readable.
+ if !l.termios.LEnabled(linux.ICANON) && len(q.readBuf) > 0 {
+ q.readable = true
+ }
+
+ return ret
+}
+
+// shouldDiscard returns whether c should be discarded. In canonical mode, if
+// too many bytes are enqueued, we keep reading input and discarding it until
+// we find a terminating character. Signal/echo processing still occurs.
+//
+// Precondition:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (l *lineDiscipline) shouldDiscard(q *queue, cBytes []byte) bool {
+ return l.termios.LEnabled(linux.ICANON) && len(q.readBuf)+len(cBytes) >= canonMaxBytes && !l.termios.IsTerminating(cBytes)
+}
+
+// peek returns the size in bytes of the next character to process. As long as
+// b isn't empty, peek returns a value of at least 1.
+func (l *lineDiscipline) peek(b []byte) int {
+ size := 1
+ // If UTF-8 support is enabled, runes might be multiple bytes.
+ if l.termios.IEnabled(linux.IUTF8) {
+ _, size = utf8.DecodeRune(b)
+ }
+ return size
+}
+
+// LINT.ThenChange(../../fs/tty/line_discipline.go)
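
With the default slave termios, OPOST and ONLCR are set, so a newline written on the slave side is expanded to "\r\n" before the master reads it. An in-package sketch of that path (mirrors the style of devpts_test.go; assumes a test context):

    func exampleONLCR(ctx context.Context) (string, error) {
        ld := newLineDiscipline(linux.DefaultSlaveTermios)
        // The slave writes "hi\n" into the output queue.
        if _, err := ld.outputQueueWrite(ctx, usermem.BytesIOSequence([]byte("hi\n"))); err != nil {
            return "", err
        }
        // The master reads it back; ONLCR has expanded "\n" to "\r\n".
        out := make([]byte, 8)
        n, err := ld.outputQueueRead(ctx, usermem.BytesIOSequence(out))
        if err != nil {
            return "", err
        }
        return string(out[:n]), nil // expected: "hi\r\n"
    }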
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
new file mode 100644
index 000000000..60340c28e
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -0,0 +1,226 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// masterInode is the inode for the master end of the Terminal.
+type masterInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // root is the devpts root inode.
+ root *rootInode
+}
+
+var _ kernfs.Inode = (*masterInode)(nil)
+
+// Open implements kernfs.Inode.Open.
+func (mi *masterInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ t, err := mi.root.allocateTerminal(rp.Credentials())
+ if err != nil {
+ return nil, err
+ }
+
+ mi.IncRef()
+ fd := &masterFileDescription{
+ inode: mi,
+ t: t,
+ }
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ mi.DecRef()
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (mi *masterInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := mi.InodeAttrs.Stat(vfsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ statx.Blksize = 1024
+ statx.RdevMajor = linux.TTYAUX_MAJOR
+ statx.RdevMinor = linux.PTMX_MINOR
+ return statx, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat
+func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 {
+ return syserror.EINVAL
+ }
+ return mi.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
+}
+
+type masterFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+
+ inode *masterInode
+ t *Terminal
+}
+
+var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (mfd *masterFileDescription) Release() {
+ mfd.inode.root.masterClose(mfd.t)
+ mfd.inode.DecRef()
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (mfd *masterFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ mfd.t.ld.masterWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (mfd *masterFileDescription) EventUnregister(e *waiter.Entry) {
+ mfd.t.ld.masterWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (mfd *masterFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mfd.t.ld.masterReadiness()
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (mfd *masterFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ return mfd.t.ld.outputQueueRead(ctx, dst)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (mfd *masterFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ return mfd.t.ld.inputQueueWrite(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (mfd *masterFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the output queue read buffer.
+ return 0, mfd.t.ld.outputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ // N.B. TCGETS on the master actually returns the configuration
+ // of the slave end.
+ return mfd.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ // N.B. TCSETS on the master actually affects the configuration
+ // of the slave end.
+ return mfd.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return mfd.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mfd.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCSPTLCK:
+ // TODO(b/29356795): Implement pty locking. For now just pretend we do.
+ return 0, nil
+ case linux.TIOCGWINSZ:
+ return 0, mfd.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, mfd.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, mfd.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, mfd.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return mfd.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return mfd.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return mfd.inode.SetStat(ctx, fs, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return mfd.inode.Stat(fs, opts)
+}
+
+// maybeEmitUnimplementedEvent emits an unimplemented event if cmd is valid.
+func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
+ switch cmd {
+ case linux.TCGETS,
+ linux.TCSETS,
+ linux.TCSETSW,
+ linux.TCSETSF,
+ linux.TIOCGWINSZ,
+ linux.TIOCSWINSZ,
+ linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ }
+}
+
+// LINT.ThenChange(../../fs/tty/master.go)
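For context on how the master ioctls above are exercised, here is a minimal userspace sketch (not part of this change, and assuming golang.org/x/sys/unix): it opens /dev/ptmx, queries the slave index with TIOCGPTN, and unlocks the slave with TIOCSPTLCK, which this implementation accepts as a no-op.

package main

import (
	"fmt"
	"log"

	"golang.org/x/sys/unix"
)

func main() {
	// Each open of /dev/ptmx allocates a new terminal pair.
	master, err := unix.Open("/dev/ptmx", unix.O_RDWR|unix.O_NOCTTY, 0)
	if err != nil {
		log.Fatalf("open /dev/ptmx: %v", err)
	}
	defer unix.Close(master)

	// TIOCGPTN returns the slave index (the N in /dev/pts/N).
	n, err := unix.IoctlGetInt(master, unix.TIOCGPTN)
	if err != nil {
		log.Fatalf("TIOCGPTN: %v", err)
	}

	// TIOCSPTLCK(0) unlocks the slave; as noted in the switch above, the
	// sentry currently pretends to honor this.
	if err := unix.IoctlSetPointerInt(master, unix.TIOCSPTLCK, 0); err != nil {
		log.Fatalf("TIOCSPTLCK: %v", err)
	}

	fmt.Printf("slave is /dev/pts/%d\n", n)
}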
diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go
new file mode 100644
index 000000000..29a6be858
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/queue.go
@@ -0,0 +1,240 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
+// TTYB_DEFAULT_MEM_LIMIT.
+const waitBufMaxBytes = 131072
+
+// queue represents one of the input or output queues between a pty master and
+// slave. Bytes written to a queue are added to the read buffer until it is
+// full, at which point they are written to the wait buffer. Bytes are
+// processed (i.e. undergo termios transformations) as they are added to the
+// read buffer. The read buffer is readable when its length is nonzero and
+// readable is true.
+//
+// +stateify savable
+type queue struct {
+ // mu protects everything in queue.
+ mu sync.Mutex `state:"nosave"`
+
+ // readBuf is a buffer of data ready to be read when readable is true.
+ // This data has been processed.
+ readBuf []byte
+
+ // waitBuf contains data that can't fit into readBuf. It is put here
+ // until it can be loaded into the read buffer. waitBuf contains data
+ // that hasn't been processed.
+ waitBuf [][]byte
+ waitBufLen uint64
+
+ // readable indicates whether the read buffer can be read from. In
+ // canonical mode, there can be an unterminated line in the read buffer,
+ // so readable must be checked.
+ readable bool
+
+ // transform is the queue's function for transforming bytes
+ // entering the queue. For example, transform might convert all '\r's
+ // entering the queue to '\n's.
+ transformer
+}
+
+// readReadiness returns whether q is ready to be read from.
+func (q *queue) readReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if len(q.readBuf) > 0 && q.readable {
+ return waiter.EventIn
+ }
+ return waiter.EventMask(0)
+}
+
+// writeReadiness returns whether q is ready to be written to.
+func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if q.waitBufLen < waitBufMaxBytes {
+ return waiter.EventOut
+ }
+ return waiter.EventMask(0)
+}
+
+// readableSize writes the number of readable bytes to userspace.
+func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ var size int32
+ if q.readable {
+ size = int32(len(q.readBuf))
+ }
+
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+// read reads from q to userspace. It returns the number of bytes read as well
+// as whether the read caused more readable data to become available (whether
+// data was pushed from the wait buffer to the read buffer).
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.readable {
+ return 0, false, syserror.ErrWouldBlock
+ }
+
+ if dst.NumBytes() > canonMaxBytes {
+ dst = dst.TakeFirst(canonMaxBytes)
+ }
+
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dst safemem.BlockSeq) (uint64, error) {
+ src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(q.readBuf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.readBuf = q.readBuf[n:]
+
+ // If we read everything, this queue is no longer readable.
+ if len(q.readBuf) == 0 {
+ q.readable = false
+ }
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, false, err
+ }
+
+ // Move data from the queue's wait buffer to its read buffer.
+ nPushed := q.pushWaitBufLocked(l)
+
+ return int64(n), nPushed > 0, nil
+}
+
+// write writes to q from userspace.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Copy data into the wait buffer.
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(src safemem.BlockSeq) (uint64, error) {
+ copyLen := src.NumBytes()
+ room := waitBufMaxBytes - q.waitBufLen
+ // If out of room, return EAGAIN.
+ if room == 0 && copyLen > 0 {
+ return 0, syserror.ErrWouldBlock
+ }
+ // Cap the size of the wait buffer.
+ if copyLen > room {
+ copyLen = room
+ src = src.TakeFirst64(room)
+ }
+ buf := make([]byte, copyLen)
+
+ // Copy the data into the wait buffer.
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.waitBufAppend(buf)
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, err
+ }
+
+ // Push data from the wait to the read buffer.
+ q.pushWaitBufLocked(l)
+
+ return n, nil
+}
+
+// writeBytes writes to q from b.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) writeBytes(b []byte, l *lineDiscipline) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Write to the wait buffer.
+ q.waitBufAppend(b)
+ q.pushWaitBufLocked(l)
+}
+
+// pushWaitBufLocked fills the queue's read buffer with data from the wait
+// buffer.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be locked.
+func (q *queue) pushWaitBufLocked(l *lineDiscipline) int {
+ if q.waitBufLen == 0 {
+ return 0
+ }
+
+ // Move data from the wait to the read buffer.
+ var total int
+ var i int
+ for i = 0; i < len(q.waitBuf); i++ {
+ n := q.transform(l, q, q.waitBuf[i])
+ total += n
+ if n != len(q.waitBuf[i]) {
+ // The read buffer filled up without consuming the
+ // entire buffer.
+ q.waitBuf[i] = q.waitBuf[i][n:]
+ break
+ }
+ }
+
+ // Update wait buffer based on consumed data.
+ q.waitBuf = q.waitBuf[i:]
+ q.waitBufLen -= uint64(total)
+
+ return total
+}
+
+// Precondition: q.mu must be locked.
+func (q *queue) waitBufAppend(b []byte) {
+ q.waitBuf = append(q.waitBuf, b)
+ q.waitBufLen += uint64(len(b))
+}
+
+// LINT.ThenChange(../../fs/tty/queue.go)
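The queue above is a bounded two-stage buffer: write lands bytes in waitBuf (capped at waitBufMaxBytes, returning EWOULDBLOCK when full), and pushWaitBufLocked moves as much as the line discipline's transformer will accept into readBuf, where read consumes it. The following is a toy model of that flow, not the devpts code itself; the fixed-size "push(n)" stands in for the real termios transformer.

package main

import "fmt"

const waitBufMax = 8 // tiny limit for illustration; the real limit is 131072

type toyQueue struct {
	readBuf []byte
	waitBuf [][]byte
	waitLen int
}

// write appends to the wait buffer, refusing bytes past the limit, the way
// queue.write returns EWOULDBLOCK once waitBuf is full.
func (q *toyQueue) write(b []byte) int {
	room := waitBufMax - q.waitLen
	if room <= 0 {
		return 0
	}
	if len(b) > room {
		b = b[:room]
	}
	q.waitBuf = append(q.waitBuf, b)
	q.waitLen += len(b)
	return len(b)
}

// push moves up to n bytes from the wait buffer into the read buffer, like
// pushWaitBufLocked with a transformer that accepts n bytes.
func (q *toyQueue) push(n int) {
	for len(q.waitBuf) > 0 && n > 0 {
		chunk := q.waitBuf[0]
		take := len(chunk)
		if take > n {
			take = n
		}
		q.readBuf = append(q.readBuf, chunk[:take]...)
		n -= take
		q.waitLen -= take
		if take == len(chunk) {
			q.waitBuf = q.waitBuf[1:]
		} else {
			q.waitBuf[0] = chunk[take:]
		}
	}
}

func main() {
	var q toyQueue
	fmt.Println(q.write([]byte("hello world"))) // 8: capped by waitBufMax
	q.push(5)
	fmt.Printf("read=%q waiting=%d\n", q.readBuf, q.waitLen) // read="hello" waiting=3
}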
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
new file mode 100644
index 000000000..e7e50d51e
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -0,0 +1,186 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// slaveInode is the inode for the slave end of the Terminal.
+type slaveInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // root is the devpts root inode.
+ root *rootInode
+
+ // t is the connected Terminal.
+ t *Terminal
+}
+
+var _ kernfs.Inode = (*slaveInode)(nil)
+
+// Open implements kernfs.Inode.Open.
+func (si *slaveInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ si.IncRef()
+ fd := &slaveFileDescription{
+ inode: si,
+ }
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ si.DecRef()
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Valid implements kernfs.Inode.Valid.
+func (si *slaveInode) Valid(context.Context) bool {
+ // Return valid if the slave still exists.
+ si.root.mu.Lock()
+ defer si.root.mu.Unlock()
+ _, ok := si.root.slaves[si.t.n]
+ return ok
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (si *slaveInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := si.InodeAttrs.Stat(vfsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ statx.Blksize = 1024
+ statx.RdevMajor = linux.UNIX98_PTY_SLAVE_MAJOR
+ statx.RdevMinor = si.t.n
+ return statx, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 {
+ return syserror.EINVAL
+ }
+ return si.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
+}
+
+type slaveFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+
+ inode *slaveInode
+}
+
+var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (sfd *slaveFileDescription) Release() {
+ sfd.inode.DecRef()
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (sfd *slaveFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ sfd.inode.t.ld.slaveWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (sfd *slaveFileDescription) EventUnregister(e *waiter.Entry) {
+ sfd.inode.t.ld.slaveWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (sfd *slaveFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return sfd.inode.t.ld.slaveReadiness()
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (sfd *slaveFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ return sfd.inode.t.ld.inputQueueRead(ctx, dst)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ return sfd.inode.t.ld.outputQueueWrite(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the input queue read buffer.
+ return 0, sfd.inode.t.ld.inputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ return sfd.inode.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ return sfd.inode.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return sfd.inode.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sfd.inode.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCGWINSZ:
+ return 0, sfd.inode.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, sfd.inode.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, sfd.inode.t.setControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, sfd.inode.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return sfd.inode.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return sfd.inode.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return sfd.inode.SetStat(ctx, fs, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return sfd.inode.Stat(fs, opts)
+}
+
+// LINT.ThenChange(../../fs/tty/slave.go)
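The slave shares the Terminal's line discipline with the master, so TCGETS/TCSETS here operate on the same termios state that the master's ioctls read and modify. As a hedged userspace sketch (assuming golang.org/x/sys/unix; not code from this change), disabling canonical mode and echo on a slave fd looks like this:

package main

import (
	"log"

	"golang.org/x/sys/unix"
)

// setRaw clears canonical mode and echo on a pty slave using the
// TCGETS/TCSETS ioctls handled above.
func setRaw(slaveFD int) error {
	t, err := unix.IoctlGetTermios(slaveFD, unix.TCGETS)
	if err != nil {
		return err
	}
	t.Lflag &^= unix.ICANON | unix.ECHO
	return unix.IoctlSetTermios(slaveFD, unix.TCSETS, t)
}

func main() {
	// 0 is only a placeholder; in practice this would be an open
	// /dev/pts/N file descriptor.
	if err := setRaw(0); err != nil {
		log.Printf("setRaw: %v", err)
	}
}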
diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go
new file mode 100644
index 000000000..b44e673d8
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/terminal.go
@@ -0,0 +1,124 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// Terminal is a pseudoterminal.
+//
+// +stateify savable
+type Terminal struct {
+ // n is the terminal index. It is immutable.
+ n uint32
+
+ // ld is the line discipline of the terminal. It is immutable.
+ ld *lineDiscipline
+
+ // masterKTTY contains the controlling process of the master end of
+ // this terminal. This field is immutable.
+ masterKTTY *kernel.TTY
+
+ // slaveKTTY contains the controlling process of the slave end of this
+ // terminal. This field is immutable.
+ slaveKTTY *kernel.TTY
+}
+
+func newTerminal(n uint32) *Terminal {
+ termios := linux.DefaultSlaveTermios
+ t := Terminal{
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{Index: n},
+ slaveKTTY: &kernel.TTY{Index: n},
+ }
+ return &t
+}
+
+// setControllingTTY makes tm the controlling terminal of the calling thread
+// group.
+func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int())
+}
+
+// releaseControllingTTY removes tm as the controlling terminal of the calling
+// thread group.
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("releaseControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster))
+}
+
+// foregroundProcessGroup gets the process group ID of tm's foreground process.
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("foregroundProcessGroup must be called from a task context")
+ }
+
+ ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster))
+ if err != nil {
+ return 0, err
+ }
+
+ // Write it out to *arg.
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setForegroundProcessGroup sets tm's foreground process group.
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setForegroundProcessGroup must be called from a task context")
+ }
+
+ // Read in the process group ID.
+ var pgid int32
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid))
+ return uintptr(ret), err
+}
+
+func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
+ if isMaster {
+ return tm.masterKTTY
+ }
+ return tm.slaveKTTY
+}
+
+// LINT.ThenChange(../../fs/tty/terminal.go)
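The Terminal helpers above back the job-control ioctls dispatched from both ends of the pty. A brief userspace sketch of the same calls, purely illustrative and assuming golang.org/x/sys/unix:

package main

import (
	"log"

	"golang.org/x/sys/unix"
)

func main() {
	// fd is a placeholder; in practice it would be an open pty slave
	// belonging to a session leader with no controlling terminal yet.
	fd := 0

	// TIOCSCTTY makes fd the controlling terminal of the calling process.
	if err := unix.IoctlSetInt(fd, unix.TIOCSCTTY, 0); err != nil {
		log.Printf("TIOCSCTTY: %v", err)
	}

	// TIOCGPGRP reads the foreground process group ID (loose approximation:
	// the kernel writes a 32-bit pid into the integer argument).
	if pgrp, err := unix.IoctlGetInt(fd, unix.TIOCGPGRP); err == nil {
		log.Printf("foreground pgrp: %d", pgrp)
	}
}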
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index 64f1b142c..142ee53b0 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -163,16 +163,25 @@ func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind v
func (a *Accessor) UserspaceInit(ctx context.Context) error {
actx := a.wrapContext(ctx)
- // systemd: src/shared/dev-setup.c:dev_setup()
+ // Initialize symlinks.
for _, symlink := range []struct {
source string
target string
}{
- // /proc/kcore is not implemented.
+ // systemd: src/shared/dev-setup.c:dev_setup()
{source: "fd", target: "/proc/self/fd"},
{source: "stdin", target: "/proc/self/fd/0"},
{source: "stdout", target: "/proc/self/fd/1"},
{source: "stderr", target: "/proc/self/fd/2"},
+ // /proc/kcore is not implemented.
+
+ // Linux implements /dev/ptmx as a device node, but advises
+ // container implementations to create /dev/ptmx as a symlink
+ // to pts/ptmx (Documentation/filesystems/devpts.txt). Systemd
+ // follows this advice (src/nspawn/nspawn.c:setup_pts()), while
+ // LXC tries to create a bind mount and falls back to a symlink
+ // (src/lxc/conf.c:lxc_setup_devpts()).
+ {source: "ptmx", target: "pts/ptmx"},
} {
if err := a.vfsObj.SymlinkAt(actx, a.creds, a.pathOperationAt(symlink.source), symlink.target); err != nil {
return fmt.Errorf("failed to create symlink %q => %q: %v", symlink.source, symlink.target, err)
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index a4947c480..ff861d0fe 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -93,8 +93,8 @@ go_test(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/syserror",
+ "//pkg/test/testutil",
"//pkg/usermem",
- "//runsc/testutil",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 29bb73765..64e9a579f 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -32,9 +32,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/pkg/usermem"
-
- "gvisor.dev/gvisor/runsc/testutil"
)
const (
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index acd061905..b9c4beee4 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -35,7 +35,6 @@ go_library(
"fstree.go",
"gofer.go",
"handle.go",
- "handle_unsafe.go",
"p9file.go",
"pagemath.go",
"regular_file.go",
@@ -53,6 +52,7 @@ go_library(
"//pkg/p9",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/hostfd",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index d02691232..c67766ab2 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -21,8 +21,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func (d *dentry) isDir() bool {
@@ -41,15 +43,46 @@ func (d *dentry) cacheNewChildLocked(child *dentry, name string) {
d.children[name] = child
}
-// Preconditions: d.dirMu must be locked. d.isDir(). fs.opts.interop !=
-// InteropModeShared.
-func (d *dentry) cacheNegativeChildLocked(name string) {
+// Preconditions: d.dirMu must be locked. d.isDir().
+func (d *dentry) cacheNegativeLookupLocked(name string) {
+ // Don't cache negative lookups if InteropModeShared is in effect (since
+ // this makes remote lookup unavoidable), or if d.isSynthetic() (in which
+ // case the only files in the directory are those for which a dentry exists
+ // in d.children). Instead, just delete any previously-cached dentry.
+ if d.fs.opts.interop == InteropModeShared || d.isSynthetic() {
+ delete(d.children, name)
+ return
+ }
if d.children == nil {
d.children = make(map[string]*dentry)
}
d.children[name] = nil
}
+// createSyntheticDirectoryLocked creates a synthetic directory with the given
+// in d.
+//
+// Preconditions: d.dirMu must be locked. d.isDir(). d does not already contain
+// a child with the given name.
+func (d *dentry) createSyntheticDirectoryLocked(name string, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) {
+ d2 := &dentry{
+ refs: 1, // held by d
+ fs: d.fs,
+ mode: uint32(mode) | linux.S_IFDIR,
+ uid: uint32(kuid),
+ gid: uint32(kgid),
+ blockSize: usermem.PageSize, // arbitrary
+ handle: handle{
+ fd: -1,
+ },
+ }
+ d2.pf.dentry = d2
+ d2.vfsd.Init(d2)
+
+ d.cacheNewChildLocked(d2, name)
+ d.syntheticChildren++
+}
+
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
@@ -77,7 +110,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
fd.dirents = ds
}
- if d.fs.opts.interop != InteropModeShared {
+ if d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
@@ -108,10 +141,10 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
// filesystem.renameMu is needed for d.parent, and must be locked before
// dentry.dirMu.
d.fs.renameMu.RLock()
+ defer d.fs.renameMu.RUnlock()
d.dirMu.Lock()
defer d.dirMu.Unlock()
if d.dirents != nil {
- d.fs.renameMu.RUnlock()
return d.dirents, nil
}
@@ -132,51 +165,81 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
NextOff: 2,
},
}
- d.fs.renameMu.RUnlock()
- off := uint64(0)
- const count = 64 * 1024 // for consistency with the vfs1 client
- d.handleMu.RLock()
- defer d.handleMu.RUnlock()
- if !d.handleReadable {
- // This should not be possible because a readable handle should have
- // been opened when the calling directoryFD was opened.
- panic("gofer.dentry.getDirents called without a readable handle")
- }
- for {
- p9ds, err := d.handle.file.readdir(ctx, off, count)
- if err != nil {
- return nil, err
+ var realChildren map[string]struct{}
+ if !d.isSynthetic() {
+ if d.syntheticChildren != 0 && d.fs.opts.interop == InteropModeShared {
+ // Record the set of children d actually has so that we don't emit
+ // duplicate entries for synthetic children.
+ realChildren = make(map[string]struct{})
}
- if len(p9ds) == 0 {
- // Cache dirents for future directoryFDs if permitted.
- if d.fs.opts.interop != InteropModeShared {
- d.dirents = dirents
+ off := uint64(0)
+ const count = 64 * 1024 // for consistency with the vfs1 client
+ d.handleMu.RLock()
+ if !d.handleReadable {
+ // This should not be possible because a readable handle should
+ // have been opened when the calling directoryFD was opened.
+ d.handleMu.RUnlock()
+ panic("gofer.dentry.getDirents called without a readable handle")
+ }
+ for {
+ p9ds, err := d.handle.file.readdir(ctx, off, count)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
+ }
+ if len(p9ds) == 0 {
+ d.handleMu.RUnlock()
+ break
+ }
+ for _, p9d := range p9ds {
+ if p9d.Name == "." || p9d.Name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: p9d.Name,
+ Ino: p9d.QID.Path,
+ NextOff: int64(len(dirents) + 1),
+ }
+ // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
+ // DMSOCKET.
+ switch p9d.Type {
+ case p9.TypeSymlink:
+ dirent.Type = linux.DT_LNK
+ case p9.TypeDir:
+ dirent.Type = linux.DT_DIR
+ default:
+ dirent.Type = linux.DT_REG
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[p9d.Name] = struct{}{}
+ }
}
- return dirents, nil
+ off = p9ds[len(p9ds)-1].Offset
}
- for _, p9d := range p9ds {
- if p9d.Name == "." || p9d.Name == ".." {
+ }
+ // Emit entries for synthetic children.
+ if d.syntheticChildren != 0 {
+ for _, child := range d.children {
+ if child == nil || !child.isSynthetic() {
continue
}
- dirent := vfs.Dirent{
- Name: p9d.Name,
- Ino: p9d.QID.Path,
- NextOff: int64(len(dirents) + 1),
- }
- // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
- // DMSOCKET.
- switch p9d.Type {
- case p9.TypeSymlink:
- dirent.Type = linux.DT_LNK
- case p9.TypeDir:
- dirent.Type = linux.DT_DIR
- default:
- dirent.Type = linux.DT_REG
+ if _, ok := realChildren[child.name]; ok {
+ continue
}
- dirents = append(dirents, dirent)
+ dirents = append(dirents, vfs.Dirent{
+ Name: child.name,
+ Type: uint8(atomic.LoadUint32(&child.mode) >> 12),
+ Ino: child.ino,
+ NextOff: int64(len(dirents) + 1),
+ })
}
- off = p9ds[len(p9ds)-1].Offset
}
+ // Cache dirents for future directoryFDs if permitted.
+ if d.cachedMetadataAuthoritative() {
+ d.dirents = dirents
+ }
+ return dirents, nil
}
// Seek implements vfs.FileDescriptionImpl.Seek.
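One detail worth calling out in getDirents above: the dirent type for synthetic children is computed as mode >> 12, which works because Linux's DT_* constants are exactly the S_IFMT bits shifted right by 12, so no mapping table is needed. A quick self-contained check (assuming the golang.org/x/sys/unix constants):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	// S_IFDIR>>12 == DT_DIR, S_IFREG>>12 == DT_REG, S_IFLNK>>12 == DT_LNK.
	fmt.Println(unix.S_IFDIR>>12 == unix.DT_DIR) // true (4)
	fmt.Println(unix.S_IFREG>>12 == unix.DT_REG) // true (8)
	fmt.Println(unix.S_IFLNK>>12 == unix.DT_LNK) // true (10)
}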
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 43e863c61..98ccb42fd 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -29,14 +29,16 @@ import (
// Sync implements vfs.FilesystemImpl.Sync.
func (fs *filesystem) Sync(ctx context.Context) error {
- // Snapshot current dentries and special files.
+ // Snapshot current syncable dentries and special files.
fs.syncMu.Lock()
- ds := make([]*dentry, 0, len(fs.dentries))
- for d := range fs.dentries {
+ ds := make([]*dentry, 0, len(fs.syncableDentries))
+ for d := range fs.syncableDentries {
+ d.IncRef()
ds = append(ds, d)
}
sffds := make([]*specialFileFD, 0, len(fs.specialFileFDs))
for sffd := range fs.specialFileFDs {
+ sffd.vfsfd.IncRef()
sffds = append(sffds, sffd)
}
fs.syncMu.Unlock()
@@ -47,9 +49,6 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// Sync regular files.
for _, d := range ds {
- if !d.TryIncRef() {
- continue
- }
err := d.syncSharedHandle(ctx)
d.DecRef()
if err != nil && retErr == nil {
@@ -60,9 +59,6 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// Sync special files, which may be writable but do not use dentry shared
// handles (so they won't be synced by the above).
for _, sffd := range sffds {
- if !sffd.vfsfd.TryIncRef() {
- continue
- }
err := sffd.Sync(ctx)
sffd.vfsfd.DecRef()
if err != nil && retErr == nil {
@@ -114,8 +110,8 @@ func putDentrySlice(ds *[]*dentry) {
// to *ds.
//
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
-// !rp.Done(). If fs.opts.interop == InteropModeShared, then d's cached
-// metadata must be up to date.
+// !rp.Done(). If !d.cachedMetadataAuthoritative(), then d's cached metadata
+// must be up to date.
//
// Postconditions: The returned dentry's cached metadata is up to date.
func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
@@ -148,7 +144,7 @@ afterSymlink:
if err := rp.CheckMount(&d.parent.vfsd); err != nil {
return nil, err
}
- if fs.opts.interop == InteropModeShared && d != d.parent {
+ if d != d.parent && !d.cachedMetadataAuthoritative() {
_, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask())
if err != nil {
return nil, err
@@ -195,7 +191,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil
return nil, syserror.ENAMETOOLONG
}
child, ok := parent.children[name]
- if ok && fs.opts.interop != InteropModeShared {
+ if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() {
// Whether child is nil or not, it is cached information that is
// assumed to be correct.
return child, nil
@@ -206,7 +202,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil
return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds)
}
-// Preconditions: As for getChildLocked.
+// Preconditions: As for getChildLocked. !parent.isSynthetic().
func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
if err != nil && err != syserror.ENOENT {
@@ -220,24 +216,41 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
child.updateFromP9Attrs(attrMask, &attr)
return child, nil
}
- // The file at this path has changed or no longer exists. Remove
- // the stale dentry from the tree, and re-evaluate its caching
- // status (i.e. if it has 0 references, drop it).
+ if file.isNil() && child.isSynthetic() {
+ // We have a synthetic file, and no remote file has arisen to
+ // replace it.
+ return child, nil
+ }
+ // The file at this path has changed or no longer exists. Mark the
+ // dentry invalidated, and re-evaluate its caching status (i.e. if it
+ // has 0 references, drop it). Wait to update parent.children until we
+ // know what to replace the existing dentry with (i.e. one of the
+ // returns below), to avoid a redundant map access.
vfsObj.InvalidateDentry(&child.vfsd)
+ if child.isSynthetic() {
+ // Normally we don't mark invalidated dentries as deleted since
+ // they may still exist (but at a different path), and also for
+ // consistency with Linux. However, synthetic files are guaranteed
+ // to become unreachable if their dentries are invalidated, so
+ // treat their invalidation as deletion.
+ child.setDeleted()
+ parent.syntheticChildren--
+ child.decRefLocked()
+ parent.dirents = nil
+ }
*ds = appendDentry(*ds, child)
}
if file.isNil() {
// No file exists at this path now. Cache the negative lookup if
// allowed.
- if fs.opts.interop != InteropModeShared {
- parent.cacheNegativeChildLocked(name)
- }
+ parent.cacheNegativeLookupLocked(name)
return nil, nil
}
// Create a new dentry representing the file.
child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
if err != nil {
file.close(ctx)
+ delete(parent.children, name)
return nil, err
}
parent.cacheNewChildLocked(child, name)
@@ -252,8 +265,9 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// rp.Start().Impl().(*dentry)). It does not check that the returned directory
// is searchable by the provider of rp.
//
-// Preconditions: fs.renameMu must be locked. !rp.Done(). If fs.opts.interop ==
-// InteropModeShared, then d's cached metadata must be up to date.
+// Preconditions: fs.renameMu must be locked. !rp.Done(). If
+// !d.cachedMetadataAuthoritative(), then d's cached metadata must be up to
+// date.
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
for !rp.Final() {
d.dirMu.Lock()
@@ -275,7 +289,7 @@ func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
// Preconditions: fs.renameMu must be locked.
func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
d := rp.Start().Impl().(*dentry)
- if fs.opts.interop == InteropModeShared {
+ if !d.cachedMetadataAuthoritative() {
// Get updated metadata for rp.Start() as required by fs.stepLocked().
if err := d.updateFromGetattr(ctx); err != nil {
return nil, err
@@ -297,16 +311,17 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
}
// doCreateAt checks that creating a file at rp is permitted, then invokes
-// create to do so.
+// createInRemoteDir (if the parent directory is a real remote directory) or
+// createInSyntheticDir (if the parent directory is synthetic) to do so.
//
// Preconditions: !rp.Done(). For the final path component in rp,
// !rp.ShouldFollowSymlink().
-func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string) error) error {
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string) error, createInSyntheticDir func(parent *dentry, name string) error) error {
var ds *[]*dentry
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(&ds)
start := rp.Start().Impl().(*dentry)
- if fs.opts.interop == InteropModeShared {
+ if !start.cachedMetadataAuthoritative() {
// Get updated metadata for start as required by
// fs.walkParentDirLocked().
if err := start.updateFromGetattr(ctx); err != nil {
@@ -340,6 +355,20 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
defer mnt.EndWrite()
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
+ if parent.isSynthetic() {
+ if child := parent.children[name]; child != nil {
+ return syserror.EEXIST
+ }
+ if createInSyntheticDir == nil {
+ return syserror.EPERM
+ }
+ if err := createInSyntheticDir(parent, name); err != nil {
+ return err
+ }
+ parent.touchCMtime()
+ parent.dirents = nil
+ return nil
+ }
if fs.opts.interop == InteropModeShared {
// The existence of a dentry at name would be inconclusive because the
// file it represents may have been deleted from the remote filesystem,
@@ -348,21 +377,21 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
// will fail with EEXIST like we would have. If the RPC succeeds, and a
// stale dentry exists, the dentry will fail revalidation next time
// it's used.
- return create(parent, name)
+ return createInRemoteDir(parent, name)
}
if child := parent.children[name]; child != nil {
return syserror.EEXIST
}
// No cached dentry exists; however, there might still be an existing file
// at name. As above, we attempt the file creation RPC anyway.
- if err := create(parent, name); err != nil {
+ if err := createInRemoteDir(parent, name); err != nil {
return err
}
+ if child, ok := parent.children[name]; ok && child == nil {
+ // Delete the now-stale negative dentry.
+ delete(parent.children, name)
+ }
parent.touchCMtime()
- // Either parent.children[name] doesn't exist (in which case this is a
- // no-op) or is nil (in which case this erases the now-stale information
- // that the file doesn't exist).
- delete(parent.children, name)
parent.dirents = nil
return nil
}
@@ -373,7 +402,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(&ds)
start := rp.Start().Impl().(*dentry)
- if fs.opts.interop == InteropModeShared {
+ if !start.cachedMetadataAuthoritative() {
// Get updated metadata for start as required by
// fs.walkParentDirLocked().
if err := start.updateFromGetattr(ctx); err != nil {
@@ -421,8 +450,10 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
// only revalidating the dentry if that fails (indicating that the existing
// dentry is a mount point).
if child != nil {
+ child.dirMu.Lock()
+ defer child.dirMu.Unlock()
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
- if fs.opts.interop != InteropModeShared {
+ if parent.cachedMetadataAuthoritative() {
return err
}
child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
@@ -437,13 +468,37 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
}
}
flags := uint32(0)
+ // If a dentry exists, use it for best-effort checks on its deletability.
if dir {
- if child != nil && !child.isDir() {
- vfsObj.AbortDeleteDentry(&child.vfsd)
- return syserror.ENOTDIR
+ if child != nil {
+ // child must be an empty directory.
+ if child.syntheticChildren != 0 {
+ // This is definitely not an empty directory, irrespective of
+ // fs.opts.interop.
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTEMPTY
+ }
+ // If InteropModeShared is in effect and the first call to
+ // PrepareDeleteDentry above succeeded, then child wasn't
+ // revalidated (so we can't expect its file type to be correct) and
+ // individually revalidating its children (to confirm that they
+ // still exist) would be a waste of time.
+ if child.cachedMetadataAuthoritative() {
+ if !child.isDir() {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTDIR
+ }
+ for _, grandchild := range child.children {
+ if grandchild != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTEMPTY
+ }
+ }
+ }
}
flags = linux.AT_REMOVEDIR
} else {
+ // child must be a non-directory file.
if child != nil && child.isDir() {
vfsObj.AbortDeleteDentry(&child.vfsd)
return syserror.EISDIR
@@ -455,28 +510,36 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
return syserror.ENOTDIR
}
}
- err = parent.file.unlinkAt(ctx, name, flags)
- if err != nil {
- if child != nil {
- vfsObj.AbortDeleteDentry(&child.vfsd)
- }
- return err
- }
- if fs.opts.interop != InteropModeShared {
- parent.touchCMtime()
- if dir {
- parent.decLinks()
+ if parent.isSynthetic() {
+ if child == nil {
+ return syserror.ENOENT
}
- parent.cacheNegativeChildLocked(name)
- parent.dirents = nil
} else {
- delete(parent.children, name)
+ err = parent.file.unlinkAt(ctx, name, flags)
+ if err != nil {
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
}
if child != nil {
- child.setDeleted()
vfsObj.CommitDeleteDentry(&child.vfsd)
+ child.setDeleted()
+ if child.isSynthetic() {
+ parent.syntheticChildren--
+ child.decRefLocked()
+ }
ds = appendDentry(ds, child)
}
+ parent.cacheNegativeLookupLocked(name)
+ if parent.cachedMetadataAuthoritative() {
+ parent.dirents = nil
+ parent.touchCMtime()
+ if dir {
+ parent.decLinks()
+ }
+ }
return nil
}
@@ -554,7 +617,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
fs.renameMu.RLock()
defer fs.renameMuRUnlockAndCheckCaching(&ds)
start := rp.Start().Impl().(*dentry)
- if fs.opts.interop == InteropModeShared {
+ if !start.cachedMetadataAuthoritative() {
// Get updated metadata for start as required by
// fs.walkParentDirLocked().
if err := start.updateFromGetattr(ctx); err != nil {
@@ -577,20 +640,32 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
}
// 9P2000.L supports hard links, but we don't.
return syserror.EPERM
- })
+ }, nil)
}
// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ creds := rp.Credentials()
return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error {
- creds := rp.Credentials()
if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil {
- return err
+ if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
+ return err
+ }
+ ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
+ parent.createSyntheticDirectoryLocked(name, opts.Mode, creds.EffectiveKUID, creds.EffectiveKGID)
}
if fs.opts.interop != InteropModeShared {
parent.incLinks()
}
return nil
+ }, func(parent *dentry, name string) error {
+ if !opts.ForSyntheticMountpoint {
+ // Can't create non-synthetic files in synthetic directories.
+ return syserror.EPERM
+ }
+ parent.createSyntheticDirectoryLocked(name, opts.Mode, creds.EffectiveKUID, creds.EffectiveKGID)
+ parent.incLinks()
+ return nil
})
}
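ForSyntheticMountpoint exists so the sandbox can create mount points (for example /dev or /tmp over a read-only remote rootfs) even when the gofer refuses the mkdir. A hedged sketch of a caller, using the vfs.MkdirOptions field introduced alongside this change; the ensureMountPoint helper and its surrounding setup are hypothetical:

package example

import (
	"gvisor.dev/gvisor/pkg/context"
	"gvisor.dev/gvisor/pkg/fspath"
	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
)

// ensureMountPoint makes sure a mount point exists before mounting over it,
// falling back to a synthetic directory if the remote filesystem rejects the
// mkdir. ctx, creds, vfsObj, and root are assumed to be set up elsewhere
// (e.g. by the sandbox's VFS2 initialization).
func ensureMountPoint(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry, path string) error {
	pop := vfs.PathOperation{
		Root:  root,
		Start: root,
		Path:  fspath.Parse(path),
	}
	return vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
		Mode: 0755,
		// Tolerate remote failure: gofer.MkdirAt falls back to a synthetic
		// directory when this flag is set.
		ForSyntheticMountpoint: true,
	})
}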
@@ -600,7 +675,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
creds := rp.Credentials()
_, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
return err
- })
+ }, nil)
}
// OpenAt implements vfs.FilesystemImpl.OpenAt.
@@ -620,7 +695,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
defer fs.renameMuRUnlockAndCheckCaching(&ds)
start := rp.Start().Impl().(*dentry)
- if fs.opts.interop == InteropModeShared {
+ if !start.cachedMetadataAuthoritative() {
// Get updated metadata for start as required by fs.stepLocked().
if err := start.updateFromGetattr(ctx); err != nil {
return nil, err
@@ -643,6 +718,10 @@ afterTrailingSymlink:
parent.dirMu.Lock()
child, err := fs.stepLocked(ctx, rp, parent, &ds)
if err == syserror.ENOENT && mayCreate {
+ if parent.isSynthetic() {
+ parent.dirMu.Unlock()
+ return nil, syserror.EPERM
+ }
fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts)
parent.dirMu.Unlock()
return fd, err
@@ -702,8 +781,10 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
if opts.Flags&linux.O_DIRECT != 0 {
return nil, syserror.EINVAL
}
- if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil {
- return nil, err
+ if !d.isSynthetic() {
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil {
+ return nil, err
+ }
}
fd := &directoryFD{}
if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
@@ -733,6 +814,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
}
// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
+// !d.isSynthetic().
func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
@@ -811,7 +893,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
child.refs = 1
// Insert the dentry into the tree.
d.cacheNewChildLocked(child, name)
- if d.fs.opts.interop != InteropModeShared {
+ if d.cachedMetadataAuthoritative() {
d.touchCMtime()
d.dirents = nil
}
@@ -888,7 +970,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
defer mnt.EndWrite()
oldParent := oldParentVD.Dentry().Impl().(*dentry)
- if fs.opts.interop == InteropModeShared {
+ if !oldParent.cachedMetadataAuthoritative() {
if err := oldParent.updateFromGetattr(ctx); err != nil {
return err
}
@@ -933,35 +1015,22 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if newParent.isDeleted() {
return syserror.ENOENT
}
- replaced := newParent.children[newName]
- // This is similar to unlinkAt, except:
- //
- // - If a dentry exists for the file to be replaced, we revalidate it
- // unconditionally (instead of only if PrepareRenameDentry fails) for
- // simplicity.
- //
- // - If rp.MustBeDir(), then we need a dentry representing the replaced
- // file regardless to confirm that it's a directory.
- if replaced != nil || rp.MustBeDir() {
- replaced, err = fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds)
- if err != nil {
- return err
- }
- if replaced != nil {
- if replaced.isDir() {
- if !renamed.isDir() {
- return syserror.EISDIR
- }
- } else {
- if rp.MustBeDir() || renamed.isDir() {
- return syserror.ENOTDIR
- }
- }
- }
+ replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds)
+ if err != nil {
+ return err
}
var replacedVFSD *vfs.Dentry
if replaced != nil {
replacedVFSD = &replaced.vfsd
+ if replaced.isDir() {
+ if !renamed.isDir() {
+ return syserror.EISDIR
+ }
+ } else {
+ if rp.MustBeDir() || renamed.isDir() {
+ return syserror.ENOTDIR
+ }
+ }
}
if oldParent == newParent && oldName == newName {
@@ -972,27 +1041,47 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
return err
}
- if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
- vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
- return err
+
+ // Update the remote filesystem.
+ if !renamed.isSynthetic() {
+ if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+ } else if replaced != nil && !replaced.isSynthetic() {
+ // We are replacing an existing real file with a synthetic one, so we
+ // need to unlink the former.
+ flags := uint32(0)
+ if replaced.isDir() {
+ flags = linux.AT_REMOVEDIR
+ }
+ if err := newParent.file.unlinkAt(ctx, newName, flags); err != nil {
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
}
- if fs.opts.interop != InteropModeShared {
- oldParent.cacheNegativeChildLocked(oldName)
- oldParent.dirents = nil
- newParent.dirents = nil
- if renamed.isDir() {
- oldParent.decLinks()
- newParent.incLinks()
+
+ // Update the dentry tree.
+ vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD)
+ if replaced != nil {
+ replaced.setDeleted()
+ if replaced.isSynthetic() {
+ newParent.syntheticChildren--
+ replaced.decRefLocked()
}
- oldParent.touchCMtime()
- newParent.touchCMtime()
- renamed.touchCtime()
- } else {
- delete(oldParent.children, oldName)
+ ds = appendDentry(ds, replaced)
}
+ oldParent.cacheNegativeLookupLocked(oldName)
+ // We don't use newParent.cacheNewChildLocked() since we don't want to mess
+ // with reference counts and queue oldParent for checkCachingLocked if the
+ // parent isn't actually changing.
if oldParent != newParent {
- appendDentry(ds, oldParent)
+ ds = appendDentry(ds, oldParent)
newParent.IncRef()
+ if renamed.isSynthetic() {
+ oldParent.syntheticChildren--
+ newParent.syntheticChildren++
+ }
}
renamed.parent = newParent
renamed.name = newName
@@ -1000,11 +1089,25 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
newParent.children = make(map[string]*dentry)
}
newParent.children[newName] = renamed
- if replaced != nil {
- replaced.setDeleted()
- appendDentry(ds, replaced)
+
+ // Update metadata.
+ if renamed.cachedMetadataAuthoritative() {
+ renamed.touchCtime()
+ }
+ if oldParent.cachedMetadataAuthoritative() {
+ oldParent.dirents = nil
+ oldParent.touchCMtime()
+ if renamed.isDir() {
+ oldParent.decLinks()
+ }
+ }
+ if newParent.cachedMetadataAuthoritative() {
+ newParent.dirents = nil
+ newParent.touchCMtime()
+ if renamed.isDir() {
+ newParent.incLinks()
+ }
}
- vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, replacedVFSD)
return nil
}
@@ -1051,6 +1154,10 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
if err != nil {
return linux.Statfs{}, err
}
+ // If d is synthetic, invoke statfs on the first ancestor of d that isn't.
+ for d.isSynthetic() {
+ d = d.parent
+ }
fsstat, err := d.file.statFS(ctx)
if err != nil {
return linux.Statfs{}, err
@@ -1080,7 +1187,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
creds := rp.Credentials()
_, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
return err
- })
+ }, nil)
}
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
@@ -1089,9 +1196,15 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
-//
-// TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt.
func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath) (transport.BoundEndpoint, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // TODO(gvisor.dev/issue/1476): Implement BoundEndpointAt.
return nil, syserror.ECONNREFUSED
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 293df2545..8b4e91d17 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -27,8 +27,9 @@
// dentry.handleMu
// dentry.dataMu
//
-// Locking dentry.dirMu in multiple dentries requires holding
-// filesystem.renameMu for writing.
+// Locking dentry.dirMu in multiple dentries requires that either ancestor
+// dentries are locked before descendant dentries, or that filesystem.renameMu
+// is locked for writing.
package gofer
import (
@@ -102,11 +103,12 @@ type filesystem struct {
cachedDentries dentryList
cachedDentriesLen uint64
- // dentries contains all dentries in this filesystem. specialFileFDs
- // contains all open specialFileFDs. These fields are protected by syncMu.
- syncMu sync.Mutex
- dentries map[*dentry]struct{}
- specialFileFDs map[*specialFileFD]struct{}
+ // syncableDentries contains all dentries in this filesystem for which
+ // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs.
+ // These fields are protected by syncMu.
+ syncMu sync.Mutex
+ syncableDentries map[*dentry]struct{}
+ specialFileFDs map[*specialFileFD]struct{}
}
type filesystemOptions struct {
@@ -187,7 +189,8 @@ const (
// InteropModeShared is appropriate when there are users of the remote
// filesystem that may mutate its state other than the client.
//
- // - The client must verify cached filesystem state before using it.
+ // - The client must verify ("revalidate") cached filesystem state before
+ // using it.
//
// - Client changes to filesystem state must be sent to the remote
// filesystem synchronously.
@@ -376,14 +379,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Construct the filesystem object.
fs := &filesystem{
- mfp: mfp,
- opts: fsopts,
- uid: creds.EffectiveKUID,
- gid: creds.EffectiveKGID,
- client: client,
- clock: ktime.RealtimeClockFromContext(ctx),
- dentries: make(map[*dentry]struct{}),
- specialFileFDs: make(map[*specialFileFD]struct{}),
+ mfp: mfp,
+ opts: fsopts,
+ uid: creds.EffectiveKUID,
+ gid: creds.EffectiveKGID,
+ client: client,
+ clock: ktime.RealtimeClockFromContext(ctx),
+ syncableDentries: make(map[*dentry]struct{}),
+ specialFileFDs: make(map[*specialFileFD]struct{}),
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
@@ -409,7 +412,7 @@ func (fs *filesystem) Release() {
mf := fs.mfp.MemoryFile()
fs.syncMu.Lock()
- for d := range fs.dentries {
+ for d := range fs.syncableDentries {
d.handleMu.Lock()
d.dataMu.Lock()
if d.handleWritable {
@@ -444,9 +447,11 @@ type dentry struct {
vfsd vfs.Dentry
// refs is the reference count. Each dentry holds a reference on its
- // parent, even if disowned. refs is accessed using atomic memory
- // operations. When refs reaches 0, the dentry may be added to the cache or
- // destroyed. If refs==-1 the dentry has already been destroyed.
+ // parent, even if disowned. An additional reference is held on all
+ // synthetic dentries until they are unlinked or invalidated. When refs
+ // reaches 0, the dentry may be added to the cache or destroyed. If refs ==
+ // -1, the dentry has already been destroyed. refs is accessed using atomic
+ // memory operations.
refs int64
// fs is the owning filesystem. fs is immutable.
@@ -465,6 +470,12 @@ type dentry struct {
// We don't support hard links, so each dentry maps 1:1 to an inode.
// file is the unopened p9.File that backs this dentry. file is immutable.
+ //
+ // If file.isNil(), this dentry represents a synthetic file, i.e. a file
+ // that does not exist on the remote filesystem. As of this writing, this
+ // is only possible for a directory created with
+ // MkdirOptions.ForSyntheticMountpoint == true.
+ // TODO(gvisor.dev/issue/1476): Support synthetic sockets (and pipes).
file p9file
// If deleted is non-zero, the file represented by this dentry has been
@@ -484,15 +495,21 @@ type dentry struct {
// - Mappings of child filenames to dentries representing those children.
//
// - Mappings of child filenames that are known not to exist to nil
- // dentries (only if InteropModeShared is not in effect).
+ // dentries (only if InteropModeShared is not in effect and the directory
+ // is not synthetic).
//
// children is protected by dirMu.
children map[string]*dentry
- // If this dentry represents a directory, InteropModeShared is not in
- // effect, and dirents is not nil, it is a cache of all entries in the
- // directory, in the order they were returned by the server. dirents is
- // protected by dirMu.
+ // If this dentry represents a directory, syntheticChildren is the number
+ // of child dentries for which dentry.isSynthetic() == true.
+ // syntheticChildren is protected by dirMu.
+ syntheticChildren int
+
+ // If this dentry represents a directory,
+ // dentry.cachedMetadataAuthoritative() == true, and dirents is not nil, it
+ // is a cache of all entries in the directory, in the order they were
+ // returned by the server. dirents is protected by dirMu.
dirents []vfs.Dirent
// Cached metadata; protected by metadataMu and accessed using atomic
@@ -589,6 +606,8 @@ func dentryAttrMask() p9.AttrMask {
// initially has no references, but is not cached; it is the caller's
// responsibility to set the dentry's reference count and/or call
// dentry.checkCachingLocked() as appropriate.
+//
+// Preconditions: !file.isNil().
func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, mask p9.AttrMask, attr *p9.Attr) (*dentry, error) {
if !mask.Mode {
ctx.Warningf("can't create gofer.dentry without file type")
@@ -612,10 +631,10 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
},
}
d.pf.dentry = d
- if mask.UID {
+ if mask.UID && attr.UID != auth.NoID {
d.uid = uint32(attr.UID)
}
- if mask.GID {
+ if mask.GID && attr.GID != auth.NoID {
d.gid = uint32(attr.GID)
}
if mask.Size {
@@ -642,11 +661,19 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d.vfsd.Init(d)
fs.syncMu.Lock()
- fs.dentries[d] = struct{}{}
+ fs.syncableDentries[d] = struct{}{}
fs.syncMu.Unlock()
return d, nil
}
+func (d *dentry) isSynthetic() bool {
+ return d.file.isNil()
+}
+
+func (d *dentry) cachedMetadataAuthoritative() bool {
+ return d.fs.opts.interop != InteropModeShared || d.isSynthetic()
+}
+
// updateFromP9Attrs is called to update d's metadata after an update from the
// remote filesystem.
func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
@@ -691,6 +718,7 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
d.metadataMu.Unlock()
}
+// Preconditions: !d.isSynthetic().
func (d *dentry) updateFromGetattr(ctx context.Context) error {
// Use d.handle.file, which represents a 9P fid that has been opened, in
// preference to d.file, which represents a 9P fid that has not. This may
@@ -758,7 +786,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
defer mnt.EndWrite()
setLocalAtime := false
setLocalMtime := false
- if d.fs.opts.interop != InteropModeShared {
+ if d.cachedMetadataAuthoritative() {
// Timestamp updates will be handled locally.
setLocalAtime = stat.Mask&linux.STATX_ATIME != 0
setLocalMtime = stat.Mask&linux.STATX_MTIME != 0
@@ -771,35 +799,37 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin
}
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
- if stat.Mask != 0 {
- if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0,
- UID: stat.Mask&linux.STATX_UID != 0,
- GID: stat.Mask&linux.STATX_GID != 0,
- Size: stat.Mask&linux.STATX_SIZE != 0,
- ATime: stat.Mask&linux.STATX_ATIME != 0,
- MTime: stat.Mask&linux.STATX_MTIME != 0,
- ATimeNotSystemTime: stat.Atime.Nsec != linux.UTIME_NOW,
- MTimeNotSystemTime: stat.Mtime.Nsec != linux.UTIME_NOW,
- }, p9.SetAttr{
- Permissions: p9.FileMode(stat.Mode),
- UID: p9.UID(stat.UID),
- GID: p9.GID(stat.GID),
- Size: stat.Size,
- ATimeSeconds: uint64(stat.Atime.Sec),
- ATimeNanoSeconds: uint64(stat.Atime.Nsec),
- MTimeSeconds: uint64(stat.Mtime.Sec),
- MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
- }); err != nil {
- return err
+ if !d.isSynthetic() {
+ if stat.Mask != 0 {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
+ UID: stat.Mask&linux.STATX_UID != 0,
+ GID: stat.Mask&linux.STATX_GID != 0,
+ Size: stat.Mask&linux.STATX_SIZE != 0,
+ ATime: stat.Mask&linux.STATX_ATIME != 0,
+ MTime: stat.Mask&linux.STATX_MTIME != 0,
+ ATimeNotSystemTime: stat.Atime.Nsec != linux.UTIME_NOW,
+ MTimeNotSystemTime: stat.Mtime.Nsec != linux.UTIME_NOW,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(stat.Mode),
+ UID: p9.UID(stat.UID),
+ GID: p9.GID(stat.GID),
+ Size: stat.Size,
+ ATimeSeconds: uint64(stat.Atime.Sec),
+ ATimeNanoSeconds: uint64(stat.Atime.Nsec),
+ MTimeSeconds: uint64(stat.Mtime.Sec),
+ MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
+ }); err != nil {
+ return err
+ }
+ }
+ if d.fs.opts.interop == InteropModeShared {
+ // There's no point to updating d's metadata in this case since
+ // it'll be overwritten by revalidation before the next time it's
+ // used anyway. (InteropModeShared inhibits client caching of
+ // regular file data, so there's no cache to truncate either.)
+ return nil
}
- }
- if d.fs.opts.interop == InteropModeShared {
- // There's no point to updating d's metadata in this case since it'll
- // be overwritten by revalidation before the next time it's used
- // anyway. (InteropModeShared inhibits client caching of regular file
- // data, so there's no cache to truncate either.)
- return nil
}
now := d.fs.clock.Now().Nanoseconds()
if stat.Mask&linux.STATX_MODE != 0 {
@@ -897,6 +927,15 @@ func (d *dentry) DecRef() {
}
}
+// decRefLocked decrements d's reference count without calling
+// d.checkCachingLocked, even if d's reference count reaches 0; callers are
+// responsible for ensuring that d.checkCachingLocked will be called later.
+func (d *dentry) decRefLocked() {
+ if refs := atomic.AddInt64(&d.refs, -1); refs < 0 {
+ panic("gofer.dentry.decRefLocked() called without holding a reference")
+ }
+}
+
// checkCachingLocked should be called after d's reference count becomes 0 or it
// becomes disowned.
//
@@ -1013,11 +1052,11 @@ func (d *dentry) destroyLocked() {
if !d.file.isNil() {
d.file.close(ctx)
d.file = p9file{}
+ // Remove d from the set of syncable dentries.
+ d.fs.syncMu.Lock()
+ delete(d.fs.syncableDentries, d)
+ d.fs.syncMu.Unlock()
}
- // Remove d from the set of all dentries.
- d.fs.syncMu.Lock()
- delete(d.fs.dentries, d)
- d.fs.syncMu.Unlock()
// Drop the reference held by d on its parent without recursively locking
// d.fs.renameMu.
if d.parent != nil {
@@ -1040,6 +1079,9 @@ func (d *dentry) setDeleted() {
// We only support xattrs prefixed with "user." (see b/148380782). Currently,
// there is no need to expose any other xattrs through a gofer.
func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
+ if d.file.isNil() {
+ return nil, nil
+ }
xattrMap, err := d.file.listXattr(ctx, size)
if err != nil {
return nil, err
@@ -1054,6 +1096,9 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui
}
func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if d.file.isNil() {
+ return "", syserror.ENODATA
+ }
if err := d.checkPermissions(creds, vfs.MayRead); err != nil {
return "", err
}
@@ -1064,6 +1109,9 @@ func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vf
}
func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if d.file.isNil() {
+ return syserror.EPERM
+ }
if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
return err
}
@@ -1074,6 +1122,9 @@ func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vf
}
func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error {
+ if d.file.isNil() {
+ return syserror.EPERM
+ }
if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
return err
}
@@ -1083,7 +1134,7 @@ func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name
return d.file.removeXattr(ctx, name)
}
-// Preconditions: d.isRegularFile() || d.isDirectory().
+// Preconditions: !d.file.isNil(). d.isRegularFile() || d.isDirectory().
func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error {
// O_TRUNC unconditionally requires us to obtain a new handle (opened with
// O_TRUNC).
@@ -1213,7 +1264,7 @@ func (fd *fileDescription) dentry() *dentry {
func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
d := fd.dentry()
const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
- if d.fs.opts.interop == InteropModeShared && opts.Mask&(validMask) != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
+ if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
// TODO(jamieliu): Use specialFileFD.handle.file for the getattr if
// available?
if err := d.updateFromGetattr(ctx); err != nil {
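The hunks above introduce synthetic dentries (dentries with a nil p9 file) and the cachedMetadataAuthoritative() helper that Stat() now consults. A minimal sketch of how the two compose, written against the gofer package's own types; the wrapper name maybeRevalidate is hypothetical and only restates the check made in fileDescription.Stat():

    // maybeRevalidate issues a remote getattr only when cached metadata cannot
    // be trusted: in InteropModeShared, and only for dentries backed by a
    // remote file (synthetic dentries have nothing to ask).
    func maybeRevalidate(ctx context.Context, d *dentry) error {
        if d.cachedMetadataAuthoritative() {
            return nil
        }
        return d.updateFromGetattr(ctx)
    }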
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index 4041fb252..adff39490 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -24,7 +24,7 @@ import (
func TestDestroyIdempotent(t *testing.T) {
fs := filesystem{
- dentries: make(map[*dentry]struct{}),
+ syncableDentries: make(map[*dentry]struct{}),
opts: filesystemOptions{
// Test relies on no dentry being held in the cache.
maxCachedDentries: 0,
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
index cfe66f797..724a3f1f7 100644
--- a/pkg/sentry/fsimpl/gofer/handle.go
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
)
// handle represents a remote "open file descriptor", consisting of an opened
@@ -77,7 +78,7 @@ func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offs
}
if h.fd >= 0 {
ctx.UninterruptibleSleepStart(false)
- n, err := hostPreadv(h.fd, dsts, int64(offset))
+ n, err := hostfd.Preadv2(h.fd, dsts, int64(offset), 0 /* flags */)
ctx.UninterruptibleSleepFinish(false)
return n, err
}
@@ -103,7 +104,7 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o
}
if h.fd >= 0 {
ctx.UninterruptibleSleepStart(false)
- n, err := hostPwritev(h.fd, srcs, int64(offset))
+ n, err := hostfd.Pwritev2(h.fd, srcs, int64(offset), 0 /* flags */)
ctx.UninterruptibleSleepFinish(false)
return n, err
}
diff --git a/pkg/sentry/fsimpl/gofer/handle_unsafe.go b/pkg/sentry/fsimpl/gofer/handle_unsafe.go
deleted file mode 100644
index 19560ab26..000000000
--- a/pkg/sentry/fsimpl/gofer/handle_unsafe.go
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package gofer
-
-import (
- "syscall"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/safemem"
-)
-
-// Preconditions: !dsts.IsEmpty().
-func hostPreadv(fd int32, dsts safemem.BlockSeq, off int64) (uint64, error) {
- // No buffering is necessary regardless of safecopy; host syscalls will
- // return EFAULT if appropriate, instead of raising SIGBUS.
- if dsts.NumBlocks() == 1 {
- // Use pread() instead of preadv() to avoid iovec allocation and
- // copying.
- dst := dsts.Head()
- n, _, e := syscall.Syscall6(syscall.SYS_PREAD64, uintptr(fd), dst.Addr(), uintptr(dst.Len()), uintptr(off), 0, 0)
- if e != 0 {
- return 0, e
- }
- return uint64(n), nil
- }
- iovs := safemem.IovecsFromBlockSeq(dsts)
- n, _, e := syscall.Syscall6(syscall.SYS_PREADV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(off), 0, 0)
- if e != 0 {
- return 0, e
- }
- return uint64(n), nil
-}
-
-// Preconditions: !srcs.IsEmpty().
-func hostPwritev(fd int32, srcs safemem.BlockSeq, off int64) (uint64, error) {
- // No buffering is necessary regardless of safecopy; host syscalls will
- // return EFAULT if appropriate, instead of raising SIGBUS.
- if srcs.NumBlocks() == 1 {
- // Use pwrite() instead of pwritev() to avoid iovec allocation and
- // copying.
- src := srcs.Head()
- n, _, e := syscall.Syscall6(syscall.SYS_PWRITE64, uintptr(fd), src.Addr(), uintptr(src.Len()), uintptr(off), 0, 0)
- if e != 0 {
- return 0, e
- }
- return uint64(n), nil
- }
- iovs := safemem.IovecsFromBlockSeq(srcs)
- n, _, e := syscall.Syscall6(syscall.SYS_PWRITEV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(off), 0, 0)
- if e != 0 {
- return 0, e
- }
- return uint64(n), nil
-}
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 82e1fb74b..44dd9f672 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -15,12 +15,11 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
- "//pkg/fd",
"//pkg/log",
"//pkg/refs",
- "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/hostfd",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index fe14476f1..ae94cfa6e 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -25,11 +25,10 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -492,19 +491,9 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off
if flags != 0 {
return 0, syserror.EOPNOTSUPP
}
-
- var reader safemem.Reader
- if offset == -1 {
- reader = safemem.FromIOReader{fd.NewReadWriter(hostFD)}
- } else {
- reader = safemem.FromVecReaderFunc{
- func(srcs [][]byte) (int64, error) {
- n, err := unix.Preadv(hostFD, srcs, offset)
- return int64(n), err
- },
- }
- }
+ reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
n, err := dst.CopyOutFrom(ctx, reader)
+ hostfd.PutReadWriterAt(reader)
return int64(n), err
}
@@ -542,19 +531,9 @@ func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offs
if flags != 0 {
return 0, syserror.EOPNOTSUPP
}
-
- var writer safemem.Writer
- if offset == -1 {
- writer = safemem.FromIOWriter{fd.NewReadWriter(hostFD)}
- } else {
- writer = safemem.FromVecWriterFunc{
- func(srcs [][]byte) (int64, error) {
- n, err := unix.Pwritev(hostFD, srcs, offset)
- return int64(n), err
- },
- }
- }
+ writer := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
n, err := src.CopyInTo(ctx, writer)
+ hostfd.PutReadWriterAt(writer)
return int64(n), err
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 01c23d192..1d46dba25 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -246,8 +246,8 @@ func (fs *Filesystem) Sync(ctx context.Context) error {
// AccessAt implements vfs.Filesystem.Impl.AccessAt.
func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
fs.mu.RLock()
- defer fs.mu.RUnlock()
defer fs.processDeferredDecRefs()
+ defer fs.mu.RUnlock()
_, inode, err := fs.walkExistingLocked(ctx, rp)
if err != nil {
@@ -391,7 +391,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// O_NOFOLLOW have no effect here (they're handled by VFS by setting
// appropriate bits in rp), but are returned by
// FileDescriptionImpl.StatusFlags().
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK
ats := vfs.AccessTypesForOpenFlags(&opts)
// Do not create new file.
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 9f526359e..a946645f6 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -216,6 +216,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, ino uint64, mode linux.FileMo
atomic.StoreUint32(&a.nlink, nlink)
}
+// Ino returns the inode id.
+func (a *InodeAttrs) Ino() uint64 {
+ return atomic.LoadUint64(&a.ino)
+}
+
// Mode implements Inode.Mode.
func (a *InodeAttrs) Mode() linux.FileMode {
return linux.FileMode(atomic.LoadUint32(&a.mode))
@@ -359,8 +364,8 @@ func (o *OrderedChildren) Destroy() {
// cache. Populate returns the number of directories inserted, which the caller
// may use to update the link count for the parent directory.
//
-// Precondition: d.Impl() must be a kernfs Dentry. d must represent a directory
-// inode. children must not contain any conflicting entries already in o.
+// Precondition: d must represent a directory inode. children must not contain
+// any conflicting entries already in o.
func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 {
var links uint32
for name, child := range children {
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 2c6f8bdfc..f3173e197 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -111,17 +111,18 @@ func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error {
}
defer m.DecUsers(ctx)
- // Space for buffer with AT_NULL (0) terminator at the end.
auxv := m.Auxv()
+ // Space for buffer with AT_NULL (0) terminator at the end.
buf.Grow((len(auxv) + 1) * 16)
for _, e := range auxv {
- var tmp [8]byte
- usermem.ByteOrder.PutUint64(tmp[:], e.Key)
- buf.Write(tmp[:])
-
- usermem.ByteOrder.PutUint64(tmp[:], uint64(e.Value))
+ var tmp [16]byte
+ usermem.ByteOrder.PutUint64(tmp[:8], e.Key)
+ usermem.ByteOrder.PutUint64(tmp[8:], uint64(e.Value))
buf.Write(tmp[:])
}
+ var atNull [16]byte
+ buf.Write(atNull[:])
+
return nil
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 3d5dc463c..f08668ca2 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -39,7 +39,7 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
"shmmni": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMNI)),
}),
"vm": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
- "mmap_min_addr": newDentry(root, inoGen.NextIno(), 0444, &mmapMinAddrData{}),
+ "mmap_min_addr": newDentry(root, inoGen.NextIno(), 0444, &mmapMinAddrData{k: k}),
"overcommit_memory": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0\n")),
}),
"net": newSysNetDir(root, inoGen, k),
diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD
new file mode 100644
index 000000000..364a78306
--- /dev/null
+++ b/pkg/sentry/hostfd/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "hostfd",
+ srcs = [
+ "hostfd.go",
+ "hostfd_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/safemem",
+ "//pkg/sync",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/hostfd/hostfd.go b/pkg/sentry/hostfd/hostfd.go
new file mode 100644
index 000000000..70dd9cafb
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostfd provides efficient I/O with host file descriptors.
+package hostfd
+
+import (
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// ReadWriterAt implements safemem.Reader and safemem.Writer by reading from
+// and writing to a host file descriptor respectively. ReadWriterAts should be
+// obtained by calling GetReadWriterAt.
+//
+// Clients should usually prefer to use Preadv2 and Pwritev2 directly.
+type ReadWriterAt struct {
+ fd int32
+ offset int64
+ flags uint32
+}
+
+var rwpool = sync.Pool{
+ New: func() interface{} {
+ return &ReadWriterAt{}
+ },
+}
+
+// GetReadWriterAt returns a ReadWriterAt that reads from / writes to the given
+// host file descriptor, starting at the given offset and using the given
+// preadv2(2)/pwritev2(2) flags. If offset is -1, the host file descriptor's
+// offset is used instead. Users are responsible for ensuring that fd remains
+// valid for the lifetime of the returned ReadWriterAt, and must call
+// PutReadWriterAt when it is no longer needed.
+func GetReadWriterAt(fd int32, offset int64, flags uint32) *ReadWriterAt {
+ rw := rwpool.Get().(*ReadWriterAt)
+ *rw = ReadWriterAt{
+ fd: fd,
+ offset: offset,
+ flags: flags,
+ }
+ return rw
+}
+
+// PutReadWriterAt releases a ReadWriterAt returned by a previous call to
+// GetReadWriterAt that is no longer in use.
+func PutReadWriterAt(rw *ReadWriterAt) {
+ rwpool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *ReadWriterAt) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ n, err := Preadv2(rw.fd, dsts, rw.offset, rw.flags)
+ if rw.offset >= 0 {
+ rw.offset += int64(n)
+ }
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *ReadWriterAt) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ n, err := Pwritev2(rw.fd, srcs, rw.offset, rw.flags)
+ if rw.offset >= 0 {
+ rw.offset += int64(n)
+ }
+ return n, err
+}
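A hedged usage sketch of the new package, mirroring readFromHostFD in fsimpl/host above. The helper name readAt and its signature are illustrative only; ctx, usermem, and hostfd are the usual gVisor imports:

    // readAt reads from a host fd into an IOSequence via the pooled
    // ReadWriterAt. Passing offset == -1 reads at the fd's current offset.
    func readAt(ctx context.Context, fd int32, dst usermem.IOSequence, offset int64) (int64, error) {
        rw := hostfd.GetReadWriterAt(fd, offset, 0 /* flags */)
        defer hostfd.PutReadWriterAt(rw) // return the wrapper to the pool
        n, err := dst.CopyOutFrom(ctx, rw)
        return int64(n), err
    }

The callers above put the wrapper back explicitly rather than with defer; either works, since the pool only recycles the small wrapper struct.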
diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go
new file mode 100644
index 000000000..5e9e60fc4
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd_unsafe.go
@@ -0,0 +1,107 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostfd
+
+import (
+ "io"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// Preadv2 reads up to dsts.NumBytes() bytes from host file descriptor fd into
+// dsts. offset and flags are interpreted as for preadv2(2).
+//
+// Preconditions: !dsts.IsEmpty().
+func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint64, error) {
+ // No buffering is necessary regardless of safecopy; host syscalls will
+ // return EFAULT if appropriate, instead of raising SIGBUS.
+ var (
+ n uintptr
+ e syscall.Errno
+ )
+ // Avoid preadv2(2) if possible, since it's relatively new and thus least
+ // likely to be supported by the host kernel.
+ if flags == 0 {
+ if dsts.NumBlocks() == 1 {
+ // Use read() or pread() to avoid iovec allocation and copying.
+ dst := dsts.Head()
+ if offset == -1 {
+ n, _, e = syscall.Syscall(unix.SYS_READ, uintptr(fd), dst.Addr(), uintptr(dst.Len()))
+ } else {
+ n, _, e = syscall.Syscall6(unix.SYS_PREAD64, uintptr(fd), dst.Addr(), uintptr(dst.Len()), uintptr(offset), 0 /* pos_h */, 0 /* unused */)
+ }
+ } else {
+ iovs := safemem.IovecsFromBlockSeq(dsts)
+ if offset == -1 {
+ n, _, e = syscall.Syscall(unix.SYS_READV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)))
+ } else {
+ n, _, e = syscall.Syscall6(unix.SYS_PREADV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, 0 /* unused */)
+ }
+ }
+ } else {
+ iovs := safemem.IovecsFromBlockSeq(dsts)
+ n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
+ }
+ if e != 0 {
+ return 0, e
+ }
+ if n == 0 {
+ return 0, io.EOF
+ }
+ return uint64(n), nil
+}
+
+// Pwritev2 writes up to srcs.NumBytes() from srcs into host file descriptor
+// fd. offset and flags are interpreted as for pwritev2(2).
+//
+// Preconditions: !srcs.IsEmpty().
+func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint64, error) {
+ // No buffering is necessary regardless of safecopy; host syscalls will
+ // return EFAULT if appropriate, instead of raising SIGBUS.
+ var (
+ n uintptr
+ e syscall.Errno
+ )
+ // Avoid pwritev2(2) if possible, since it's relatively new and thus least
+ // likely to be supported by the host kernel.
+ if flags == 0 {
+ if srcs.NumBlocks() == 1 {
+ // Use write() or pwrite() to avoid iovec allocation and copying.
+ src := srcs.Head()
+ if offset == -1 {
+ n, _, e = syscall.Syscall(unix.SYS_WRITE, uintptr(fd), src.Addr(), uintptr(src.Len()))
+ } else {
+ n, _, e = syscall.Syscall6(unix.SYS_PWRITE64, uintptr(fd), src.Addr(), uintptr(src.Len()), uintptr(offset), 0 /* pos_h */, 0 /* unused */)
+ }
+ } else {
+ iovs := safemem.IovecsFromBlockSeq(srcs)
+ if offset == -1 {
+ n, _, e = syscall.Syscall(unix.SYS_WRITEV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)))
+ } else {
+ n, _, e = syscall.Syscall6(unix.SYS_PWRITEV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, 0 /* unused */)
+ }
+ }
+ } else {
+ iovs := safemem.IovecsFromBlockSeq(srcs)
+ n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
+ }
+ if e != 0 {
+ return 0, e
+ }
+ return uint64(n), nil
+}
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index dedf0fa15..75eedd5a2 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -24,6 +24,7 @@ go_library(
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/abi/linux",
"//pkg/context",
"//pkg/refs",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 592650923..3d78cd48f 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -20,6 +20,7 @@ import (
"fmt"
"syscall"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -30,19 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// Event describes the event mask that was observed and the user data to be
-// returned when one of the events occurs. It has this format to match the linux
-// format to avoid extra copying/allocation when writing events to userspace.
-type Event struct {
- // Events is the event mask containing the set of events that have been
- // observed on an entry.
- Events uint32
-
- // Data is an opaque 64-bit value provided by the caller when adding the
- // entry, and returned to the caller when the entry reports an event.
- Data [2]int32
-}
-
// EntryFlags is a bitmask that holds an entry's flags.
type EntryFlags int
@@ -227,9 +215,9 @@ func (e *EventPoll) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// ReadEvents returns up to max available events.
-func (e *EventPoll) ReadEvents(max int) []Event {
+func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent {
var local pollEntryList
- var ret []Event
+ var ret []linux.EpollEvent
e.listsMu.Lock()
@@ -251,7 +239,7 @@ func (e *EventPoll) ReadEvents(max int) []Event {
}
// Add event to the array that will be returned to caller.
- ret = append(ret, Event{
+ ret = append(ret, linux.EpollEvent{
Events: uint32(ready),
Data: entry.userData,
})
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index e5d133d6c..f48247c94 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -484,7 +484,7 @@ type Task struct {
// bit.
//
// numaPolicy and numaNodeMask are protected by mu.
- numaPolicy int32
+ numaPolicy linux.NumaPolicy
numaNodeMask uint64
// netns is the task's network namespace. netns is never nil.
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index 2ba8d7e63..d654dd997 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -96,6 +96,7 @@ func (t *Task) run(threadID uintptr) {
t.tg.liveGoroutines.Done()
t.tg.pidns.owner.liveGoroutines.Done()
t.tg.pidns.owner.runningGoroutines.Done()
+ t.p.Release()
// Keep argument alive because stack trace for dead variables may not be correct.
runtime.KeepAlive(threadID)
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index 8b148db35..09366b60c 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -653,14 +653,14 @@ func (t *Task) SetNiceness(n int) {
}
// NumaPolicy returns t's current numa policy.
-func (t *Task) NumaPolicy() (policy int32, nodeMask uint64) {
+func (t *Task) NumaPolicy() (policy linux.NumaPolicy, nodeMask uint64) {
t.mu.Lock()
defer t.mu.Unlock()
return t.numaPolicy, t.numaNodeMask
}
// SetNumaPolicy sets t's numa policy.
-func (t *Task) SetNumaPolicy(policy int32, nodeMask uint64) {
+func (t *Task) SetNumaPolicy(policy linux.NumaPolicy, nodeMask uint64) {
t.mu.Lock()
defer t.mu.Unlock()
t.numaPolicy = policy
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index f07de2089..7d25e98f7 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -263,6 +263,19 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
if t.haveSavedSignalMask {
mask = t.savedSignalMask
}
+
+ // Set up the restorer.
+ // x86-64 should always use SA_RESTORER, but this flag is optional on other platforms.
+ // See the Linux code for reference:
+ // linux/arch/x86/kernel/signal.c:__setup_rt_frame()
+ // If SA_RESTORER is not configured, we can use the sigreturn trampolines
+ // that the vdso provides instead.
+ // See the Linux code for reference:
+ // linux/arch/arm64/kernel/signal.c:setup_return()
+ if act.Flags&linux.SA_RESTORER == 0 {
+ act.Restorer = t.MemoryManager().VDSOSigReturn()
+ }
+
if err := t.Arch().SignalSetup(st, &act, info, &alt, mask); err != nil {
return err
}
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index d6675b8f0..88449fe95 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -311,6 +311,15 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
m.SetAuxv(auxv)
m.SetExecutable(file)
+ symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn")
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to find rt_sigreturn in vdso: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Found rt_sigreturn.
+ addr := uint64(vdsoAddr) + symbolValue - vdsoPrelink
+ m.SetVDSOSigReturn(addr)
+
ac.SetIP(uintptr(loaded.entry))
ac.SetStack(uintptr(stack.Bottom))
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 161b28c2c..00977fc08 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -15,9 +15,11 @@
package loader
import (
+ "bytes"
"debug/elf"
"fmt"
"io"
+ "strings"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/context"
@@ -38,6 +40,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+const vdsoPrelink = 0xffffffffff700000
+
type fileContext struct {
context.Context
}
@@ -221,6 +225,27 @@ type VDSO struct {
phdrs []elf.ProgHeader `state:".([]elfProgHeader)"`
}
+// getSymbolValueFromVDSO returns the specific symbol value in vdso.so.
+func getSymbolValueFromVDSO(symbol string) (uint64, error) {
+ f, err := elf.NewFile(bytes.NewReader(vdsoBin))
+ if err != nil {
+ return 0, err
+ }
+ syms, err := f.Symbols()
+ if err != nil {
+ return 0, err
+ }
+
+ for _, sym := range syms {
+ if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF {
+ if strings.Contains(sym.Name, symbol) {
+ return sym.Value, nil
+ }
+ }
+ }
+ return 0, fmt.Errorf("no %v in vdso.so", symbol)
+}
+
// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
// param page for updating by the kernel.
func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
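For context on the arithmetic in Load() above: symbol values read from the embedded vdso.so are prelinked addresses (the image is linked at vdsoPrelink), so the address visible to the task is recovered by rebasing onto wherever the VDSO was actually mapped. A short sketch using the names from loader.go and vdso.go:

    // The symbol table holds prelinked addresses, so rebase onto the mapping:
    //   sigreturnAddr = vdsoAddr + symbolValue - vdsoPrelink
    symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn")
    if err != nil {
        return err
    }
    m.SetVDSOSigReturn(uint64(vdsoAddr) + symbolValue - vdsoPrelink)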
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 34d3bde7a..6db7c3d40 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -35,6 +35,7 @@
package mm
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fsbridge"
@@ -286,7 +287,7 @@ type vma struct {
mlockMode memmap.MLockMode
// numaPolicy is the NUMA policy for this vma set by mbind().
- numaPolicy int32
+ numaPolicy linux.NumaPolicy
// numaNodemask is the NUMA nodemask for this vma set by mbind().
numaNodemask uint64
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
index 1ab92f046..6efe5102b 100644
--- a/pkg/sentry/mm/procfs.go
+++ b/pkg/sentry/mm/procfs.go
@@ -148,7 +148,7 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI
// Do not include the guard page: fs/proc/task_mmu.c:show_map_vma() =>
// stack_guard_page_start().
- fmt.Fprintf(b, "%08x-%08x %s%s %08x %02x:%02x %d ",
+ lineLen, _ := fmt.Fprintf(b, "%08x-%08x %s%s %08x %02x:%02x %d ",
vseg.Start(), vseg.End(), vma.realPerms, private, vma.off, devMajor, devMinor, ino)
// Figure out our filename or hint.
@@ -165,7 +165,7 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI
}
if s != "" {
// Per linux, we pad until the 74th character.
- if pad := 73 - b.Len(); pad > 0 {
+ if pad := 73 - lineLen; pad > 0 {
b.WriteString(strings.Repeat(" ", pad))
}
b.WriteString(s)
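The padding fix above keys the 73-column alignment off the length of the current maps line (the byte count returned by Fprintf) rather than off b.Len(), which counts every line already accumulated in the buffer. A small self-contained sketch of the rule, with a hypothetical helper name:

    // padTo73 pads an in-progress /proc/[pid]/maps line out to the 74th
    // column before the filename or hint is appended, as
    // appendVMAMapsEntryLocked now does.
    func padTo73(b *bytes.Buffer, lineLen int) {
        if pad := 73 - lineLen; pad > 0 {
            b.WriteString(strings.Repeat(" ", pad))
        }
    }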
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index c5dfa5972..3f496aa9f 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -974,7 +974,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error
}
// NumaPolicy implements the semantics of Linux's get_mempolicy(MPOL_F_ADDR).
-func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (int32, uint64, error) {
+func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (linux.NumaPolicy, uint64, error) {
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
vseg := mm.vmas.FindSegment(addr)
@@ -986,7 +986,7 @@ func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (int32, uint64, error) {
}
// SetNumaPolicy implements the semantics of Linux's mbind().
-func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy int32, nodemask uint64) error {
+func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error {
if !addr.IsPageAligned() {
return syserror.EINVAL
}
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
index c769ac7b4..6507121ea 100644
--- a/pkg/sentry/platform/kvm/context.go
+++ b/pkg/sentry/platform/kvm/context.go
@@ -85,3 +85,6 @@ func (c *context) Switch(as platform.AddressSpace, ac arch.Context, _ int32) (*a
func (c *context) Interrupt() {
c.interrupt.NotifyInterrupt()
}
+
+// Release implements platform.Context.Release().
+func (c *context) Release() {}
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index a9b4af43e..ae813e24e 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -191,6 +191,11 @@ func (*constructor) OpenDevice() (*os.File, error) {
return OpenDevice()
}
+// Requirements implements platform.Constructor.Requirements().
+func (*constructor) Requirements() platform.Requirements {
+ return platform.Requirements{}
+}
+
func init() {
platform.Register("kvm", &constructor{})
}
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index 2ca696382..171513f3f 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -148,6 +148,9 @@ type Context interface {
// Interrupt interrupts a concurrent call to Switch(), causing it to return
// ErrContextInterrupt.
Interrupt()
+
+ // Release() releases any resources associated with this context.
+ Release()
}
var (
@@ -353,10 +356,28 @@ func (fr FileRange) String() string {
return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
}
+// Requirements is used to specify platform specific requirements.
+type Requirements struct {
+ // RequiresCurrentPIDNS indicates that the sandbox has to be started in the
+ // current pid namespace.
+ RequiresCurrentPIDNS bool
+ // RequiresCapSysPtrace indicates that the sandbox has to be started with
+ // the CAP_SYS_PTRACE capability.
+ RequiresCapSysPtrace bool
+}
+
// Constructor represents a platform type.
type Constructor interface {
+ // New returns a new platform instance.
+ //
+ // Arguments:
+ //
+ // * deviceFile - the device file (e.g. /dev/kvm for the KVM platform).
New(deviceFile *os.File) (Platform, error)
OpenDevice() (*os.File, error)
+
+ // Requirements returns platform specific requirements.
+ Requirements() Requirements
}
// platforms contains all available platform types.
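A minimal sketch of a Constructor satisfying the extended interface; this no-op platform is hypothetical and exists only to show where Requirements() sits alongside New() and OpenDevice() (assuming the obvious fmt, os, and platform imports):

    type noopConstructor struct{}

    // New implements platform.Constructor.New.
    func (*noopConstructor) New(deviceFile *os.File) (platform.Platform, error) {
        return nil, fmt.Errorf("noop platform is illustrative only")
    }

    // OpenDevice implements platform.Constructor.OpenDevice.
    func (*noopConstructor) OpenDevice() (*os.File, error) {
        return nil, nil // no device file needed
    }

    // Requirements implements platform.Constructor.Requirements.
    func (*noopConstructor) Requirements() platform.Requirements {
        // Needs neither CAP_SYS_PTRACE nor the caller's pid namespace.
        return platform.Requirements{}
    }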
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index 03adb624b..08d055e05 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -177,6 +177,9 @@ func (c *context) Interrupt() {
c.interrupt.NotifyInterrupt()
}
+// Release implements platform.Context.Release().
+func (c *context) Release() {}
+
// PTrace represents a collection of ptrace subprocesses.
type PTrace struct {
platform.MMapMinAddr
@@ -248,6 +251,16 @@ func (*constructor) OpenDevice() (*os.File, error) {
return nil, nil
}
+// Requirements implements platform.Constructor.Requirements().
+func (*constructor) Requirements() platform.Requirements {
+ // TODO(b/75837838): Also set a new PID namespace so that we limit
+ // access to other host processes.
+ return platform.Requirements{
+ RequiresCapSysPtrace: true,
+ RequiresCurrentPIDNS: true,
+ }
+}
+
func init() {
platform.Register("ptrace", &constructor{})
}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index a644609ef..773ddb1ed 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -332,7 +332,7 @@ func (t *thread) unexpectedStubExit() {
msg, err := t.getEventMessage()
status := syscall.WaitStatus(msg)
if status.Signaled() && status.Signal() == syscall.SIGKILL {
- // SIGKILL can be only sent by an user or OOM-killer. In both
+ // SIGKILL can be only sent by a user or OOM-killer. In both
// these cases, we don't need to panic. There is no reason to
// think that something is wrong in gVisor.
log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid)
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index ff1cfd8f6..55c0f04f3 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
tcpHeader = header.TCP(pkt.TransportHeader)
} else {
// The TCP header hasn't been parsed yet. We have to do it here.
- if len(pkt.Data.First()) < header.TCPMinimumSize {
+ hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize)
+ if !ok {
// There's no valid TCP header here, so we hotdrop the
// packet.
return false, true
}
- tcpHeader = header.TCP(pkt.Data.First())
+ tcpHeader = header.TCP(hdr)
}
// Check whether the source and destination ports are within the
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 3359418c1..04d03d494 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
udpHeader = header.UDP(pkt.TransportHeader)
} else {
// The UDP header hasn't been parsed yet. We have to do it here.
- if len(pkt.Data.First()) < header.UDPMinimumSize {
+ hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok {
// There's no valid UDP header here, so we hotdrop the
// packet.
return false, true
}
- udpHeader = header.UDP(pkt.Data.First())
+ udpHeader = header.UDP(hdr)
}
// Check whether the source and destination ports are within the
diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go
index 87dcad18b..d9fb808c0 100644
--- a/pkg/sentry/syscalls/epoll.go
+++ b/pkg/sentry/syscalls/epoll.go
@@ -17,6 +17,7 @@ package syscalls
import (
"time"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -118,7 +119,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error {
}
// WaitEpoll implements the epoll_wait(2) linux syscall.
-func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]epoll.Event, error) {
+func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) {
// Get epoll from the file descriptor.
epollfile := t.GetFile(fd)
if epollfile == nil {
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index 3ab93fbde..51bf205cf 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
"gvisor.dev/gvisor/pkg/sentry/syscalls"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -72,7 +71,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
var data [2]int32
if op != linux.EPOLL_CTL_DEL {
var e linux.EpollEvent
- if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ if _, err := e.CopyIn(t, eventAddr); err != nil {
return 0, nil, err
}
@@ -105,28 +104,6 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
}
-// copyOutEvents copies epoll events from the kernel to user memory.
-func copyOutEvents(t *kernel.Task, addr usermem.Addr, e []epoll.Event) error {
- const itemLen = 12
- buffLen := len(e) * itemLen
- if _, ok := addr.AddLength(uint64(buffLen)); !ok {
- return syserror.EFAULT
- }
-
- b := t.CopyScratchBuffer(buffLen)
- for i := range e {
- usermem.ByteOrder.PutUint32(b[i*itemLen:], e[i].Events)
- usermem.ByteOrder.PutUint32(b[i*itemLen+4:], uint32(e[i].Data[0]))
- usermem.ByteOrder.PutUint32(b[i*itemLen+8:], uint32(e[i].Data[1]))
- }
-
- if _, err := t.CopyOutBytes(addr, b); err != nil {
- return err
- }
-
- return nil
-}
-
// EpollWait implements the epoll_wait(2) linux syscall.
func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
epfd := args[0].Int()
@@ -140,7 +117,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
if len(r) != 0 {
- if err := copyOutEvents(t, eventsAddr, r); err != nil {
+ if _, err := linux.CopyEpollEventSliceOut(t, eventsAddr, r); err != nil {
return 0, nil, err
}
}
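With epoll events represented as linux.EpollEvent end to end, the hand-rolled byte packing above gives way to the marshalling helpers on that type (CopyIn and CopyEpollEventSliceOut, which appear to be generated from the EpollEvent definition). A short sketch of both directions, assuming a kernel.Task t, user addresses, and a ready slice events as in the handlers above:

    // Copy a single epoll_event in from userspace, as EpollCtl now does.
    var e linux.EpollEvent
    if _, err := e.CopyIn(t, eventAddr); err != nil {
        return 0, nil, err
    }

    // Copy a batch of ready events back out, as EpollWait now does.
    if _, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events); err != nil {
        return 0, nil, err
    }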
diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go
index ac934dc6f..9b4a5c3f1 100644
--- a/pkg/sentry/syscalls/linux/sys_mempolicy.go
+++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go
@@ -162,10 +162,10 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if err != nil {
return 0, nil, err
}
- policy = 0 // maxNodes == 1
+ policy = linux.MPOL_DEFAULT // maxNodes == 1
}
if mode != 0 {
- if _, err := t.CopyOut(mode, policy); err != nil {
+ if _, err := policy.CopyOut(t, mode); err != nil {
return 0, nil, err
}
}
@@ -199,10 +199,10 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE {
return 0, nil, syserror.EINVAL
}
- policy = 0 // maxNodes == 1
+ policy = linux.MPOL_DEFAULT // maxNodes == 1
}
if mode != 0 {
- if _, err := t.CopyOut(mode, policy); err != nil {
+ if _, err := policy.CopyOut(t, mode); err != nil {
return 0, nil, err
}
}
@@ -216,7 +216,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// SetMempolicy implements the syscall set_mempolicy(2).
func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- modeWithFlags := args[0].Int()
+ modeWithFlags := linux.NumaPolicy(args[0].Int())
nodemask := args[1].Pointer()
maxnode := args[2].Uint()
@@ -233,7 +233,7 @@ func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
length := args[1].Uint64()
- mode := args[2].Int()
+ mode := linux.NumaPolicy(args[2].Int())
nodemask := args[3].Pointer()
maxnode := args[4].Uint()
flags := args[5].Uint()
@@ -258,9 +258,9 @@ func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
-func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags int32, nodemask usermem.Addr, maxnode uint32) (int32, uint64, error) {
- flags := modeWithFlags & linux.MPOL_MODE_FLAGS
- mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS
+func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask usermem.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) {
+ flags := linux.NumaPolicy(modeWithFlags & linux.MPOL_MODE_FLAGS)
+ mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS)
if flags == linux.MPOL_MODE_FLAGS {
// Can't specify both mode flags simultaneously.
return 0, 0, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 6ff2d84d2..f6fb0f219 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -6,7 +6,6 @@ go_library(
name = "vfs2",
srcs = [
"epoll.go",
- "epoll_unsafe.go",
"execve.go",
"fd.go",
"filesystem.go",
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
index 5a938cee2..34c90ae3e 100644
--- a/pkg/sentry/syscalls/linux/vfs2/epoll.go
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+var sizeofEpollEvent = (*linux.EpollEvent)(nil).SizeBytes()
+
// EpollCreate1 implements Linux syscall epoll_create1(2).
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
@@ -124,7 +126,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
maxEvents := int(args[2].Int())
timeout := int(args[3].Int())
- const _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
+ var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
return 0, nil, syserror.EINVAL
}
@@ -157,7 +159,8 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
maxEvents -= n
if n != 0 {
// Copy what we read out.
- copiedEvents, err := copyOutEvents(t, eventsAddr, events[:n])
+ copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events[:n])
+ copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
total += copiedEvents
if err != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
deleted file mode 100644
index 825f325bf..000000000
--- a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs2
-
-import (
- "reflect"
- "runtime"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/gohacks"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-const sizeofEpollEvent = int(unsafe.Sizeof(linux.EpollEvent{}))
-
-func copyOutEvents(t *kernel.Task, addr usermem.Addr, events []linux.EpollEvent) (int, error) {
- if len(events) == 0 {
- return 0, nil
- }
- // Cast events to a byte slice for copying.
- var eventBytes []byte
- eventBytesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&eventBytes))
- eventBytesHdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(&events[0])))
- eventBytesHdr.Len = len(events) * sizeofEpollEvent
- eventBytesHdr.Cap = len(events) * sizeofEpollEvent
- copiedBytes, err := t.CopyOutBytes(addr, eventBytes)
- runtime.KeepAlive(events)
- copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
- return copiedEvents, err
-}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
index 62e98817d..c7c7bf7ce 100644
--- a/pkg/sentry/syscalls/linux/vfs2/getdents.go
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -130,7 +130,7 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
if cb.t.Arch().Width() != 8 {
panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
}
- size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
size = (size + 7) &^ 7 // round up to multiple of sizeof(long)
if size > cb.remaining {
return syserror.EINVAL
@@ -143,11 +143,11 @@ func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
// Zero out all remaining bytes in buf, including the NUL terminator
// after dirent.Name and the zero padding byte between the name and
// dirent type.
- bufTail := buf[18+len(dirent.Name):]
+ bufTail := buf[18+len(dirent.Name) : size-1]
for i := range bufTail {
bufTail[i] = 0
}
- bufTail[2] = dirent.Type
+ buf[size-1] = dirent.Type
}
n, err := cb.t.CopyOutBytes(cb.addr, buf)
if err != nil {
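The two hunks above adjust the legacy (non-64-bit) getdents record: the stray extra byte is dropped from the size computation and d_type is written into the record's final byte, after the NUL-terminated, zero-padded name, rather than at a fixed offset past the name. A sketch of the intended struct linux_dirent encoding; the helper name is hypothetical, and using usermem.ByteOrder for the fixed-width fields is an assumption:

    // packDirent lays out one struct linux_dirent record; d_type occupies the
    // last byte of the 8-byte-aligned record.
    func packDirent(ino, off uint64, name string, typ uint8) []byte {
        size := 8 + 8 + 2 + 1 + 1 + len(name) // d_ino + d_off + d_reclen + name + NUL + d_type
        size = (size + 7) &^ 7                // round up to a multiple of sizeof(long)
        buf := make([]byte, size)
        usermem.ByteOrder.PutUint64(buf[0:8], ino)
        usermem.ByteOrder.PutUint64(buf[8:16], off)
        usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
        copy(buf[18:], name)
        // Bytes between the name and the final byte stay zero (NUL + padding).
        buf[size-1] = typ
        return buf
    }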
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index 21eb98444..74920f785 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -58,8 +58,8 @@ func Override(table map[uintptr]kernel.Syscall) {
table[51] = syscalls.PartiallySupported("getsockname", GetSockName, "In process of porting socket syscalls to VFS2.", nil)
table[52] = syscalls.PartiallySupported("getpeername", GetPeerName, "In process of porting socket syscalls to VFS2.", nil)
table[53] = syscalls.PartiallySupported("socketpair", SocketPair, "In process of porting socket syscalls to VFS2.", nil)
- table[54] = syscalls.PartiallySupported("getsockopt", GetSockOpt, "In process of porting socket syscalls to VFS2.", nil)
- table[55] = syscalls.PartiallySupported("setsockopt", SetSockOpt, "In process of porting socket syscalls to VFS2.", nil)
+ table[54] = syscalls.PartiallySupported("setsockopt", SetSockOpt, "In process of porting socket syscalls to VFS2.", nil)
+ table[55] = syscalls.PartiallySupported("getsockopt", GetSockOpt, "In process of porting socket syscalls to VFS2.", nil)
table[59] = syscalls.Supported("execve", Execve)
table[72] = syscalls.Supported("fcntl", Fcntl)
delete(table, 73) // flock
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 15cc091e2..418d69b96 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -111,10 +111,10 @@ type FileDescriptionOptions struct {
}
// Init must be called before first use of fd. If it succeeds, it takes
-// references on mnt and d. statusFlags is the initial file description status
-// flags, which is usually the full set of flags passed to open(2).
-func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) error {
- writable := MayWriteFileWithOpenFlags(statusFlags)
+// references on mnt and d. flags is the initial file description flags, which
+// is usually the full set of flags passed to open(2).
+func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) error {
+ writable := MayWriteFileWithOpenFlags(flags)
if writable {
if err := mnt.CheckBeginWrite(); err != nil {
return err
@@ -122,7 +122,10 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mn
}
fd.refs = 1
- fd.statusFlags = statusFlags
+
+ // Remove "file creation flags" to mirror the behavior from file.f_flags in
+ // fs/open.c:do_dentry_open
+ fd.statusFlags = flags &^ (linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC)
fd.vd = VirtualDentry{
mount: mnt,
dentry: d,
@@ -130,7 +133,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, statusFlags uint32, mn
mnt.IncRef()
d.IncRef()
fd.opts = *opts
- fd.readable = MayReadFileWithOpenFlags(statusFlags)
+ fd.readable = MayReadFileWithOpenFlags(flags)
fd.writable = writable
fd.impl = impl
return nil
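One observable consequence of the masking above, sketched as plain flag arithmetic (no VFS calls assumed): creation-time flags passed to open(2) are not reported back as status flags, mirroring file.f_flags in Linux.

    openFlags := uint32(linux.O_WRONLY | linux.O_CREAT | linux.O_TRUNC | linux.O_NONBLOCK)
    statusFlags := openFlags &^ uint32(linux.O_CREAT|linux.O_EXCL|linux.O_NOCTTY|linux.O_TRUNC)
    // statusFlags == O_WRONLY|O_NONBLOCK: the creation-time flags are dropped,
    // while genuine status flags such as O_NONBLOCK survive.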
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 74577bc2f..20e5bb072 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -443,8 +443,7 @@ type FilesystemImpl interface {
// Errors:
//
// - If extended attributes are not supported by the filesystem,
- // ListxattrAt returns nil. (See FileDescription.Listxattr for an
- // explanation.)
+ // ListxattrAt returns ENOTSUP.
//
// - If the size of the list (including a NUL terminating byte after every
// entry) would exceed size, ERANGE may be returned. Note that
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 534528ce6..022bac127 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -33,6 +33,25 @@ type GetDentryOptions struct {
type MkdirOptions struct {
// Mode is the file mode bits for the created directory.
Mode linux.FileMode
+
+ // If ForSyntheticMountpoint is true, FilesystemImpl.MkdirAt() may create
+ // the given directory in memory only (as opposed to persistent storage).
+ // The created directory should be able to support the creation of
+ // subdirectories with ForSyntheticMountpoint == true. It does not need to
+ // support the creation of subdirectories with ForSyntheticMountpoint ==
+ // false, or files of other types.
+ //
+ // FilesystemImpls are permitted to ignore the ForSyntheticMountpoint
+ // option.
+ //
+ // The ForSyntheticMountpoint option exists because, unlike mount(2), the
+ // OCI Runtime Specification permits the specification of mount points that
+ // do not exist, under the expectation that container runtimes will create
+ // them. (More accurately, the OCI Runtime Specification completely fails
+ // to document this feature, but it's implemented by runc.)
+ // ForSyntheticMountpoint allows such mount points to be created even when
+ // the underlying persistent filesystem is immutable.
+ ForSyntheticMountpoint bool
}
// MknodOptions contains options to VirtualFilesystem.MknodAt() and
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 8ec5d5d5c..f01217c91 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
-// TrimFront removes the first "count" bytes of the vectorised view.
+// TrimFront removes the first "count" bytes of the vectorised view. It panics
+// if count > vv.Size().
func (vv *VectorisedView) TrimFront(count int) {
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
@@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) {
return
}
count -= len(vv.views[0])
- vv.RemoveFirst()
+ vv.removeFirst()
}
}
@@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) {
count -= len(vv.views[0])
copy(v[copied:], vv.views[0])
copied += len(vv.views[0])
- vv.RemoveFirst()
+ vv.removeFirst()
}
if copied == 0 {
return 0, io.EOF
@@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
count -= len(vv.views[0])
dstVV.AppendView(vv.views[0])
copied += len(vv.views[0])
- vv.RemoveFirst()
+ vv.removeFirst()
}
return copied
}
@@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
-// First returns the first view of the vectorised view.
-func (vv *VectorisedView) First() View {
+// PullUp returns the first "count" bytes of the vectorised view. If those
+// bytes aren't already contiguous inside the vectorised view, PullUp will
+// reallocate as needed to make them contiguous. PullUp fails and returns false
+// when count > vv.Size().
+func (vv *VectorisedView) PullUp(count int) (View, bool) {
if len(vv.views) == 0 {
- return nil
+ return nil, count == 0
+ }
+ if count <= len(vv.views[0]) {
+ return vv.views[0][:count], true
+ }
+ if count > vv.size {
+ return nil, false
}
- return vv.views[0]
-}
-// RemoveFirst removes the first view of the vectorised view.
-func (vv *VectorisedView) RemoveFirst() {
- if len(vv.views) == 0 {
- return
+ newFirst := NewView(count)
+ i := 0
+ for offset := 0; offset < count; i++ {
+ copy(newFirst[offset:], vv.views[i])
+ if count-offset < len(vv.views[i]) {
+ vv.views[i].TrimFront(count - offset)
+ break
+ }
+ offset += len(vv.views[i])
+ vv.views[i] = nil
}
- vv.size -= len(vv.views[0])
- vv.views[0] = nil
- vv.views = vv.views[1:]
+ // We're guaranteed that i > 0, since count is too large for the first
+ // view.
+ vv.views[i-1] = newFirst
+ vv.views = vv.views[i-1:]
+ return newFirst, true
}
// Size returns the size in bytes of the entire content stored in the vectorised view.
@@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader {
}
return readers
}
+
+// removeFirst removes the first view; it panics when len(vv.views) < 1.
+func (vv *VectorisedView) removeFirst() {
+ vv.size -= len(vv.views[0])
+ vv.views[0] = nil
+ vv.views = vv.views[1:]
+}
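A short, self-contained sketch of the pattern PullUp enables (not part of the diff); the 4- and 3-byte header lengths are illustrative.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/buffer"
)

func main() {
	// Three bytes of payload split across two views.
	vv := buffer.NewVectorisedView(3, []buffer.View{
		buffer.View("ab"),
		buffer.View("c"),
	})

	// Asking for more bytes than the view holds fails instead of panicking,
	// so callers can drop short packets gracefully.
	if _, ok := vv.PullUp(4); !ok {
		fmt.Println("packet too short, dropping")
	}

	// Asking for bytes that span views returns them as one contiguous View.
	if hdr, ok := vv.PullUp(3); ok {
		fmt.Printf("header: %q\n", hdr) // "abc"
	}
}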
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index 106e1994c..c56795c7b 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -16,6 +16,7 @@
package buffer
import (
+ "bytes"
"reflect"
"testing"
)
@@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) {
})
}
}
+
+var pullUpTestCases = []struct {
+ comment string
+ in VectorisedView
+ count int
+ want []byte
+ result VectorisedView
+ ok bool
+}{
+ {
+ comment: "simple case",
+ in: vv(2, "12"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(2, "12"),
+ ok: true,
+ },
+ {
+ comment: "entire View",
+ in: vv(2, "1", "2"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(2, "1", "2"),
+ ok: true,
+ },
+ {
+ comment: "spanning across two Views",
+ in: vv(3, "1", "23"),
+ count: 2,
+ want: []byte("12"),
+ result: vv(3, "12", "3"),
+ ok: true,
+ },
+ {
+ comment: "spanning across all Views",
+ in: vv(5, "1", "23", "45"),
+ count: 5,
+ want: []byte("12345"),
+ result: vv(5, "12345"),
+ ok: true,
+ },
+ {
+ comment: "count = 0",
+ in: vv(1, "1"),
+ count: 0,
+ want: []byte{},
+ result: vv(1, "1"),
+ ok: true,
+ },
+ {
+ comment: "count = size",
+ in: vv(1, "1"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(1, "1"),
+ ok: true,
+ },
+ {
+ comment: "count too large",
+ in: vv(3, "1", "23"),
+ count: 4,
+ want: nil,
+ result: vv(3, "1", "23"),
+ ok: false,
+ },
+ {
+ comment: "empty vv",
+ in: vv(0, ""),
+ count: 1,
+ want: nil,
+ result: vv(0, ""),
+ ok: false,
+ },
+ {
+ comment: "empty vv, count = 0",
+ in: vv(0, ""),
+ count: 0,
+ want: nil,
+ result: vv(0, ""),
+ ok: true,
+ },
+ {
+ comment: "empty views",
+ in: vv(3, "", "1", "", "23"),
+ count: 2,
+ want: []byte("12"),
+ result: vv(3, "12", "3"),
+ ok: true,
+ },
+}
+
+func TestPullUp(t *testing.T) {
+ for _, c := range pullUpTestCases {
+ got, ok := c.in.PullUp(c.count)
+
+ // Is the return value right?
+ if ok != c.ok {
+ t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t",
+ c.comment, c.count, c.in, ok, c.ok)
+ }
+ if !bytes.Equal(got, View(c.want)) {
+ t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v",
+ c.comment, c.count, c.in, got, c.want)
+ }
+
+ // Is the underlying structure right?
+ if !reflect.DeepEqual(c.in, c.result) {
+ t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v",
+ c.comment, c.count, c.in, c.result)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 1e2255bfa..073c84ef9 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- // Reject the packet if it's shorter than an ethernet header.
- if vv.Size() < header.EthernetMinimumSize {
+ // There should be an ethernet header at the beginning of vv.
+ hdr, ok := vv.PullUp(header.EthernetMinimumSize)
+ if !ok {
+ // Reject the packet if it's shorter than an ethernet header.
return tcpip.ErrBadAddress
}
-
- // There should be an ethernet header at the beginning of vv.
- linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize])
+ linkHeader := header.Ethernet(hdr)
vv.TrimFront(len(linkHeader))
e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{
Data: vv,
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 14b527bc2..9cc08d0e2 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -18,3 +18,10 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "rawfile_test",
+ size = "small",
+ srcs = ["rawfile_test.go"],
+ library = ":rawfile",
+)
diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go
new file mode 100644
index 000000000..8f14ba761
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_test.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "syscall"
+ "testing"
+)
+
+func TestNonBlockingWrite3ZeroLength(t *testing.T) {
+ fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0)
+ if err != nil {
+ t.Fatalf("failed to open /dev/null: %v", err)
+ }
+ defer syscall.Close(fd)
+
+ if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil {
+ t.Fatalf("failed to write: %v", err)
+ }
+}
+
+func TestNonBlockingWrite3Nil(t *testing.T) {
+ fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0)
+ if err != nil {
+ t.Fatalf("failed to open /dev/null: %v", err)
+ }
+ defer syscall.Close(fd)
+
+ if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil {
+ t.Fatalf("failed to write: %v", err)
+ }
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 44e25d475..92efd0bf8 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -76,9 +76,13 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
// We have two buffers. Build the iovec that represents them and issue
// a writev syscall.
+ var base *byte
+ if len(b1) > 0 {
+ base = &b1[0]
+ }
iovec := [3]syscall.Iovec{
{
- Base: &b1[0],
+ Base: base,
Len: uint64(len(b1)),
},
{
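The guard above exists because indexing into a zero-length slice panics; here is a standalone illustration of the same pattern (names are illustrative, not from the diff).

package main

import "fmt"

// basePointer mirrors the fix above: an iovec for an empty buffer should carry
// a nil base pointer (with Len == 0) rather than &b[0], which panics when b is
// empty.
func basePointer(b []byte) *byte {
	if len(b) > 0 {
		return &b[0]
	}
	return nil
}

func main() {
	fmt.Println(basePointer(nil))          // <nil>
	fmt.Println(basePointer([]byte{0x42})) // a non-nil address
}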
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 27ea3f531..33f640b85 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) {
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
c.mu.Lock()
- rcvd := []byte(c.packets[0].vv.First())
+ rcvd := []byte(c.packets[0].vv.ToView())
c.packets = c.packets[:0]
c.mu.Unlock()
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index be2537a82..0799c8f4d 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- first := pkt.Header.View()
- if len(first) == 0 {
- first = pkt.Data.First()
- }
- logPacket(prefix, protocol, first, gso)
+ logPacket(prefix, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Header.UsedLength() + pkt.Data.Size()
@@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
// Wait implements stack.LinkEndpoint.Wait.
func (e *endpoint) Wait() { e.lower.Wait() }
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
size := uint16(0)
var fragmentOffset uint16
var moreFragments bool
+
+ // Create a clone of pkt, including any headers if present. Avoid allocating
+ // backing memory for the clone.
+ views := [8]buffer.View{}
+ vv := buffer.NewVectorisedView(0, views[:0])
+ vv.AppendView(pkt.Header.View())
+ vv.Append(pkt.Data)
+
switch protocol {
case header.IPv4ProtocolNumber:
- ipv4 := header.IPv4(b)
+ hdr, ok := vv.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ ipv4 := header.IPv4(hdr)
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
- b = b[ipv4.HeaderLength():]
+ vv.TrimFront(int(ipv4.HeaderLength()))
id = int(ipv4.ID())
case header.IPv6ProtocolNumber:
- ipv6 := header.IPv6(b)
+ hdr, ok := vv.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return
+ }
+ ipv6 := header.IPv6(hdr)
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
transProto = ipv6.NextHeader()
size = ipv6.PayloadLength()
- b = b[header.IPv6MinimumSize:]
+ vv.TrimFront(header.IPv6MinimumSize)
case header.ARPProtocolNumber:
- arp := header.ARP(b)
+ hdr, ok := vv.PullUp(header.ARPSize)
+ if !ok {
+ return
+ }
+ vv.TrimFront(header.ARPSize)
+ arp := header.ARP(hdr)
log.Infof(
"%s arp %v (%v) -> %v (%v) valid:%v",
prefix,
@@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
// We aren't guaranteed to have a transport header - it's possible for
// writes via raw endpoints to contain only network headers.
- if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize {
+ if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize {
log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto)
return
}
@@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- icmp := header.ICMPv4(b)
+ hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
+ if !ok {
+ break
+ }
+ icmp := header.ICMPv4(hdr)
icmpType := "unknown"
if fragmentOffset == 0 {
switch icmp.Type() {
@@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- icmp := header.ICMPv6(b)
+ hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
+ if !ok {
+ break
+ }
+ icmp := header.ICMPv6(hdr)
icmpType := "unknown"
switch icmp.Type() {
case header.ICMPv6DstUnreachable:
@@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.UDPProtocolNumber:
transName = "udp"
- udp := header.UDP(b)
+ hdr, ok := vv.PullUp(header.UDPMinimumSize)
+ if !ok {
+ break
+ }
+ udp := header.UDP(hdr)
if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
@@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.TCPProtocolNumber:
transName = "tcp"
- tcp := header.TCP(b)
+ hdr, ok := vv.PullUp(header.TCPMinimumSize)
+ if !ok {
+ break
+ }
+ tcp := header.TCP(hdr)
if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7acbfa0a8..cf73a939e 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -93,7 +93,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
}
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- v := pkt.Data.First()
+ v, ok := pkt.Data.PullUp(header.ARPSize)
+ if !ok {
+ return
+ }
h := header.ARP(v)
if !h.IsValid() {
return
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index c4bf1ba5c..4cbefe5ab 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -25,7 +25,11 @@ import (
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
- h := header.IPv4(pkt.Data.First())
+ h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ hdr := header.IPv4(h)
// We don't use IsValid() here because ICMP only requires that the IP
// header plus 8 bytes of the transport header be included. So it's
@@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
//
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match the endpoint's address.
- if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ if hdr.SourceAddress() != e.id.LocalAddress {
return
}
- hlen := int(h.HeaderLength())
- if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 {
+ hlen := int(hdr.HeaderLength())
+ if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 {
// We won't be able to handle this if it doesn't contain the
// full IPv4 header, or if it's a fragment not at offset 0
// (because it won't have the transport header).
@@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
- p := h.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ p := hdr.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
- v := pkt.Data.First()
- if len(v) < header.ICMPv4MinimumSize {
+ v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
+ if !ok {
received.Invalid.Increment()
return
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 104aafbed..17202cc7a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -328,7 +328,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
- ip := header.IPv4(pkt.Data.First())
+ h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
+ ip := header.IPv4(h)
if !ip.IsValid(pkt.Data.Size()) {
return tcpip.ErrInvalidOptionValue
}
@@ -378,7 +382,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView := pkt.Data.First()
+ headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
h := header.IPv4(headerView)
if !h.IsValid(pkt.Data.Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index b68983d10..bdf3a0d25 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -28,7 +28,11 @@ import (
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) {
- h := header.IPv6(pkt.Data.First())
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return
+ }
+ hdr := header.IPv6(h)
// We don't use IsValid() here because ICMP only requires that up to
// 1280 bytes of the original packet be included. So it's likely that it
@@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
//
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match the endpoint's address.
- if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ if hdr.SourceAddress() != e.id.LocalAddress {
return
}
// Skip the IP header, then handle the fragmentation header if there
// is one.
pkt.Data.TrimFront(header.IPv6MinimumSize)
- p := h.TransportProtocol()
+ p := hdr.TransportProtocol()
if p == header.IPv6FragmentHeader {
- f := header.IPv6Fragment(pkt.Data.First())
- if !f.IsValid() || f.FragmentOffset() != 0 {
+ f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize)
+ if !ok {
+ return
+ }
+ fragHdr := header.IPv6Fragment(f)
+ if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 {
// We can't handle fragments that aren't at offset 0
// because they don't have the transport headers.
return
@@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.
// Skip fragmentation header and find out the actual protocol
// number.
pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
- p = f.TransportProtocol()
+ p = fragHdr.TransportProtocol()
}
// Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
- v := pkt.Data.First()
- if len(v) < header.ICMPv6MinimumSize {
+ v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
+ if !ok {
received.Invalid.Increment()
return
}
@@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
// Validate ICMPv6 checksum before processing the packet.
//
- // Only the first view in vv is accounted for by h. To account for the
- // rest of vv, a shallow copy is made and the first view is removed.
// This copy is used as extra payload during the checksum calculation.
payload := pkt.Data.Clone(nil)
- payload.RemoveFirst()
+ payload.TrimFront(len(h))
if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
received.Invalid.Increment()
return
@@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
switch h.Type() {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
- if len(v) < header.ICMPv6PacketTooBigMinimumSize {
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
+ if !ok {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := h.MTU()
+ mtu := header.ICMPv6(hdr).MTU()
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
- if len(v) < header.ICMPv6DstUnreachableMinimumSize {
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
+ if !ok {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
- switch h.Code() {
+ switch header.ICMPv6(hdr).Code() {
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
received.Invalid.Increment()
return
}
- ns := header.NDPNeighborSolicit(h.NDPPayload())
+ // The remainder of payload must be only the neighbor solicitation, so
+ // payload.ToView() always returns the solicitation. Per RFC 6980 section 5,
+ // NDP messages cannot be fragmented. Also note that in the common case NDP
+ // datagrams are very small and ToView() will not incur allocations.
+ ns := header.NDPNeighborSolicit(payload.ToView())
it, err := ns.Options().Iter(true)
if err != nil {
// If we have a malformed NDP NS option, drop the packet.
@@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
received.Invalid.Increment()
return
}
- na := header.NDPNeighborAdvert(h.NDPPayload())
+ // The remainder of payload must be only the neighbor advertisement, so
+ // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // 5, NDP messages cannot be fragmented. Also note that in the common case
+ // NDP datagrams are very small and ToView() will not incur allocations.
+ na := header.NDPNeighborAdvert(payload.ToView())
it, err := na.Options().Iter(true)
if err != nil {
// If we have a malformed NDP NA option, drop the packet.
@@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
- if len(v) < header.ICMPv6EchoMinimumSize {
+ icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize)
+ if !ok {
received.Invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
- copy(packet, h)
+ copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{
@@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6EchoReply:
received.EchoReply.Increment()
- if len(v) < header.ICMPv6EchoMinimumSize {
+ if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
received.Invalid.Increment()
return
}
@@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
case header.ICMPv6RouterAdvert:
received.RouterAdvert.Increment()
- p := h.NDPPayload()
- if len(p) < header.NDPRAMinimumSize || !isNDPValid() {
+ // Is the NDP payload of sufficient size to hold a Router
+ // Advertisement?
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
received.Invalid.Increment()
return
}
@@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P
return
}
- ra := header.NDPRouterAdvert(p)
+ // The remainder of payload must be only the router advertisement, so
+ // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // 5, NDP messages cannot be fragmented. Also note that in the common case
+ // NDP datagrams are very small and ToView() will not incur allocations.
+ ra := header.NDPRouterAdvert(payload.ToView())
opts := ra.Options()
// Are options valid as per the wire format?
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index bd099a7f8..d412ff688 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) {
},
{
typ: header.ICMPv6NeighborSolicit,
- size: header.ICMPv6NeighborSolicitMinimumSize},
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ },
{
typ: header.ICMPv6NeighborAdvert,
size: header.ICMPv6NeighborAdvertMinimumSize,
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 331b0817b..486725131 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
- headerView := pkt.Data.First()
+ headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
h := header.IPv6(headerView)
if !h.IsValid(pkt.Data.Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index e9c652042..c7c663498 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
// Consume the network header.
- b := pkt.Data.First()
+ b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
+ if !ok {
+ return
+ }
pkt.Data.TrimFront(fwdTestNetHeaderLen)
// Dispatch the packet to the transport protocol.
@@ -473,7 +476,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Header.View()
+ b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -517,7 +520,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Header.View()
+ b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -564,7 +567,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Header.View()
+ b := p.Pkt.Data.ToView()
if b[0] != 3 {
t.Fatalf("got b[0] = %d, want = 3", b[0])
}
@@ -619,7 +622,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
- b := p.Pkt.Header.View()
+ b := p.Pkt.Data.ToView()
if b[0] < 8 {
t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 6c0a4b24d..6b91159d4 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
// CheckPackets runs pkts through the rules for hook and returns a map of packets that
// should not go forward.
//
+// Precondition: pkt is an IPv4 packet of at least header.IPv4MinimumSize bytes.
+//
+// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
+// precondition.
+//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) {
@@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa
return drop
}
-// Precondition: pkt.NetworkHeader is set.
+// Precondition: pkt is an IPv4 packet of at least header.IPv4MinimumSize bytes.
+// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
+// precondition.
func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
@@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx
return chainDrop
}
-// Precondition: pk.NetworkHeader is set.
+// Precondition: pkt is an IPv4 packet of at least header.IPv4MinimumSize bytes.
+// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
+// precondition.
func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// If pkt.NetworkHeader hasn't been set yet, it will be contained in
- // pkt.Data.First().
+ // pkt.Data.
if pkt.NetworkHeader == nil {
- pkt.NetworkHeader = pkt.Data.First()
+ var ok bool
+ pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ // Precondition has been violated.
+ panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize))
+ }
}
// Check whether the packet matches the IP header filter.
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 7b4543caf..8be61f4b1 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
newPkt := pkt.Clone()
// Set network header.
- headerView := newPkt.Data.First()
+ headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return RuleDrop, 0
+ }
netHeader := header.IPv4(headerView)
- newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize]
+ newPkt.NetworkHeader = headerView
hlen := int(netHeader.HeaderLength())
tlen := int(netHeader.TotalLength())
@@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
if newPkt.TransportHeader != nil {
udpHeader = header.UDP(newPkt.TransportHeader)
} else {
- if len(pkt.Data.First()) < header.UDPMinimumSize {
+ if pkt.Data.Size() < header.UDPMinimumSize {
+ return RuleDrop, 0
+ }
+ hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok {
return RuleDrop, 0
}
- udpHeader = header.UDP(newPkt.Data.First())
+ udpHeader = header.UDP(hdr)
}
udpHeader.SetDestinationPort(rt.MinPort)
case header.TCPProtocolNumber:
@@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
if newPkt.TransportHeader != nil {
tcpHeader = header.TCP(newPkt.TransportHeader)
} else {
- if len(pkt.Data.First()) < header.TCPMinimumSize {
+ if pkt.Data.Size() < header.TCPMinimumSize {
return RuleDrop, 0
}
- tcpHeader = header.TCP(newPkt.TransportHeader)
+ hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize)
+ if !ok {
+ return RuleDrop, 0
+ }
+ tcpHeader = header.TCP(hdr)
}
// TODO(gvisor.dev/issue/170): Need to recompute checksum
// and implement nat connection tracking to support TCP.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 016dbe15e..0c2b1f36a 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
n.stack.stats.IP.PacketsReceived.Increment()
}
- if len(pkt.Data.First()) < netProto.MinimumPacketSize() {
+ netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize())
+ if !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
-
- src, dst := netProto.ParseAddresses(pkt.Data.First())
+ src, dst := netProto.ParseAddresses(netHeader)
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@@ -1289,22 +1289,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
-
- firstData := pkt.Data.First()
- pkt.Data.RemoveFirst()
-
- if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 {
- pkt.Header = buffer.NewPrependableFromView(firstData)
- } else {
- firstDataLen := len(firstData)
-
- // pkt.Header should have enough capacity to hold n.linkEP's headers.
- pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen)
-
- // TODO(b/151227689): avoid copying the packet when forwarding
- if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen {
- panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen))
- }
+ if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 {
+ pkt.Header = buffer.NewPrependable(linkHeaderLen)
}
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
@@ -1332,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
- if len(pkt.Data.First()) < transProto.MinimumPacketSize() {
+ transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
+ if !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
+ srcPort, dstPort, err := transProto.ParsePorts(transHeader)
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
@@ -1375,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
- if len(pkt.Data.First()) < 8 {
+ transHeader, ok := pkt.Data.PullUp(8)
+ if !ok {
return
}
- srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
+ srcPort, dstPort, err := transProto.ParsePorts(transHeader)
if err != nil {
return
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index dc125f25e..7d36f8e84 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -37,7 +37,13 @@ type PacketBuffer struct {
Data buffer.VectorisedView
// Header holds the headers of outbound packets. As a packet is passed
- // down the stack, each layer adds to Header.
+ // down the stack, each layer adds to Header. Note that forwarded
+ // packets don't populate Headers on their way out -- their headers and
+ // payload are never parsed out and remain in Data.
+ //
+ // TODO(gvisor.dev/issue/170): Forwarded packets don't currently
+ // populate Header, but should. This will be doable once early parsing
+ // (https://github.com/google/gvisor/pull/1995) is supported.
Header buffer.Prependable
// These fields are used by both inbound and outbound packets. They
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index c7634ceb1..d45d2cc1f 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
// Consume the network header.
- b := pkt.Data.First()
+ b, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
+ return
+ }
pkt.Data.TrimFront(fakeNetHeaderLen)
// Handle control packets.
if b[2] == uint8(fakeControlProtocol) {
- nb := pkt.Data.First()
- if len(nb) < fakeNetHeaderLen {
+ nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
return
}
-
pkt.Data.TrimFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
return
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 3084e6593..a611e44ab 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- if dst := p.Pkt.Header.View()[0]; dst != 3 {
+ hdrs := p.Pkt.Data.ToView()
+ if dst := hdrs[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := p.Pkt.Header.View()[1]; src != 1 {
+ if src := hdrs[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index feef8dca0..b1d820372 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h := header.ICMPv4(pkt.Data.First())
- if h.Type() != header.ICMPv4EchoReply {
+ h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
+ if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h := header.ICMPv6(pkt.Data.First())
- if h.Type() != header.ICMPv6EchoReply {
+ h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
+ if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 61426623c..f2aa69069 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -105,8 +105,8 @@ go_test(
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp/testing/context",
+ "//pkg/test/testutil",
"//pkg/waiter",
- "//runsc/testutil",
],
)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 45f2aa78b..07d3e64c8 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2158,8 +2158,6 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
//
// By not removing this endpoint from the demuxer mapping, we
// ensure that any other bind to the same port fails, as on Linux.
- // TODO(gvisor.dev/issue/2468): We need to enable applications to
- // start listening on this endpoint again similar to Linux.
e.rcvListMu.Lock()
e.rcvClosed = true
e.rcvListMu.Unlock()
@@ -2188,26 +2186,31 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
e.LockUser()
defer e.UnlockUser()
- // Allow the backlog to be adjusted if the endpoint is not shutting down.
- // When the endpoint shuts down, it sets workerCleanup to true, and from
- // that point onward, acceptedChan is the responsibility of the cleanup()
- // method (and should not be touched anywhere else, including here).
- if e.EndpointState() == StateListen && !e.workerCleanup {
- // Adjust the size of the channel iff we can fix existing
- // pending connections into the new one.
+ if e.EndpointState() == StateListen && !e.closed {
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- if len(e.acceptedChan) > backlog {
- return tcpip.ErrInvalidEndpointState
- }
- if cap(e.acceptedChan) == backlog {
- return nil
- }
- origChan := e.acceptedChan
- e.acceptedChan = make(chan *endpoint, backlog)
- close(origChan)
- for ep := range origChan {
- e.acceptedChan <- ep
+ if e.acceptedChan == nil {
+ // listen is called after shutdown.
+ e.acceptedChan = make(chan *endpoint, backlog)
+ e.shutdownFlags = 0
+ e.rcvListMu.Lock()
+ e.rcvClosed = false
+ e.rcvListMu.Unlock()
+ } else {
+ // Adjust the size of the channel iff we can fix
+ // existing pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return tcpip.ErrInvalidEndpointState
+ }
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *endpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
}
// Notify any blocked goroutines that they can attempt to
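For context, the Linux behavior this hunk appears to mirror is that listen(2) may be called again on a TCP socket that was previously shut down; a hedged, Linux-only sketch (port and backlog are illustrative, error handling abbreviated).

// +build linux

package main

import (
	"fmt"
	"syscall"
)

func main() {
	fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
	if err != nil {
		panic(err)
	}
	defer syscall.Close(fd)

	if err := syscall.Bind(fd, &syscall.SockaddrInet4{Port: 0}); err != nil {
		panic(err)
	}
	if err := syscall.Listen(fd, 10); err != nil {
		panic(err)
	}
	if err := syscall.Shutdown(fd, syscall.SHUT_RD); err != nil {
		panic(err)
	}

	// On Linux this second listen generally succeeds; the change above lets
	// gVisor's endpoint make the same transition instead of failing.
	if err := syscall.Listen(fd, 10); err != nil {
		fmt.Println("re-listen failed:", err)
	} else {
		fmt.Println("re-listen succeeded")
	}
}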
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index c3c692555..8b7562396 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -247,6 +247,11 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
+ e.LockUser()
+ if e.shutdownFlags != 0 {
+ e.shutdownLocked(e.shutdownFlags)
+ }
+ e.UnlockUser()
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 40461fd31..7712ce652 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
- h := header.TCP(s.data.First())
+ h, ok := s.data.PullUp(header.TCPMinimumSize)
+ if !ok {
+ return false
+ }
+ hdr := header.TCP(h)
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
@@ -156,12 +160,16 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
- offset := int(h.DataOffset())
- if offset < header.TCPMinimumSize || offset > len(h) {
+ offset := int(hdr.DataOffset())
+ if offset < header.TCPMinimumSize {
+ return false
+ }
+ hdrWithOpts, ok := s.data.PullUp(offset)
+ if !ok {
return false
}
- s.options = []byte(h[header.TCPMinimumSize:offset])
+ s.options = []byte(hdrWithOpts[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@@ -173,18 +181,19 @@ func (s *segment) parse() bool {
s.data.TrimFront(offset)
}
if verifyChecksum {
- s.csum = h.Checksum()
+ hdr = header.TCP(hdrWithOpts)
+ s.csum = hdr.Checksum()
xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
- xsum = h.CalculateChecksum(xsum)
+ xsum = hdr.CalculateChecksum(xsum)
s.data.TrimFront(offset)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
- s.sequenceNumber = seqnum.Value(h.SequenceNumber())
- s.ackNumber = seqnum.Value(h.AckNumber())
- s.flags = h.Flags()
- s.window = seqnum.Size(h.WindowSize())
+ s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
+ s.ackNumber = seqnum.Value(hdr.AckNumber())
+ s.flags = hdr.Flags()
+ s.window = seqnum.Size(hdr.WindowSize())
return true
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 359a75e73..5fe23113b 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -31,7 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
- "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
func TestFastRecovery(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index ab1014c7f..286c66cf5 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- tcpbuf := vv.First()[header.IPv4MinimumSize:]
+ tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
c.SendSegment(vv)
@@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- tcpbuf := vv.First()[header.IPv4MinimumSize:]
+ tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
// Overwrite a byte in the payload which should cause checksum
// verification to fail.
tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index edb54f0be..756ab913a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr := header.UDP(pkt.Data.First())
- if int(hdr.Length()) > pkt.Data.Size() {
+ hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
- Port: hdr.SourcePort(),
+ Port: header.UDP(hdr).SourcePort(),
},
}
packet.data = pkt.Data
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 6e31a9bac..52af6de22 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// that don't match any existing endpoint.
func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
// Get the header then trim it from the view.
- hdr := header.UDP(pkt.Data.First())
- if int(hdr.Length()) > pkt.Data.Size() {
+ h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok {
+ // Malformed packet.
+ r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ return true
+ }
+ if int(header.UDP(h).Length()) > pkt.Data.Size() {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true
diff --git a/pkg/test/criutil/BUILD b/pkg/test/criutil/BUILD
new file mode 100644
index 000000000..a7b082cee
--- /dev/null
+++ b/pkg/test/criutil/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "criutil",
+ testonly = 1,
+ srcs = ["criutil.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go
new file mode 100644
index 000000000..bebebb48e
--- /dev/null
+++ b/pkg/test/criutil/criutil.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package criutil contains utility functions for interacting with the
+// Container Runtime Interface (CRI), principally via the crictl command line
+// tool. This requires critools to be installed on the local system.
+package criutil
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "os/exec"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Crictl contains information required to run the crictl utility.
+type Crictl struct {
+ logger testutil.Logger
+ endpoint string
+ cleanup []func()
+}
+
+// resolvePath attempts to find binary paths. It may return an invalid path,
+// in which case execution will fail with a sensible error.
+func resolvePath(executable string) string {
+ guess, err := exec.LookPath(executable)
+ if err != nil {
+ guess = fmt.Sprintf("/usr/local/bin/%s", executable)
+ }
+ return guess
+}
+
+// NewCrictl returns a Crictl configured with a timeout and an endpoint over
+// which it will talk to containerd.
+func NewCrictl(logger testutil.Logger, endpoint string) *Crictl {
+ // Attempt to find the executable, but don't bother propagating the
+ // error at this point. The first command executed will return with a
+ // binary not found error.
+ return &Crictl{
+ logger: logger,
+ endpoint: endpoint,
+ }
+}
+
+// CleanUp executes cleanup functions.
+func (cc *Crictl) CleanUp() {
+ for _, c := range cc.cleanup {
+ c()
+ }
+ cc.cleanup = nil
+}
+
+// RunPod creates a sandbox. It corresponds to `crictl runp`.
+func (cc *Crictl) RunPod(sbSpecFile string) (string, error) {
+ podID, err := cc.run("runp", sbSpecFile)
+ if err != nil {
+ return "", fmt.Errorf("runp failed: %v", err)
+ }
+ // Strip the trailing newline from crictl output.
+ return strings.TrimSpace(podID), nil
+}
+
+// Create creates a container within a sandbox. It corresponds to `crictl
+// create`.
+func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) {
+ podID, err := cc.run("create", podID, contSpecFile, sbSpecFile)
+ if err != nil {
+ return "", fmt.Errorf("create failed: %v", err)
+ }
+ // Strip the trailing newline from crictl output.
+ return strings.TrimSpace(podID), nil
+}
+
+// Start starts a container. It corresponds to `crictl start`.
+func (cc *Crictl) Start(contID string) (string, error) {
+ output, err := cc.run("start", contID)
+ if err != nil {
+ return "", fmt.Errorf("start failed: %v", err)
+ }
+ return output, nil
+}
+
+// Stop stops a container. It corresponds to `crictl stop`.
+func (cc *Crictl) Stop(contID string) error {
+ _, err := cc.run("stop", contID)
+ return err
+}
+
+// Exec execs a program inside a container. It corresponds to `crictl exec`.
+func (cc *Crictl) Exec(contID string, args ...string) (string, error) {
+ a := []string{"exec", contID}
+ a = append(a, args...)
+ output, err := cc.run(a...)
+ if err != nil {
+ return "", fmt.Errorf("exec failed: %v", err)
+ }
+ return output, nil
+}
+
+// Rm removes a container. It corresponds to `crictl rm`.
+func (cc *Crictl) Rm(contID string) error {
+ _, err := cc.run("rm", contID)
+ return err
+}
+
+// StopPod stops a pod. It corresponds to `crictl stopp`.
+func (cc *Crictl) StopPod(podID string) error {
+ _, err := cc.run("stopp", podID)
+ return err
+}
+
+// containerConfig is a minimal copy of
+// https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/apis/cri/runtime/v1alpha2/api.proto
+// It only contains fields needed for testing.
+type containerConfig struct {
+ Status containerStatus
+}
+
+type containerStatus struct {
+ Network containerNetwork
+}
+
+type containerNetwork struct {
+ IP string
+}
+
+// PodIP returns a pod's IP address.
+func (cc *Crictl) PodIP(podID string) (string, error) {
+ output, err := cc.run("inspectp", podID)
+ if err != nil {
+ return "", err
+ }
+ conf := &containerConfig{}
+ if err := json.Unmarshal([]byte(output), conf); err != nil {
+ return "", fmt.Errorf("failed to unmarshal JSON: %v, %s", err, output)
+ }
+ if conf.Status.Network.IP == "" {
+ return "", fmt.Errorf("no IP found in config: %s", output)
+ }
+ return conf.Status.Network.IP, nil
+}
+
+// RmPod removes a container. It corresponds to `crictl rmp`.
+func (cc *Crictl) RmPod(podID string) error {
+ _, err := cc.run("rmp", podID)
+ return err
+}
+
+// Import imports the given container from the local Docker instance.
+func (cc *Crictl) Import(image string) error {
+ // Note that we provide a 10 minute timeout after connect because we may
+ // be pushing a lot of bytes in order to import the image. The connect
+ // timeout stays the same and is inherited from the Crictl instance.
+ cmd := testutil.Command(cc.logger,
+ resolvePath("ctr"),
+ fmt.Sprintf("--connect-timeout=%s", 30*time.Second),
+ fmt.Sprintf("--address=%s", cc.endpoint),
+ "-n", "k8s.io", "images", "import", "-")
+ cmd.Stderr = os.Stderr // Pass through errors.
+
+ // Create a pipe and start the program.
+ w, err := cmd.StdinPipe()
+ if err != nil {
+ return err
+ }
+ if err := cmd.Start(); err != nil {
+ return err
+ }
+
+ // Save the image on the other end.
+ if err := dockerutil.Save(cc.logger, image, w); err != nil {
+ cmd.Wait()
+ return err
+ }
+
+ // Close our pipe reference & see if it was loaded.
+ if err := w.Close(); err != nil {
+ return err
+ }
+
+ return cmd.Wait()
+}
+
+// StartContainer pulls the given image and starts the container in the
+// sandbox with the given podID.
+//
+// Note that the image will always be imported from the local docker daemon.
+func (cc *Crictl) StartContainer(podID, image, sbSpec, contSpec string) (string, error) {
+ if err := cc.Import(image); err != nil {
+ return "", err
+ }
+
+ // Write the specs to files that can be read by crictl.
+ sbSpecFile, cleanup, err := testutil.WriteTmpFile("sbSpec", sbSpec)
+ if err != nil {
+ return "", fmt.Errorf("failed to write sandbox spec: %v", err)
+ }
+ cc.cleanup = append(cc.cleanup, cleanup)
+ contSpecFile, cleanup, err := testutil.WriteTmpFile("contSpec", contSpec)
+ if err != nil {
+ return "", fmt.Errorf("failed to write container spec: %v", err)
+ }
+ cc.cleanup = append(cc.cleanup, cleanup)
+
+ return cc.startContainer(podID, image, sbSpecFile, contSpecFile)
+}
+
+func (cc *Crictl) startContainer(podID, image, sbSpecFile, contSpecFile string) (string, error) {
+ contID, err := cc.Create(podID, contSpecFile, sbSpecFile)
+ if err != nil {
+ return "", fmt.Errorf("failed to create container in pod %q: %v", podID, err)
+ }
+
+ if _, err := cc.Start(contID); err != nil {
+ return "", fmt.Errorf("failed to start container %q in pod %q: %v", contID, podID, err)
+ }
+
+ return contID, nil
+}
+
+// StopContainer stops and deletes the container with the given container ID.
+func (cc *Crictl) StopContainer(contID string) error {
+ if err := cc.Stop(contID); err != nil {
+ return fmt.Errorf("failed to stop container %q: %v", contID, err)
+ }
+
+ if err := cc.Rm(contID); err != nil {
+ return fmt.Errorf("failed to remove container %q: %v", contID, err)
+ }
+
+ return nil
+}
+
+// StartPodAndContainer starts a sandbox and container in that sandbox. It
+// returns the pod ID and container ID.
+func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) {
+ if err := cc.Import(image); err != nil {
+ return "", "", err
+ }
+
+ // Write the specs to files that can be read by crictl.
+ sbSpecFile, cleanup, err := testutil.WriteTmpFile("sbSpec", sbSpec)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to write sandbox spec: %v", err)
+ }
+ cc.cleanup = append(cc.cleanup, cleanup)
+ contSpecFile, cleanup, err := testutil.WriteTmpFile("contSpec", contSpec)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to write container spec: %v", err)
+ }
+ cc.cleanup = append(cc.cleanup, cleanup)
+
+ podID, err := cc.RunPod(sbSpecFile)
+ if err != nil {
+ return "", "", err
+ }
+
+ contID, err := cc.startContainer(podID, image, sbSpecFile, contSpecFile)
+
+ return podID, contID, err
+}
+
+// StopPodAndContainer stops a container and pod.
+func (cc *Crictl) StopPodAndContainer(podID, contID string) error {
+ if err := cc.StopContainer(contID); err != nil {
+ return fmt.Errorf("failed to stop container %q in pod %q: %v", contID, podID, err)
+ }
+
+ if err := cc.StopPod(podID); err != nil {
+ return fmt.Errorf("failed to stop pod %q: %v", podID, err)
+ }
+
+ if err := cc.RmPod(podID); err != nil {
+ return fmt.Errorf("failed to remove pod %q: %v", podID, err)
+ }
+
+ return nil
+}
+
+// run runs crictl with the given args.
+func (cc *Crictl) run(args ...string) (string, error) {
+ defaultArgs := []string{
+ resolvePath("crictl"),
+ "--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
+ "--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
+ }
+ fullArgs := append(defaultArgs, args...)
+ out, err := testutil.Command(cc.logger, fullArgs...).CombinedOutput()
+ return string(out), err
+}
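A hedged usage sketch of the new criutil package (not part of the diff): the socket path, image name, and specs are illustrative, and it is assumed that *testing.T satisfies testutil.Logger.

package example

import (
	"strings"
	"testing"

	"gvisor.dev/gvisor/pkg/test/criutil"
)

// runEchoInPod starts a sandbox and container, runs a command inside it, and
// tears everything down via the cleanup helpers added above.
func runEchoInPod(t *testing.T, sbSpec, contSpec string) {
	cc := criutil.NewCrictl(t, "/run/containerd/containerd.sock")
	defer cc.CleanUp()

	podID, contID, err := cc.StartPodAndContainer("basic/busybox", sbSpec, contSpec)
	if err != nil {
		t.Fatalf("StartPodAndContainer failed: %v", err)
	}
	defer cc.StopPodAndContainer(podID, contID)

	out, err := cc.Exec(contID, "echo", "hello")
	if err != nil || !strings.Contains(out, "hello") {
		t.Fatalf("exec failed: %v, output: %q", err, out)
	}
}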
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
new file mode 100644
index 000000000..7c8758e35
--- /dev/null
+++ b/pkg/test/dockerutil/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "dockerutil",
+ testonly = 1,
+ srcs = ["dockerutil.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/testutil",
+ "@com_github_kr_pty//:go_default_library",
+ ],
+)
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
new file mode 100644
index 000000000..baa8fc2f2
--- /dev/null
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -0,0 +1,581 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package dockerutil is a collection of utility functions.
+package dockerutil
+
+import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "os"
+ "os/exec"
+ "path"
+ "regexp"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/kr/pty"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+var (
+ // runtime is the runtime to use for tests. This will be applied to all
+ // containers. Note that the default here ("runsc") corresponds to the
+ // default used by the installations. This is important, because the
+ // default installer for vm_tests (in tools/installers:head, invoked
+ // via tools/vm:defs.bzl) will install with this name. So without
+ // changing anything, tests should have a runsc runtime available to
+ // them. Otherwise installers should update the existing runtime
+ // instead of installing a new one.
+ runtime = flag.String("runtime", "runsc", "specify which runtime to use")
+
+ // config is the default Docker daemon configuration path.
+ config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+)
+
+// EnsureSupportedDockerVersion checks that a supported Docker version is installed.
+//
+// This logs directly to stderr, as it is typically called from a Main wrapper.
+func EnsureSupportedDockerVersion() {
+ cmd := exec.Command("docker", "version")
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ log.Fatalf("error running %q: %v", "docker version", err)
+ }
+ re := regexp.MustCompile(`Version:\s+(\d+)\.(\d+)\.\d.*`)
+ matches := re.FindStringSubmatch(string(out))
+ if len(matches) != 3 {
+ log.Fatalf("Invalid docker output: %s", out)
+ }
+ major, _ := strconv.Atoi(matches[1])
+ minor, _ := strconv.Atoi(matches[2])
+ if major < 17 || (major == 17 && minor < 9) {
+ log.Fatalf("Docker version 17.09.0 or greater is required, found: %02d.%02d", major, minor)
+ }
+}
+
+// RuntimePath returns the binary path for the current runtime.
+func RuntimePath() (string, error) {
+ // Read the configuration data; the file must exist.
+ configBytes, err := ioutil.ReadFile(*config)
+ if err != nil {
+ return "", err
+ }
+
+ // Unmarshal the configuration.
+ c := make(map[string]interface{})
+ if err := json.Unmarshal(configBytes, &c); err != nil {
+ return "", err
+ }
+
+ // Decode the expected configuration.
+ r, ok := c["runtimes"]
+ if !ok {
+ return "", fmt.Errorf("no runtimes declared: %v", c)
+ }
+ rs, ok := r.(map[string]interface{})
+ if !ok {
+ // The runtimes are not a map.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ r, ok = rs[*runtime]
+ if !ok {
+ // The expected runtime is not declared.
+ return "", fmt.Errorf("runtime %q not found: %v", *runtime, c)
+ }
+ rs, ok = r.(map[string]interface{})
+ if !ok {
+ // The runtime is not a map.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ p, ok := rs["path"].(string)
+ if !ok {
+ // The runtime does not declare a path.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ return p, nil
+}
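+
+// A minimal sketch of the daemon.json layout that RuntimePath expects; the
+// runtime path below is only an example:
+//
+//	{
+//	  "runtimes": {
+//	    "runsc": {
+//	      "path": "/usr/local/bin/runsc"
+//	    }
+//	  }
+//	}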
+
+// Save exports a container image to the given Writer.
+//
+// Note that the writer should be actively consuming the output, otherwise it
+// is not guaranteed that the Save will make any progress and the call may
+// stall indefinitely.
+//
+// This is called by criutil in order to import images.
+func Save(logger testutil.Logger, image string, w io.Writer) error {
+ cmd := testutil.Command(logger, "docker", "save", testutil.ImageByName(image))
+ cmd.Stdout = w // Send directly to the writer.
+ return cmd.Run()
+}
+
+// MountMode describes if the mount should be ro or rw.
+type MountMode int
+
+const (
+ // ReadOnly is what the name says.
+ ReadOnly MountMode = iota
+ // ReadWrite is what the name says.
+ ReadWrite
+)
+
+// String returns the mount mode argument for this MountMode.
+func (m MountMode) String() string {
+ switch m {
+ case ReadOnly:
+ return "ro"
+ case ReadWrite:
+ return "rw"
+ }
+ panic(fmt.Sprintf("invalid mode: %d", m))
+}
+
+// Docker contains the name and the runtime of a docker container.
+type Docker struct {
+ logger testutil.Logger
+ Runtime string
+ Name string
+ copyErr error
+ mounts []string
+ cleanups []func()
+}
+
+// MakeDocker sets up the struct for a Docker container.
+//
+// Names of containers will be unique.
+func MakeDocker(logger testutil.Logger) *Docker {
+ return &Docker{
+ logger: logger,
+ Name: testutil.RandomID(logger.Name()),
+ Runtime: *runtime,
+ }
+}
+
+// Mount mounts the given source and makes it available in the container.
+func (d *Docker) Mount(target, source string, mode MountMode) {
+ d.mounts = append(d.mounts, fmt.Sprintf("-v=%s:%s:%v", source, target, mode))
+}
+
+// CopyFiles copies in and mounts the given files. They are always ReadOnly.
+func (d *Docker) CopyFiles(target string, sources ...string) {
+ dir, err := ioutil.TempDir("", d.Name)
+ if err != nil {
+ d.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err)
+ return
+ }
+ d.cleanups = append(d.cleanups, func() { os.RemoveAll(dir) })
+ if err := os.Chmod(dir, 0755); err != nil {
+ d.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err)
+ return
+ }
+ for _, name := range sources {
+ src, err := testutil.FindFile(name)
+ if err != nil {
+ d.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
+ return
+ }
+ dst := path.Join(dir, path.Base(name))
+ if err := testutil.Copy(src, dst); err != nil {
+ d.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
+ return
+ }
+ d.logger.Logf("copy: %s -> %s", src, dst)
+ }
+ d.Mount(target, dir, ReadOnly)
+}
+
+// Link makes the given source container reachable from this container under
+// the given target alias.
+func (d *Docker) Link(target string, source *Docker) {
+ d.mounts = append(d.mounts, fmt.Sprintf("--link=%s:%s", source.Name, target))
+}
+
+// RunOpts are options for running a container.
+type RunOpts struct {
+ // Image is the image relative to images/. This will be mangled
+ // appropriately, to ensure that only first-party images are used.
+ Image string
+
+ // Memory is the memory limit in kB.
+ Memory int
+
+ // Ports are the ports to be allocated.
+ Ports []int
+
+ // WorkDir sets the working directory.
+ WorkDir string
+
+ // ReadOnly sets the read-only flag.
+ ReadOnly bool
+
+ // Env are additional environment variables.
+ Env []string
+
+ // User is the user to use.
+ User string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // CapAdd are the extra set of capabilities to add.
+ CapAdd []string
+
+ // CapDrop are the extra set of capabilities to drop.
+ CapDrop []string
+
+	// Pty indicates that a pty will be allocated. If this is non-nil, then
+	// this function will be called after start-up with the *exec.Cmd and
+	// pty file passed in to it.
+ Pty func(*exec.Cmd, *os.File)
+
+ // Foreground indicates that the container should be run in the
+ // foreground. If this is true, then the output will be available as a
+ // return value from the Run function.
+ Foreground bool
+
+ // Extra are extra arguments that may be passed.
+ Extra []string
+}
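+
+// A minimal usage sketch from a test (the variable names and command are only
+// illustrative; "basic/alpine" is one of the images known to testutil):
+//
+//	d := dockerutil.MakeDocker(logger)
+//	defer d.CleanUp()
+//	out, err := d.Run(dockerutil.RunOpts{Image: "basic/alpine"}, "echo", "hello")
+//	if err != nil || !strings.Contains(out, "hello") {
+//		t.Errorf("unexpected output %q: %v", out, err)
+//	}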
+
+// argsFor returns the common arguments for the given command.
+//
+// Note that this does not define the complete behavior.
+func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) {
+ isExec := command == "exec"
+ isRun := command == "run"
+
+ if isRun || isExec {
+ rv = append(rv, "-i")
+ }
+ if r.Pty != nil {
+ rv = append(rv, "-t")
+ }
+ if r.User != "" {
+ rv = append(rv, fmt.Sprintf("--user=%s", r.User))
+ }
+ if r.Privileged {
+ rv = append(rv, "--privileged")
+ }
+ for _, c := range r.CapAdd {
+ rv = append(rv, fmt.Sprintf("--cap-add=%s", c))
+ }
+ for _, c := range r.CapDrop {
+ rv = append(rv, fmt.Sprintf("--cap-drop=%s", c))
+ }
+ for _, e := range r.Env {
+ rv = append(rv, fmt.Sprintf("--env=%s", e))
+ }
+ if r.WorkDir != "" {
+ rv = append(rv, fmt.Sprintf("--workdir=%s", r.WorkDir))
+ }
+ if !isExec {
+ if r.Memory != 0 {
+ rv = append(rv, fmt.Sprintf("--memory=%dk", r.Memory))
+ }
+ for _, p := range r.Ports {
+ rv = append(rv, fmt.Sprintf("--publish=%d", p))
+ }
+ if r.ReadOnly {
+			rv = append(rv, "--read-only")
+ }
+ if len(p) > 0 {
+ rv = append(rv, "--entrypoint=")
+ }
+ }
+
+ // Always attach the test environment & Extra.
+ rv = append(rv, fmt.Sprintf("--env=RUNSC_TEST_NAME=%s", d.Name))
+ rv = append(rv, r.Extra...)
+
+ // Attach necessary bits.
+ if isExec {
+ rv = append(rv, d.Name)
+ } else {
+ rv = append(rv, d.mounts...)
+ rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime))
+ rv = append(rv, fmt.Sprintf("--name=%s", d.Name))
+ rv = append(rv, testutil.ImageByName(r.Image))
+ }
+
+ // Attach other arguments.
+ rv = append(rv, p...)
+ return rv
+}
+
+// run runs a complete command.
+func (d *Docker) run(r RunOpts, command string, p ...string) (string, error) {
+ if d.copyErr != nil {
+ return "", d.copyErr
+ }
+ basicArgs := []string{"docker"}
+ if command == "spawn" {
+ command = "run"
+ basicArgs = append(basicArgs, command)
+ basicArgs = append(basicArgs, "-d")
+ } else {
+ basicArgs = append(basicArgs, command)
+ }
+ customArgs := d.argsFor(&r, command, p)
+ cmd := testutil.Command(d.logger, append(basicArgs, customArgs...)...)
+ if r.Pty != nil {
+ // If allocating a terminal, then we just ignore the output
+ // from the command.
+ ptmx, err := pty.Start(cmd.Cmd)
+ if err != nil {
+ return "", err
+ }
+ defer cmd.Wait() // Best effort.
+ r.Pty(cmd.Cmd, ptmx)
+ } else {
+ // Can't support PTY or streaming.
+ out, err := cmd.CombinedOutput()
+ return string(out), err
+ }
+ return "", nil
+}
+
+// Create calls 'docker create' with the arguments provided.
+func (d *Docker) Create(r RunOpts, args ...string) error {
+ _, err := d.run(r, "create", args...)
+ return err
+}
+
+// Start calls 'docker start'.
+func (d *Docker) Start() error {
+ return testutil.Command(d.logger, "docker", "start", d.Name).Run()
+}
+
+// Stop calls 'docker stop'.
+func (d *Docker) Stop() error {
+ return testutil.Command(d.logger, "docker", "stop", d.Name).Run()
+}
+
+// Run calls 'docker run' with the arguments provided.
+func (d *Docker) Run(r RunOpts, args ...string) (string, error) {
+ return d.run(r, "run", args...)
+}
+
+// Spawn starts the container and detaches.
+func (d *Docker) Spawn(r RunOpts, args ...string) error {
+ _, err := d.run(r, "spawn", args...)
+ return err
+}
+
+// Logs calls 'docker logs'.
+func (d *Docker) Logs() (string, error) {
+	// Don't log this command; its output would swamp the test logs.
+ out, err := exec.Command("docker", "logs", d.Name).CombinedOutput()
+ return string(out), err
+}
+
+// Exec calls 'docker exec' with the arguments provided.
+func (d *Docker) Exec(r RunOpts, args ...string) (string, error) {
+ return d.run(r, "exec", args...)
+}
+
+// Pause calls 'docker pause'.
+func (d *Docker) Pause() error {
+ return testutil.Command(d.logger, "docker", "pause", d.Name).Run()
+}
+
+// Unpause calls 'docker unpause'.
+func (d *Docker) Unpause() error {
+ return testutil.Command(d.logger, "docker", "unpause", d.Name).Run()
+}
+
+// Checkpoint calls 'docker checkpoint'.
+func (d *Docker) Checkpoint(name string) error {
+ return testutil.Command(d.logger, "docker", "checkpoint", "create", d.Name, name).Run()
+}
+
+// Restore calls 'docker start --checkpoint=[name]'.
+func (d *Docker) Restore(name string) error {
+ return testutil.Command(d.logger, "docker", "start", fmt.Sprintf("--checkpoint=%s", name), d.Name).Run()
+}
+
+// Kill calls 'docker kill'.
+func (d *Docker) Kill() error {
+	// Skip logging this command; it will often fail simply because the
+	// container is not running.
+ out, err := exec.Command("docker", "kill", d.Name).CombinedOutput()
+ if err != nil && !strings.Contains(string(out), "is not running") {
+ return err
+ }
+ return nil
+}
+
+// Remove calls 'docker rm'.
+func (d *Docker) Remove() error {
+ return testutil.Command(d.logger, "docker", "rm", d.Name).Run()
+}
+
+// CleanUp kills and deletes the container (best effort).
+func (d *Docker) CleanUp() {
+ // Kill the container.
+ if err := d.Kill(); err != nil {
+ // Just log; can't do anything here.
+ d.logger.Logf("error killing container %q: %v", d.Name, err)
+ }
+	// Remove the container.
+ if err := d.Remove(); err != nil {
+ d.logger.Logf("error removing container %q: %v", d.Name, err)
+ }
+ // Forget all mounts.
+ d.mounts = nil
+ // Execute all cleanups.
+ for _, c := range d.cleanups {
+ c()
+ }
+ d.cleanups = nil
+}
+
+// FindPort returns the host port that is mapped to 'sandboxPort'. Docker
+// allocates a free port on the host when the container port is published,
+// which prevents conflicts.
+func (d *Docker) FindPort(sandboxPort int) (int, error) {
+ format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort)
+ out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
+ if err != nil {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+ }
+ port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
+ if err != nil {
+ return -1, fmt.Errorf("error parsing port %q: %v", out, err)
+ }
+ return port, nil
+}
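+
+// A usage sketch (d and t are from a hypothetical test; the container is
+// assumed to have published port 80, e.g. via Spawn with
+// RunOpts{Image: "basic/nginx", Ports: []int{80}}):
+//
+//	hostPort, err := d.FindPort(80)
+//	if err != nil {
+//		t.Fatalf("FindPort failed: %v", err)
+//	}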
+
+// FindIP returns the IP address of the container.
+func (d *Docker) FindIP() (net.IP, error) {
+ const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}`
+ out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput()
+ if err != nil {
+ return net.IP{}, fmt.Errorf("error retrieving IP: %v", err)
+ }
+ ip := net.ParseIP(strings.TrimSpace(string(out)))
+ if ip == nil {
+ return net.IP{}, fmt.Errorf("invalid IP: %q", string(out))
+ }
+ return ip, nil
+}
+
+// SandboxPid returns the PID of the sandbox process.
+func (d *Docker) SandboxPid() (int, error) {
+ out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput()
+ if err != nil {
+ return -1, fmt.Errorf("error retrieving pid: %v", err)
+ }
+ pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
+ if err != nil {
+ return -1, fmt.Errorf("error parsing pid %q: %v", out, err)
+ }
+ return pid, nil
+}
+
+// ID returns the container ID.
+func (d *Docker) ID() (string, error) {
+ out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.Id}}", d.Name).CombinedOutput()
+ if err != nil {
+ return "", fmt.Errorf("error retrieving ID: %v", err)
+ }
+ return strings.TrimSpace(string(out)), nil
+}
+
+// Wait waits for the container to exit, up to the given timeout. It returns an
+// error if the wait fails or the timeout is hit, and the application's exit
+// status otherwise. Note that the application may have failed even if err ==
+// nil; always check the exit code.
+func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) {
+ timeoutChan := time.After(timeout)
+ waitChan := make(chan (syscall.WaitStatus))
+ errChan := make(chan (error))
+
+	go func() {
+		out, err := testutil.Command(d.logger, "docker", "wait", d.Name).CombinedOutput()
+		if err != nil {
+			errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err)
+			return
+		}
+		exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
+		if err != nil {
+			errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err)
+			return
+		}
+		waitChan <- syscall.WaitStatus(uint32(exit))
+	}()
+
+ select {
+ case ws := <-waitChan:
+ return ws, nil
+ case err := <-errChan:
+ return syscall.WaitStatus(1), err
+ case <-timeoutChan:
+ return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name)
+ }
+}
+
+// WaitForOutput calls 'docker logs' to retrieve the container's output and
+// searches for the given pattern.
+func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) {
+ matches, err := d.WaitForOutputSubmatch(pattern, timeout)
+ if err != nil {
+ return "", err
+ }
+ if len(matches) == 0 {
+ return "", nil
+ }
+ return matches[0], nil
+}
+
+// WaitForOutputSubmatch calls 'docker logs' to retrieve the container's output
+// and searches for the given pattern. It returns any regexp submatches as well.
+func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) {
+ re := regexp.MustCompile(pattern)
+ var (
+ lastOut string
+ stopped bool
+ )
+ for exp := time.Now().Add(timeout); time.Now().Before(exp); {
+ out, err := d.Logs()
+ if err != nil {
+ return nil, err
+ }
+ if out != lastOut {
+ if lastOut == "" {
+ d.logger.Logf("output (start): %s", out)
+ } else if strings.HasPrefix(out, lastOut) {
+ d.logger.Logf("output (contn): %s", out[len(lastOut):])
+ } else {
+ d.logger.Logf("output (trunc): %s", out)
+ }
+ lastOut = out // Save for future.
+ if matches := re.FindStringSubmatch(lastOut); matches != nil {
+ return matches, nil // Success!
+ }
+ } else if stopped {
+ // The sandbox stopped and we looked at the
+ // logs at least once since determining that.
+			return nil, fmt.Errorf("container %q is no longer running", d.Name)
+ } else if pid, err := d.SandboxPid(); pid == 0 || err != nil {
+ // The sandbox may have stopped, but it's
+ // possible that it has emitted the terminal
+ // line between the last call to Logs and here.
+ stopped = true
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+ return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), lastOut)
+}
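+
+// A usage sketch (the pattern, timeout, and variables are illustrative only):
+//
+//	if _, err := d.WaitForOutput("Hello World", 30*time.Second); err != nil {
+//		t.Fatalf("did not see expected output: %v", err)
+//	}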
diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD
new file mode 100644
index 000000000..03b1b4677
--- /dev/null
+++ b/pkg/test/testutil/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "testutil",
+ testonly = 1,
+ srcs = [
+ "testutil.go",
+ "testutil_runfiles.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/sync",
+ "//runsc/boot",
+ "//runsc/specutils",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ ],
+)
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
new file mode 100644
index 000000000..d75ceca3d
--- /dev/null
+++ b/pkg/test/testutil/testutil.go
@@ -0,0 +1,550 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil contains utility functions for runsc tests.
+package testutil
+
+import (
+ "bufio"
+ "context"
+ "debug/elf"
+ "encoding/base32"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math"
+ "math/rand"
+ "net/http"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+ checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+)
+
+// IsCheckpointSupported returns the relevant command line flag.
+func IsCheckpointSupported() bool {
+ return *checkpoint
+}
+
+// nameToActual is used by ImageByName (for now).
+var nameToActual = map[string]string{
+ "basic/alpine": "alpine",
+ "basic/busybox": "busybox:1.31.1",
+ "basic/httpd": "httpd",
+ "basic/mysql": "mysql",
+ "basic/nginx": "nginx",
+ "basic/python": "gcr.io/gvisor-presubmit/python-hello",
+ "basic/resolv": "k8s.gcr.io/busybox",
+ "basic/ruby": "ruby",
+ "basic/tomcat": "tomcat:8.0",
+ "basic/ubuntu": "ubuntu:trusty",
+ "iptables": "gcr.io/gvisor-presubmit/iptables-test",
+ "packetdrill": "gcr.io/gvisor-presubmit/packetdrill",
+ "packetimpact": "gcr.io/gvisor-presubmit/packetimpact",
+ "runtimes/go1.12": "gcr.io/gvisor-presubmit/go1.12",
+ "runtimes/java11": "gcr.io/gvisor-presubmit/java11",
+ "runtimes/nodejs12.4.0": "gcr.io/gvisor-presubmit/nodejs12.4.0",
+ "runtimes/php7.3.6": "gcr.io/gvisor-presubmit/php7.3.6",
+ "runtimes/python3.7.3": "gcr.io/gvisor-presubmit/python3.7.3",
+}
+
+// ImageByName mangles the image name used locally.
+//
+// For now, this is implemented as a static lookup table. In a subsequent
+// change, this will be used to reference a locally-generated image.
+func ImageByName(name string) string {
+ actual, ok := nameToActual[name]
+ if !ok {
+ panic(fmt.Sprintf("unknown image: %v", name))
+ }
+ // A terrible hack, for now execute a manual pull.
+ if out, err := exec.Command("docker", "pull", actual).CombinedOutput(); err != nil {
+ panic(fmt.Sprintf("error pulling image %q -> %q: %v, out: %s", name, actual, err, string(out)))
+ }
+ return actual
+}
+
+// ConfigureExePath configures the executable for runsc in the test environment.
+func ConfigureExePath() error {
+ path, err := FindFile("runsc/runsc")
+ if err != nil {
+ return err
+ }
+ specutils.ExePath = path
+ return nil
+}
+
+// TmpDir returns the absolute path to a writable directory that can be used as
+// scratch by the test.
+func TmpDir() string {
+ dir := os.Getenv("TEST_TMPDIR")
+ if dir == "" {
+ dir = "/tmp"
+ }
+ return dir
+}
+
+// Logger is a simple logging wrapper.
+//
+// This is designed to be implemented by *testing.T.
+type Logger interface {
+ Name() string
+ Logf(fmt string, args ...interface{})
+}
+
+// DefaultLogger logs using the log package.
+type DefaultLogger string
+
+// Name implements Logger.Name.
+func (d DefaultLogger) Name() string {
+ return string(d)
+}
+
+// Logf implements Logger.Logf.
+func (d DefaultLogger) Logf(fmt string, args ...interface{}) {
+ log.Printf(fmt, args...)
+}
+
+// Cmd is a simple wrapper.
+type Cmd struct {
+ logger Logger
+ *exec.Cmd
+}
+
+// CombinedOutput runs the command, logs its output and any error, and returns
+// the output.
+func (c *Cmd) CombinedOutput() ([]byte, error) {
+ out, err := c.Cmd.CombinedOutput()
+ if len(out) > 0 {
+ c.logger.Logf("output: %s", string(out))
+ }
+ if err != nil {
+ c.logger.Logf("error: %v", err)
+ }
+ return out, err
+}
+
+// Command is a simple wrapper around exec.Command that logs the command being
+// run.
+func Command(logger Logger, args ...string) *Cmd {
+ logger.Logf("command: %s", strings.Join(args, " "))
+ return &Cmd{
+ logger: logger,
+ Cmd: exec.Command(args[0], args[1:]...),
+ }
+}
+
+// TestConfig returns the default configuration to use in tests. Note that
+// 'RootDir' must be set by the caller if required.
+func TestConfig(t *testing.T) *boot.Config {
+ logDir := os.TempDir()
+ if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
+ logDir = dir + "/"
+ }
+ return &boot.Config{
+ Debug: true,
+ DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"),
+ LogFormat: "text",
+ DebugLogFormat: "text",
+ LogPackets: true,
+ Network: boot.NetworkNone,
+ Strace: true,
+ Platform: "ptrace",
+ FileAccess: boot.FileAccessExclusive,
+ NumNetworkChannels: 1,
+
+ TestOnlyAllowRunAsCurrentUserWithoutChroot: true,
+ }
+}
+
+// NewSpecWithArgs creates a simple spec with the given args suitable for use
+// in tests.
+func NewSpecWithArgs(args ...string) *specs.Spec {
+ return &specs.Spec{
+ // The host filesystem root is the container root.
+ Root: &specs.Root{
+ Path: "/",
+ Readonly: true,
+ },
+ Process: &specs.Process{
+ Args: args,
+ Env: []string{
+ "PATH=" + os.Getenv("PATH"),
+ },
+ Capabilities: specutils.AllCapabilities(),
+ },
+ Mounts: []specs.Mount{
+ // Hide the host /etc to avoid any side-effects.
+ // For example, bash reads /etc/passwd and if it is
+ // very big, tests can fail by timeout.
+ {
+ Type: "tmpfs",
+ Destination: "/etc",
+ },
+ // Root is readonly, but many tests want to write to tmpdir.
+			// This creates a writable mount inside the root. Also, when tmpdir points
+			// to "/tmp", it causes the actual /tmp to be mounted rather than a tmpfs
+			// inside the sentry.
+ {
+ Type: "bind",
+ Destination: TmpDir(),
+ Source: TmpDir(),
+ },
+ },
+ Hostname: "runsc-test-hostname",
+ }
+}
+
+// SetupRootDir creates a root directory for containers.
+func SetupRootDir() (string, func(), error) {
+ rootDir, err := ioutil.TempDir(TmpDir(), "containers")
+ if err != nil {
+ return "", nil, fmt.Errorf("error creating root dir: %v", err)
+ }
+ return rootDir, func() { os.RemoveAll(rootDir) }, nil
+}
+
+// SetupContainer creates a bundle and root dir for the container, generates a
+// test config, and writes the spec to config.json in the bundle dir.
+func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, cleanup func(), err error) {
+ rootDir, rootCleanup, err := SetupRootDir()
+ if err != nil {
+ return "", "", nil, err
+ }
+ conf.RootDir = rootDir
+ bundleDir, bundleCleanup, err := SetupBundleDir(spec)
+ if err != nil {
+ rootCleanup()
+ return "", "", nil, err
+ }
+ return rootDir, bundleDir, func() {
+ bundleCleanup()
+ rootCleanup()
+ }, err
+}
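+
+// A typical setup flow in a test might look like this sketch (the sleep
+// command is only an example):
+//
+//	conf := testutil.TestConfig(t)
+//	spec := testutil.NewSpecWithArgs("sleep", "1000")
+//	rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+//	if err != nil {
+//		t.Fatalf("error setting up container: %v", err)
+//	}
+//	defer cleanup()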
+
+// SetupBundleDir creates a bundle dir and writes the spec to config.json.
+func SetupBundleDir(spec *specs.Spec) (string, func(), error) {
+ bundleDir, err := ioutil.TempDir(TmpDir(), "bundle")
+ if err != nil {
+ return "", nil, fmt.Errorf("error creating bundle dir: %v", err)
+ }
+ cleanup := func() { os.RemoveAll(bundleDir) }
+ if err := writeSpec(bundleDir, spec); err != nil {
+ cleanup()
+ return "", nil, fmt.Errorf("error writing spec: %v", err)
+ }
+ return bundleDir, cleanup, nil
+}
+
+// writeSpec writes the spec to disk in the given directory.
+func writeSpec(dir string, spec *specs.Spec) error {
+ b, err := json.Marshal(spec)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filepath.Join(dir, "config.json"), b, 0755)
+}
+
+// RandomID returns the given prefix followed by 20 random bytes, base32-encoded.
+func RandomID(prefix string) string {
+ // Read 20 random bytes.
+ b := make([]byte, 20)
+ // "[Read] always returns len(p) and a nil error." --godoc
+ if _, err := rand.Read(b); err != nil {
+ panic("rand.Read failed: " + err.Error())
+ }
+ return fmt.Sprintf("%s-%s", prefix, base32.StdEncoding.EncodeToString(b))
+}
+
+// RandomContainerID generates a random container id for each test.
+//
+// The container id is used to create an abstract unix domain socket, which
+// must be unique. While the runtime forbids creating two containers with the
+// same name, sometimes between test runs the socket does not get cleaned up
+// quickly enough, causing container creation to fail.
+func RandomContainerID() string {
+ return RandomID("test-container-")
+}
+
+// Copy copies file from src to dst.
+func Copy(src, dst string) error {
+ in, err := os.Open(src)
+ if err != nil {
+ return err
+ }
+ defer in.Close()
+
+ st, err := in.Stat()
+ if err != nil {
+ return err
+ }
+
+ out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, st.Mode().Perm())
+ if err != nil {
+ return err
+ }
+ defer out.Close()
+
+ // Mirror the local user's permissions across all users. This is
+ // because as we inject things into the container, the UID/GID will
+ // change. Also, the build system may generate artifacts with different
+ // modes. At the top-level (volume mapping) we have a big read-only
+ // knob that can be applied to prevent modifications.
+ //
+ // Note that this must be done via a separate Chmod call, otherwise the
+ // current process's umask will get in the way.
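+	//
+	// For example, a source file with mode 0640 (owner rw, group r) is
+	// copied with mode 0666, and a source with mode 0750 is copied with
+	// mode 0777.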
+ var mode os.FileMode
+ if st.Mode()&0100 != 0 {
+ mode |= 0111
+ }
+ if st.Mode()&0200 != 0 {
+ mode |= 0222
+ }
+ if st.Mode()&0400 != 0 {
+ mode |= 0444
+ }
+ if err := os.Chmod(dst, mode); err != nil {
+ return err
+ }
+
+ _, err = io.Copy(out, in)
+ return err
+}
+
+// Poll repeatedly calls cb until it succeeds or the given timeout expires.
+func Poll(cb func() error, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
+ return backoff.Retry(cb, b)
+}
+
+// WaitForHTTP tries GET requests on a port until the call succeeds or the
+// timeout expires.
+func WaitForHTTP(port int, timeout time.Duration) error {
+ cb := func() error {
+ c := &http.Client{
+ // Calculate timeout to be able to do minimum 5 attempts.
+ Timeout: timeout / 5,
+ }
+ url := fmt.Sprintf("http://localhost:%d/", port)
+ resp, err := c.Get(url)
+ if err != nil {
+ log.Printf("Waiting %s: %v", url, err)
+ return err
+ }
+ resp.Body.Close()
+ return nil
+ }
+ return Poll(cb, timeout)
+}
+
+// Reaper reaps child processes.
+type Reaper struct {
+ // mu protects ch, which will be nil if the reaper is not running.
+ mu sync.Mutex
+ ch chan os.Signal
+}
+
+// Start starts reaping child processes.
+func (r *Reaper) Start() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.ch != nil {
+ panic("reaper.Start called on a running reaper")
+ }
+
+ r.ch = make(chan os.Signal, 1)
+ signal.Notify(r.ch, syscall.SIGCHLD)
+
+ go func() {
+ for {
+ r.mu.Lock()
+ ch := r.ch
+ r.mu.Unlock()
+ if ch == nil {
+ return
+ }
+
+ _, ok := <-ch
+ if !ok {
+ // Channel closed.
+ return
+ }
+ for {
+ cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil)
+ if cpid < 1 {
+ break
+ }
+ }
+ }
+ }()
+}
+
+// Stop stops reaping child processes.
+func (r *Reaper) Stop() {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.ch == nil {
+ panic("reaper.Stop called on a stopped reaper")
+ }
+
+ signal.Stop(r.ch)
+ close(r.ch)
+ r.ch = nil
+}
+
+// StartReaper is a helper that starts a new Reaper and returns a function to
+// stop it.
+func StartReaper() func() {
+ r := &Reaper{}
+ r.Start()
+ return r.Stop
+}
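+
+// A usage sketch: a test that spawns child processes can arrange for them to
+// be reaped for its duration with:
+//
+//	defer testutil.StartReaper()()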
+
+// WaitUntilRead reads from the given reader until the wanted string is found
+// or until the timeout expires.
+func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time.Duration) error {
+ sc := bufio.NewScanner(r)
+ if split != nil {
+ sc.Split(split)
+ }
+ // done must be accessed atomically. A value greater than 0 indicates
+ // that the read loop can exit.
+ var done uint32
+ doneCh := make(chan struct{})
+ go func() {
+ for sc.Scan() {
+ t := sc.Text()
+ if strings.Contains(t, want) {
+ atomic.StoreUint32(&done, 1)
+ close(doneCh)
+ break
+ }
+ if atomic.LoadUint32(&done) > 0 {
+ break
+ }
+ }
+ }()
+ select {
+ case <-time.After(timeout):
+ atomic.StoreUint32(&done, 1)
+ return fmt.Errorf("timeout waiting to read %q", want)
+ case <-doneCh:
+ return nil
+ }
+}
+
+// KillCommand kills the process running cmd, if it has been started. It
+// returns an error if it cannot kill the process, unless the reason is that
+// the process has already exited.
+//
+// KillCommand will also reap the process.
+func KillCommand(cmd *exec.Cmd) error {
+ if cmd.Process == nil {
+ return nil
+ }
+ if err := cmd.Process.Kill(); err != nil {
+ if !strings.Contains(err.Error(), "process already finished") {
+ return fmt.Errorf("failed to kill process %v: %v", cmd, err)
+ }
+ }
+ return cmd.Wait()
+}
+
+// WriteTmpFile writes text to a temporary file, closes the file, and returns
+// the name of the file. A cleanup function is also returned.
+func WriteTmpFile(pattern, text string) (string, func(), error) {
+ file, err := ioutil.TempFile(TmpDir(), pattern)
+ if err != nil {
+ return "", nil, err
+ }
+ defer file.Close()
+ if _, err := file.Write([]byte(text)); err != nil {
+ return "", nil, err
+ }
+ return file.Name(), func() { os.RemoveAll(file.Name()) }, nil
+}
+
+// IsStatic returns true iff the given file is a static binary.
+func IsStatic(filename string) (bool, error) {
+ f, err := elf.Open(filename)
+ if err != nil {
+ return false, err
+ }
+ for _, prog := range f.Progs {
+ if prog.Type == elf.PT_INTERP {
+ return false, nil // Has interpreter.
+ }
+ }
+ return true, nil
+}
+
+// TestIndicesForShard returns indices for this test shard based on the
+// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
+//
+// If either of the env vars are not present, then the function will return all
+// tests. If there are more shards than there are tests, then the returned list
+// may be empty.
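+//
+// For example, with numTests = 10 and TEST_TOTAL_SHARDS = 3, shard 0 returns
+// indices {0, 3, 6, 9}, shard 1 returns {1, 4, 7}, and shard 2 returns {2, 5, 8}.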
+func TestIndicesForShard(numTests int) ([]int, error) {
+ var (
+ shardIndex = 0
+ shardTotal = 1
+ )
+
+ indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
+ if indexStr != "" && totalStr != "" {
+ // Parse index and total to ints.
+ var err error
+ shardIndex, err = strconv.Atoi(indexStr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
+ }
+ shardTotal, err = strconv.Atoi(totalStr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ }
+ }
+
+ // Calculate!
+ var indices []int
+ numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal)))
+ for i := 0; i < numBlocks; i++ {
+ pick := i*shardTotal + shardIndex
+ if pick < numTests {
+ indices = append(indices, pick)
+ }
+ }
+ return indices, nil
+}
diff --git a/pkg/test/testutil/testutil_runfiles.go b/pkg/test/testutil/testutil_runfiles.go
new file mode 100644
index 000000000..ece9ea9a1
--- /dev/null
+++ b/pkg/test/testutil/testutil_runfiles.go
@@ -0,0 +1,75 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+)
+
+// FindFile searches for a file inside the test run environment. It returns the
+// full path to the file. It fails if the file is not found or if more than one
+// match is found.
+func FindFile(path string) (string, error) {
+ wd, err := os.Getwd()
+ if err != nil {
+ return "", err
+ }
+
+ // The test root is demarcated by a path element called "__main__". Search for
+ // it backwards from the working directory.
+ root := wd
+ for {
+ dir, name := filepath.Split(root)
+ if name == "__main__" {
+ break
+ }
+ if len(dir) == 0 {
+ return "", fmt.Errorf("directory __main__ not found in %q", wd)
+ }
+ // Remove ending slash to loop around.
+ root = dir[:len(dir)-1]
+ }
+
+ // Annoyingly, bazel adds the build type to the directory path for go
+	// binaries, but not for c++ binaries. We use two different patterns to
+	// find our file.
+ patterns := []string{
+ // Try the obvious path first.
+ filepath.Join(root, path),
+ // If it was a go binary, use a wildcard to match the build
+ // type. The pattern is: /test-path/__main__/directories/*/file.
+ filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)),
+ }
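+
+	// For example (sketch), FindFile("runsc/runsc") first tries
+	// <root>/runsc/runsc and, failing that, <root>/runsc/*/runsc, which can
+	// match a bazel-generated directory such as
+	// <root>/runsc/linux_amd64_pure_stripped/runsc (the exact directory
+	// name depends on the build configuration).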
+
+ for _, p := range patterns {
+ matches, err := filepath.Glob(p)
+ if err != nil {
+ // "The only possible returned error is ErrBadPattern,
+ // when pattern is malformed." -godoc
+ return "", fmt.Errorf("error globbing %q: %v", p, err)
+ }
+ switch len(matches) {
+ case 0:
+ // Try the next pattern.
+ case 1:
+ // We found it.
+ return matches[0], nil
+ default:
+ return "", fmt.Errorf("more than one match found for %q: %s", path, matches)
+ }
+ }
+ return "", fmt.Errorf("file %q not found", path)
+}